天下一プログラマーコンテスト2016本戦 C - たんごたくさん

https://atcoder.jp/contests/tenka1-2016-final/tasks/tenka1_2016_final_c

問題概要

文字列SとM要素の文字列の集合Pが与えられる。SからPの要素を取り除くと、取り除いた要素に対応する点数Wが与えられる。任意の回数この操作を行ったときの、得られる点数の最大値を求めよ

|S| <= 2×105

M <= 5000

|P_i| <= 200

解法

Pの各要素がSのどこに一致するかさえわかれば、以下のような単純なDPで解ける。

dp(i) = Sのi文字目までで得られる最大の点数

遷移については、Sの区間[i, i+k]に一致するようなPがある場合、dp(i+k) = dp(i)+そのPの点数 という方になる。

問題はPの各要素がSとどこで一致するか求めることである。これを本当に単純にやると(Sの文字数)×(Pの要素数)×(Piの最大文字数)かかる。これでは1011とかになって明らかに間に合わない。

方法のひとつはTrie木を使うことである。Trie木の特徴の一つとして、「あるひとつの文字列に対して複数の文字列をマッチさせる操作」が高速に行えるという点がある。今回の問題に当てはめていうと、木の構築を(Pの要素数)×(Piの最大文字数)で行え、S[i]を始点とした際にマッチする単語の列挙を(P_iの最大文字数)で行える。ゆえにDPの前にTrie木を構築しておけば、DPの計算は(Sの文字数)×(P_iの最大文字数)で済む。これは107のオーダーに収まるので十分間に合う。

感想

ローリングハッシュで109をゴリ押そうとしたら全然通らなかった

コード (D言語)

import std.stdio, std.array, std.string, std.conv, std.algorithm;
import std.typecons, std.range, std.random, std.math, std.container;
import std.numeric, std.bigint, core.bitop, std.bitmanip;

void main() {
    auto S = readln.chomp;
    auto M = readln.chomp.to!int;
    auto P = M.iota.map!(_ => readln.chomp).array;
    auto W = M.iota.map!(_ => readln.chomp.to!int).array;
    auto N = S.length.to!int;

    auto trie = new Trie;
    foreach (i; 0..M) {
        trie.add_word(P[i], W[i]);
    }

    auto dp = new int[](N+1);
    foreach (i; 0..N) {
        dp[i+1] = max(dp[i+1], dp[i]);
        auto x = trie.search(S[i..min(i+200, N)]);
        foreach (j; 1..x.length) {
            dp[i+j] = max(dp[i+j], dp[i] + x[j]);
        }
    }

    dp[N].writeln;
}


class Node {
    int w;
    Node[char] children;

    this (int w) {
        this.w = w;
    }
}


class Trie {
    Node root;
    
    this() {
        root = new Node(0);
    }

    void add_word(string s, int w) {
        auto cur = root;
        foreach (c; s) {
            if (c !in cur.children) {
                cur.children[c] = new Node(0);
            }
            cur = cur.children[c];
        }
        cur.w = w;
    }

    int[] search(string s) {
        int[] ret = [0];
        auto cur = root;
        foreach (c; s) {
            if (c !in cur.children) {
                break;
            }
            cur = cur.children[c];
            ret ~= cur.w;
        }
        return ret;
    }
}