Untitled
// Weighted-tree query solver.
// Input: n, then n-1 weighted edges "u v w", then q queries of five vertices each.
// For every query, prints the total edge weight of the smallest subtree of the
// tree that contains all queried vertices. Preprocessing is binary lifting
// (ancestor table + path-weight table) rooted at vertex 0.
#include <cstring>
#include <iostream>
#include <iterator>
#include <set>
#include <utility>
#include <vector>

// Capacity of the per-vertex tables; vertex labels must lie in [0, mxn).
constexpr int mxn = 50005;
// Binary-lifting levels: any depth difference (< mxn < 2^mxl) fits in mxl bits.
constexpr int mxl = 17;

// g[u] lists (neighbour, edge weight) pairs.
std::vector<std::pair<int, int>> g[mxn];
// par[v][i]: the 2^i-th ancestor of v, or -1 when it does not exist.
int par[mxn][mxl];
// sum[v][i]: total weight of the path from v up to par[v][i]; only meaningful
// when par[v][i] != -1. long long so long paths of heavy edges cannot overflow.
long long sum[mxn][mxl];
// dep[v]: number of edges between v and the root.
int dep[mxn];

// Roots the tree at `root` (whose parent is `root_parent`, reached through an
// edge of weight `root_weight`) and fills dep/par/sum for every vertex reachable
// from it. Requires par[][] to be pre-filled with -1.
// Iterative (explicit stack) so a path-shaped tree of ~50k vertices cannot
// overflow the call stack. A vertex is pushed only after its parent has been
// processed, so the parent's ancestor rows are complete when it is popped.
void dfs(int root, int root_parent, int root_weight) {
    struct Frame {
        int node;
        int parent;
        int weight;  // weight of the edge node -> parent
    };
    std::vector<Frame> pending;
    pending.push_back({root, root_parent, root_weight});

    while (!pending.empty()) {
        const Frame f = pending.back();
        pending.pop_back();

        if (f.parent != -1) dep[f.node] = dep[f.parent] + 1;
        par[f.node][0] = f.parent;
        sum[f.node][0] = f.weight;
        for (int i = 1; i < mxl; ++i) {
            const int mid = par[f.node][i - 1];
            if (mid == -1) break;
            par[f.node][i] = par[mid][i - 1];
            if (par[f.node][i] != -1) sum[f.node][i] = sum[f.node][i - 1] + sum[mid][i - 1];
        }

        for (const auto& edge : g[f.node])
            if (edge.first != f.parent) pending.push_back({edge.first, f.node, edge.second});
    }
}

// Returns {lowest common ancestor of u and v, total edge weight of the u-v path}.
// Both vertices must belong to the tree rooted by dfs.
std::pair<int, long long> lca(int u, int v) {
    if (dep[u] > dep[v]) std::swap(u, v);
    long long dist = 0;

    // Lift the deeper vertex v up to u's depth.
    const int diff = dep[v] - dep[u];
    for (int i = 0; i < mxl; ++i) {
        if (diff & (1 << i)) {
            dist += sum[v][i];
            v = par[v][i];
        }
    }
    if (u == v) return {u, dist};

    // Lift both while their ancestors differ; they end as children of the LCA.
    for (int i = mxl - 1; i >= 0; --i) {
        if (par[u][i] != par[v][i]) {
            dist += sum[u][i] + sum[v][i];
            u = par[u][i];
            v = par[v][i];
        }
    }
    return {par[u][0], dist + sum[u][0] + sum[v][0]};
}

int main() {
    std::ios_base::sync_with_stdio(false);
    std::cin.tie(nullptr);
    // -1 marks "no such ancestor"; dfs only overwrites entries that exist.
    std::memset(par, -1, sizeof(par));

    int n = 0;
    std::cin >> n;
    for (int i = 1; i < n; ++i) {
        int u = 0, v = 0, w = 0;
        std::cin >> u >> v >> w;
        g[u].push_back({v, w});
        g[v].push_back({u, w});
    }
    // NOTE(review): rooting at 0 assumes 0-based vertex labels — confirm
    // against the input format.
    dfs(0, -1, -1);

    int q = 0;
    std::cin >> q;
    std::set<int> s;  // distinct vertices that still need to be connected
    while (q-- > 0) {
        s.clear();
        for (int i = 0; i < 5; ++i) {
            int a = 0;
            std::cin >> a;
            s.insert(a);
        }

        long long ans = 0;
        // Greedy merge: join the pair whose LCA is deepest, pay for their path,
        // and replace both by that LCA. Because that LCA is the deepest, any
        // other pending vertex inside its subtree lies in a different child
        // branch (or is the LCA itself), so no edge is ever charged twice.
        while (s.size() > 1) {
            int best_depth = -1, u = -1, v = -1, lc = 0;
            long long path = 0;
            // Each unordered pair once; the strict '>' keeps the first deepest
            // pair, exactly as a full ordered scan would.
            for (auto it = s.begin(); it != s.end(); ++it) {
                for (auto jt = std::next(it); jt != s.end(); ++jt) {
                    const std::pair<int, long long> res = lca(*it, *jt);
                    if (dep[res.first] > best_depth) {
                        best_depth = dep[res.first];
                        path = res.second;
                        u = *it;
                        v = *jt;
                        lc = res.first;
                    }
                }
            }
            s.erase(u);
            s.erase(v);
            s.insert(lc);
            ans += path;
        }
        std::cout << ans << '\n';  // '\n' instead of std::endl: no per-query flush
    }
    return 0;
}
Leave a Comment