Untitled
unknown
plain_text
2 years ago
8.8 kB
4
Indexable
// k-d tree demo: a K-dimensional point type, a static k-d tree supporting
// nearest-neighbour and k-nearest-neighbour queries, and a small driver.

// standard library facilities used below
#include <algorithm>
#include <cstddef>
#include <iostream>
#include <limits>
#include <memory>
#include <optional>
#include <queue>
#include <stdexcept>
#include <string>
#include <utility>
#include <vector>

// use standard namespace
using namespace std;

// Depict k-dimensional point.
// K is the number of dimensions, T the (arithmetic) coordinate type.
template <int K, typename T> struct Point {
  static_assert(K > 0, "Point must have at least one dimension");

  vector<T> coord;

  // Build a point from exactly K coordinates.
  // Throws invalid_argument when the number of coordinates is not K.
  Point(vector<T> coordinates) : coord(std::move(coordinates)) {
    if (coord.size() != static_cast<size_t>(K)) {
      throw invalid_argument("Point must have K coordinates");
    }
  }

  // Two points are equal when they have the same coordinates in the same
  // order (vector equality also covers the size check).
  bool operator==(const Point &other) const { return coord == other.coord; }

  // Return the SQUARED Euclidean distance to `other` (no square root is
  // taken). The subtraction is done in double so that integral coordinate
  // types cannot overflow.
  double l2Norm(const Point &other) const {
    double squared = 0;
    for (size_t i = 0; i < coord.size(); i++) {
      const double delta =
          static_cast<double>(coord[i]) - static_cast<double>(other.coord[i]);
      squared += delta * delta;
    }
    return squared;
  }

  // print the coordinates of the point, followed by a newline
  void show() const {
    cout << "Point: [[[ ";
    for (const T &value : coord) {
      cout << value << " ";
    }
    cout << "]]]" << endl;
  }
};

// K-dimensional tree holding points of K dimensions each.
// The tree is built once from a vector of points (there is no insert
// operation); every node stores the median point along the axis
// `depth % K` and owns its two subtrees.
// Because the children are held by unique_ptr the tree is movable but not
// copyable.
template <int K, typename T> class KDTree {
  static_assert(K > 0, "KDTree must have at least one dimension");

  // (squared distance to the query, point) as stored in the k-NN heap
  using Candidate = pair<double, Point<K, T>>;

  // Orders candidates by their cached squared distance so that the heap's
  // top() is the farthest candidate kept so far.
  struct compareDist {
    bool operator()(const Candidate &a, const Candidate &b) const {
      return a.first < b.first;
    }
  };

  using NeighborHeap =
      priority_queue<Candidate, vector<Candidate>, compareDist>;

public:
  // Build the tree from a vector of points.
  // `depth` is the level of this node and selects the splitting axis;
  // callers normally leave it at 0. An empty vector yields an empty node.
  KDTree(vector<Point<K, T>> points, int depth = 0) : depth(depth) {
    // if there are no points, return
    if (points.empty()) {
      return;
    }

    // sort the points along this level's axis and pick the median
    const int dimension = depth % K;
    sort(points.begin(), points.end(),
         [dimension](const Point<K, T> &a, const Point<K, T> &b) {
           return a.coord[dimension] < b.coord[dimension];
         });

    const auto medianPos =
        points.begin() + static_cast<ptrdiff_t>(points.size() / 2);
    median = *medianPos;

    // create left and right subtrees from the points on either side of the
    // median; an empty side leaves the child null
    if (medianPos != points.begin()) {
      left = make_unique<KDTree<K, T>>(
          vector<Point<K, T>>(points.begin(), medianPos), depth + 1);
    }
    if (medianPos + 1 != points.end()) {
      right = make_unique<KDTree<K, T>>(
          vector<Point<K, T>>(medianPos + 1, points.end()), depth + 1);
    }
  }

  // print the tree in pre-order; each node is indented by depth * depth
  // spaces
  void show() const {
    cout << string(static_cast<size_t>(depth * depth), ' ');
    if (median) {
      median->show();
    } else {
      cout << "Empty Node" << endl;
    }
    if (left) {
      left->show();
    }
    if (right) {
      right->show();
    }
  }

  // Find the nearest neighbor of a query point.
  // `best` / `bestDistance` optionally seed the search with a candidate and
  // its SQUARED distance to the query; only strictly closer points replace
  // the candidate.
  // Returns nullopt when this node is empty, or when no candidate was given
  // and no point lies strictly within `bestDistance`.
  optional<Point<K, T>>
  nearestNeighbor(Point<K, T> query, optional<Point<K, T>> best = nullopt,
                  double bestDistance =
                      numeric_limits<double>::infinity()) const {
    if (!median) {
      return nullopt;
    }

    // Track the winner through a pointer so that the recursion does not copy
    // points around and never has to dereference an empty optional.
    const Point<K, T> *closest = best ? &*best : nullptr;
    searchNearest(query, closest, bestDistance);
    if (closest == nullptr) {
      return nullopt;
    }
    return *closest;
  }

  // Return up to `nearestNeighbors` points closest to `query`, ordered from
  // nearest to farthest. A non-positive count yields an empty vector.
  vector<Point<K, T>> kNearestNeighbors(Point<K, T> query,
                                        int nearestNeighbors = 1) const {
    vector<Point<K, T>> neighbors;
    if (nearestNeighbors <= 0) {
      return neighbors;
    }

    NeighborHeap nearest;
    kNearestNeighborsHelper(query, nearest, nearestNeighbors);

    // the heap pops farthest-first, so collect and then reverse
    neighbors.reserve(nearest.size());
    while (!nearest.empty()) {
      neighbors.push_back(nearest.top().second);
      nearest.pop();
    }
    reverse(neighbors.begin(), neighbors.end());
    return neighbors;
  }

  // Recursive worker for kNearestNeighbors: visits this subtree and keeps
  // the `nearestNeighbors` closest points seen so far in `nearest`, a
  // max-heap keyed by squared distance to `query`.
  void kNearestNeighborsHelper(const Point<K, T> &query, NeighborHeap &nearest,
                               int nearestNeighbors = 1) const {
    if (!median || nearestNeighbors <= 0) {
      return;
    }
    const size_t capacity = static_cast<size_t>(nearestNeighbors);

    // squared distance between the query point and the median point
    const double distanceSq = query.l2Norm(*median);

    // keep the point if the heap still has room, or if it beats the current
    // farthest candidate (which is then evicted)
    if (nearest.size() < capacity) {
      nearest.emplace(distanceSq, *median);
    } else if (distanceSq < nearest.top().first) {
      nearest.pop();
      nearest.emplace(distanceSq, *median);
    }

    // descend first into the side of the splitting plane holding the query
    const int dimension = depth % K;
    const bool queryOnLeft = query.coord[dimension] < median->coord[dimension];
    const KDTree<K, T> *nearSide = queryOnLeft ? left.get() : right.get();
    const KDTree<K, T> *farSide = queryOnLeft ? right.get() : left.get();

    if (nearSide) {
      nearSide->kNearestNeighborsHelper(query, nearest, nearestNeighbors);
    }

    // The far side can only be skipped when the heap is already full AND the
    // splitting plane is no closer than the current farthest candidate.
    if (farSide) {
      const double offset = static_cast<double>(query.coord[dimension]) -
                            static_cast<double>(median->coord[dimension]);
      const double distanceToPlane = offset * offset;
      if (nearest.size() < capacity ||
          distanceToPlane < nearest.top().first) {
        farSide->kNearestNeighborsHelper(query, nearest, nearestNeighbors);
      }
    }
  }

private:
  // Recursive worker for nearestNeighbor. `best` and `bestDistance` (squared)
  // are updated in place whenever a strictly closer point is found.
  void searchNearest(const Point<K, T> &query, const Point<K, T> *&best,
                     double &bestDistance) const {
    if (!median) {
      return;
    }

    const double distanceSq = query.l2Norm(*median);
    if (distanceSq < bestDistance) {
      bestDistance = distanceSq;
      best = &*median;
    }

    // descend first into the side of the splitting plane holding the query
    const int dimension = depth % K;
    const bool queryOnLeft = query.coord[dimension] < median->coord[dimension];
    const KDTree<K, T> *nearSide = queryOnLeft ? left.get() : right.get();
    const KDTree<K, T> *farSide = queryOnLeft ? right.get() : left.get();

    if (nearSide) {
      nearSide->searchNearest(query, best, bestDistance);
    }

    // bestDistance is squared, so the plane offset must be squared as well
    // before the two are compared.
    if (farSide) {
      const double offset = static_cast<double>(query.coord[dimension]) -
                            static_cast<double>(median->coord[dimension]);
      if (offset * offset < bestDistance) {
        farSide->searchNearest(query, best, bestDistance);
      }
    }
  }

  optional<Point<K, T>> median;         // point stored at this node, if any
  unique_ptr<KDTree<K, T>> left, right; // owned subtrees (null when absent)
  int depth;                            // level of this node; axis = depth % K
};

