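// Untitled
//
// (pastecode.io page header captured together with the code; kept as a
// comment so the file still compiles)
// mail@pastecode.io avatar
// unknown
// plain_text
// 13 days ago
// 3.2 kB
// 4
// Indexable
// Never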
#pragma GCC optimize("Ofast")
#pragma GCC target("avx,avx2")

#include <bits/stdc++.h>

#define pb push_back
#define f first
#define s second

using uint = unsigned int;

using namespace std;

// Array bound for per-vertex data; vertices are 1-indexed, so n <= N - 1.
const int N = 2e5 + 5;

// Adjacency lists of the tree.
std::vector<int> v[N];
// n, q   : vertex and query counts read in main().
// timer  : next free position of the Euler tour built by dfs().
// in/out : tour positions at which dfs() enters / leaves each vertex.
// O      : the Euler tour itself, O[t] = vertex at position t. dfs() writes
//          two positions per vertex (enter and leave), so it needs 2 * n
//          slots -- sizing it N (as the other arrays) overflows once n > N / 2.
// dep    : depth of each vertex (the root, passed as its own parent, gets 1).
// c      : colour of each vertex; used directly as a bucket index by ins/del.
// LCA    : lowest common ancestor per query id (1..q).
//          NOTE(review): this (and ANS) assumes q <= N - 1 -- confirm limits.
// P      : binary-lifting table, P[x][j] = 2^j-th ancestor of x (root -> root).
int n, q, timer, in[N], out[N], O[2 * N], dep[N], c[N], LCA[N], P[N][18];
// Vertex weights.
uint h[N];

// True when a is an ancestor of b, counting every vertex as its own ancestor:
// b's Euler-tour interval [in[b], out[b]] must lie inside a's interval.
bool is_anc(int a, int b) {
    const bool starts_inside = in[b] >= in[a];
    const bool ends_inside = out[a] >= out[b];
    return starts_inside && ends_inside;
}

// Returns the lowest common ancestor of a and b, using the binary-lifting
// table P and the Euler-tour ancestor test; dfs() must have filled both.
int lca(int a, int b) {
    // Work from the shallower endpoint: a deeper vertex can never be an
    // ancestor of a shallower one.
    if (dep[b] < dep[a]) {
        const int t = a;
        a = b;
        b = t;
    }
    if (is_anc(a, b)) {
        return a;
    }
    // Lift `a` by decreasing powers of two as long as the target is still not
    // an ancestor of b; afterwards the parent of `a` is the answer.
    int k = 17;
    while (k >= 0) {
        const int jump = P[a][k];
        if (!is_anc(jump, b)) {
            a = jump;
        }
        --k;
    }
    return P[a][0];
}

// Builds the Euler tour and the binary-lifting table for the subtree rooted
// at x, whose parent is p (main() calls dfs(1, 1), so the root is its own
// parent).
//
// For every vertex u it performs, in this order:
//   on entry : O[timer] = u, in[u] = timer++, fills P[u][0..17] and sets
//              dep[u] = dep[parent] + 1;
//   on leave : O[timer] = u, out[u] = timer++ (after all children are done).
// Children are visited in adjacency-list order, skipping the parent.
//
// Uses an explicit stack instead of recursion: recursion depth would equal
// the tree height, which reaches n on a path-shaped tree and can overflow the
// call stack. Visiting order and all writes are the same as the recursive form.
void dfs(int x, int p) {
    struct Frame {
        int node;          // vertex being expanded
        int parent;        // its parent, skipped when scanning neighbours
        std::size_t next;  // index of the next neighbour of `node` to look at
    };
    std::vector<Frame> frames;

    // Entry (pre-order) work for `node`, then schedule its neighbours.
    auto enter = [&](int node, int parent) {
        O[timer] = node;
        in[node] = timer++;
        P[node][0] = parent;
        for (int j = 1; j < 18; j++) {
            P[node][j] = P[P[node][j - 1]][j - 1];
        }
        dep[node] = dep[parent] + 1;
        frames.push_back({node, parent, 0});
    };

    enter(x, p);
    while (!frames.empty()) {
        Frame &top = frames.back();
        if (top.next < v[top.node].size()) {
            const int from = top.node;
            const int to = v[from][top.next++];
            if (to != top.parent) {
                // may reallocate `frames`; `top` is not used after this call
                enter(to, from);
            }
        } else {
            // all children finished: leave (post-order) work
            O[timer] = top.node;
            out[top.node] = timer++;
            frames.pop_back();
        }
    }
}

// State of Mo's algorithm. Everything is uint, so the arithmetic wraps
// (modulo 2^32 where unsigned int is 32 bits wide).
uint cur_ans;    // running answer for the currently active vertices
uint contr[N];   // last recorded total*sum_sq - sum*sum for each bucket
uint total[N];   // number of active vertices in each bucket
uint sum_sq[N];  // sum of squared weights of active vertices per bucket
uint sum[N];     // sum of weights of active vertices per bucket
uint fr[N];      // per vertex: 1 while active, 0 otherwise (toggled by shift)
uint ANS[N];     // answer for each query id

inline void update_ans(int i) {
    if (!i) {
        cur_ans -= contr[i];
        contr[i] = total[i] * sum_sq[i] - sum[i] * sum[i];
        cur_ans += contr[i];
    } else {
        cur_ans += contr[i];
        contr[i] = total[i] * sum_sq[i] - sum[i] * sum[i];
        cur_ans -= contr[i];
    }
}

inline void updateAdd(int x, int i) {
    sum[i] += h[x];
    sum_sq[i] += h[x] * h[x];
    total[i]++;
    update_ans(i);
}

inline void updateDec(int x, int i) {
    sum[i] -= h[x];
    sum_sq[i] -= h[x] * h[x];
    total[i]--;
    update_ans(i);
}

// add node x
inline void ins(int x) {
    updateAdd(x, 0);
    updateAdd(x, c[x]);
}

inline void del(int x) {
    updateDec(x, 0);
    updateDec(x, c[x]);
}

// Toggles Euler-tour position x: vertex O[x] becomes active if it was
// inactive and inactive if it was active (fr holds that flag), and the
// bucket statistics are updated to match.
inline void shift(int x) {
    const int node = O[x];
    fr[node] ^= 1;
    if (fr[node]) {
        ins(node);
    } else {
        del(node);
    }
}

// Entry point: reads the tree and the queries, answers every query offline
// with Mo's algorithm over the Euler tour built by dfs(), and prints the
// answers in input order, one per line.
//
// Input format (as read below): n q, then c[1..n], then h[1..n], then n-1
// undirected edges "a b", then q queries "l r".
int main() {
    ios::sync_with_stdio(false);
    cin.tie(nullptr);

    cin >> n >> q;
    for (int i = 1; i <= n; i++) {
        cin >> c[i];
    }
    for (int i = 1; i <= n; i++) {
        cin >> h[i];
    }
    for (int i = 1; i < n; i++) {
        int a, b;
        cin >> a >> b;
        v[a].pb(b);
        v[b].pb(a);
    }

    // Vertex 1 is the root; passing it as its own parent makes P[1][*] == 1.
    dfs(1, 1);

    // Query i becomes the Euler-tour range [out[l], in[r]] plus LCA[i].
    // shift() tracks, per vertex, the parity of its occurrences between the
    // two pointers, so the active set is well defined even when
    // out[l] > in[r] (e.g. when one endpoint is an ancestor of the other).
    // In every case the odd-parity vertices are exactly the l..r path minus
    // its LCA, which is why the LCA is added separately for each query below.
    vector < pair < pair < int , int > , int > > Q;
    Q.reserve(q);
    for (int i = 1; i <= q; i++) {
        int l, r;
        cin >> l >> r;
        LCA[i] = lca(l, r);
        Q.pb({{out[l], in[r]}, i});
    }

    int sq = min(2 * (int)sqrt(n), 333);

    // Mo order: group by block of the right end and alternate the direction
    // of the left end between blocks (the usual trick to cut pointer moves).
    sort(Q.begin(), Q.end(), [&](const auto &x, const auto &y) {
        if (x.f.s / sq != y.f.s / sq) {
            return x.f.s < y.f.s;
        }
        if ((x.f.s / sq) & 1) return x.f.f < y.f.f;
        return x.f.f > y.f.f;
    });

    int L = 0, R = -1;
    for (const auto &query : Q) {
        int id = query.s;
        int l = query.f.f;
        int r = query.f.s;

        // Each shift() flips one tour position and flips commute, so the
        // order of these loops does not change the resulting state. L may end
        // up greater than R + 1 (see the range comment above); that is intended.
        while (R < r) shift(++R);
        while (l < L) shift(--L);
        while (r < R) shift(R--);
        while (L < l) shift(L++);

        // The LCA is never inside the tracked range, so add it just for
        // reading the answer and remove it again.
        ins(LCA[id]);
        ANS[id] = cur_ans;
        del(LCA[id]);
    }

    for (int i = 1; i <= q; i++) {
        cout << ANS[i] << "\n";
    }
    return 0;
}
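/*

Sample input
(n q / colours c[1..n] / weights h[1..n] / n-1 tree edges / q queries "l r"):

5 2
1 2 1 2 1
5 3 7 2 6
1 2
1 3
3 4
3 5
1 5
2 4

Expected output:
0
54

*/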

// Leave a Comment  (pastecode.io page footer, not part of the program)