AtCoder Grand Contest 002 D - Stamp Rally
http://agc002.contest.atcoder.jp/tasks/agc002_d
問題概要
N頂点M辺の無向グラフがあり、頂点には1~N, 辺には1~Mの番号がついている。またQ個のクエリがあり、ひとつのクエリでは2つの頂点番号x, yと整数zが与えられる。各クエリごとに、x, yの2つを出発点として合計でz個の頂点を訪れるとき、通る辺の最大の番号の最小値を求めよ。
N, M, Q <= 105
解法
まずクエリが1回だけの場合を考える。最初グラフのすべての辺がなくなったものとして、ここに番号の小さい順から辺を追加していくことを考えると、この操作の過程でx, yから訪れることができる頂点数が初めてzを超えるタイミングが答えになることがわかる。具体的にはUnionFindを用いて、(1)もしxとyが同じ連結成分に入っていればその連結成分のサイズ (2)違う連結成分であればそれぞれのサイズの合計 を求めれば「x, yから訪れることができる頂点数」が割り出せるので、UnionFindで辺をつなぐたびこの値を計算すればよい。
この方法はクエリが複数回のときにも通用しそうに見える。つまりUnionFindで辺をつないでいく操作は各クエリの内容に関係なく共通なので、前計算しておけばよさそうな気がする。ただしここで問題があり、普通のUnionFindでは辺をつなぐ操作は非可逆的に行われるため「i番目の辺をつないだときの状態」を取り出すことができない。もちろん愚直には辺を追加した時点でコピーを作って取っておけば後から見ることができるが、これは計算量がかかりすぎる(時間も空間も)。
これを解決するのが (部分)永続UnionFind である。このデータ構造では空間計算量(N+Q)で以下の3つの操作をいずれもO(logN)で行うことができる。
- unite(u, v): 頂点u, vをつなぐ
- find(u, t): t回目のuniteを行った時点での頂点uの親
- size(u, t): t回目のuniteを行った時点での頂点uが属する連結成分のサイズ
(仕組みとか実装に関してはhttps://camypaper.bitbucket.io/2016/12/18/adc2016/ がとてもわかりやすかったのでこちらを見てください)
これで各辺を追加した段階での特定頂点が属する連結成分のサイズをO(logN)で計算できるようになった。あとは各クエリごとに2分探索でzを超える瞬間の辺番号を求めればよい。合わせてO(Q(logN)2)となる。
感想
初めて永続的データ構造の類を使ったけど思ったよりゴテゴテしてなくてよかった(?)想定解はこれじゃなくて並列二分探索なるものらしいので要理解
コード (D言語)
import std.stdio, std.array, std.string, std.conv, std.algorithm; import std.typecons, std.range, std.random, std.math, std.container; import std.numeric, std.bigint, core.bitop; const int INF = 1 << 29; void main() { auto s = readln.split.map!(to!int); auto N = s[0]; auto M = s[1]; auto uf = new PersistentUnionFind(N); foreach (i; 0..M) { s = readln.split.map!(to!int); uf.unite(s[0]-1, s[1]-1); } auto Q = readln.chomp.to!int; while (Q--) { s = readln.split.map!(to!int); auto x = s[0]-1; auto y = s[1]-1; auto z = s[2]; int hi = M; int lo = 0; while (hi - lo > 1) { int size; int mid = (hi + lo) / 2; if (uf.find(x, mid) == uf.find(y, mid)) size = uf.size(x, mid); else size = uf.size(x, mid) + uf.size(y, mid); if (size >= z) hi = mid; else lo = mid; } hi.writeln; } } class PersistentUnionFind { int[][] rank; int[][] time; int[][] parent; int n; int global_time; this(int n) { this.n = n; rank = new int[][](n); time = new int[][](n); parent = new int[][](n); foreach (i; 0..n) { rank[i] ~= 1; time[i] ~= 0; parent[i] ~= i; } global_time = 0; } void unite(int u, int v) { global_time += 1; u = find(u, global_time); v = find(v, global_time); if (u == v) return; if (rank[u] < rank[v]) swap(u, v); int r = rank[u].back + rank[v].back; rank[u] ~= r; time[u] ~= global_time; parent[u] ~= u; rank[v] ~= r; time[v] ~= global_time; parent[v] ~= u; } int find(int u, int t) { if (parent[u].back == u) return u; if (time[u].back > t) return u; return find(parent[u].back, t); } int size(int u, int t) { int v = find(u, t); int i = time[v].assumeSorted.lowerBound(t+1).length.to!int; return rank[v][i-1]; } }