// Untitled
//
// avatar
// unknown
// plain_text
// 2 years ago
// 8.8 kB
// 4
// Indexable
// include standard input/output library
#include <algorithm>
#include <cmath>
#include <cstddef>
#include <iostream>
#include <limits>
#include <memory>
#include <optional>
#include <queue>
#include <stdexcept>
#include <string>
#include <utility>
#include <vector>

// use standard namespace
using namespace std;

// A point in K-dimensional space; its coordinates are held in a vector of T.
template <int K, typename T> struct Point {
  std::vector<T> coord;

  // Builds a point from its coordinates.
  // Throws std::invalid_argument unless exactly K coordinates are supplied.
  Point(std::vector<T> coord) : coord(std::move(coord)) {
    // Check the member: the parameter of the same name was just moved from.
    if (this->coord.size() != static_cast<std::size_t>(K)) {
      throw std::invalid_argument("Point must have K coordinates");
    }
  }

  // Two points are equal when they have the same number of coordinates and
  // every coordinate compares equal.
  bool operator==(const Point &other) const { return coord == other.coord; }

  // Negation of operator==.
  bool operator!=(const Point &other) const { return !(*this == other); }

  // Returns the SQUARED Euclidean distance to `other` (the sum of squared
  // coordinate differences). Despite the name, no square root is taken.
  // Squared distances order points the same way as true distances, and the
  // tree code compares these values directly.
  inline double l2Norm(const Point &other) const {
    double squaredDistance = 0;
    for (std::size_t i = 0; i < coord.size(); ++i) {
      // Subtract in double so unsigned coordinates cannot wrap around and
      // integer coordinates cannot overflow.
      const double delta = static_cast<double>(coord[i]) -
                           static_cast<double>(other.coord[i]);
      squaredDistance += delta * delta;
    }
    return squaredDistance;
  }

  // Prints the point as "Point: [[[ c0 c1 ... ]]]" followed by a newline.
  void show() const {
    std::cout << "Point: [[[ ";
    for (const T &value : coord) {
      std::cout << value << " ";
    }
    std::cout << "]]]" << '\n';
  }
};

// K-dimensional tree holds points
// each of k-dimensions
// allowing for O(log N) insert and query
// where N is the total number of points
// in the tree
template <int K, typename T> class KDTree {
  struct compareDist {
    Point<K, T> query;
    compareDist(Point<K, T> query) : query(query) {}

    bool operator()(const pair<double, Point<K, T>> &a,
                    const pair<double, Point<K, T>> &b) {
      return a.second.l2Norm(query) < b.second.l2Norm(query);
    }
  };

public:
  // constructor to build the tree from a vector of points
  KDTree(vector<Point<K, T>> points, int depth = 0)
      : median(nullopt), left(nullptr), right(nullptr), depth(depth) {
    // if there are no points, return
    if (points.empty()) {
      return;
    }

    // sort the points and pick the median
    int dimension = depth % K;
    sort(points.begin(), points.end(),
         [&dimension](const Point<K, T> &a, const Point<K, T> &b) {
           return a.coord[dimension] < b.coord[dimension];
         });

    // set median
    int medianIndex = points.size() / 2;
    median = points[medianIndex];

    // create left and right subtrees
    if (medianIndex > 0) {
      left = new KDTree<K, T>(
          vector<Point<K, T>>(points.begin(), points.begin() + medianIndex),
          depth + 1);
    }
    if (medianIndex + 1 < points.size()) {
      right = new KDTree<K, T>(
          vector<Point<K, T>>(points.begin() + medianIndex + 1, points.end()),
          depth + 1);
    }
  }

  // destructor to free memory
  ~KDTree() {
    delete left;
    delete right;
  }

  // print the tree
  void show() {
    cout << string(depth * depth, ' ');
    if (median) {
      median->show();
    } else {
      cout << "Empty Node" << endl;
    }
    if (left) {
      left->show();
    }
    if (right) {
      right->show();
    }
  }

  // find the nearest neighbor of a query point
  optional<Point<K, T>>
  nearestNeighbor(Point<K, T> query, optional<Point<K, T>> best = nullopt,
                  double bestDistance = numeric_limits<double>::infinity()) {
    if (!median) {
      return nullopt;
    }
    // calculate the distance between the query point and the median point
    double distance = query.l2Norm(*median);
    // if the distance is less than the current best distance, update the best
    // distance and best point
    if (distance < bestDistance) {
      bestDistance = distance;
      best = median;
    }

    // get pointers to next and other branches
    int dimension = depth % K;
    KDTree<K, T> *nextBranch = nullptr, *otherBranch = nullptr;

    // determine which branch to search next based on the current dimension
    if (query.coord[dimension] < median->coord[dimension]) {
      nextBranch = left;
      otherBranch = right;
    } else {
      nextBranch = right;
      otherBranch = left;
    }

    // recursively search the next branch
    if (nextBranch) {
      best = nextBranch->nearestNeighbor(query, best, bestDistance);
      bestDistance = query.l2Norm(*best);
    }
    // if the distance between the query point and the splitting plane is less
    // than the current best distance, recursively search the other branch
    if (otherBranch and
        bestDistance > abs(query.coord[dimension] - median->coord[dimension])) {
      auto otherBest = otherBranch->nearestNeighbor(query, best, bestDistance);
      if (query.l2Norm(*otherBest) < bestDistance) {
        best = otherBest;
      }
    }

    return best;
  }

  vector<Point<K, T>> kNearestNeighbors(Point<K, T> query,
                                        int nearestNeighbors = 1) {
    priority_queue<pair<double, Point<K, T>>, vector<pair<double, Point<K, T>>>,
                   compareDist>
        nearest((compareDist(query)));

    kNearestNeighborsHelper(query, nearest, nearestNeighbors);

    vector<Point<K, T>> neighbors;
    while (!nearest.empty()) {
      neighbors.push_back(nearest.top().second);
      nearest.pop();
    }
    reverse(neighbors.begin(), neighbors.end());
    return neighbors;
  }

  // find the nearest neighbor of a query point
  // vector<optional<Point<K, T>>>
  void kNearestNeighborsHelper(
      Point<K, T> &query,
      priority_queue<pair<double, Point<K, T>>,
                     vector<pair<double, Point<K, T>>>, compareDist> &nearest,
      int nearestNeighbors = 1) {
    if (!median) {
      return;
    }
    // calculate the distance between the query point and the median point
    double distance = query.l2Norm(*median);
    // if the distance lesser than the n-th nearest, add it into the heap
    // if number of points is less than n, add it into the heap
    if (nearest.size() < nearestNeighbors or distance < nearest.top().first) {
      if (nearest.size() == nearestNeighbors) {
        nearest.pop();
      }
      nearest.push({distance, *median});
    }

    // get pointers to next and other branches
    int dimension = depth % K;
    KDTree<K, T> *nextBranch = nullptr, *otherBranch = nullptr;

    // determine which branch to search next based on the current dimension
    if (query.coord[dimension] < median->coord[dimension]) {
      nextBranch = left;
      otherBranch = right;
    } else {
      nextBranch = right;
      otherBranch = left;
    }

    // recursively search the next branch
    if (nextBranch) {
      nextBranch->kNearestNeighborsHelper(query, nearest, nearestNeighbors);
    }

    double distanceToPlane =
        pow(query.coord[dimension] - median->coord[dimension], 2);
    // if the distance between the query point and the splitting plane is less
    // than the current best distance, recursively search the other branch
    if (otherBranch and distanceToPlane < nearest.top().first) {
      otherBranch->kNearestNeighborsHelper(query, nearest, nearestNeighbors);
    }
  }

private:
  optional<Point<K, T>> median;
  KDTree<K, T> *left, *right;
  int depth;
};

