Untitled
unknown
plain_text
13 days ago
3.2 kB
4
Indexable
Never
#pragma GCC optimize("Ofast")
#pragma GCC target("avx,avx2")
#include <bits/stdc++.h>

using uint = unsigned int;

// Offline path queries on a tree, answered with Mo's algorithm over an Euler
// tour. Every node has a colour c[x] and a value h[x]. For the set of nodes on
// the queried path the program prints, modulo 2^32,
//     sum over pairs {a, b} with c[a] != c[b] of (h[a] - h[b])^2.

constexpr int N = 200000 + 5;  // capacity for nodes / colours (1-based)
constexpr int LOG = 18;        // binary-lifting levels (2^18 > N)

std::vector<int> v[N];  // adjacency lists
int n, q, timer;
int in[N], out[N];      // Euler-tour entry / exit time of every node
int O[2 * N];           // O[t] = node visited at tour time t; the tour has 2n entries
int dep[N], c[N];       // depth and colour of every node
int P[N][LOG];          // P[x][j] = 2^j-th ancestor of x (the root is its own parent)
uint h[N];              // node values

// True when a is an ancestor of b (or a == b), judged by tour-interval nesting.
bool is_anc(int a, int b) {
    return in[a] <= in[b] && out[b] <= out[a];
}

// Lowest common ancestor of a and b via binary lifting.
int lca(int a, int b) {
    if (dep[a] > dep[b]) std::swap(a, b);
    if (is_anc(a, b)) return a;
    // Lift a as high as possible while it still is not an ancestor of b.
    for (int i = LOG - 1; i >= 0; i--) {
        if (!is_anc(P[a][i], b)) a = P[a][i];
    }
    return P[a][0];
}

// Builds the Euler tour (in / out / O), depths and the binary-lifting table
// for the subtree of x, whose parent is p.
void dfs(int x, int p) {
    O[timer] = x;
    in[x] = timer++;
    P[x][0] = p;
    for (int j = 1; j < LOG; j++) {
        P[x][j] = P[P[x][j - 1]][j - 1];
    }
    dep[x] = dep[p] + 1;
    for (int to : v[x]) {
        if (to != p) dfs(to, x);
    }
    O[timer] = x;
    out[x] = timer++;
}

// Bucket 0 aggregates every active node; bucket k > 0 aggregates the active
// nodes of colour k. For a bucket,
//     contr = total * sum_sq - sum^2 = sum over pairs of (h_a - h_b)^2,
// and cur_ans is maintained as contr[0] - sum over colours of contr[colour].
// All arithmetic deliberately wraps modulo 2^32.
uint cur_ans, contr[N], total[N], sum_sq[N], sum[N];
bool active[N];  // active[x]: node x currently occurs an odd number of times in the window

// Recomputes contr[i] from the bucket's aggregates and patches cur_ans:
// bucket 0 contributes with a plus sign, colour buckets with a minus sign.
inline void update_ans(int i) {
    const uint fresh = total[i] * sum_sq[i] - sum[i] * sum[i];
    if (i == 0) {
        cur_ans += fresh - contr[i];
    } else {
        cur_ans -= fresh - contr[i];
    }
    contr[i] = fresh;
}

// Adds node x's value to bucket i.
inline void updateAdd(int x, int i) {
    sum[i] += h[x];
    sum_sq[i] += h[x] * h[x];
    total[i]++;
    update_ans(i);
}

// Removes node x's value from bucket i.
inline void updateDec(int x, int i) {
    sum[i] -= h[x];
    sum_sq[i] -= h[x] * h[x];
    total[i]--;
    update_ans(i);
}

// Adds node x to the active set.
// NOTE(review): assumes colours are >= 1, so c[x] never aliases bucket 0 - confirm.
inline void ins(int x) {
    updateAdd(x, 0);
    updateAdd(x, c[x]);
}

// Removes node x from the active set.
inline void del(int x) {
    updateDec(x, 0);
    updateDec(x, c[x]);
}

// Toggles the node located at Euler-tour position pos.
inline void shift(int pos) {
    const int x = O[pos];
    active[x] = !active[x];
    if (active[x]) {
        ins(x);
    } else {
        del(x);
    }
}

// One offline query, expressed as a window over the Euler tour.
struct Query {
    int l;    // out[] time of the first endpoint
    int r;    // in[] time of the second endpoint
    int lca;  // lowest common ancestor of the two endpoints
    int id;   // 0-based position in the input, used to restore output order
};

int main() {
    std::ios::sync_with_stdio(false);
    std::cin.tie(nullptr);

    std::cin >> n >> q;
    for (int i = 1; i <= n; i++) std::cin >> c[i];
    for (int i = 1; i <= n; i++) std::cin >> h[i];
    for (int i = 1; i < n; i++) {
        int a, b;
        std::cin >> a >> b;
        v[a].push_back(b);
        v[b].push_back(a);
    }
    dfs(1, 1);

    std::vector<Query> queries;
    queries.reserve(q > 0 ? q : 0);
    for (int i = 0; i < q; i++) {
        int l, r;
        std::cin >> l >> r;
        queries.push_back({out[l], in[r], lca(l, r), i});
    }
    std::vector<uint> answers(queries.size());

    // Mo ordering: bucket by r, alternate the direction of l between buckets.
    const int sq = std::max(1, std::min(2 * static_cast<int>(std::sqrt(static_cast<double>(n))), 333));
    std::sort(queries.begin(), queries.end(), [sq](const Query& a, const Query& b) {
        const int block_a = a.r / sq;
        const int block_b = b.r / sq;
        if (block_a != block_b) return a.r < b.r;
        if (block_a & 1) return a.l < b.l;
        return a.l > b.l;
    });

    // The window [l, r] may be inverted (l > r + 1). Because shift() toggles,
    // the active set is always the set of nodes occurring exactly once among
    // the tour positions between the two pointers, which is the path between
    // the endpoints without their LCA. The LCA is therefore inserted only for
    // the moment the answer is read.
    int L = 0, R = -1;
    for (const Query& query : queries) {
        while (R < query.r) shift(++R);
        while (query.l < L) shift(--L);
        while (query.r < R) shift(R--);
        while (L < query.l) shift(L++);
        ins(query.lca);
        answers[query.id] = cur_ans;
        del(query.lca);
    }

    for (const uint answer : answers) {
        std::cout << answer << '\n';
    }
    return 0;
}

/*
Sample input:
5 2
1 2 1 2 1
5 3 7 2 6
1 2
1 3
3 4
3 5
1 5
2 4
*/
Leave a Comment