// Paste-site metadata captured along with this source (not part of the program):
// FULL | unknown | c_cpp | 2 years ago | 9.2 kB | 2 | Indexable | Never
struct diam_dp { int farthest, diam; diam_dp() { farthest = -inf; diam = 0; } void add(int x) { farthest = max(x, farthest); diam = max(diam, farthest); } void add(diam_dp child_dp) { diam = max(diam, farthest + child_dp.farthest + 1); diam = max(diam, child_dp.diam); farthest = max(farthest, child_dp.farthest + 1); } }; struct k_max { pii a[5]; int k; k_max(int _k) { k = _k; for (int i = 0; i < _k; i++) a[i] = {0, 0}; } void add(int val, int label) { for (int i = 0; i < k; i++) { if (val >= a[i].f) { for (int j = k - 1; j > i; j--) a[j] = a[j - 1]; a[i] = {val, label}; break; } } } int get(vector<int> labels, int cnt = 1) { int res = 0; for (int i = 0; i < k && cnt; i++) { bool found = 0; for (auto label : labels) found |= (a[i].s == label); if (!found) res += a[i].f, cnt--; } return res; } }; struct fen { ll t[N]; fen() { for (int i = 0; i < N; i++) t[i] = 0; } ll sum(int r) { int result = 0; for (; r >= 0; r = (r & (r + 1)) - 1) result += t[r]; return result; } ll get(int l, int r) { if (l > r) return 0; return sum(r) - sum(l - 1); } void inc(int i, int val) { for (; i < N; i = (i | (i + 1))) t[i] += val; } }; vector<int> g[N]; int n; ll ans[N]; int A, B; vector<int> DIAM; bool in_DIAM[N]; int in_DIAM_pos[N]; int diam_parent[N]; diam_dp dp[N]; int self_val[N], left_val[N], right_val[N]; int sz[N]; int m; pair<int, pii> events[3 * N]; int depth[N], parent[N]; int get_farthest(int v, int p) { parent[v] = p; if (v == p) depth[v] = 0; else depth[v] = depth[p] + 1; int cur_farthest = v; for (auto to : g[v]) { if (to == p) continue; int u = get_farthest(to, v); if (depth[cur_farthest] < depth[u]) cur_farthest = u; } return cur_farthest; } void dfs(int v, int p) { if (v == p) diam_parent[v] = v; else diam_parent[v] = diam_parent[p]; dp[v].add(0); for (auto to : g[v]) { if (in_DIAM[to] || to == p) continue; dfs(to, v); dp[v].add(dp[to]); } } void calc_self(int v, int p, int k = 0) { sz[v] = 1; self_val[v] = max(self_val[v], dp[v].diam); self_val[v] = max(self_val[v], 
k); k_max maxs = k_max(2), kek = k_max(3); for (auto to : g[v]) { if (in_DIAM[to] || to == p) continue; maxs.add(max(dp[to].diam, dp[to].farthest + 1), to); kek.add(dp[to].farthest + 1, to); } for (auto to : g[v]) { if (in_DIAM[to] || to == p) continue; int kk = k; kk = max(kk, maxs.get({to})); kk = max(kk, kek.get({to}, 2)); calc_self(to, v, kk); sz[v] += sz[to]; } } void calc_left(int v, int p, int k) { left_val[v] = max(self_val[v], k); for (auto to : g[v]) { if (in_DIAM[to] || to == p) continue; calc_left(to, v, k); } } void calc_right(int v, int p, int k) { right_val[v] = max(self_val[v], k); for (auto to : g[v]) { if (in_DIAM[to] || to == p) continue; calc_right(to, v, k); } } vector<ll> solve(int _n, vector<pii> edges, vector<pii> queries) { n = _n; for (auto edge : edges) { int v = edge.f; int u = edge.s; g[v].pb(u); g[u].pb(v); } A = get_farthest(1, 1); B = get_farthest(A, A); for (int v = B; v != A; v = parent[v]) DIAM.pb(v); DIAM.pb(A); for (int i = 0; i < DIAM.size(); i++) { int v = DIAM[i]; in_DIAM[v] = 1; in_DIAM_pos[v] = i; } for (int i = 0; i < DIAM.size(); i++) { int v = DIAM[i]; dfs(v, v); calc_self(v, v); } for (int i = 0, farthest = 0; i < DIAM.size(); i++) { int v = DIAM[i]; int prev = 0; left_val[v] = max(left_val[v], self_val[v]); if (i) { prev = left_val[DIAM[i - 1]]; farthest = farthest + 1; left_val[v] = max(left_val[v], prev); left_val[v] = max(left_val[v], farthest + dp[v].farthest); } k_max maxs = k_max(2); for (auto to : g[v]) { if (in_DIAM[to]) continue; maxs.add(dp[to].farthest + 1, to); } for (auto to : g[v]) { if (in_DIAM[to]) continue; calc_left(to, v, max(prev, farthest + maxs.get({to}))); } farthest = max(farthest, dp[v].farthest); } for (int i = DIAM.size() - 1, farthest = 0; i >= 0; i--) { int v = DIAM[i]; int prev = 0; right_val[v] = max(right_val[v], self_val[v]); if (i + 1 < DIAM.size()) { prev = right_val[DIAM[i + 1]]; farthest = farthest + 1; right_val[v] = max(right_val[v], prev); right_val[v] = max(right_val[v], 
farthest + dp[v].farthest); } k_max maxs = k_max(2); for (auto to : g[v]) { if (in_DIAM[to]) continue; maxs.add(dp[to].farthest + 1, to); } for (auto to : g[v]) { if (in_DIAM[to]) continue; calc_right(to, v, max(prev, farthest + maxs.get({to}))); } farthest = max(farthest, dp[v].farthest); } fen lefts, rights; for (int i = 1; i <= n; i++) { int pos = in_DIAM_pos[diam_parent[i]]; lefts.inc(pos, 1); rights.inc(pos, 1); events[m++] = {left_val[i], {0, i}}; events[m++] = {right_val[i], {1, i}}; } for (int i = 0; i < DIAM.size(); i++) { int v = DIAM[i]; events[m++] = {self_val[v], {2, v}}; } sort(events, events + m); reverse(events, events + m); set<int> blocked; blocked.insert(0); blocked.insert(DIAM.size() - 1); for (int i = 0; i < m; i++) { int val = events[i].f; int type = events[i].s.f; int pos = in_DIAM_pos[diam_parent[events[i].s.s]]; if (type == 0) { int l = pos + 1; int r = (*blocked.upper_bound(pos)); ans[val] += rights.get(l, r); lefts.inc(pos, -1); } if (type == 1) { int l = (*(--blocked.lower_bound(pos))); int r = pos - 1; ans[val] += lefts.get(l, r); rights.inc(pos, -1); } if (type == 2) { int l = (*(--blocked.lower_bound(pos))); int r = (*blocked.upper_bound(pos)); ans[val] += lefts.get(l, pos - 1) * rights.get(pos + 1, r); blocked.insert(pos); } } for (auto v : DIAM) ans[DIAM.size() - 1] += 1ll * sz[v] * (sz[v] - 1) / 2; k_max selfs = k_max(3); for (auto v : DIAM) selfs.add(self_val[v], v); vector<ll> result; for (auto query : queries) { if (query.f > 0) { int v = query.f; int u = query.s; int val = 0; if (diam_parent[v] == diam_parent[u]) val = max(val, (int)DIAM.size() - 1); if (in_DIAM_pos[diam_parent[v]] > in_DIAM_pos[diam_parent[u]]) swap(v, u); val = max(val, left_val[v]); val = max(val, right_val[u]); val = max(val, selfs.get({diam_parent[u], diam_parent[v]})); result.pb(val); } else { int k = query.s; result.pb(ans[k]); } } return result; } int main() { int n, m; vector<pii> edges; vector<pii> queries; scanf("%d", &n); for (int i = 1; i < n; i++) 
{ int v, u; scanf("%d%d", &v, &u); edges.pb({v, u}); } scanf("%d", &m); for (int i = 0; i < m; i++) { int type; scanf("%d", &type); if (type == 1) { int v, u; scanf("%d%d", &v, &u); queries.pb({v, u}); } else { int k; scanf("%lld", &k); queries.pb({0, k}); } } vector<ll> ans = solve(n, edges, queries); for (auto val : ans) printf("%lld\n", val); return 0; }