yukicoder No.776 A Simple RMQ Problem

https://yukicoder.me/problems/no/776

問題概要

N要素の数列Aに対して以下の2種類のクエリがQ個飛んでくるので順次処理せよ。

  • set(i, x): Ai の値を x に変更する。
  • max(l1, l2, r1, r2): l1 <= l <= l2, r1 <= r <= r2, l <= r を満たすすべての整数の組l, rにおける、Aの区間[l, r]の和の最大値を出力する。

N, Q <= 105

解法

まずmaxクエリでは、l1より左にはみ出したr1と、r2より右にはみ出したl2については考慮する必要がない(そこにr1やl2をとっても l <= r が絶対に満たせないため)。よって r1 = max(l1, r1), l2 = min(l2, r2) としてしまってよい。その上で以下の2パターンに場合分けして考える。

1. l2 < r1 のとき

このときlとrをどう取ろうがお互いに干渉することがない(どう取っても l <= r は満たされるため)。よってそれぞれを独立に決めることができる。具体的には以下の3つの区間についてそれぞれ値を求めて足し合わせればよい。

  • 区間[l1, l2]: 右端をl2に固定して左端を区間内で自由に決められるときの、最大の区間
  • 区間[l2+1, r1-1]: 区間全体の和(ここの区間は絶対に合計に入ってきてしまうので)
  • 区間[r1, r2]: 左端をr1に固定して右端を区間内で自由に決められるときの、最大の区間

2. l2 >= r1 のとき

lとrのとり方によってはl > rとなる可能性がある。以下の3パターンに分けて考える。

2-1. l < r1 となるようにlをとるとき

言い換えると[l1, r1-1]の区間内でlを取るとき。このときlをどう取ろうがrを追い越さないので先程と同じような値の決め方ができる。つまり

  • 区間[l1, r1-1]: 右端をr1-1に固定して左端を区間内で自由に決められるときの、最大の区間
  • 区間[r1, r2]: 左端をr1に固定して右端を区間内で自由に決められるときの、最大の区間

の2つを足し合わせればよい。

2-2. l2 < r となるようにrをとるとき

これも上と同じ。一応書くと

  • 区間[l1, l2]: 右端をl2に固定して左端を区間内で自由に決められるときの、最大の区間
  • 区間[l2+1, r2]: 左端をl2+1に固定して右端を区間内で自由に決められるときの、最大の区間

の2つを足し合わせる。

2-3. 上記以外のとき

言い換えるとlもrも[r1, l2]の範囲で取るとき。このときはこの区間の中の部分区間のうち、最大の和を求める必要がある。

セグ木に乗せる

以上より、maxクエリを処理するためには、ある区間[l, r]について以下の4つの情報がわかればよいということがわかった。

  1. sum(l, r): 区間[l, r]の単純な区間
  2. lmax(l, r): 区間[l, r]において左端をlに固定して右端を区間内で自由に決められるときの、最大の区間
  3. rmax(l, r): 区間[l, r]において右端をrに固定して左端を区間内で自由に決められるときの、最大の区間
  4. allmax(l, r): 区間[l, r]のすべての部分区間のうちの、区間和の最大値

で、これらはセグ木に乗せてうまいこと処理することができる。具体的には以下のように隣り合う2区間のマージをやっていく(以下ではマージする左側の区間を左区間、右側を右区間と便宜的に呼ぶ)。

  1. sum(l, r): これは普通のセグ木でも扱うような単純な区間和なので普通に足すだけ
  2. lmax(l, r): あらたにlmaxとなりうるのは「左区間のlmax」もしくは「左区間のsum + 右区間のlmax」のどちらかなので、これらのmaxをとる。
  3. rmax(l, r): 同上。
  4. allmax(l, r): あらたにallmaxとなりうるのは「左区間のallmax」もしくは「右区間のallmax」もしくは「左区間のrmax + 右区間のlmax」のどれかなので、これらのmaxをとる。

以上でmaxクエリの処理に必要なセグ木を作ることができる。更新も一点更新のみなので普通にやることができる。あとは上の考察通りにセグ木にクエリを投げて計算していけばよい。

感想

自力で解けなかったが、セグ木を組み立てるパズルみたいで楽しかった

コード (D言語)

import std.stdio, std.array, std.string, std.conv, std.algorithm;
import std.typecons, std.range, std.random, std.math, std.container;
import std.numeric, std.bigint, core.bitop, core.stdc.string;

immutable long MOD = 10^^9 + 7;
immutable long INF = 1L << 59;

alias Node = Tuple!(long, "sum", long, "lmax", long, "rmax", long, "allmax");

Node merge(Node l, Node r) {
    long lmax = max(l.lmax, l.sum + r.lmax);
    long rmax = max(r.rmax, r.sum + l.rmax);
    long allmax = max(l.allmax, r.allmax, l.rmax + r.lmax);
    return Node(l.sum + r.sum, lmax, rmax, allmax);
}

class SegmentTree(T, alias op, T e) {
    T[] table;
    int size;
    int offset;

    this(int n) {
        size = 1;
        while (size <= n) size <<= 1;
        size <<= 1;
        table = new T[](size);
        fill(table, e);
        offset = size / 2;
    }

    void assign(int pos, T val) {
        pos += offset;
        table[pos] = val;
        while (pos > 1) {
            pos /= 2;
            table[pos] = op(table[pos*2], table[pos*2+1]);
        }
    }

    T query(int l, int r) {
        if (r < l) return e;
        return query(l, r, 1, 0, offset-1);
    }

    T query(int l, int r, int i, int a, int b) {
        if (b < l || r < a) {
            return e;
        } else if (l <= a && b <= r) {
            return table[i];
        } else {
            return op(query(l, r, i*2, a, (a+b)/2), query(l, r, i*2+1, (a+b)/2+1, b));
        }
    }
}


void main() {
    auto s = readln.split.map!(to!int);
    auto N = s[0];
    auto Q = s[1];
    auto A = readln.split.map!(to!long).array;

    auto st = new SegmentTree!(Node, merge, Node(0L, -INF, -INF, -INF))(N);
    foreach (i; 0..N) st.assign(i, Node(A[i], A[i], A[i], A[i]));

    while (Q--) {
        auto q = readln.split;
        if (q[0] == "set") {
            auto i = q[1].to!int - 1;
            auto x = q[2].to!long;
            st.assign(i, Node(x, x, x, x));
        } else {
            auto l1 = q[1].to!int - 1;
            auto l2 = q[2].to!int - 1;
            auto r1 = q[3].to!int - 1;
            auto r2 = q[4].to!int - 1;
            l2 = min(l2, r2);
            r1 = max(l1, r1);
            long ans = -INF;
            if (l2 < r1) {
                auto v1 = st.query(l1, l2).rmax;
                auto v2 = st.query(l2+1, r1-1).sum;
                auto v3 = st.query(r1, r2).lmax;
                ans = v1 + v2 + v3;
            } else {
                if (l1 < r1) {
                    auto v1 = st.query(l1, r1-1).rmax;
                    auto v2 = st.query(r1, r2).lmax;
                    ans = max(ans, v1 + v2);
                }
                if (l2 < r2) {
                    auto v1 = st.query(l1, l2).rmax;
                    auto v2 = st.query(l2+1, r2).lmax;
                    ans = max(ans, v1 + v2);
                }
                auto hoge = st.query(r1, l2);
                ans = max(ans, st.query(r1, l2).allmax);
            }
            ans.writeln;
        }
    }
}