/*
 * Residue from the paste-hosting page this file was copied from (not code):
 * Untitled
 * avatar
 * unknown
 * c_cpp
 * 17 days ago
 * 2.8 kB
 * 2
 * Indexable
 */
#include <bits/stdc++.h>
using namespace std;
// Unties the C++ streams from C stdio and from each other for faster bulk I/O.
// Only safe when the program does not mix printf/scanf with cin/cout.
#define fastio() std::ios_base::sync_with_stdio(false);std::cin.tie(nullptr);std::cout.tie(nullptr)
// Lowest set bit of i, e.g. LSB(12) == 4. The argument is parenthesized at
// every use so that expressions such as LSB(a + b) expand correctly; note the
// argument is still evaluated twice, so do not pass expressions with side effects.
#define LSB(i) ((i) & (-(i)))
#define ll long long
// Neighbour offsets: indices 0-3 are the four orthogonal moves, 4-7 the diagonals.
const int dx[]{-1,1,0,0,-1,-1,1,1};
const int dy[]{0,0,1,-1,-1,1,-1,1};
// Declared before `#define int ll`, so this remains a 32-bit int.
const int MOD = 1e9+7;
// From this line on, the token `int` is rewritten to `ll` (long long), so every
// later `int` in this file (locals, arrays, parameters, return types) is
// 64-bit. Declarations above this line (dx, dy, MOD) are unaffected. main
// below is spelled int32_t because `int main` would become `long long main`.
#define int ll


// Capacity of every per-vertex array below; presumably each tree has at most
// ~2e5 vertices -- TODO confirm against the problem limits.
const int N = 2e5+5;

// Adjacency lists of the first and second tree (vertices are 0-indexed).
std::vector<int> G1[N], G2[N];
// h1 / h2: per-vertex values for tree 1 / tree 2 -- subtree heights after
// dfs0 / dfs2, eccentricities (longest distance in edges) after dfs1 / dfs3.
// pre: prefix sums of the sorted h2 values, filled in solve().
int h1[N], h2[N], pre[N];

// Post-order pass over tree 1 from vertex u, never stepping back into p
// (u's parent; the root is called with p == u). Since h1 starts zeroed, on
// return h1[x] is the height in edges of the subtree rooted at x, for every x
// in u's subtree (0 for a leaf).
// Returns h1[u] + 1: the longest downward path that starts at u's parent and
// goes through u.
int dfs0(int u, int p){
    for(int child : G1[u]){
        if(child != p){
            const int through = dfs0(child, u);
            if(through > h1[u]) h1[u] = through;
        }
    }
    return 1 + h1[u];
}
// Rerooting pass over tree 1. It turns h1[u] from "height of u's subtree" (as
// left by dfs0) into u's eccentricity: the largest distance, in edges, from u
// to any vertex of tree 1.
//   u  - current vertex
//   p  - parent of u in the traversal (skipped when scanning G1[u])
//   up - length of the longest path that starts at u and leaves u's subtree
//        through p (0 for the root)
// Precondition: dfs0 was run from the same root, so each child's h1[v] still
// holds its subtree height when it is read below. (dfs1(v, ...) only
// overwrites h1[v] after v's own value has been used.)
// NOTE(review): recursion depth equals the tree depth (up to ~N on a path);
// fine with a large judge stack, but it may overflow a default 8 MB stack.
void dfs1(int u, int p, int up){
    h1[u] = max(h1[u], up);
    // The two largest child subtree heights (ties kept) plus the child count.
    // This replaces a per-call std::multiset: no heap allocation, smaller frames.
    int best = -1, second = -1, children = 0;
    for(int v : G1[u]){
        if(v == p) continue;
        ++children;
        if(h1[v] > best){ second = best; best = h1[v]; }
        else if(h1[v] > second) second = h1[v];
    }
    for(int v : G1[u]){
        if(v == p) continue;
        int nxt = up;
        if(children > 1){
            // Longest downward branch from u that avoids v's subtree.
            const int sibling = (h1[v] == best) ? second : best;
            nxt = max(nxt, 1 + sibling);
        }
        // One more edge (v -> u) before continuing upward/sideways from u.
        dfs1(v, u, nxt + 1);
    }
}

