Untitled

mail@pastecode.io avatar
unknown
plain_text
a year ago
2.1 kB
2
Indexable
Never
#include <bits/stdc++.h>

/**
 * Disjoint-set (union-find) over the integer elements 0..n-1.
 *
 * Uses union by rank together with path halving in find(), which gives
 * near-constant amortized cost per operation.
 */
class DisjointSet {
private:
    // Members are initialized in declaration order; the constructor's
    // initializer list below follows the same order.
    std::vector<int> rank;    // Per-element rank (upper bound on tree height); only read for roots.
    std::vector<int> parent;  // parent[i] == i exactly when i is the root of its set.

    // Returns the root of x's set. Path halving re-points each visited node at
    // its grandparent, flattening the tree for later lookups. Assumes 0 <= x < n.
    int find(int x) {
        while (x != parent[x]) {
            parent[x] = parent[parent[x]];
            x = parent[x];
        }
        return x;
    }

public:
    // Creates n singleton sets: every element starts as its own root with rank 0.
    DisjointSet(int n) : rank(n, 0), parent(n) {
        std::iota(parent.begin(), parent.end(), 0);
    }

    // Merges the sets containing x and y (both assumed to be in [0, n)).
    // Returns true if two distinct sets were merged, false if x and y were
    // already in the same set.
    bool union_sets(int x, int y) {
        int xroot = find(x);
        int yroot = find(y);
        if (xroot == yroot) {
            return false;  // Unsuccessful union as already connected
        }
        // Hang the lower-rank tree under the higher-rank one. On a tie xroot
        // becomes the new root and its rank grows by one.
        if (rank[xroot] < rank[yroot]) {
            std::swap(xroot, yroot);
        }
        parent[yroot] = xroot;
        if (rank[xroot] == rank[yroot]) {
            rank[xroot]++;
        }
        return true;  // Successful union
    }

    // Returns a copy of the raw parent array. Entries are not necessarily roots:
    // follow parent links until parent[x] == x to reach x's representative.
    std::vector<int> getParent() const {
        return parent;
    }
};
class Solution {
public:
    // Counts unordered pairs {u, v} of distinct nodes in [0, n) that have no
    // path between them in the undirected graph described by `edges`.
    //
    // Every pair either lies inside one connected component or straddles two,
    // so:  unreachable = C(n, 2) - sum over components of C(size, 2).
    //
    // Assumes each edge holds at least two node ids, both in [0, n)
    // (NOTE(review): not validated here — confirm against callers).
    long long countPairs(int n, std::vector<std::vector<int>>& edges) {
        if (n <= 1) {
            return 0;  // Fewer than two nodes: there are no pairs at all.
        }

        DisjointSet disjointSet(n);
        for (const auto& edge : edges) {
            disjointSet.union_sets(edge[0], edge[1]);
        }

        // Tally component sizes, indexed by each component's root. Node ids are
        // dense in [0, n), so a plain vector suffices instead of a hash map.
        const std::vector<int> parent = disjointSet.getParent();
        std::vector<long long> componentSize(n, 0);
        for (int node = 0; node < n; ++node) {
            int root = node;
            while (parent[root] != root) {
                root = parent[root];
            }
            ++componentSize[root];
        }

        const long long nodes = n;  // Widen before multiplying to avoid int overflow.
        long long unreachable = nodes * (nodes - 1) / 2;
        for (const long long size : componentSize) {
            unreachable -= size * (size - 1) / 2;
        }
        return unreachable;
    }
};