CODE FESTIVAL 2017 Elimination Tournament Round 1: D - Ancient Tree Record

https://atcoder.jp/contests/cf17-tournament-round1-open/tasks/asaporo2_d

問題概要

N頂点の木とN要素の数列Sが与えられる。木のすべての辺に整数の長さを割り当てて、Siが木の頂点iとすべての頂点との最短距離の和と等しくなるようにせよ。与えられる入力においてはそのような割り当て方がただひとつ存在することが保証される。

N <= 105

解法

ある頂点の組(a, b)を結ぶ辺eに注目すると、「Saの計算過程においてe以外の辺を通る回数」と「Sbの計算過程においてe以外の辺を通る回数」は等しい。また「Saの計算過程においてeを通る回数」はaから見てbの向こう側にある頂点の数に等しく、「Sbの計算過程においてeを通る回数」はbから見てaの向こう側にある頂点の数に等しい。よってSaとSbの差を取ったあと、eを通る回数の差でそれを割ればeの長さが一意に決定できる。

一点例外として、辺eの繋ぐ頂点(a, b)が両方とも木の中心である場合は注意が必要で、この場合SaもSbもeを通る回数も同じであるため上の方法では計算できない。しかしこの場合にも他のすべての辺は上の方法で求めることができるので、先にそれらを求めておけば残る辺eもそこから計算できる。

感想

なんか解けた、嬉しい 木の問題は辺を通る回数に着目するのが重要テクのひとつっぽいですね

コード (D言語)

import std.stdio, std.array, std.string, std.conv, std.algorithm;
import std.typecons, std.range, std.random, std.math, std.container;
import std.numeric, std.bigint, core.bitop, std.bitmanip;

void main() {
    auto N = readln.chomp.to!int;
    auto E = new Tuple!(int, int)[](N-1);
    auto G = new Tuple!(int, int)[][](N);
    foreach (i; 0..N-1) {
        auto s = readln.split.map!(to!int);
        G[s[0]-1] ~= tuple(s[1]-1, i);
        G[s[1]-1] ~= tuple(s[0]-1, i);
        E[i] = tuple(s[0]-1, s[1]-1);
    }
    auto S = readln.split.map!(to!long).array;

    if (N == 2) {
        writeln(S[0]);
        return;
    }
    
    auto depth = new int[](N);
    auto sub = new int[](N);
    auto ans = new long[](N-1);
    
    int dfs(int n, int p, int d) {
        depth[n] = d;
        foreach (m; G[n]) if (m[0] != p) sub[n] += dfs(m[0], n, d+1);
        return sub[n] + 1;
    }
    dfs(0, -1, 0);

    long dfs2(int n, int p, long d) {
        long ret = d;
        foreach (m; G[n]) if (m[0] != p) ret += dfs2(m[0], n, d+ans[m[1]]);
        return ret;
    }

    foreach (i, e; E.enumerate) {
        int a = e[0];
        int b = e[1];
        if (S[a] == S[b]) continue; 
        if (depth[a] > depth[b]) swap(a, b);
        long aa = sub[b];
        long bb = N - sub[b] - 2;
        ans[i] = (S[a] - S[b]) / (aa - bb);
    }

    foreach (i, e; E.enumerate) {
        int a = e[0];
        int b = e[1];
        if (S[a] != S[b]) continue;
        if (depth[a] > depth[b]) swap(a, b);
        long aa = dfs2(a, b, 0);
        long bb = dfs2(b, a, 0);
        long ss = aa + bb;
        ans[i] = (S[a] - ss) / (sub[b] + 1);
    }

    ans.each!writeln;
}