いろはちゃんコンテスト Day2: F - 総入れ替え

https://atcoder.jp/contests/iroha2019-day2/tasks/iroha2019_day2_f

問題概要

箱が3つあり、それぞれの箱には100円玉がA1枚、B1枚、C1枚、50円玉がA2枚、B2枚、C2枚入っている。コインが1枚以上入っている箱をひとつ指定して、そこからランダムにコインを1枚取り出すという操作を2人で交互に繰り返すゲームを考える。双方が自分の手に入れる金額の期待値を最大化するように行動するとき、先手が最終的に手に入れる金額の期待値を求めよ。

0 <= A1, A2, B1, B2, C1, C2 <= 10

解法

dp[a1][a2][b1][b2][c1][c2]: 残っているコインの枚数がこれのとき先手が最終的に手に入れる金額の期待値

とすると、先手番の場合たとえば箱Aを選んだとき確率 a / (a+b) で 期待値 (dp[a1 -1][a2][b1][b2][c1][c2] + 100) 円となり、確率 b / (a+b) で 期待値 (dp[a1][a2 - 1][b1][b2][c1][c2] + 50) 円となる。なのでこのときの期待値は

a / (a+b) * (dp[a1 -1][a2][b1][b2][c1][c2] + 100) + b / (a+b) * (dp[a1][a2 - 1][b1][b2][c1][c2] + 50)

である。他の箱の場合も同様の計算をして最も大きい値のものをdp[a1][a2][b1][b2][c1][c2]として採用すればよい。後手番の場合は期待値の最も低いものをとる(決まった合計金額を2人で取り合うゲームなので、後手の期待値の最大化=先手の期待値の最小化)。後手のdp計算の場合 +100 とか +50 は入らないことに注意する。あとはdpのインデックスがゼロのときとかをよしなにやる。実装はメモ化再帰のほうがやりやすい(当社比)。

感想

期待値の問題に対してものすごい苦手意識があるが、この問題の場合はそもそもよくあるタイプの2人ゲームなのでそういう感じのメモ化再帰を丁寧にやっていくことを考えるべきだった

コード (C++)

#include <bits/stdc++.h>
using namespace std;
#define REP(i,n) for (int i=0;i<(n);i++)
#define REP2(i,m,n) for (int i=m;i<(n);i++)
typedef long long ll;

int A, B, C, D, E, F, parity;
double mem[11][11][11][11][11][11];

double dfs(int a, int b, int c, int d, int e, int f) {
    if (a == 0 && b == 0 && c == 0 && d == 0 && e == 0 && f == 0) return 0;
    if (a < 0 || b < 0 || c < 0 || d < 0 || e < 0 || f < 0) return 0;
    if (mem[a][b][c][d][e][f] > 0) return mem[a][b][c][d][e][f];
    
    if ((a + b + c + d + e + f) % 2 == parity) {
        double x = (a + b > 0) ? (dfs(a-1, b, c, d, e, f) + 100) * a / (a + b) + (dfs(a, b-1, c, d, e, f) + 50) * b / (a + b) : 0;
        double y = (c + d > 0) ? (dfs(a, b, c-1, d, e, f) + 100) * c / (c + d) + (dfs(a, b, c, d-1, e, f) + 50) * d / (c + d) : 0;
        double z = (e + f > 0) ? (dfs(a, b, c, d, e-1, f) + 100) * e / (e + f) + (dfs(a, b, c, d, e, f-1) + 50) * f / (e + f) : 0;
        return mem[a][b][c][d][e][f] = max({x, y, z});
    } else {
        double x = (a + b > 0) ? dfs(a-1, b, c, d, e, f) * a / (a + b) + dfs(a, b-1, c, d, e, f) * b / (a + b) : -1;
        double y = (c + d > 0) ? dfs(a, b, c-1, d, e, f) * c / (c + d) + dfs(a, b, c, d-1, e, f) * d / (c + d) : -1;
        double z = (e + f > 0) ? dfs(a, b, c, d, e-1, f) * e / (e + f) + dfs(a, b, c, d, e, f-1) * f / (e + f) : -1;
        double ret = 1 << 29;
        if (x != -1) ret = min(ret, x);
        if (y != -1) ret = min(ret, y);
        if (z != -1) ret = min(ret, z);
        return mem[a][b][c][d][e][f] = ret;
    }
}

void solve() {
    cin >> A >> B >> C >> D >> E >> F;
    parity = (A + B + C + D + E + F) % 2;

    cout.precision(20);
    cout << fixed << dfs(A, B, C, D, E, F) << endl;
}

int main() {
    cin.tie(0);
    ios::sync_with_stdio(false);
    solve();
}