// Centroid Decomposition
//
// Page metadata captured from pastecode.io (kept verbatim):
//   mail@pastecode.io avatar | unknown | c_cpp | 2 years ago | 2.5 kB | 9 |
//   Indexable | Never
// Centroid decomposition of a tree plus binary-lifting LCA, used to answer
// "distance to the nearest painted (red) vertex" queries:
//   paint(v) marks v red, get(v) returns the tree distance from v to the
//   closest vertex painted so far (INT_MAX if none).
//
// The adjacency list is read from an external `g` declared elsewhere in the
// file (presumably vector<vector<int>> indexed by vertex id - confirm).
// Vertex ids must lie in [0, MAXN).
//
// Expected order of use (each step reads what the previous one fills in):
//   CentroidDecomposition cd(n + 1);
//   cd.dfs(root, 0);        // depth[] and par[][] used by lca()/dist()
//   cd.centroids(root, 0);  // level[], centPar[] (top call must use lv = 0)
//   cd.paint(v); cd.get(v);
class CentroidDecomposition{
public:
  std::vector<int> sz;                 // component sizes from the last calcSz() pass
  std::vector<int> depth;              // depth of each vertex as set by dfs()
  std::vector<std::vector<int> > par;  // par[v][i] = 2^i-th ancestor of v
  int LOG = 0;                         // highest lifting index; par[v] has LOG + 1 entries

  std::vector<int> level;      // level of a vertex in the centroid tree (lv passed to centroids())
  std::vector<bool> isCenter;  // true once the vertex has been chosen as a centroid
  std::vector<int> centPar;    // parent in the centroid tree (the `prev` argument of centroids())

  // distToRed[c]: smallest dist(c, r) over painted vertices r that have c on
  // their centroid-ancestor chain; INT_MAX means "none painted yet".
  std::vector<int> distToRed;

  CentroidDecomposition(int MAXN){
    sz.resize(MAXN);
    depth.resize(MAXN);
    level.resize(MAXN);
    isCenter.resize(MAXN);
    centPar.resize(MAXN);
    // Previously left empty, so paint()/get() indexed out of bounds.
    distToRed.assign(MAXN, INT_MAX);

    // LOG = floor(log2(MAXN)) (0 for MAXN <= 1), computed with integers to
    // avoid std::log2 rounding and the log2(0) = -inf cast. Depth
    // differences are < MAXN, so lifting indices 0..LOG are sufficient.
    LOG = 0;
    while ((2LL << LOG) <= MAXN)
      ++LOG;
    par.resize(MAXN);
    for (int i = 0; i < MAXN; ++i)
      par[i].resize(LOG + 1);
  }

  // Roots the tree at v with depth d and fills depth[] and par[][].
  // p is stored as v's parent. The default 0 serves as a sentinel parent for
  // the root, so a neighbour equal to p is never descended into. With
  // 1-indexed vertices call dfs(1, 0). Recursion depth equals the tree height.
  void dfs(int v, int d, int p = 0) {
      depth[v] = d;

      par[v][0] = p;
      // Only LOG + 1 levels exist; the old hard-coded 20 wrote past the end.
      for (int i = 1; i <= LOG; i++)
          par[v][i] = par[par[v][i - 1]][i - 1];

      for (int to : g[v])
          if (to != p)
              dfs(to, d + 1, v);
  }

  // Lowest common ancestor of a and b (requires dfs() to have run).
  int lca(int a, int b) const {
      if (depth[a] < depth[b])
          std::swap(a, b);
      // Lift a by exactly depth[a] - depth[b], one set bit at a time. This
      // never overshoots the root, so it does not depend on the depth
      // recorded for the root's sentinel parent.
      for (int diff = depth[a] - depth[b], i = 0; diff > 0; diff >>= 1, ++i)
          if (diff & 1)
              a = par[a][i];
      if (a == b)
          return a;
      for (int i = LOG; i >= 0; i--)
          if (par[a][i] != par[b][i]) {
              a = par[a][i];
              b = par[b][i];
          }
      return par[a][0];
  }

  // Number of edges on the tree path between a and b.
  int dist(int a, int b) const {
      int lc = lca(a, b);
      return depth[a] + depth[b] - 2 * depth[lc];
  }

  // Walks from v towards any not-yet-centered neighbour whose sz[] exceeds
  // compSz / 2. Returns the first vertex with no such neighbour, which is a
  // centroid of the component. sz[] must come from calcSz() on this component.
  int findCentroid(int v, int compSz, int p = -1) const {
      for (int to : g[v])
          if (!isCenter[to] && to != p) {
              if (sz[to] > compSz / 2)
                  return findCentroid(to, compSz, v);
          }
      return v;
  }

  // Fills sz[] for the component containing v (ignoring vertices already
  // chosen as centroids) and returns the size of v's part.
  int calcSz(int v, int p = -1) {
      sz[v] = 1;
      // Change the graph name to your own
      for (int to : g[v])
          if (!isCenter[to] && to != p)
              sz[v] += calcSz(to, v);
      return sz[v];
  }

  // Recursively decomposes the component containing v. Each chosen centroid
  // gets level lv and centroid-tree parent prev. paint()/get() walk level[v]
  // steps down to 0, so the top-level call is expected to pass lv = 0.
  void centroids(int v, int lv, int prev = -1) {
      calcSz(v);

      int center = findCentroid(v, sz[v]);
      isCenter[center] = true;
      centPar[center] = prev;
      level[center] = lv;

      for (int to : g[center])
          if (!isCenter[to])
              centroids(to, lv + 1, center);
  }

  // Marks v as red: updates distToRed[] for v and each of its centroid
  // ancestors.
  void paint(int v) {
      int a = v;
      for (int lv = level[v]; lv >= 0; lv--) {
          distToRed[a] = std::min(distToRed[a], dist(a, v));
          a = centPar[a];
      }
  }

  // Distance from v to the nearest painted vertex, or INT_MAX if no painted
  // vertex has been recorded on v's centroid-ancestor chain.
  int get(int v) const {
      int res = INT_MAX;
      int a = v;
      for (int lv = level[v]; lv >= 0; lv--) {
          // Skip unset entries: INT_MAX + dist(...) would overflow (UB).
          if (distToRed[a] != INT_MAX)
              res = std::min(res, distToRed[a] + dist(a, v));
          a = centPar[a];
      }
      return res;
  }
};