AtCoder Regular Contest 020: D - お菓子の国の旅行

https://atcoder.jp/contests/arc020/tasks/arc020_4

問題概要

N個の町があり、隣り合う町同士は道で結ばれている。町iと町(i+1)を結ぶ距離はDiとして与えられる。いまN個の町からK個の町を選び順に移動していくことを考える。このとき訪れる町に重複があってはならず、町から町への移動は必ず最短距離で行わなければならない。ただし始点と終点はどのように選んでもよい。このルールを守って移動する町の列を決めるとき、移動距離の合計がMの倍数となるようなものは何通りあるか求めよ。

N <= 100

M <= 30

K <= 10

解法

あるひとつの道に着目してその道が移動に何回使われるかということを考える。もし今見ている道が町3と町4を繋ぐものだった場合、その道が使われるのは(3以下の町→4以上の町)の移動と(4以上の町→3以下の町)の移動の2パターンである。つまり移動する町の列が与えられたとき、列の中で隣り合う2つの町のうち片方が3以下で片方が4以上となっている部分の数を数えればその道が使われる回数を出すことが出来る。例えばK=5で(5, 9, 1, 4, 3)の順に移動するとき、9->1, 1->4, 4->3 の移動で道3をまたぐので、道3は3回使われるということがわかる。

これをすべてのありうる移動列ごとに計算していけば当然答えは出せるが、移動列は最大で100C10*10!とかのパターンがあるのでまったく間に合わない。そこでさっきの考察に立ち戻ると、町iと町(i+1)を結ぶのコスト計算で必要なのは移動列そのものではなく「移動列に含まれる町のそれぞれが道の左にあるか右にあるか」という情報だけであった。このことを使うと移動列は愚直に持つ必要がなく、「左か右か」の情報だけを持ったbit列に潰すことができる。例えば先程と同じくK=5で移動列が(5, 9, 1, 4, 3)のとき、道3から見たこの移動列は(0, 0, 1, 0, 1)になる(道3より左の町を1, 右の町を0としている)。この形にしてしまっても、道が使われるのは 0 -> 1 への移動が行われる場所と 1 -> 0 への移動が行われる箇所だけということが依然としてわかるので、コストの計算は問題なく行うことができる。

以上を踏まえると、以下のようなbitDPを行うことができる。

dp(i, mask, m): i番目までの道までを見ていて、現在の道から見た移動列がmask(ビット列)で表され、コストの合計が m (MOD M) であるような場合の数

このDPの状態数はO(N * 2K * M)である。iからi+1への遷移ではmaskはそのまま維持されるか、どこかに1がひとつ増えることになる。例えば(5, 9, 4, 1, 3)の移動列は

  • 道1から見ると (0, 0, 0, 1, 0)
  • 道2から見ると (0, 0, 0, 1, 0)
  • 道3から見ると (0, 0, 0, 1, 1)
  • 道4から見ると (0, 0, 1, 1, 1)

という感じになっている。つまりmaskに1が1個増えるのは「さっきまでの道にとっては右だった町が今回の道からは左になる」パターンである。道は右にずれていくことを考えると1が減ったりすることはありえないし、同じ町には2度訪れないという制約を考えると1がいきなり2個以上増えたりしないこともわかる。この遷移はたかだかO(K)なので、状態数と併せてもO(NKM * 2K)で十分間に合う計算量になる。

感想

公式解説では何をどうbit列にしてるのかよくわからなかったのでkmjpさんのブログを見ました。木の問題とかを思うとひとつの辺に注目するというテク自体はそんなに突拍子がないものとは感じなくなってきたのはまあ成長といえるんだろうか 結局解けてないんだけど

コード (D言語)

import std.stdio, std.array, std.string, std.conv, std.algorithm;
import std.typecons, std.range, std.random, std.math, std.container;
import std.numeric, std.bigint, core.bitop, core.stdc.string;

immutable long MOD = 10^^9 + 7;

void main() {
    auto s = readln.split.map!(to!int);
    auto N = s[0];
    auto M = s[1];
    auto K = s[2];
    auto A = iota(N-1).map!(_ => readln.chomp.to!int).array ~ 0;

    auto cross = new int[](1<<K);
    foreach (mask; 0..1<<K) cross[mask] = iota(K-1).map!(i => (mask & (0b11 << i)).popcnt == 1).sum;

    auto dp = new long[][][](N+1, 1<<K, M); // dp(i, mask): i番目の道まで見た / 訪問する町を順に並べた数列を「道iより左にあるか否か」でbitに変換したbitmask / コスト合計の mod M
    dp[0][0][0] = 1;

    foreach (i; 0..N) {
        foreach (mask; 0..1<<K) {
            foreach (j; 0..M) {
                int cost = (j + cross[mask] * A[i]) % M;
                (dp[i+1][mask][cost] += dp[i][mask][j]) %= MOD;
                foreach (k; 0..K) {
                    if (mask & (1<<k)) continue;
                    int nmask = mask | (1<<k);
                    int ncost = (j + cross[nmask] * A[i]) % M;
                    (dp[i+1][nmask][ncost] += dp[i][mask][j]) %= MOD;
                }
            }
        }
    }

    dp[N][(1<<K)-1][0].writeln;
}