// Pairs the input of one query scenario (`in`) with the expected outcome
// (`want`) for K-dimensional points with coordinate type T.
template <int K, typename T> struct TestCase {
  // Input: a set of points and the point to query with.
  struct In {
    std::vector<Point<K, T>> points;
    Point<K, T> query;
  } in;

  // Expectation: the point the query should yield, if any.
  struct Want {
    std::optional<Point<K, T>> best;
  } want;
};

// Writes `n` line breaks to standard output, flushing after each one.
// A zero or negative `n` writes nothing.
void emptyLine(int n = 1) {
  int remaining = n;
  while (remaining > 0) {
    std::cout << std::endl;
    --remaining;
  }
}

int main() {
  // create a 2-dimensional point vector
  using Point2dInt = Point<2, int>;
  vector<Point2dInt> points2d;
  int limit2d = 11;
  for (int i = 1; i < limit2d; i++) {
    points2d.push_back(Point2dInt({i, i}));
    points2d.push_back(Point2dInt({i, limit2d - i}));
  }

  // build a 2-dimensional KD tree from the point vector
  auto t2 = KDTree<2, int>(points2d);
  // print the tree
  // t2.show();
  // // find the nearest neighbor of {5, 0}
  cout << "Nearest Neighbor of {5, 0} is ";
  t2.nearestNeighbor(Point2dInt({5, 0}))->show();
  emptyLine(2);
  cout << "1 Nearest Neighbor of {5, 0} is " << endl;
  for (auto neighbor : t2.kNearestNeighbors(Point2dInt({5, 0}), 1)) {
    neighbor.show();
    cout << " | Distance: " << neighbor.l2Norm(Point2dInt({5, 0})) << endl;
  }
  // find 5 nearest neighbors of {5, 0}
  emptyLine(5);
  cout << "5 Nearest Neighbors of {5, 0} are " << endl;
  for (auto neighbor : t2.kNearestNeighbors(Point2dInt({5, 0}), 5)) {
    neighbor.show();
    cout << " | Distance: " << neighbor.l2Norm(Point2dInt({5, 0})) << endl;
  }
  return 0;
}
// Editor is loading...
// Leave a Comment