CS Academy Round #41 E - Candles

https://csacademy.com/contest/round-41/task/candles/

問題概要

N要素の数列HとM要素の数列Cが与えられる。以下のような操作をHに対して行っていく。

操作: i回目の操作においてHからC[i]個の正の要素を選び、すべてを1デクリメントする。

数列Hに正の要素がC[i]個未満しか存在しないか、あるいはiがMを超えた時点で操作を終了する。最大で何回の操作を行うことができるか求めよ。

1 <= N, M, max(H), max(C) <= 105

解法

最適な選び方自体は単純で、大きい順にC[i]個の要素を選んでデクリメントするだけ。問題は計算量で、「大きい順にC[i]個選ぶ」「それらに-1を加算する」の操作は愚直にやるとどちらも毎回O(N)かかる。

ここでHをソートしておくと「大きい順にC[i]個選ぶ」はO(1)で可能になる。またセグメントツリーでうまいことやると区間加算はO(logN)で可能になる。というわけで1回の操作はセグ木に突っ込んでおいたソート済の列の後ろC[i]個に対して-1を加算することでO(logN)で行える。ここで問題になるのは加算によってソートが崩れることである。例えば

1 2 3 3 3 4 5 に対してC[i] = 4のとき

1 2 3 3 3 4 5 にデクリメントをかけると

1 2 3 2 2 3 4 という風にソートが崩れる。

ここでソートが崩れる場所を観察すると、値の逆転が起こりうるのは「選ばれたC[i]個の中で最も小さい数値」の周りであることがわかる。これは1回につき必ず1ずつしか値が減らないという性質に依る。ゆえにデクリメントをかける際に「選ばれるC[i]個の数値のうち、もっとも小さいもの」を求めておいて、そこへのデクリメントのみ前から行う、という風にすればソート順を崩すことなく操作を行える。

1 2 3 3 4 4 5

1 2 3 3 3 4 5

1 2 2 2 3 3 4

この操作のためには、「選ばれるC[i]個の数値のうちもっとも小さいもの」の値をvとおくと「値がvであるもののうちもっとも左のindex」「値がvであるもののうちもっとも右のindex」の情報が必要になる。これは例えばセグメントツリーを上位ノードでminをとるようなタイプにしておくと、セグ木上で二分探索みたいなことができてO(logN)で可能。以上より全体としてO(MlogN)で答えが出せる。

感想

遅延評価タイプのセグ木をわかってなさすぎて実装が辛かった。本番中に正しい方針を思いつけていたところまでは良かった。

コード (C++)

#include <bits/stdc++.h>
using namespace std;
typedef long long ll;
#define REP(i,n) for (int i=0;i<(n);i++)
#define REP2(i,m,n) for (int i=m;i<(n);i++)


ll N, M;
ll H[101010];
ll C[101010];


class LazySegmentTree {
public:
    vector<ll> table;
    vector<ll> lazy_;
    int size;
    ll defval;
 
    LazySegmentTree(int n, ll defval = 0) {
        size = 1;
        while (size < n) size *= 2;
        size *= 2;
        table = vector<ll>(size, defval);
        lazy_ = vector<ll>(size, defval);
        LazySegmentTree::defval = defval;
    }

    void assign(int a, ll num) {
        return assign(a, num, 0, 0, size/2-1);
    }
 
    void assign(int a, ll num, int i, int l, int r) {
        if (a < l || a > r) return;
        
        if (l == r) {
            table[i] = num;
            return;
        }
        
        assign(a, num, i*2+1, l, (l+r)/2);
        assign(a, num, i*2+2, (l+r)/2+1, r);
        table[i] = max(table[i*2+1] + lazy_[i*2+1],
                       table[i*2+2] + lazy_[i*2+2]);
    }

    void add(int a, int b, ll num) {
        return add(a, b, num, 0, 0, size/2-1);
    }
 
    void add(int a, int b, ll num, int i, int l, int r) {
        if (a > r || b < l) return;
        if (a <= l && r <= b) {
            lazy_[i] += num;
            return;
        }
        
        add(a, b, num, i*2+1, l, (l+r)/2);
        add(a, b, num, i*2+2, (l+r)/2+1, r);
        table[i] = max(table[i*2+1] + lazy_[i*2+1],
                       table[i*2+2] + lazy_[i*2+2]);
    }    
 
    ll query(int a, int b) {
        return query(a, b, 0, 0, size/2-1);
    }
 
    ll query(int a, int b, int i, int l, int r) {
        if (a > r || b < l) return defval;
        if (a <= l && r <= b) return table[i] + lazy_[i];
        
        ll vl = query(a, b, i*2+1, l, (l+r)/2);
        ll vr = query(a, b, i*2+2, (l+r)/2+1, r);
        return max(vl, vr) + lazy_[i];
    }

    ll search_left(ll num) {
        if (num > query(0, N-1)) return N;
        return search_left(num, 0, 0, size/2-1, 0);
    }

    ll search_left(ll num, int i, int l, int r, ll acm) {
        if (l == r) return l;
        acm += lazy_[i];
        ll vl = table[i*2+1] + lazy_[i*2+1] + acm;
        ll vr = table[i*2+2] + lazy_[i*2+2] + acm;
        if (vl >= num) return search_left(num, i*2+1, l, (l+r)/2, acm);
        else           return search_left(num, i*2+2, (l+r)/2+1, r, acm);
    }
};


int main() {
    cin >> N >> M;
    REP(i, N) cin >> H[i];
    REP(i, M) cin >> C[i];
    sort(H, H+N);
    LazySegmentTree st = LazySegmentTree(N, 0);
    REP(i, N) st.add(i, i, H[i]);

    ll ans = 0;
        
    REP(i, M) {
        ll x = N - C[i];
        if (x < 0) break;
        ll v = st.query(x, x);
        if (v == 0) break;

        ll l = st.search_left(v);
        ll r = st.search_left(v + 1);
        ll seg = r - x - 1;
        st.add(l, l + seg, -1);
        st.add(r, N - 1, -1);

        ans = i + 1;

    }

    cout << ans << endl;
}