// Post-order pass over tree 2 from vertex u, never stepping back into p
// (u's parent; the root is called with p == u). Since h2 starts zeroed, on
// return h2[x] is the height in edges of the subtree rooted at x, for every x
// in u's subtree (0 for a leaf).
// Returns h2[u] + 1: the longest downward path that starts at u's parent and
// goes through u.
int dfs2(int u, int p){
    for(int child : G2[u]){
        if(child != p){
            const int through = dfs2(child, u);
            if(through > h2[u]) h2[u] = through;
        }
    }
    return 1 + h2[u];
}

// Rerooting pass over tree 2. It turns h2[u] from "height of u's subtree" (as
// left by dfs2) into u's eccentricity: the largest distance, in edges, from u
// to any vertex of tree 2.
//   u  - current vertex
//   p  - parent of u in the traversal (skipped when scanning G2[u])
//   up - length of the longest path that starts at u and leaves u's subtree
//        through p (0 for the root)
// Precondition: dfs2 was run from the same root, so each child's h2[v] still
// holds its subtree height when it is read below. (dfs3(v, ...) only
// overwrites h2[v] after v's own value has been used.)
// NOTE(review): recursion depth equals the tree depth (up to ~N on a path);
// fine with a large judge stack, but it may overflow a default 8 MB stack.
void dfs3(int u, int p, int up){
    h2[u] = max(h2[u], up);
    // The two largest child subtree heights (ties kept) plus the child count.
    // This replaces a per-call std::multiset: no heap allocation, smaller frames.
    int best = -1, second = -1, children = 0;
    for(int v : G2[u]){
        if(v == p) continue;
        ++children;
        if(h2[v] > best){ second = best; best = h2[v]; }
        else if(h2[v] > second) second = h2[v];
    }
    for(int v : G2[u]){
        if(v == p) continue;
        int nxt = up;
        if(children > 1){
            // Longest downward branch from u that avoids v's subtree.
            const int sibling = (h2[v] == best) ? second : best;
            nxt = max(nxt, 1 + sibling);
        }
        // One more edge (v -> u) before continuing upward/sideways from u.
        dfs3(v, u, nxt + 1);
    }
}



void solve(int testCase)  {
    int n1, n2;
    cin >> n1;
    for(int i = 1; i < n1; ++i){
        int a, b;
        cin >> a >> b;
        --a, --b;
        G1[a].push_back(b);
        G1[b].push_back(a);
    }
    cin >> n2;
    for(int i = 1; i < n2; ++i){
        int a, b;
        cin >> a >> b;
        --a, --b;
        G2[a].push_back(b);
        G2[b].push_back(a);
    }
    dfs0(0, 0);
    dfs1(0, 0, 0);
    dfs2(0, 0);
    dfs3(0, 0, 0);
    int s = 0, ans = 0;
    for(int i = 0; i < n2; ++i) s += h2[i];
    int mx = 0;
    for(int i = 0; i < n1; ++i) mx = max(mx, h1[i]);
    for(int i = 0; i < n2; ++i) mx = max(mx, h2[i]);
    sort(h1, h1 + n1);
    sort(h2, h2 + n2);
    pre[0] = h2[0];
    for(int i = 1; i < n2; ++i) pre[i] = h2[i] + pre[i-1];
    for(int i = 0; i < n1; ++i){
        int up = upper_bound(h2, h2 + n2, mx - h1[i] - 1) - h2;
        ans += mx * up;
        int other = n2 - up;
        ans += (h1[i] + 1) * other + pre[n2-1] - (up ? pre[up-1] : 0);
    }
    cout << ans;
} 

// Program entry point. Spelled int32_t because `#define int ll` has turned
// `int` into long long. Runs solve() once; enable the read below for
// multi-test input.
int32_t main(){
    fastio();
    int tests = 1;
    // cin >> tests;
    for(int tc = 1; tc <= tests; ++tc) solve(tc);
    return 0;
}
/*
 * Residue from the paste-hosting page (not code):
 * Editor is loading...
 * Leave a Comment
 */