DISCO presents ディスカバリーチャンネル プログラミングコンテスト2016 本選: C - 特別講演「括弧列と塗り分け」

https://beta.atcoder.jp/contests/discovery2016-final/tasks/discovery_2016_final_c

問題概要

対応の取れた括弧列Sと非負整数Kが与えられる。Sの各文字を赤か青の2色で塗り分けるとき、すべての対応括弧の組(S[i], S[j])について、Sの区間[i, j]において赤と青の数の差がK以内になるような塗り方は何通りあるか。

|S|, K <= 3000

解法

木DPをした。括弧Aが括弧Bを包含しているときAがBの親となるようにすると、括弧の包含関係は木になる。また互いに包含関係にない括弧同士の塗り方は題意から独立である。このことから包含関係で木を作っていくと、1個以上の木(=森)ができる。それぞれの木でDPを行った結果を掛け合わせれば答えとなる。具体的なDPは

dp(n, k): 構築した森のノードnにあたる括弧(S[i], S[j])において、Sの区間[i, j]を赤と青の差がk以下になるよう塗る塗り方の総数

を順に各ノードでやっていくわけだが、これはまず子のDP値を全部求めた上でそれを使ってノード内でDPをすると求まる。

各ノードでO(NK2)?っぽいDPをやるので全部でO(N2 K2)かかりそうな気がするが、最大でも区間長×2の差しか生まれない(それ以上のkは見る必要がない)ことに気を付けて適切に枝刈りを入れるとオーダーが落ちて通る。こういう感じの木DPにおける典型知識らしい→参照: http://topcoder.g.hatena.ne.jp/iwiwi/20120428/1335635594

感想

気を付けず適切に枝刈りを入れなかったのでTLEを出しまくった

コード (D言語)

import std.stdio, std.array, std.string, std.conv, std.algorithm;
import std.typecons, std.range, std.random, std.math, std.container;
import std.numeric, std.bigint, core.bitop, std.bitmanip;

void main() {
    immutable long MOD = 10^^9 + 7;
    
    auto S = readln.chomp;
    auto N = S.length.to!int;
    auto K = readln.chomp.to!int;
    K = min(K, 2*N);
    auto P = new Tuple!(int, int)[](N/2);
    auto G = new int[][](N/2);
    auto root = new bool[](N/2);
    root.fill(true);
    long inv2 = powmod(2, MOD-2, MOD);

    auto stack = new Tuple!(int, int)[](N/2);
    for (int i = 0, p = 0, sp = -1; i < N; ++i) {
        if (S[i] == '(') {
            if (sp >= 0) G[stack[sp][1]] ~= p, root[p] = false;
            stack[++sp] = tuple(i, p++);
        } else {
            P[stack[sp][1]] = tuple(stack[sp][0], i);
            sp--;
        }
    }

    auto dp = new long[][](N/2, K+1);
    auto tmp = new long[][](2, N+1);
    
    void dfs(int n) {
        foreach (m; G[n]) dfs(m);

        int M = P[n][1] - P[n][0] + 1;
        tmp[0].fill(0);
        tmp[1].fill(0);
        
        tmp[0][0] = 2;
        tmp[0][1] = 2;
        int cur = 0, tar = 1;
        
        foreach (m; G[n]) {
            tmp[tar].fill(0);
            for (int i = 0; i <= M; ++i) {
                if (tmp[cur][i] == 0) continue;
                for (int j = 0; j <= K/2 && i+j <= M && dp[m][j] > 0; ++j) {
                    (tmp[tar][i+j] += tmp[cur][i] * dp[m][j] % MOD * inv2 % MOD) %= MOD;
                }
                for (int j = 0; j <= K/2 && abs(i-j) <= M && dp[m][j] > 0; ++j) {
                    (tmp[tar][abs(i-j)] += tmp[cur][i] * dp[m][j] % MOD * inv2 % MOD) %= MOD;
                }
            }
            swap(cur, tar);
        }

        for (int i = 0; i <= K/2; ++i) dp[n][i] = tmp[cur][i];
    }

    
    long ans = 1;
    foreach (i; 0..N/2) {
        if (!root[i]) continue;
        dfs(i);
        long t = 0;
        foreach (j; 0..K/2+1) t = (t + dp[i][j]) % MOD;
        ans = ans * t % MOD;
    }

    ans.writeln;
}

long powmod(long a, long x, long m) {
    long ret = 1;
    while (x) {
        if (x % 2) ret = ret * a % m;
        a = a * a % m;
        x /= 2;
    }
    return ret;
}