FULL

mail@pastecode.io avatar
unknown
c_cpp
2 years ago
9.2 kB
2
Indexable
Never
 
// Per-vertex accumulator holding two numbers: `farthest`, the largest depth
// recorded so far, and `diam`, the largest path length recorded so far.
struct diam_dp
{
    int farthest, diam;

    // Empty state: no depth recorded yet (-inf) and a best path of length 0.
    diam_dp()
    {
        diam = 0;
        farthest = -inf;
    }

    // Records a single depth value x; diam is raised to at least the
    // (possibly updated) farthest depth.
    void add(int x)
    {
        if (farthest < x)
            farthest = x;
        if (diam < farthest)
            diam = farthest;
    }

    // Merges the accumulator of a child that hangs one edge below this vertex.
    // Candidates for diam: a path joining the current deepest branch with the
    // child's deepest branch through the connecting edge, and the child's own
    // best path. farthest is updated last so the join uses the old value.
    void add(diam_dp child_dp)
    {
        const int joined = farthest + child_dp.farthest + 1;
        if (diam < joined)
            diam = joined;
        if (diam < child_dp.diam)
            diam = child_dp.diam;

        const int lifted = child_dp.farthest + 1;
        if (farthest < lifted)
            farthest = lifted;
    }
};
// Keeps the k largest (value, label) pairs seen so far, ordered by value from
// largest to smallest. The backing array has 5 slots, so k must not exceed 5.
struct k_max
{
    pii a[5];
    int k;

    // Starts with k placeholder entries of value 0 and label 0.
    k_max(int _k) : k(_k)
    {
        int i = 0;
        while (i < _k)
        {
            a[i] = {0, 0};
            ++i;
        }
    }

    // Inserts (val, label) in front of the first entry whose value is not
    // greater than val, pushing the tail down by one and dropping the last
    // entry. Does nothing when val is smaller than every stored value.
    void add(int val, int label)
    {
        int slot = 0;
        while (slot < k && val < a[slot].f)
            ++slot;
        if (slot == k)
            return;

        for (int j = k - 1; j > slot; --j)
            a[j] = a[j - 1];
        a[slot] = {val, label};
    }

    // Returns the sum of the first cnt stored values (largest first) whose
    // label does not appear in labels. Fewer than cnt values are summed when
    // not enough entries qualify.
    int get(vector<int> labels, int cnt = 1)
    {
        int total = 0;
        int i = 0;
        while (i < k && cnt != 0)
        {
            bool excluded = false;
            for (int label : labels)
            {
                if (a[i].s == label)
                {
                    excluded = true;
                    break;
                }
            }
            if (!excluded)
            {
                total += a[i].f;
                --cnt;
            }
            ++i;
        }
        return total;
    }
};
// Fenwick (binary indexed) tree over the 0-based positions [0, N) with
// point updates and prefix / range sums.
struct fen
{
    ll t[N];

    // All cells start at zero.
    fen()
    {
        for (int i = 0; i < N; i++)
            t[i] = 0;
    }

    // Sum of the values at positions [0, r]. Returns 0 when r is negative
    // because the loop body is never entered.
    ll sum(int r)
    {
        // Accumulate in ll: the cells and the return type are ll, so an int
        // accumulator would silently truncate large totals.
        ll result = 0;
        for (; r >= 0; r = (r & (r + 1)) - 1)
            result += t[r];
        return result;
    }

    // Sum of the values at positions [l, r]; an empty range (l > r) yields 0.
    ll get(int l, int r)
    {
        if (l > r)
            return 0;
        return sum(r) - sum(l - 1);
    }

    // Adds val to position i. Positions outside [0, N) are ignored: a negative
    // index would otherwise write before the array and, for i == -1, never
    // leave the loop because (-1 | 0) is still -1.
    void inc(int i, int val)
    {
        if (i < 0)
            return;
        for (; i < N; i = (i | (i + 1)))
            t[i] += val;
    }
};
 
// Adjacency lists of the tree; solve() appends both directions of every edge.
vector<int> g[N];
// Vertex count, copied from solve()'s _n. Vertices are iterated as 1..n.
int n;
// ans[k] is accumulated in solve() and returned for queries that ask about k.
ll ans[N];
 
// Results of the two get_farthest() sweeps in solve(): A is the vertex found
// from vertex 1, B the vertex found from A.
int A, B;
// Vertices on the parent[] chain from B up to A, in that order.
vector<int> DIAM;
 
