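// Untitled
//
// Page metadata captured from pastecode.io together with the source below:
//   mail@pastecode.io avatar
//   unknown
//   c_cpp
//   22 days ago
//   6.6 kB
//   2
//   Indexable
//   Never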
#include <bits/stdc++.h>
using namespace std;

using i64 = long long;

// Fenwick (binary indexed) tree over 0-indexed positions [0, n) holding int
// values. Values are loaded in bulk by build(); prefix and range sums are then
// answered in O(log n).
//
// 0-indexed layout: f[i] stores the sum of positions [i & (i + 1), i], and the
// next node whose range also covers position i is i | (i + 1).
struct Fenwick {
    int n = 0;           // number of positions
    std::vector<int> f;  // f[i] = sum of positions [i & (i + 1), i]

    // Creates a tree over n positions, all zero.
    Fenwick(int n) : n(n) {
        f.assign(n, 0);
    }

    // Sets position i to val[u[i]] for every i in [0, n) and builds the tree
    // in O(n). u must have at least n entries, each a valid index into val.
    // Inputs are taken by const reference so const vectors and temporaries
    // can be passed; neither argument is modified.
    void build(const std::vector<int> &val, const std::vector<int> &u) {
        for (int i = 0; i < n; ++i) {
            f[i] = val[u[i]];
        }
        // Push every node's partial sum into the next node covering it.
        for (int i = 0; i < n; ++i) {
            int const next = i | (i + 1);
            if (next < n) {
                f[next] += f[i];
            }
        }
    }

    // Sum of positions [0, i]; 0 when i < 0. Requires i < n.
    int query(int i) const {
        int res = 0;
        for (; i >= 0; i = (i & (i + 1)) - 1) {
            res += f[i];
        }
        return res;
    }

    // Sum of positions [l, r] (0 when l == r + 1). Requires r < n.
    int get(int l, int r) const {
        return query(r) - query(l - 1);
    }
};

void solve() {
    int n, q;
    cin >> n >> q; 

    int a[n];
    for (int i = 0; i < n; ++i) {
        cin >> a[i];
    }

    vector<vector<int>> adj(n);
    for (int i = 0; i < n - 1; ++i) {
        int u, v;
        cin >> u >> v;
        --u, --v;
        adj[u].push_back(v);
        adj[v].push_back(u);
    }

    int timer = -1;
    int dep[n] {}, in[n], out[n], par[n];
    vector<vector<int>> layer(n);
    function<void(int, int)> dfs = [&](int u, int p) {
        in[u] = ++timer;
        par[u] = p;
        layer[dep[u]].push_back(u);

        for (auto v : adj[u]) {
            if (v == p) continue;
            dep[v] = dep[u] + 1;
            a[v] ^= a[u];
            dfs(v, u);
        }

        out[u] = timer;
    };
    dfs(0, -1);

    int const lg = 20;

    vector<vector<vector<int>>> pref(n);
    for (int i = 0; i < n; ++i) {
        pref[i].assign(layer[i].size(), vector<int>(lg, 0));
        for (int j = 0; j < layer[i].size(); ++j) {
            int u = layer[i][j];
            for (int k = 0; k < lg; ++k) {
                if (j) {
                    pref[i][j][k] = pref[i][j - 1][k];
                }
                pref[i][j][k] += (1 << k & a[u]) > 0;
            }
        }
    }
   

    vector<Fenwick> st(lg, Fenwick(n));

    vector<int> eulerNode(n);
    for (int u = 0; u < n; ++u) {
        eulerNode[in[u]] = u;
    }

    auto updateBit = [&]() {
        for (int j = 0; j < lg; ++j) {
            vector<int> val(n);
            for (int u = 0; u < n; ++u) {
                val[u] = (1 << j & a[u]) > 0;
            }
            st[j].build(val, eulerNode);
        }
    };
    updateBit();

    int const S = 350;

    int depthXor[n] {}, layerVis[n] {};
    int l[q], r[q], x[q], op[q];
    bool isFlipped[n][lg] {};
    for (int i = 0; i < q; ++i) {
        if (i % S == 0 && i) {
            for (int u = 0; u < n; ++u) {
                a[u] ^= depthXor[dep[u]];
            }
            for (int i = 0; i < n; ++i) {
                for (int j = 0; j < lg; ++j) {
                    if ((1 << j & depthXor[i])) {
                        isFlipped[i][j] = !isFlipped[i][j];
                    }
                }
            }
            memset(depthXor, 0, sizeof depthXor);
            updateBit();
        }
        cin >> op[i];
        if (op[i] == 1) {
            cin >> l[i] >> r[i] >> x[i];
            depthXor[l[i]] ^= x[i];
        } else {
            int u;
            cin >> u;
            --u;

            vector<int> reqBit(lg, 1);
            if (par[u] != -1) {
                int x = a[par[u]];
                x ^= depthXor[dep[par[u]]];
                for (int j = 0; j < lg; ++j) {
                    reqBit[j] = (1 << j & x) == 0;
                }
            }

            vector<int> cnt(lg);
            for (int j = 0; j < lg; ++j) {
                int res = st[j].get(in[u], out[u]);
                if (reqBit[j] == 0) {
                    res = out[u] - in[u] + 1 - res;
                }
                cnt[j] += res;
            }

            for (int prevQ = i - 1; prevQ >= max(0, i - S); --prevQ) {
                if (op[prevQ] == 1) {
                    int z = l[prevQ];
                    if (z < dep[u] || layerVis[z] || layer[z].size() == 0) {
                        continue;
                    }
                    layerVis[z] = 1;

                    int L = 0, R = layer[z].size() - 1;
                    while (L <= R) {
                        int mid = (L + R) / 2;
                        if (in[layer[z][mid]] >= in[u]) {
                            R = mid - 1;
                        } else {
                            L = mid + 1;
                        }
                    }
                    int prefStart = R + 1;

                    L = 0, R = layer[z].size() - 1;
                    while (L <= R) {
                        int mid = (L + R) / 2;
                        if (in[layer[z][mid]] <= out[u]) {
                            L = mid + 1;
                        } else {
                            R = mid - 1;
                        }
                    }
                    int prefEnd = L - 1;

                    if (prefStart <= prefEnd) {
                        for (int j = 0; j < lg; ++j) {
                            if ((1 << j & depthXor[z])) {
                                int on = pref[z][prefEnd][j] - (prefStart ? pref[z][prefStart - 1][j] : 0);
                                int off = prefEnd - prefStart + 1 - on;
                                if (isFlipped[z][j]) {
                                    swap(on, off);
                                }
                                if (reqBit[j]) {
                                    cnt[j] -= on;
                                    cnt[j] += off;
                                } else {
                                    cnt[j] -= off;
                                    cnt[j] += on;
                                }
                            }   
                        }
                    }
                }
            }
            for (int prevQ = i - 1; prevQ >= max(0, i - S); --prevQ) {
                if (op[prevQ] == 1) {
                    int z = l[prevQ];
                    layerVis[z] = 0;
                }
            }

            i64 ans = 0;
            for (int j = 0; j < lg; ++j) {
                ans += 1LL * cnt[j] * (1 << j);
            }

            cout << ans << "\n";
        }
    }
}

int main() {
    // All I/O goes through cin/cout: detach from C stdio and untie the
    // streams so reads do not force output flushes.
    std::ios::sync_with_stdio(false);
    std::cin.tie(nullptr);

    // The input begins with the number of test cases; each is handled by solve().
    int remaining = 1;
    std::cin >> remaining;
    for (; remaining != 0; --remaining) {
        solve();
    }

    return 0;
}
// Leave a Comment  (pastecode.io page footer captured with the source)