yukicoder No.584 赤、緑、青の色塗り

https://yukicoder.me/problems/no/584

問題概要

1×Nのグリッドを以下のルールを守って赤、緑、青で彩色するとき、ありうる塗り方は何通りか。

  • R個のマスが赤、G個のマスが緑、B個のマスが青になるようにし、残りのマスは未彩色のまま残す
  • 3マス以上連続して色が塗られないようにする
  • 2マス以上連続して同じ色が塗られないようにする

N <= 3000

解法

正しい塗り方をすると、最終的には「2マス連続で色が塗られている箇所」「1マスだけ色が塗られている箇所」が空白マスに区切られて並んでいる形になる。

まず「2マス連続で色が塗られている箇所」が何個あるかを決めると、「1マス孤立して色が塗られている箇所」が何個あるかが決まり、さらに空白マスが何個残るかも決まる。その上でこれらの並べ方が何個あるかを考えると、「2マス」と「1マス」をそれぞれ空白の間に挿し込んでいく形になるので、(空白の数+1)C(「2マス」の数+「1マス」の数)となる。

次に並べたマスに色を塗っていくことを考える。まず赤の塗り方であるが、赤を「2マス」に何個塗って「1マス」に何個塗るかは総当たりする必要がある。なぜならどちらのマスに塗ったかによって次の色に残されるマス数が変わってくるからである(1マスの方に塗った場合は別の色で塗れない)。

こうして赤を塗ったあとは「2マス(片方が赤で塗られている)」「2マス(両方塗られていない)」「1マス(塗られていない)」の3種類のマスが残ることになる。ここに緑を塗っていくことを考えると、まず「2マス(両方塗られていない)」には必ず緑を塗る必要がある(そうしないと青で両方を塗ることになってしまう)。ゆえに緑の塗り方の総数は(「2マス(片方が赤で塗られている)」+「1マス(塗られていない)」)から (B-「2マス(両方塗られていない)」)を選ぶ総数となる。残った青の塗り方はここまでの手続きで自動的に1通りに定まるので考慮する必要はない。

以上より、最初の2マス・1マスの取り方を総当たりするのにO(N), 赤の塗り方を総当たりするのにO(N)かかるのでO(N2)で答えが出せる。コンビネーションは前計算しておく。

感想

めっちゃ泥臭いが、一歩一歩解法を詰めていける感じで良かった

コード (D言語)

import std.stdio, std.array, std.string, std.conv, std.algorithm;
import std.typecons, std.range, std.random, std.math, std.container;
import std.numeric, std.bigint, core.bitop;


void main() {
    immutable int MAX = 4000;
    immutable long MOD = 10^^9+7;

    auto modinv = new long[](MAX);
    modinv[0] = modinv[1] = 1;
    foreach(i; 2..MAX) {
        modinv[i] = modinv[MOD % i] * (MOD - MOD / i) % MOD;
    }

    auto f_mod = new long[](MAX);
    auto f_modinv = new long[](MAX);
    f_mod[0] = f_mod[1] = 1;
    f_modinv[0] = f_modinv[1] = 1;

    foreach(i; 2..MAX) {
        f_mod[i] = (i * f_mod[i-1]) % MOD;
        f_modinv[i] = (modinv[i] * f_modinv[i-1]) % MOD;
    }

    long comb(int n, int k) {
        if (n < k) return 0;
        return f_mod[n] * f_modinv[n-k] % MOD * f_modinv[k] % MOD;
    }

    
    auto s = readln.split.map!(to!int);
    auto N = s[0], R = s[1], G = s[2], B = s[3];
    auto M = max(R, G, B);
    auto RGB = R + G + B;
    long ans = 0;

    foreach (two; 0..RGB/2+1) {
        auto one = RGB - two * 2;
        auto emp = N - two * 2 - one;
        if (M > two + one) continue;
        if (emp + 1 < one + two) continue;
        long tmp1 = comb(emp+1, one) * comb(emp+1-one, two) % MOD;
        foreach (i; max(0, R-one)..min(R, two)+1) {
            long tmp2 = tmp1 * comb(two, i) % MOD * comb(one, R-i) % MOD;
            int one_rest = one - (R-i);
            int two_rest = two - i;
            if (G < two_rest) continue;
            tmp2 = tmp2 * comb(one_rest + i, G - two_rest) % MOD;
            tmp2 = tmp2 * powmod(2, two, MOD);
            ans = (ans + tmp2) % MOD;
        }
    }

    ans.writeln;
}


long powmod(long a, long x, long m) {
    long ret = 1;
    while (x) {
        if (x % 2) ret = ret * a % m;
        a = a * a % m;
        x /= 2;
    }
    return ret;
}