// Input / expected-output pair for a nearest-neighbour test case.
template <int K, typename T> struct TestCase {
  struct In {
    vector<Point<K, T>> points;
    Point<K, T> query;
  };
  In in;
  struct Want {
    optional<Point<K, T>> best;
  };
  Want want;
};

// print `n` empty lines
void emptyLine(int n = 1) {
  for (int i = 0; i < n; i++) {
    cout << endl;
  }
}

int main() {
  // create a 2-dimensional point vector: the two diagonals of a 10x10 grid
  using Point2dInt = Point<2, int>;
  vector<Point2dInt> points2d;
  const int limit2d = 11;
  for (int i = 1; i < limit2d; i++) {
    points2d.push_back(Point2dInt({i, i}));
    points2d.push_back(Point2dInt({i, limit2d - i}));
  }

  // build a 2-dimensional KD tree from the point vector
  const KDTree<2, int> t2(points2d);
  const Point2dInt query({5, 0});

  // print the tree
  // t2.show();

  // find the nearest neighbor of {5, 0}
  cout << "Nearest Neighbor of {5, 0} is ";
  if (const auto nearest = t2.nearestNeighbor(query)) {
    nearest->show();
  }

  emptyLine(2);
  cout << "1 Nearest Neighbor of {5, 0} is " << endl;
  for (const auto &neighbor : t2.kNearestNeighbors(query, 1)) {
    neighbor.show();
    cout << " | Distance: " << neighbor.l2Norm(query) << endl;
  }

  // find 5 nearest neighbors of {5, 0}
  emptyLine(5);
  cout << "5 Nearest Neighbors of {5, 0} are " << endl;
  for (const auto &neighbor : t2.kNearestNeighbors(query, 5)) {
    neighbor.show();
    cout << " | Distance: " << neighbor.l2Norm(query) << endl;
  }
  return 0;
}
Editor is loading...
Leave a Comment