yukicoder No.719 Coprime

https://yukicoder.me/problems/no/719

問題概要

2〜Nの数の集合から、「どの相異なる2つの要素を取り出してもそれらは互いに素である」という条件を満たすように部分集合を作る。このような部分集合のうち、要素の合計が最大になるときの値を求めよ。

2 <= N <= 1262

解法

DP的な気持ちになると、「素因数に2を持つ集合」、「素因数に3を持つ集合」、「素因数に5を持つ集合」、... という感じで各素因数のグループごとに処理していきたくなるが、これらの集合は排反ではないので独立に計算していくことができない(6は素因数に2も持つし3も持つ)。

これをうまいこと回避するためには、「『最大の』素因数が2である集合」、「『最大の』素因数が3である集合」、... という形で考えればよい。これらの集合に同時に属する数はないので、集合ごと独立に計算していくことができる。

あとはこの集合ごとにDPをしていくことを考える。どの要素も互いに素という条件を満たすためにはどの素因数も高々1回までしか使えない(=素因数にa, b, cを持つ数を合計に組み入れた場合、以後もうa, b, cを素因数に持つ数は使えない)ことから、「どの素因数を既に使っているか」という情報も持っておく必要があるため、DPは以下のような形になる。

dp(n, S) := 最大の素因数がnである集合まで見て、使った素因数の集合がSであるときの最大合計値

ここで最大ケースであるN=1262までの素数はだいたい200個くらいあるため、すべての素数について使ったかどうかを記録しておくのは不可能である。しかし実は記録しておくのは高々sqrt(N)までの素数だけで問題ない。なぜなら

  • 最大の素因数がsqrt(N)以下の集合の場合、各要素の最大の素因数は当然sqrt(N)以下
  • 最大の素因数がsqrt(N)より大きい場合、「最大の素因数」を除けば他の素因数はsqrt(N)以下のものしかありえない(もし存在するなら、その2つを掛け合わせることでNを超えてしまう=矛盾)

ということになっているからである。よってsqrt(N)以下の素因数は複数の「最大の素因数別の集合」に含まれうるが、sqrt(N)より大きい素因数pは「最大の素因数がpである集合」内の要素にしか素因数として含まれ得ない。そのため大きい素因数pを使うかどうかは「最大の素因数がpである集合」を足すときだけしか考えなくてよいので、記録しておく必要がない、ということになる。

N=1262のときsqrt(N)以下の素数は11個しかないので、これは211で覚えておいても全然足りる。これで上のDPを時間内に解くことができる。

感想

最大の素因数で考える理由を理解するのが難しかった これからも集合をうまいこと切り分けたい気持ちになっていきたい

コード (D言語)

import std.stdio, std.array, std.string, std.conv, std.algorithm;
import std.typecons, std.range, std.random, std.math, std.container;
import std.numeric, std.bigint, core.bitop, core.stdc.string;

immutable int INF = 1 << 29;

bool is_prime(int n) {
    if (n <= 1)
        return false;
    for (int i = 2; i * i <= n; ++i)
        if (n % i == 0)
            return false;
    return true;
}

void main() {
    auto N = readln.chomp.to!int;

    auto P = iota(2, N+1).filter!(i => is_prime(i)).array;
    int cnt = 0;
    for (int i = 0; i < P.length && P[i] * P[i] <= N; ++i, ++cnt) {}

    auto F = new int[][](N+1);
    foreach (i; 2..N+1) {
        int j = i;
        foreach (p; P) {
            while (j % p == 0) {
                F[i] ~= p;
                j /= p;
            }
        }
    }

    auto A = new int[](N+1);
    foreach (i; 2..N+1) foreach (j; 0..cnt) if (i % P[j] == 0) A[i] |= (1 << j);

    auto dp = new int[][](P.length+1, 1<<cnt);
    foreach (i; 0..P.length) fill(dp[i], -INF);
    dp[0][0] = 0;

    foreach (i; 0..P.length.to!int) {
        dp[i+1] = dp[i].dup;
        foreach (mask; 0..(1<<cnt)) {
            if (dp[i][mask] < 0) continue;
            for (int j = P[i]; j <= N; j += P[i]) {
                if (A[j] & mask) continue;
                if (F[j].back > P[i]) continue;
                int nmask = mask | A[j];
                dp[i+1][nmask] = max(dp[i+1][nmask], dp[i][mask] + j);
            }
        }
    }

    dp[P.length].reduce!max.writeln;
}