AtCoder Grand Contest 009 C - Division into Two

http://agc009.contest.atcoder.jp/tasks/agc009_c

問題概要

昇順ソート済みのN要素の数列Sが与えられる。この数列を2つの集合X, Yに分割する。ただしX, Yは以下の条件を満たしていなければならない。

  • Xに含まれるどの相異なる2要素も差の絶対値がA以上
  • Yに含まれるどの相異なる2要素も差の絶対値がB以上

条件を満たすような分割の仕方が何通りあるか求めよ。

1 <= N <= 105

解法

DPの計算量を落としていく系の問題。まず分割を「Sの要素を小さい順にX, Yどちらかに追加していく操作」と言い換えると、S[i]をX, Yに入れられるかどうかはX, Yそれぞれに最後に入れた要素だけに依存するため、以下のようなdpが立つ。

dp(i, x, y): S[i]を追加したときにXの最後の要素がS[x]でYの最後の要素がS[y]であるような分割の総数

このdpはO(N3)かかり、全然間に合わない。ここでx, yのうちどちらかは必ず直前に入れた要素になっていることを考えると、x, yのうちどちらかは必ず直前の数になっているはずである。つまり実質ありうるのはdp(i, i, y), dp(i, x, i)の2パターンしかない。ゆえに次元をひとつ落として以下のようなdpが可能になる。

dp(i, isX, j): S[i]を追加したとき、S[i]を入れた箱がX(もしくはY)であり、S[i]を入れなかった方の集合の最後の要素がS[j]であるような分割の総数

これで状態数が削減され、N×N×N のdpが N×2×N に落ちた。すなわちO(N2)である。しかしこれでもまだ時間が足りない。状態数はこれ以上落ちそうにないので、今度は遷移の方を工夫する。状態を(Xの最後の要素, Yの最後の要素)のように表すことにすると遷移は以下のようになる。

  • ① (S[i-1], S[j]) -> (S[i], S[j]) ※ S[i] - S[i-1] >= Aのときだけ可
  • ② (S[i-1], S[j]) -> (S[i-1], S[i]) ※ S[i] - S[j] >= Aのときだけ可
  • ③ (S[j], S[i-1]) -> (S[i], S[i-1]) ※ S[i] - S[j] >= Bのときだけ可
  • ④ (S[j], S[i-1]) -> (S[j], S[i]) ※ S[i] - S[i-1] >= Bのときだけ可

これらを上のdpの形で表すと

  • ① dp(i, X, j) += dp(i-1, X, j) if S[i] - S[i-1] >= A
  • ② dp(i, Y, i-1) += dp(i-1, X, j) if S[i] - S[j] >= A
  • ③ dp(i, X, i-1) += dp(i-1, Y, j) if S[i] - S[j] >= B
  • ④ dp(i, Y, j) += dp(i-1, Y, j) if S[i] - S[i-1] >= B

②, ③の形はつまり S[i] - S[j] >= A かつ j <= i - 2 であるようなすべてのjについてdp(i-1, X, j)を加算する、ということである。①, ④はj→jという遷移なので加算というよりはS[i]-S[i-1]である限りdp(j)の値が保存されるが、そうでなくなった瞬間そこの値が0になる、ということである(定性的に言うと、S[i-1]がXに入っていてS[i]-S[i-1] < A なら、 S[i]をXに入れることはできないので、S[j]が入ってる方に入れるしかなく、S[j]が最後の要素であるということがありえなくなってしまうということ)。

以上より実はdpの遷移においてjの値を個別に見ていく必要はなく、「あるjまでのdp(j)の合計をとる」「あるj以下のdp(j)をまとめて0にする」という操作だけが必要であることがわかる。このような操作を高速に行う方法はいろいろあると思うが、自分は以下のような実装を行った。

  • セグメント木を2本用意する。それぞれが上のdpでいうXかYかの部分に対応
  • それぞれのセグ木に対して「どこまでのjがゼロクリアされたか」を覚えておく変数を用意
  • あとは上のdpの通りに計算

これでO(NlogN)になる。セグ木でやってるところをがんばってなんとかするとO(N)にもなるらしい。

感想

何とか解けたが時間かかりまくった。このレベルを本番で通すのはまだまだ難しそう

コード (D言語)

import std.stdio, std.array, std.string, std.conv, std.algorithm;
import std.typecons, std.range, std.random, std.math, std.container;
import std.numeric, std.bigint, core.bitop;

immutable long MOD = 10^^9 + 7;

void main() {
    auto s = readln.split.map!(to!long);
    auto N = s[0].to!int;
    auto A = s[1];
    auto B = s[2];
    auto S = N.iota.map!(_ => readln.chomp.to!long).array;
    S ~= - 10L ^^ 18 - 10;
    S.sort();

    auto st1 = new SegmentTree(N + 1);
    auto st2 = new SegmentTree(N + 1);
    st1.add(0, 1);
    st2.add(0, 1);
    int lb1 = 0;
    int lb2 = 0;

    foreach (i; 2..N+1) {
        int j = S.assumeSorted.lowerBound(S[i] - A + 1).length.to!int - 1;
        int k = S.assumeSorted.lowerBound(S[i] - B + 1).length.to!int - 1;
        st2.add(i - 1, st1.sum(lb1, min(j, i - 2)));
        st1.add(i - 1, st2.sum(lb2, min(k, i - 2)));
        if (S[i] - S[i - 1] < A) lb2 = i - 1;
        if (S[i] - S[i - 1] < B) lb1 = i - 1;
    }

    writeln( (st1.sum(lb1, N + 1) + st2.sum(lb2, N + 1)) % MOD );
}

class SegmentTree {
    long[] table;
    int size;

    this(int n) {
        assert(bsr(n) < 29);
        size = 1 << (bsr(n) + 2);
        table = new long[](size);
    }

    void add(int pos, long num) {
        return add(pos, num, 0, 0, size/2-1);
    }

    void add(int pos, long num, int i, int left, int right) {
        (table[i] += num) %= MOD;
        if (left == right)
            return;
        auto mid = (left + right) / 2;
        if (pos <= mid)
            add(pos, num, i*2+1, left, mid);
        else
            add(pos, num, i*2+2, mid+1, right);
    }

    long sum(int pl, int pr) {
        if (pl > pr) return 0;
        return sum(pl, pr, 0, 0, size/2-1);
    }

    long sum(int pl, int pr, int i, int left, int right) {
        if (pl > right || pr < left)
            return 0;
        else if (pl <= left && right <= pr)
            return table[i];
        else
            return
                (sum(pl, pr, i*2+1, left, (left+right)/2) +
                 sum(pl, pr, i*2+2, (left+right)/2+1, right)) % MOD;
    }
}