AtCoder Regular Contest 067: F - Yakiniku Restaurants

問題概要

N軒の焼肉屋が横一直線に並んでおり、それぞれの間の距離は数列Aで表される。またM枚のチケットを1枚ずつ持っており、チケットjを焼肉屋iがある地点で使うことでB(i, j)の幸福度を得ることができる。好きな地点からスタートして自由に移動してチケットを使うとき、(得られる幸福度 - 移動距離の和) の最大値を求めよ。

N <= 5×103

M <= 200

解法

訪れる区間を固定して考える。訪れる中で最も左にある焼肉屋をl, 最も右にある焼肉屋をrとすると、求めるスコアは(各チケットごとの区間[l, r]での最大値)の和 -(lからrまでの距離)となる(単純に端から端へ移動しながら区間内の一番高いところでチケットを使えばいいので)。

愚直に上の値を求めようとすると「各チケットごとの区間[l, r]での最大値」の部分がネックになる。区間の個数はO(N2)でありチケットの種類はM個あるため一個一個求めていく方法ではどう頑張ってもO(MN2)はかかってしまうためである。なのでここを何とかして高速にやる必要がある。以下ではこの値を max[l][r] = 区間[l, r]におけるそのチケットの最大値 とおく。

ここであるチケットがとりうる最大値を考えてみる。もしあるチケットが焼肉屋xで最大値vを取るのであれば、xを含むどの区間[l, r]においてもmax[l][r]は全部vになるはずである。この手続で「xを含む区間」の最大値は全部わかるので、次はxを含まない区間[1, x)と(x, N]で同じ手続きを行う。……というように区間を分割しながら再帰的に同様の手続きを続けていくと、最終的に全部の区間の最大値を埋めることができる。そしてこの一連の手続きは必ずN回で終了する。なぜなら一回の手続で区間内から必ずひとつインデックスが消えるからである。

このように上の手続自体はN回で済むので、1回の手続にかかる計算量を抑えることができればmax[l][r]のテーブルを埋めることができる。まず必要な処理は「指定された区間の中で最大の値をもつインデックスを取得する」で、これはおなじみのRMQを用いれば1回につきO(logN)で済む。難しいのがもうひとつの処理「区間[l, r]で最大値vをとるインデックスxについて、[l, r]内でxを含むすべての区間に値vを割り当てる」だが、これはmax[l][r]のテーブルを2次元的に考えることでうまくやることができる。実はこの2次元テーブルにおいては、「区間[i, j]においてインデックスxを含むようなすべての区間」は長方形の形をとる。

f:id:fluffyowl:20180525192744p:plain

上図は区間[0, 6]で最大値をとるインデックスが仮に4だった場合を示している。ここで「インデックス4を含むすべての区間」は図で青く塗ってある部分に等しい。つまりこの部分の長方形に一様にインデックス4での値を足すことができればよく、それを効率的に行えるアルゴリズムが存在する。いもす法である。これに関しては御本人の解説がそれそのものなので参照されたい。とにかくこれを使うと最後にテーブルの累積和をとる部分でO(N2)かかるが途中の一回一回の長方形への足し算はO(1)で可能になる。よってこれでひとつのチケットにおけるすべての区間での最大値をO(N2)で求めO(1)で参照できるようになった。さらに各チケットのスコアはただ足し算すればいいだけなので、ひとつの2次元テーブルで同じ加算をいっしょくたにやってしまっていい。これで全部のチケットを合わせた結果も同様の計算量で求まることになる。

求めたい値は (各チケットごとの区間[l, r]での最大値)の和 -(lからrまでの距離)であった。これの前者は今まで述べてきたとおりO(1)で参照することができるようになり、後者は単純な累積和でこれも参照O(1)なので、区間をO(N2)で全探索しても十分間に合うことになる。

感想

どの解説漁っても突然いもす法が出てくるので????(??????????)になってたがpekempeyさんの記事の表とにらめっこしてやっと理解できた。なんかすごい汎用性ありそうなテクだ

コード (D言語)

import std.stdio, std.array, std.string, std.conv, std.algorithm;
import std.typecons, std.range, std.random, std.math, std.container;
import std.numeric, std.bigint, core.bitop, std.bitmanip;

void main() {
    auto s = readln.split.map!(to!int);
    auto N = s[0];
    auto M = s[1];
    auto A = readln.split.map!(to!long).array;
    auto B = N.iota.map!(_ => readln.split.map!(to!long).array).array;
    auto C = new long[](N);
    foreach (i; 0..N-1) C[i+1] = C[i] + A[i];

    auto table = new long[][](N+1, N+1);
    auto st = new SegmentTree!(Tuple!(long, int), (a, b) => max(a, b), tuple(-(1L<<59), -1))(N);
    auto stack = new Tuple!(int, int)[](N+10);

    foreach (m; 0..M) {
        foreach (i; 0..N) st.assign(i, tuple(B[i][m], i));
        stack[0] = tuple(0, N-1);
        int sp = 0;
        while (sp >= 0) {
            auto l = stack[sp][0];
            auto r = stack[sp][1];
            sp -= 1;
            auto t = st.query(l, r);
            auto v = t[0];
            auto p = t[1];
            table[l][p] += v;
            table[l][r+1] -= v;
            table[p+1][p] -= v;
            table[p+1][r+1] += v;
            if (l <= p-1) stack[++sp] = tuple(l, p-1);
            if (r >= p+1) stack[++sp] = tuple(p+1, r);
        }
    }

    long ans = 0;
    foreach (i; 0..N) foreach (j; 0..N) table[i][j+1] += table[i][j];
    foreach (j; 0..N) foreach (i; 0..N) table[i+1][j] += table[i][j];
    foreach (i; 0..N) foreach (j; i..N) ans = max(ans, table[i][j] - C[j] + C[i]);
    ans.writeln;
}


class SegmentTree(T, alias op, T e) {
    T[] table;
    int size;
    int offset;

    this(int n) {
        assert(bsr(n) < 29);
        size = 1 << (bsr(n) + 2);
        table = new T[](size);
        fill(table, e);
        offset = size / 2;
    }

    void assign(int pos, T val) {
        pos += offset;
        table[pos] = val;
        while (pos > 1) {
            pos /= 2;
            table[pos] = op(table[pos*2], table[pos*2+1]);
        }
    }

    T query(int l, int r) {
        return query(l, r, 1, 0, offset-1);
    }

    T query(int l, int r, int i, int a, int b) {
        if (b < l || r < a) {
            return e;
        } else if (l <= a && b <= r) {
            return table[i];
        } else {
            return op(query(l, r, i*2, a, (a+b)/2), query(l, r, i*2+1, (a+b)/2+1, b));
        }
    }
}