AtCoder Grand Contest 019 C - Fountain Walk

http://agc019.contest.atcoder.jp/tasks/agc019_c

問題概要

東西方向に伸びる108本の道と、南北方向に伸びる108本の道がある。それぞれの道はすべて100メートルごとに等間隔で並んでおり、0から108 -1の番号がついている。南北方向のx番目の道と東西方向のy番目の道が交差している点を(x, y)のように表す。

いまN個の噴水があり、これらはN個の異なる交差点にそれぞれ置かれている。噴水は交差点を中心にした半径10メートルの円形である。また東西方向・南北方向いずれの道にも噴水が2つ以上置かれることはない。

人が通行できるのはそれぞれの道と噴水の外周のみである。なお噴水の内部は通行できない。スタートの交差点とゴールの交差点が与えられるので最短の通行距離を求めよ。

N <= 2 × 105

解法

曲がる場合は普通の道を直角に曲がるより噴水を経由した方が得である。ただしその得はあまり大きくないので、噴水を通るためにあえて遠回りするのはかえって損になる。つまり基本的には南北・東西に無駄なく移動しながら、できるだけ多くの噴水を通るのが最短となる。

通る噴水の数を最大化するためにはどうすればいいか? 例えばスタートが左下でゴールが右上の場合、噴水も左下から右上に並ぶようにとっていく必要がある。この最大値は最長増加部分列の考え方で計算可能。あとは単純な距離から噴水経由によって得した距離を引けばよい。

ただしひとつだけ厄介なコーナーケースがある。それは全部の行あるいは全部の列で噴水を通れる場合である。そのような場合に必ず噴水を通るとすると、最後の噴水では曲がれないのでショートカットにならない。普通に直線を通るところで半円を回る形になるのでその分距離が増える。

感想

なんとか本番で通せた(5WA)

コード (D言語)

import std.stdio, std.array, std.string, std.conv, std.algorithm;
import std.typecons, std.range, std.random, std.math, std.container;
import std.numeric, std.bigint, core.bitop;

alias Tuple!(long, "x", long, "y") Point;
immutable long MAX = 10^^8;

void main() {
    auto s = readln.split.map!(to!long);
    auto x1 = s[0], y1 = s[1], x2 = s[2], y2 = s[3];
    auto N = readln.chomp.to!int;
    auto P = new Point[](N);
    foreach (i; 0..N) {
        s = readln.split.map!(to!long);
        P[i] = Point(s[0], s[1]);
    }

    if (x1 > x2) swap(x1, x2), swap(y1, y2);
    if (y1 > y2) {
        y1 = MAX - y1;
        y2 = MAX - y2;
        foreach (i; 0..N) P[i].y = MAX - P[i].y;
    }

    if (x1 == x2) {
        real ans = (y2 - y1) * 100;
        bool hoge;
        foreach (i; 0..N) if (P[i].x == x1 && P[i].y >= y1 && P[i].y <= y2) hoge = true;
        if (hoge) {
            ans -= 20;
            ans += 10 * PI;
        }
        writefln("%.12f", ans);
        return;
    } else if (y1 == y2) {
        real ans = (x2 - x1) * 100;
        bool hoge;
        foreach (i; 0..N) if (P[i].y == y1 && P[i].x >= x1 && P[i].x <= x2) hoge = true;
        if (hoge) {
            ans -= 20;
            ans += 10 * PI;
        }
        writefln("%.12f", ans);
        return;
    }


    Point[] Q;
    foreach (i; 0..N) {
        if (P[i].x >= x1 && P[i].x <= x2 && P[i].y >= y1 && P[i].y <= y2) Q ~= P[i];
    }

    auto M = Q.length.to!int;
    Q.sort!"a[0] < b[0]"();
    auto st = new SegmentTree(M);
    long[] hoge;
    foreach (i; 0..M) hoge ~= Q[i].y;
    auto fuga = hoge.sort().uniq.array;
    int[long] comp;
    foreach (i; 0..fuga.length) comp[fuga[i.to!int]] = i.to!int;

    foreach (i; 0..M) {
        auto xxx = st.sum(0, comp[Q[i].y]) + 1;
        st.assign(comp[Q[i].y], xxx);
    }

    real ans = (x2 - x1) * 100 + (y2 - y1) * 100;
    auto xxx = st.sum(0, M-1).to!long;
    ans -= 20 * xxx;
    ans += 5 * xxx * PI;

    if (x2 - x1 + 1 == xxx || y2 - y1 + 1 == xxx) {
        ans += 5 * PI;
    } 

    writefln("%.12f", ans);
}

class SegmentTree {
    int[] table;
    int size;

    this(int n) {
        assert(bsr(n) < 29);
        size = 1 << (bsr(n) + 2);
        table = new int[](size);
    }

    void assign(int pos, int num) {
        return assign(pos, num, 0, 0, size/2-1);
    }

    void assign(int pos, int num, int i, int left, int right) {
        if (left == right) {
            table[i] = num;
            return;
        }
        auto mid = (left + right) / 2;
        if (pos <= mid)
            add(pos, num, i*2+1, left, mid);
        else
            add(pos, num, i*2+2, mid+1, right);
        table[i] = max(table[i*2+1], table[i*2+2]);
    }

    void add(int pos, int num) {
        return add(pos, num, 0, 0, size/2-1);
    }

    void add(int pos, int num, int i, int left, int right) {
        if (left == right) {
            table[i] += num;
            return;
        }
        auto mid = (left + right) / 2;
        if (pos <= mid)
            add(pos, num, i*2+1, left, mid);
        else
            add(pos, num, i*2+2, mid+1, right);
        table[i] = max(table[i*2+1], table[i*2+2]);
    }

    int sum(int pl, int pr) {
        return sum(pl, pr, 0, 0, size/2-1);
    }

    int sum(int pl, int pr, int i, int left, int right) {
        if (pl > right || pr < left)
            return 0;
        else if (pl <= left && right <= pr)
            return table[i];
        else
            return
                max(sum(pl, pr, i*2+1, left, (left+right)/2),
                    sum(pl, pr, i*2+2, (left+right)/2+1, right));
    }
}