// in_DIAM[v] is set for every vertex stored in DIAM; in_DIAM_pos[v] is its index there.
bool in_DIAM[N];
int in_DIAM_pos[N];
// Written by dfs(): the DIAM vertex from which the dfs that reached v was started.
int diam_parent[N];
 
// Per-vertex depth / path accumulator filled by dfs().
diam_dp dp[N];
 
// self_val is written by calc_self(); left_val / right_val are written by
// calc_left() / calc_right() and by the two sweeps over DIAM in solve().
int self_val[N], left_val[N], right_val[N];
// Written by calc_self(): 1 plus the sz of every child visited from v.
int sz[N];
 
// Number of used entries in events.
int m;
// Entries are (value, (type, vertex)) with type 0, 1 or 2; filled and sorted in solve().
pair<int, pii> events[3 * N];
 
// Written by get_farthest(): distance from the sweep's start vertex and the
// vertex each one was entered from (the start vertex is its own parent).
int depth[N], parent[N];
 
// Recursive sweep over the tree starting at v, entered from p (p == v marks
// the start vertex). Fills parent[] and depth[] for every vertex reached and
// returns the deepest vertex found; among equally deep vertices the one
// encountered first is kept.
int get_farthest(int v, int p)
{
    parent[v] = p;
    depth[v] = (v == p) ? 0 : depth[p] + 1;

    int deepest = v;
    for (int nxt : g[v])
    {
        if (nxt != p)
        {
            const int candidate = get_farthest(nxt, v);
            if (depth[candidate] > depth[deepest])
                deepest = candidate;
        }
    }
    return deepest;
}
// Walks the part of the tree reachable from v without stepping onto a DIAM
// vertex or back to p. Each visited vertex inherits diam_parent from the
// vertex the walk was started at (the call with v == p), and dp[v] collects
// depth 0 for v itself plus the merged accumulators of its children.
void dfs(int v, int p)
{
    diam_parent[v] = (v == p) ? v : diam_parent[p];
    dp[v].add(0);

    for (int nxt : g[v])
    {
        const bool skip = (nxt == p) || in_DIAM[nxt];
        if (!skip)
        {
            dfs(nxt, v);
            dp[v].add(dp[nxt]);
        }
    }
}
void calc_self(int v, int p, int k = 0)
{
    sz[v] = 1;
    self_val[v] = max(self_val[v], dp[v].diam);
    self_val[v] = max(self_val[v], k);
    k_max maxs = k_max(2), kek = k_max(3);
    for (auto to : g[v])
    {
        if (in_DIAM[to] || to == p)
            continue;
        maxs.add(max(dp[to].diam, dp[to].farthest + 1), to);
        kek.add(dp[to].farthest + 1, to);
    }
    for (auto to : g[v])
    {
        if (in_DIAM[to] || to == p)
            continue;
        int kk = k;
        kk = max(kk, maxs.get({to}));
        kk = max(kk, kek.get({to}, 2));
        calc_self(to, v, kk);
        sz[v] += sz[to];
    }
}
// Sets left_val to max(self_val, k) for v and for every vertex reachable from
// v without stepping onto a DIAM vertex or back to p; k is passed unchanged.
void calc_left(int v, int p, int k)
{
    left_val[v] = (self_val[v] < k) ? k : self_val[v];

    for (int nxt : g[v])
    {
        if (nxt != p && !in_DIAM[nxt])
            calc_left(nxt, v, k);
    }
}
 
