AtCoder Regular Contest 093: E - Bichrome Spanning Tree
https://beta.atcoder.jp/contests/arc093/tasks/arc093_c
問題概要
N頂点M辺のグラフが与えられる。各辺にはコストがある。すべての辺をそれぞれ白か黒かで塗り分けるとき、以下の条件が満たされるような塗り方は何通りあるか。条件:「白の辺も黒の辺も少なくともひとつ含むような全域木が存在し、かつその中で最小コストを持つ全域木のコストはXに等しい」
N <= 1000
M <= 2000
解法
とりあえず元のグラフの最小全域木を適当にひとつ構成し、そのコストをYとおく。
まず簡単なケースとして、Y > X ならば条件を満たす全域木は存在しない(最小全域木よりコストの小さい全域木はないので)。
次に Y = X の場合。この場合、元のグラフの最小全域木としてありえるもののうち、少なくともひとつが黒と白の辺を両方含むのであればその塗り方はOKということになる。そしてそのような塗り方は「元のグラフの最小全域木に含まれうる辺(以下 safe edge と呼ぶ)の集合のうち、少なくともひとつの辺は集合内の他の辺と異なる色で塗る」という塗り方と同じである。仮にsafe edgeのうちひとつが黒、別のひとつが白だった場合、それら2つを両方含む最小全域木が必ず構成できる。逆にすべてのsafe edgeの色が同じだった場合、本来最小全域木に含まれない辺(=使うと無駄なコストが発生する辺)を少なくともひとつ使って全域木を作らざるをえず、Y = Xという前提が崩れる。よってsafe edgeの数をかぞえてそれらがすべて同じ色にならない場合の数を数えればこのケースは過不足なく数えられる。
safe edgeの数え方だが、例えば以下のような方法がある。最初にひとつ適当につくった最小全域木におけるすべての頂点対(u, v)間で「(u, v)間のパス上に含まれる辺のうち最大のコストを持つもののコスト」(max_edge(u, v))を計算しておく。ここで元のグラフの辺(a, b)をひとつとったとき、辺のコストがmax_edge(a, b)と同じであればその辺はsafe edgeで、大きければsafe edgeではない。なぜなら辺(a, b)をつないだあとでmax_edge(a, b)にあたる辺を木から削除すれば、木の連結を保ったままコストは (追加した辺 - 削除した辺) の分増えることになるからである(差が0ならコストが増えない=最小全域木のまま)。この要領で全部の辺をチェックすればsafe edgeの数をかぞえることができる。
最後に Y < X の場合。この場合、元の最小全域木からコストが(X-Y)増えた全域木を作ることが目的になる。このための条件は結構厳しくて、まず元の最小全域木は明らかに作れてはいけないので、safe edgeは全部同じ色にする必要がある。そのうえで余分な辺が使われるように塗るわけだが、余分な辺をひとつ追加することによって全域木のコストがどれくらい増えるかというと 辺(a, b)のコスト - max_edge(a, b) である。これは「絶対使わなければいけない余分な辺(a, b)」を全域木にひとつ追加したとき、木の形を保つために元の木のパス(a, b)上の辺をひとつ落とすことになるが、落とすべき辺は明らかに最もコストの大きい辺なので、そういうことになる。しかも元の最小全域木に対して追加できる余分な辺は高々ひとつである。余分な辺を追加するためには少なくともsafe edgeは全部同色で塗った上で追加したい辺を別の色で塗る必要があるが、仮にひとつでも違う色を全域木に組み込めればあとは何の辺を使おうが自由なので、safe edgeだけを使っていくのが最適になるからである。そして追加される余分な辺も、safe edgeと違う色を持っているもののうちで edge_cost(a, b) - max_edge(a, b) がもっとも小さいものが自動的に選ばれることになる。以上より目的を達成するためには (1) safe edgeは全部同じ色で塗る (2) edge_cost(a, b) - max_edge(a, b) が X-Yより小さい辺もsafe edgeと同じ色で塗る (3) edge_cost(a, b) - max_edge(a, b) が X-Y と等しい辺のうち、少なくともひとつは safe edgeと別の色 (4) 残りの辺は自由 という4つの条件を満たすように塗ればよい、ということになる。
感想
なんかこれまで書いたブログの中で最も意味不明の怪文書になってしまった気がする
感想として問題解いてるときは最小全域木が何通り作れるか?みたいなところに思考が行ってしまったんだけどそんなことを考える必要はなく、色を塗ると自動的に全域木が決まるのでそれを数える、というのが考えるべきことだったんだなあという感じです
コード (D言語)
import std.stdio, std.array, std.string, std.conv, std.algorithm; import std.typecons, std.range, std.random, std.math, std.container; import std.numeric, std.bigint, core.bitop, std.bitmanip; immutable long INF = 1L << 59; immutable long MOD = 10^^9 + 7; void main() { auto s = readln.split.map!(to!int); auto N = s[0]; auto M = s[1]; auto X = readln.chomp.to!long; auto G = new Tuple!(int, long)[][](N); auto E = new Tuple!(int, int, long)[](M); foreach (i; 0..M) { s = readln.split.map!(to!int); G[s[0]-1] ~= tuple(s[1]-1, s[2].to!long); G[s[1]-1] ~= tuple(s[0]-1, s[2].to!long); E[i] = tuple(s[0]-1, s[1]-1, s[2].to!long); } auto T = new Tuple!(int, long)[][](N); E.sort!"a[2] < b[2]"(); auto uf = new UnionFind(N); long mst_cost = 0; foreach (e; E) { if (uf.find(e[0]) == uf.find(e[1])) continue; T[e[0]] ~= tuple(e[1], e[2]); T[e[1]] ~= tuple(e[0], e[2]); uf.unite(e[0], e[1]); mst_cost += e[2]; } auto max_edge_cost = new long[][](N, N); void dfs(int n, int p, int root, long d) { max_edge_cost[root][n] = d; foreach (e; T[n]) { int m = e[0]; long c = e[1]; if (m != p) dfs(m, n, root, max(d, c)); } } bool is_safe_edge(Tuple!(int, int, long) e) { return max_edge_cost[e[0]][e[1]] >= e[2]; } foreach (i; 0..N) dfs(i, -1, i, 0); long ans = 0; if (mst_cost == X) { long safe_edges_cnt = E.map!(e => is_safe_edge(e)).sum; ans = (powmod(2, safe_edges_cnt, MOD) - 2) * powmod(2, M - safe_edges_cnt, MOD) % MOD; ans = (ans + MOD) % MOD; } else if (mst_cost < X) { long diff = X - mst_cost; long a = E.map!(e => e[2] - max_edge_cost[e[0]][e[1]] < diff).sum; long b = E.map!(e => e[2] - max_edge_cost[e[0]][e[1]] == diff).sum; long c = E.map!(e => e[2] - max_edge_cost[e[0]][e[1]] > diff).sum; ans = 2 * (powmod(2, b, MOD) - 1) % MOD * powmod(2, c, MOD) % MOD; ans = (ans + MOD) % MOD; } ans.writeln; } class UnionFind { int N; int[] table; this(int n) { N = n; table = new int[](N); fill(table, -1); } int find(int x) { return table[x] < 0 ? x : (table[x] = find(table[x])); } void unite(int x, int y) { x = find(x); y = find(y); if (x == y) return; if (table[x] > table[y]) swap(x, y); table[x] += table[y]; table[y] = x; } } long powmod(long a, long x, long m) { long ret = 1; while (x) { if (x % 2) ret = ret * a % m; a = a * a % m; x /= 2; } return ret; }