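/*
 * Paste-site metadata (pastecode.io), kept as a comment so the file compiles:
 *
 * Untitled
 * mail@pastecode.io avatar
 * unknown
 * java
 * 23 days ago
 * 5.6 kB
 * 3
 * Indexable
 * Never
 */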
import java.io.BufferedReader;
import java.io.IOException;
import java.io.InputStreamReader;
import java.io.UncheckedIOException;
import java.nio.charset.StandardCharsets;
import java.util.*;

/**
 * Answers tree path queries offline with Mo's algorithm over an Euler tour.
 *
 * <p>Input: {@code n q}, then {@code n} node colors, {@code n} node heights, {@code n - 1}
 * undirected edges (1-based), and {@code q} queries {@code u v} (1-based). For each query the
 * program prints, modulo {@link #MOD}, the sum of {@code (h_i - h_j)^2} over all unordered pairs
 * of nodes on the u-v path whose colors differ. {@link Info#apply} maintains that sum
 * incrementally as nodes are toggled in and out of the current window.
 */
public class Solution {
    static final int N = 100009, LOG = 19, MOD = 1000000007;
    /** Bucket width for Mo's ordering; set in main once the Euler tour is built. */
    static int BLOCK_SIZE;

    // C: dense color id per node; dep/par: depth and parent; in/out: Euler entry/exit times;
    // frqNode: 1 while the node is counted in the current Mo window, 0 otherwise.
    static int[] C = new int[N], dep = new int[N], par = new int[N], in = new int[N], out = new int[N], frqNode = new int[N];
    /** Binary-lifting table: up[u][k] is the 2^k-th ancestor of u (the root maps to itself). */
    static int[][] up = new int[N][LOG];
    /** H: node heights as read; cnt: answers indexed by query id. */
    static long[] H = new long[N], cnt = new long[N];
    @SuppressWarnings("unchecked") // generic array creation; every slot holds an ArrayList<Integer>
    static ArrayList<Integer>[] adj = new ArrayList[N];
    /** Euler tour: each node appears twice, at positions in[u] and out[u]. */
    static ArrayList<Integer> euler = new ArrayList<>();
    static int timer = 0;
    /** Totals (mod MOD) over every node in the window, plus the current answer. */
    static long totFrq = 0, totH = 0, totH2 = 0, ans = 0;

    /** Per-color aggregates (count, sum of h, sum of h^2, all mod MOD) of nodes in the window. */
    static class Info {
        long frq, sumH, sumH2;

        /**
         * Adds ({@code f = 1}) or removes ({@code f = -1}) one node of height {@code h} for this
         * color. First it adjusts {@code ans} by {@code f * sum((h - h_j)^2)} over the window's
         * nodes of other colors. Then it updates this color's aggregates and the global totals.
         */
        void apply(int h, int f) {
            long hh = mul(h, h);

            // (totFrq - frq), (totH - sumH), (totH2 - sumH2) are the aggregates of other colors;
            // expanding sum((h - h_j)^2) gives cnt*h^2 + sum(h_j^2) - 2*h*sum(h_j).
            ans = add(ans, mul(mul(f, totFrq - frq), hh));
            ans = add(ans, mul(f, totH2 - sumH2));
            ans = add(ans, mul(mul(-2 * f, totH - sumH), h));

            frq = add(frq, f);
            sumH = add(sumH, (long) f * h);
            sumH2 = add(sumH2, f * hh);

            totFrq = add(totFrq, f);
            totH = add(totH, (long) f * h);
            totH2 = add(totH2, f * hh);
        }
    }

    static Info[] I = new Info[N];

    /**
     * An Euler-tour window [l, r] for query {@code id}; {@code x} is an extra node (the LCA) that
     * must be toggled in while answering, or -1 if none.
     */
    static class Query implements Comparable<Query> {
        int l, r, id, x;

        Query(int l, int r, int id, int x) {
            this.l = l;
            this.r = r;
            this.id = id;
            this.x = x;
        }

        /** Mo's order: by block of l, then r ascending in even blocks and descending in odd ones. */
        @Override
        public int compareTo(Query other) {
            int b1 = l / BLOCK_SIZE, b2 = other.l / BLOCK_SIZE;
            if (b1 != b2) return Integer.compare(b1, b2);
            return (b1 % 2 == 0) ? Integer.compare(r, other.r) : Integer.compare(other.r, r);
        }
    }

    static ArrayList<Query> Q = new ArrayList<>();

    public static void main(String[] args) {
        FastReader sc = new FastReader();
        int n = sc.nextInt();
        int q = sc.nextInt();

        for (int i = 0; i < n; ++i) {
            adj[i] = new ArrayList<>();
            I[i] = new Info();
        }

        // Colors only group nodes, so compress them to dense ids 0..k-1 (k <= n). This keeps
        // I[] indexing in range for 1-based or arbitrarily large color values.
        Map<Integer, Integer> colorId = new HashMap<>();
        for (int i = 0; i < n; ++i) {
            int raw = sc.nextInt();
            Integer id = colorId.get(raw);
            if (id == null) {
                id = colorId.size();
                colorId.put(raw, id);
            }
            C[i] = id;
        }

        for (int i = 0; i < n; ++i) H[i] = sc.nextLong();

        for (int i = 1; i < n; ++i) {
            int u = sc.nextInt() - 1;
            int v = sc.nextInt() - 1;
            adj[u].add(v);
            adj[v].add(u);
        }

        dfs(0, 0);

        // Mo's window moves over the Euler tour (length 2n), so size buckets from that length.
        BLOCK_SIZE = (int) Math.sqrt(euler.size()) + 1;

        for (int i = 0; i < q; ++i) {
            int u = sc.nextInt() - 1;
            int v = sc.nextInt() - 1;
            genQuery(u, v, i);
        }

        if (cnt.length < q) cnt = new long[q];

        runMo();

        // One buffered write instead of q unbuffered println calls; same line separator.
        StringBuilder sb = new StringBuilder();
        String nl = System.lineSeparator();
        for (int i = 0; i < q; ++i) {
            sb.append(cnt[i]).append(nl);
        }
        System.out.print(sb);
    }

    /**
     * Roots the tree at {@code u} (with {@code p} recorded as its parent) and fills par, up, dep,
     * in, out and euler. Children are visited in adjacency-list order. The traversal uses an
     * explicit stack so deep (e.g. path-shaped) trees cannot overflow the Java call stack.
     */
    static void dfs(int u, int p) {
        // Frame layout: {node, parent, index of the next adjacency entry to examine}.
        Deque<int[]> stack = new ArrayDeque<>();
        enter(u, p);
        stack.push(new int[] {u, p, 0});
        while (!stack.isEmpty()) {
            int[] top = stack.peek();
            int node = top[0], parent = top[1];
            List<Integer> nbrs = adj[node];
            if (top[2] < nbrs.size()) {
                int v = nbrs.get(top[2]++);
                if (v != parent) {
                    dep[v] = dep[node] + 1;
                    enter(v, node);
                    stack.push(new int[] {v, node, 0});
                }
            } else {
                // All children done: emit the exit occurrence of this node.
                stack.pop();
                euler.add(node);
                out[node] = timer++;
            }
        }
    }

    /** Records the pre-order visit of {@code u}: parent, binary-lifting row and Euler entry. */
    private static void enter(int u, int p) {
        par[u] = p;
        up[u][0] = p;
        for (int x = 1; x < LOG; ++x) {
            up[u][x] = up[up[u][x - 1]][x - 1];
        }
        euler.add(u);
        in[u] = timer++;
    }

    /** Returns a*b mod MOD; the sign follows Java's % (callers normalize via add). */
    static long mul(long a, long b) {
        return a * b % MOD;
    }

    /** Returns (a+b) normalized into [0, MOD). */
    static long add(long a, long b) {
        a = (a + b) % MOD;
        if (a < 0) a += MOD;
        return a;
    }

    /** Toggles node {@code ind} in or out of the current window and updates the aggregates. */
    static void doWork(int ind) {
        int c = C[ind];
        // All arithmetic is mod MOD, so reduce first instead of truncating the long with a cast.
        int h = (int) Math.floorMod(H[ind], (long) MOD);
        frqNode[ind] ^= 1;
        if (frqNode[ind] == 1) I[c].apply(h, 1);
        else I[c].apply(h, -1);
    }

    /**
     * Converts the path query (a, b) into an Euler-tour window and appends it to Q.
     * Nodes seen exactly once in the window lie on the path; the LCA is passed separately
     * when it is not an endpoint.
     */
    static void genQuery(int a, int b, int id) {
        // Order endpoints by entry time so the window bounds below are increasing.
        if (in[a] > in[b]) {
            int temp = a;
            a = b;
            b = temp;
        }
        int u = a, v = b;
        if (dep[u] < dep[v]) {
            int temp = u;
            u = v;
            v = temp;
        }
        // Lift the deeper node to the shallower node's depth.
        int k = dep[u] - dep[v];
        for (int x = 0; x < LOG; ++x) {
            if (((k >> x) & 1) == 1) u = up[u][x];
        }
        if (u == v) {
            // a is an ancestor of b (or a == b): [in[a], in[b]] already includes the LCA.
            Q.add(new Query(in[a], in[b], id, -1));
            return;
        }
        for (int x = LOG - 1; x >= 0; --x) {
            if (up[u][x] != up[v][x]) {
                u = up[u][x];
                v = up[v][x];
            }
        }
        int lca = up[u][0];
        // [out[a], in[b]] excludes the LCA, so runMo toggles it in temporarily.
        Q.add(new Query(out[a], in[b], id, lca));
    }

    /**
     * Sorts Q into Mo's order and slides the window [mo_l, mo_r] over the Euler tour,
     * toggling every node that enters or leaves it. Each answer is stored in cnt[id].
     */
    static void runMo() {
        Collections.sort(Q);
        int mo_l = 0, mo_r = -1;
        for (Query q : Q) {
            while (mo_l > q.l) doWork(euler.get(--mo_l));
            while (mo_r < q.r) doWork(euler.get(++mo_r));
            while (mo_l < q.l) doWork(euler.get(mo_l++));
            while (mo_r > q.r) doWork(euler.get(mo_r--));

            // Include the LCA only while the answer is read, then undo it.
            if (q.x != -1) doWork(q.x);
            cnt[q.id] = ans;
            if (q.x != -1) doWork(q.x);
        }
    }
}

/**
 * Minimal whitespace-delimited token reader over {@code System.in}.
 *
 * <p>Lines are read lazily through a {@link BufferedReader} and split with a
 * {@link StringTokenizer}. I/O failures are rethrown as {@link UncheckedIOException} instead of
 * being swallowed. Requesting a token after the input is exhausted raises
 * {@link NoSuchElementException}.
 */
class FastReader {
    BufferedReader br;
    StringTokenizer st;

    public FastReader() {
        br = new BufferedReader(new InputStreamReader(System.in, StandardCharsets.UTF_8));
    }

    /**
     * Returns the next whitespace-delimited token, reading further lines as needed.
     *
     * @throws NoSuchElementException if the input ends before another token is found
     * @throws UncheckedIOException if reading from the input fails
     */
    String next() {
        while (st == null || !st.hasMoreTokens()) {
            String line = readLineOrThrow();
            if (line == null) {
                throw new NoSuchElementException("unexpected end of input");
            }
            st = new StringTokenizer(line);
        }
        return st.nextToken();
    }

    int nextInt() {
        return Integer.parseInt(next());
    }

    long nextLong() {
        return Long.parseLong(next());
    }

    double nextDouble() {
        return Double.parseDouble(next());
    }

    /**
     * Reads the next raw line from the input. Any tokens still buffered from the current line
     * are ignored. Returns {@code null} at end of input.
     *
     * @throws UncheckedIOException if reading from the input fails
     */
    String nextLine() {
        return readLineOrThrow();
    }

    /** Reads one line, converting IOException to UncheckedIOException with the cause kept. */
    private String readLineOrThrow() {
        try {
            return br.readLine();
        } catch (IOException e) {
            throw new UncheckedIOException("failed to read from standard input", e);
        }
    }
}
// (paste-site footer) Leave a Comment