// Sets right_val to max(self_val, k) for v and for every vertex reachable
// from v without stepping onto a DIAM vertex or back to p; k is passed unchanged.
void calc_right(int v, int p, int k)
{
    right_val[v] = (self_val[v] < k) ? k : self_val[v];

    for (int nxt : g[v])
    {
        if (nxt != p && !in_DIAM[nxt])
            calc_right(nxt, v, k);
    }
}
// Builds the tree from `edges`, precomputes per-vertex values along the path
// stored in DIAM, and answers `queries` in order.
//
// Parameters:
//   _n      - number of vertices; vertices are iterated as 1.._n and vertex 1
//             is used as the start of the first sweep.
//   edges   - undirected edges as (v, u) pairs.
//   queries - (v, u) with v > 0 asks for the value computed for that vertex
//             pair; any pair whose first component is <= 0 asks for ans[k]
//             with k taken from the second component.
// Returns one ll per query, in input order.
//
// NOTE(review): every global this function writes (g, DIAM, m, ans, ...) is
// only ever appended to / accumulated into and never reset, so this is
// presumably meant to be called once per process - confirm before reusing.
vector<ll> solve(int _n, vector<pii> edges, vector<pii> queries)
{
    n = _n;
    for (auto edge : edges)
    {
        int v = edge.f;
        int u = edge.s;
        g[v].pb(u);
        g[u].pb(v);
    }

    // Two farthest-vertex sweeps; after the second one parent[] points
    // towards A, so walking up from B lists the whole A-B path.
    A = get_farthest(1, 1);
    B = get_farthest(A, A);
    for (int v = B; v != A; v = parent[v])
        DIAM.pb(v);
    DIAM.pb(A);

    // Signed copy of the path size so index arithmetic below never mixes
    // signed and unsigned operands.
    const int diam_len = static_cast<int>(DIAM.size());

    for (int i = 0; i < diam_len; i++)
    {
        int v = DIAM[i];
        in_DIAM[v] = 1;
        in_DIAM_pos[v] = i;
    }

    // Every path vertex roots its own off-path region.
    for (int i = 0; i < diam_len; i++)
    {
        int v = DIAM[i];
        dfs(v, v);
        calc_self(v, v);
    }

    // Sweep from DIAM[0] forwards. `farthest` is the largest depth reachable
    // from the current path vertex going back over already-visited positions.
    for (int i = 0, farthest = 0; i < diam_len; i++)
    {
        int v = DIAM[i];
        int prev = 0;
        left_val[v] = max(left_val[v], self_val[v]);
        if (i)
        {
            prev = left_val[DIAM[i - 1]];
            farthest = farthest + 1;
            left_val[v] = max(left_val[v], prev);
            left_val[v] = max(left_val[v], farthest + dp[v].farthest);
        }
        k_max maxs = k_max(2);
        for (auto to : g[v])
        {
            if (in_DIAM[to])
                continue;
            maxs.add(dp[to].farthest + 1, to);
        }
        for (auto to : g[v])
        {
            if (in_DIAM[to])
                continue;
            // maxs.get({to}) is the deepest arm of v that avoids child `to`.
            calc_left(to, v, max(prev, farthest + maxs.get({to})));
        }
        farthest = max(farthest, dp[v].farthest);
    }

    // Mirror image of the sweep above, from the last path vertex backwards.
    for (int i = diam_len - 1, farthest = 0; i >= 0; i--)
    {
        int v = DIAM[i];
        int prev = 0;
        right_val[v] = max(right_val[v], self_val[v]);
        if (i + 1 < diam_len)
        {
            prev = right_val[DIAM[i + 1]];
            farthest = farthest + 1;
            right_val[v] = max(right_val[v], prev);
            right_val[v] = max(right_val[v], farthest + dp[v].farthest);
        }
        k_max maxs = k_max(2);
        for (auto to : g[v])
        {
            if (in_DIAM[to])
                continue;
            maxs.add(dp[to].farthest + 1, to);
        }
        for (auto to : g[v])
        {
            if (in_DIAM[to])
                continue;
            calc_right(to, v, max(prev, farthest + maxs.get({to})));
        }
        farthest = max(farthest, dp[v].farthest);
    }

    // Each fen holds N ll cells; keep the two trees on the heap instead of in
    // this function's stack frame.
    vector<fen> trees(2);
    fen &lefts = trees[0];
    fen &rights = trees[1];

    // Every vertex contributes one "left" and one "right" unit at the path
    // position of its region, plus one event of each kind.
    for (int i = 1; i <= n; i++)
    {
        int pos = in_DIAM_pos[diam_parent[i]];
        lefts.inc(pos, 1);
        rights.inc(pos, 1);
        events[m++] = {left_val[i], {0, i}};
        events[m++] = {right_val[i], {1, i}};
    }
    for (int i = 0; i < diam_len; i++)
    {
        int v = DIAM[i];
        events[m++] = {self_val[v], {2, v}};
    }
    // Descending by (value, (type, vertex)): larger values first, and for
    // equal values type 2 before type 1 before type 0.
    sort(events, events + m);
    reverse(events, events + m);

    set<int> blocked;
    blocked.insert(0);
    blocked.insert(diam_len - 1);

    // Nearest blocked position strictly before `at`. When there is none the
    // function returns `at` itself so that the range [result, at - 1] is
    // empty; stepping back from begin() would be undefined behaviour.
    auto blocked_before = [&blocked](int at) -> int {
        auto it = blocked.lower_bound(at);
        if (it == blocked.begin())
            return at;
        --it;
        return *it;
    };
    // Nearest blocked position strictly after `at`. When there is none the
    // function returns `at` itself so that the range [at + 1, result] is
    // empty; dereferencing end() would be undefined behaviour.
    auto blocked_after = [&blocked](int at) -> int {
        auto it = blocked.upper_bound(at);
        if (it == blocked.end())
            return at;
        return *it;
    };

    for (int i = 0; i < m; i++)
    {
        int val = events[i].f;
        int type = events[i].s.f;
        int pos = in_DIAM_pos[diam_parent[events[i].s.s]];
        if (type == 0)
        {
            // Pair this vertex with the still-active "right" units between
            // its position and the next blocked position, then retire it.
            int l = pos + 1;
            int r = blocked_after(pos);
            ans[val] += rights.get(l, r);
            lefts.inc(pos, -1);
        }
        if (type == 1)
        {
            int l = blocked_before(pos);
            int r = pos - 1;
            ans[val] += lefts.get(l, r);
            rights.inc(pos, -1);
        }
        if (type == 2)
        {
            // All still-active pairs that straddle pos inside the current
            // unblocked stretch get this value; pos then becomes a barrier.
            int l = blocked_before(pos);
            int r = blocked_after(pos);
            ans[val] += lefts.get(l, pos - 1) * rights.get(pos + 1, r);
            blocked.insert(pos);
        }
    }
    // Pairs inside one region (same diam_parent) are all counted under the
    // full path length.
    for (auto v : DIAM)
        ans[diam_len - 1] += 1ll * sz[v] * (sz[v] - 1) / 2;

    // Top three self_val over the path, so two positions can be excluded.
    k_max selfs = k_max(3);
    for (auto v : DIAM)
        selfs.add(self_val[v], v);

    vector<ll> result;

    for (auto query : queries)
    {
        if (query.f > 0)
        {
            int v = query.f;
            int u = query.s;
            int val = 0;
            if (diam_parent[v] == diam_parent[u])
                val = max(val, diam_len - 1);
            // Order the pair so that v's region comes first along DIAM.
            if (in_DIAM_pos[diam_parent[v]] > in_DIAM_pos[diam_parent[u]])
                swap(v, u);
            val = max(val, left_val[v]);
            val = max(val, right_val[u]);
            val = max(val, selfs.get({diam_parent[u], diam_parent[v]}));
            result.pb(val);
        }
        else
        {
            int k = query.s;
            // ans has N cells; a k outside that table cannot have been
            // counted, so report 0 instead of reading out of bounds.
            result.pb((k >= 0 && k < N) ? ans[k] : 0);
        }
    }
    return result;
}
// Reads the tree and the queries from stdin, runs solve() and prints one
// answer per line.
//
// Input layout: n, then n - 1 edges "v u", then the query count, then each
// query as "1 v u" or "<other type> k".
// Exits with status 1 if the input ends early or a number cannot be parsed.
int main()
{
    int n, m;
    vector<pii> edges;
    vector<pii> queries;
    if (scanf("%d", &n) != 1)
        return 1;
    for (int i = 1; i < n; i++)
    {
        int v, u;
        if (scanf("%d%d", &v, &u) != 2)
            return 1;
        edges.pb({v, u});
    }
    if (scanf("%d", &m) != 1)
        return 1;
    for (int i = 0; i < m; i++)
    {
        int type;
        if (scanf("%d", &type) != 1)
            return 1;
        if (type == 1)
        {
            int v, u;
            if (scanf("%d%d", &v, &u) != 2)
                return 1;
            queries.pb({v, u});
        }
        else
        {
            int k;
            // k is an int, so it must be read with %d; %lld would store a
            // long long through an int pointer and overwrite adjacent memory.
            if (scanf("%d", &k) != 1)
                return 1;
            // A first component of 0 marks this as a "by k" query for solve().
            queries.pb({0, k});
        }
    }
    vector<ll> ans = solve(n, edges, queries);
    for (auto val : ans)
        printf("%lld\n", val);
    return 0;
}