CS Academy Round #42 (Div. 2 only) E - Xor Submatrix

https://csacademy.com/contest/round-42/task/xor-submatrix/

問題概要

N要素の数列UとM要素の数列Vが与えられる。これらを用いて,  A_{i, j} = V_i xor U_j となるようなN*Mの行列Aを構成する。Aから任意のサイズの部分行列を取り出し全要素のxorを取った時に得られる最大の値を求めよ。

1 <= N, M <= 1000, 0 <= V_i, U_i <= 229

解法

①部分行列の行または列が偶数のとき

同じもの2つで打ち消し合うというxorの性質上、部分行列を偶数行取った場合にはVの要素は関係なくなり、Uの部分列のxorを取っているのと同じことになる。偶数列取った場合も同じ。

②部分行列の行と列がともに奇数のとき

この場合も打消しを考えると結局Uの部分列(長さ奇数)とVの部分列(長さ奇数)のxorになる。

以上をまとめると求めたい値の候補は以下になる

  1. Uの偶数長の部分列のxor和
  2. Vの偶数長の部分列のxor和
  3. (Uの奇数長の部分列のxor和)xor(Vの奇数長の部分列のxor和)

このうち1, 2はO(N2)で列挙できる。3は単純に見ようとするとO(N4)かかってしまうので工夫が必要になる。

ここでxorを大きくするためにはどうすればよいかを考えると、上位の桁のbitがなるべく立つようにしていけばよいことがわかる。つまり集合Aと値xが与えられたとき、(Aの任意の1要素 xor x)の値を最大化するようなAの要素は、bitを上から見ていけば貪欲に選択できるということになる。

これを効率よく行うために、自分は左の子/右の子が各桁の0/1に対応するような2分木を作った。これである値がクエリとして与えられたとき、木に存在する値の中でxorを最大にするものがlog(集合の要素数)で取ってこれるようになる。この木に「Vの奇数長の部分列」によってできるすべてのxor和を詰めておいて、「Uの奇数長の部分列のxor和」をクエリとして投げれば上の箇条書きの3番の値もO(N2 logN)で全部見ることができる。

感想

2分木を自前のクラスで一生懸命作ってたらめちゃくちゃバグらせてしまった。vectorで十分だったかもしれない

コード (C++)

#include <bits/stdc++.h>
using namespace std;
#define REP(i,n) for (int i=0;i<(n);i++)
#define REP2(i,m,n) for (int i=m;i<(n);i++)
typedef long long ll;

class BinaryNode {
public:
    int data = -1;
    BinaryNode* left = nullptr;
    BinaryNode* right = nullptr;
};


int N, M;
int A[2020];
int B[2020];

BinaryNode* construct(vector<int> p, int k) {
    auto bn = new BinaryNode;
    if (p.size() == 1) {
        bn->data = p[0];
        return bn;
    }
    vector<int> l = vector<int>(0), r = vector<int>(0);
    for (auto pp: p) pp & (1 << k) ? l.push_back(pp) : r.push_back(pp);
    if (l.size() > 0) bn->left = construct(l, k-1);
    if (r.size() > 0) bn->right = construct(r, k-1);
    return bn;
}

int search(int x, int k, BinaryNode* bn) {
    if (bn->data != -1) return bn->data;
    if (bn->left == nullptr) return search(x, k - 1, bn->right);
    if (bn->right == nullptr) return search(x, k - 1, bn->left);
    if (x & (1 << k)) return search(x, k - 1, bn->right);
    return search(x, k - 1, bn->left);
}

int main() {
    cin >> N >> M;
    REP(i, N) cin >> A[i];
    REP(i, M) cin >> B[i];
    
    vector<int> hoge, fuga;
    int ans = 0;

    REP2(len, 1, M+1) {
        int tmp = 0;
        REP(i, len) tmp ^= B[i];
        if (len % 2 == 0) ans = max(ans, tmp);
        if (len % 2 == 1) fuga.push_back(tmp);
        REP2(i, len, M) {
            tmp ^= B[i-len];
            tmp ^= B[i];
            if (len % 2 == 0) ans = max(ans, tmp);
            if (len % 2 == 1) fuga.push_back(tmp);
        }
    }

    REP2(len, 1, N+1) {
        int tmp = 0;
        REP(i, len) tmp ^= A[i];
        if (len % 2 == 0) ans = max(ans, tmp);
        if (len % 2 == 1) hoge.push_back(tmp);
        REP2(i, len, N) {
            tmp ^= A[i-len];
            tmp ^= A[i];
            if (len % 2 == 0) ans = max(ans, tmp);
            if (len % 2 == 1) hoge.push_back(tmp);
        }
    }


    sort(hoge.begin(), hoge.end());
    sort(fuga.begin(), fuga.end());
    hoge.erase(unique(hoge.begin(), hoge.end()), hoge.end());
    fuga.erase(unique(fuga.begin(), fuga.end()), fuga.end());
    auto at = construct(hoge, 30);
    auto bt = construct(fuga, 30);

    for (auto a: hoge) ans = max(ans, a ^ search(a, 30, bt));
    for (auto b: fuga) ans = max(ans, b ^ search(b, 30, at));
    
    cout << ans << endl;
}