DISCO presents ディスカバリーチャンネル コードコンテスト2019 本戦: D - DISCO!
https://atcoder.jp/contests/ddcc2019-final/tasks/ddcc2019_final_d
問題概要
どの文字もDISCOの5文字のうちいずれか1文字であるような文字列Sが与えられる。この文字列に対する以下のQ個の問に答えよ。
問: Sの区間[l, r]の範囲内で長さ5の部分文字列(連続でなくともよい)を抜き出したとき、それが"DISCO"となっているようなものは何通りか。mod 232 で答えよ。
|S| <= 106
Q <= 105
解法
自分がこの問題を解いたときは類題(ドワコン2018予選のC問題 k-DMC)を思い出しながら解いていた。この類題では、解法の1ステップとして以下のような小問題を解く必要があった(便宜上簡略化してます)。
問: いま見ている区間には 文字Aがa個, 文字Bがb個, 部分文字列ABがc個あることがわかっている。この区間を右にひとつだけ伸ばしたとき、この数はどのように変化するか?
答:
- 新たな文字が"A"であれば"A"の数が1増える
- 新たな文字が"B"であれば"B"の数が1増え、"AB"の数が"A"の数だけ増える
後者の「"AB"の数が"A"の数だけ増える」のところは「新たにBがひとつ後ろに付け加わったのであれば、前に出てきているAの分だけABが増える」という理屈である。
ではこれを少し変化させた以下のような問題だとどうなるか。
問: 区間[x, y)の文字Aの数, 文字Bの数, 部分文字列ABの数と、区間[y, z)の文字Aの数, 文字Bの数, 部分文字列ABの数がわかっている。この2つの区間をくっつけた区間[x, z)の文字Aの数, 文字Bの数, 部分文字列ABの数はどうなっているか?
答:
- [x, z)のAの数 = [x, y)のAの数 + [y, z)のAの数
- [x, z)のBの数 = [x, y)のBの数 + [y, z)のBの数
- [x, z)のABの数 = [x, y)のABの数 + [y, z)のABの数 + [x, y)のAの数 × [y, z)のBの数
足し算のところは元あったものをそのまま足しているだけなので問題ないと思う。また最後の掛け算も「後ろにBが付け加わると、前に出てきたAの分だけABが増える」という計算をしているだけなので、やっていることはさっきと同じである(というかそもそもさっきの問題はこの問題の特殊なパターンのひとつとみなすことができる)。
以上より、「左の区間と右の区間をあわせてひとつの区間にするとき、そこに含まれる部分文字列の数がどうなるか?」という問題に対する基本的なアイデアが得られた。上の例だと部分文字列が最大でも長さ2であるが、今回の問題のようにもう少し長いものが必要な場合でも同じことができる。たとえば左(s), 右(s)をそれぞれ左区間、右区間の部分文字列sの数であるとすると、両方をあわせた区間の部分文字列"DISC"の数は以下のように計算できる。
左("DISC") + 右("DISC") + 左("D") × 右("ISC") + 左("DI") × 右("SC") + 左("DIS") × 右("C")
このように左の区間の値と右の区間の値を合成することでより長い区間の結果を計算することができる。合成の際に変な演算をすることもないのでこれはセグメントツリーに乗せられそうだということがわかる。自分の実装では以下のような形のセグ木を作った。
- 各ノードは「Dの数, DIの数, DISの数, DISCの数, DISCOの数, Iの数, ISの数, ...」 のように"DISCO"のありえる連続部分文字列(15通り)についてそれぞれ何個あるかを配列とかで持っておく。
- 上の"DISC"の計算式の例のように左と右を合成したときどうなるかの計算式を15通りすべてに対して気合で書いてmerge関数を作り、セグ木の区間のマージをそれでやるようにする(コード参照)。
あとはこのセグ木に区間クエリを投げて、帰ってきた結果から"DISCO"の数にあたる部分を答えればよい。
実装上の細かい点として、この問題ではmodが232なので、C++であればいちいち剰余を取らずとも数をunsigned intで持っておくだけでよい(勝手にmod 232の値になる)。
感想
いきなりセグ木で殴る話をしてもあれなので自分の思考過程的な感じで書いてみたけど人が読んだらかなり意味不明な気がして不安になってきた それはともかくこの問題のおかげで結構良い順位取れたのでOKです
コード (C++)
※冒頭のコメントは「セグ木のノードに持たせるvectorにおいて、何番がどの部分文字列の数に対応しているか」をメモったもの(ひどい)
#include <bits/stdc++.h> using namespace std; #define REP(i,n) for (int i=0;i<(n);i++) #define REP2(i,m,n) for (int i=m;i<(n);i++) typedef long long ll; typedef long double ld; typedef unsigned int uint; typedef vector<int> VI; typedef vector<ll> VL; /* 00: D 01: DI 02: DIS 03: DISC 04: DISCO 05: I 06: IS 07: ISC 08: ISCO 09: S 10: SC 11: SCO 12: C 13: CO 14: O */ const int K = 15; vector<uint> merge(vector<uint> a, vector<uint> b) { vector<uint> ret = vector<uint>(K, 0); REP(i, K) ret[i] = a[i] + b[i]; ret[1] += a[0] * b[5]; ret[2] += a[0] * b[6] + a[1] * b[9]; ret[3] += a[0] * b[7] + a[1] * b[10] + a[2] * b[12]; ret[4] += a[0] * b[8] + a[1] * b[11] + a[2] * b[13] + a[3] * b[14]; ret[6] += a[5] * b[9]; ret[7] += a[5] * b[10] + a[6] * b[12]; ret[8] += a[5] * b[11] + a[6] * b[13] + a[7] * b[14]; ret[10] += a[9] * b[12]; ret[11] += a[9] * b[13] + a[10] * b[14]; ret[13] += a[12] * b[14]; return ret; } class SegmentTree { public: vector<vector<uint>> table; int size; int offset; SegmentTree(int n) { size = 1; while (size <= n) size <<= 1; size <<= 1; table = vector<vector<uint>>(size, vector<uint>(K, 0)); offset = size / 2; } void assign(int pos, vector<uint> val) { pos += offset; table[pos] = val; while (pos > 1) { pos /= 2; table[pos] = merge(table[pos*2], table[pos*2+1]); } } vector<uint> query(int l, int r) { return query(l, r, 1, 0, offset-1); } vector<uint> query(int l, int r, int i, int a, int b) { if (b < l || r < a) { return vector<uint>(K, 0); } else if (l <= a && b <= r) { return table[i]; } else { return merge(query(l, r, i*2, a, (a+b)/2), query(l, r, i*2+1, (a+b)/2+1, b)); } } }; string S; int N, Q; int L[1010101]; int R[1010101]; void solve() { cin >> S >> Q; N = (int)S.size(); SegmentTree st = SegmentTree(N); REP(i, N) { vector<uint> v = vector<uint>(K, 0); if (S[i] == 'D') v[0] = 1; else if (S[i] == 'I') v[5] = 1; else if (S[i] == 'S') v[9] = 1; else if (S[i] == 'C') v[12] = 1; else v[14] = 1; st.assign(i, v); } while (Q--) { int l, r; cin >> l >> r; --l, --r; vector<uint> v = st.query(l, r); cout << v[4] << "\n"; } } int main() { cin.tie(0); ios::sync_with_stdio(false); solve(); }