Untitled
unknown
plain_text
2 years ago
8.8 kB
5
Indexable
// include standard input/output library
#include <iostream>
// other standard library facilities used below
#include <algorithm>
#include <cmath>
#include <cstddef>
#include <cstdlib>
#include <limits>
#include <memory>
#include <optional>
#include <queue>
#include <stdexcept>
#include <string>
#include <utility>
#include <vector>
// use standard namespace
using namespace std;
// Depict k-dimensional point
template <int K, typename T> struct Point {
vector<T> coord;
// constructor to initialize the coordinate vector
Point(vector<T> coord) : coord(coord) {
// if the size of the coordinate vector is not equal to K, throw an
// exception
if (coord.size() != K) {
throw invalid_argument("Point must have K coordinates");
}
}
// overload the == operator to compare two points
bool operator==(const Point &other) {
// if the size of the coordinate vectors is not equal, return false
if (coord.size() != other.coord.size()) {
return false;
}
// compare each coordinate of the two points
for (int i = 0; i < coord.size(); i++) {
if (coord[i] != other.coord[i]) {
return false;
}
}
return true;
}
// calculate the L2 norm between two points
inline double l2Norm(const Point &other) const {
double euclidian = 0;
// calculate the sum of the squared differences between each coordinate of
// the two points
for (int i = 0; i < coord.size(); i++) {
euclidian += pow(coord[i] - other.coord[i], 2);
}
return euclidian;
};
// print the coordinates of the point
void show() {
cout << "Point: [[[ ";
for (int i = 0; i < coord.size(); i++) {
cout << coord[i] << " ";
}
cout << "]]]" << endl;
}
};
// K-dimensional tree holds points
// each of k-dimensions
// allowing for O(log N) insert and query
// where N is the total number of points
// in the tree
template <int K, typename T> class KDTree {
struct compareDist {
Point<K, T> query;
compareDist(Point<K, T> query) : query(query) {}
bool operator()(const pair<double, Point<K, T>> &a,
const pair<double, Point<K, T>> &b) {
return a.second.l2Norm(query) < b.second.l2Norm(query);
}
};
public:
// constructor to build the tree from a vector of points
KDTree(vector<Point<K, T>> points, int depth = 0)
: median(nullopt), left(nullptr), right(nullptr), depth(depth) {
// if there are no points, return
if (points.empty()) {
return;
}
// sort the points and pick the median
int dimension = depth % K;
sort(points.begin(), points.end(),
[&dimension](const Point<K, T> &a, const Point<K, T> &b) {
return a.coord[dimension] < b.coord[dimension];
});
// set median
int medianIndex = points.size() / 2;
median = points[medianIndex];
// create left and right subtrees
if (medianIndex > 0) {
left = new KDTree<K, T>(
vector<Point<K, T>>(points.begin(), points.begin() + medianIndex),
depth + 1);
}
if (medianIndex + 1 < points.size()) {
right = new KDTree<K, T>(
vector<Point<K, T>>(points.begin() + medianIndex + 1, points.end()),
depth + 1);
}
}
// destructor to free memory
~KDTree() {
delete left;
delete right;
}
// print the tree
void show() {
cout << string(depth * depth, ' ');
if (median) {
median->show();
} else {
cout << "Empty Node" << endl;
}
if (left) {
left->show();
}
if (right) {
right->show();
}
}
// find the nearest neighbor of a query point
optional<Point<K, T>>
nearestNeighbor(Point<K, T> query, optional<Point<K, T>> best = nullopt,
double bestDistance = numeric_limits<double>::infinity()) {
if (!median) {
return nullopt;
}
// calculate the distance between the query point and the median point
double distance = query.l2Norm(*median);
// if the distance is less than the current best distance, update the best
// distance and best point
if (distance < bestDistance) {
bestDistance = distance;
best = median;
}
// get pointers to next and other branches
int dimension = depth % K;
KDTree<K, T> *nextBranch = nullptr, *otherBranch = nullptr;
// determine which branch to search next based on the current dimension
if (query.coord[dimension] < median->coord[dimension]) {
nextBranch = left;
otherBranch = right;
} else {
nextBranch = right;
otherBranch = left;
}
// recursively search the next branch
if (nextBranch) {
best = nextBranch->nearestNeighbor(query, best, bestDistance);
bestDistance = query.l2Norm(*best);
}
// if the distance between the query point and the splitting plane is less
// than the current best distance, recursively search the other branch
if (otherBranch and
bestDistance > abs(query.coord[dimension] - median->coord[dimension])) {
auto otherBest = otherBranch->nearestNeighbor(query, best, bestDistance);
if (query.l2Norm(*otherBest) < bestDistance) {
best = otherBest;
}
}
return best;
}
vector<Point<K, T>> kNearestNeighbors(Point<K, T> query,
int nearestNeighbors = 1) {
priority_queue<pair<double, Point<K, T>>, vector<pair<double, Point<K, T>>>,
compareDist>
nearest((compareDist(query)));
kNearestNeighborsHelper(query, nearest, nearestNeighbors);
vector<Point<K, T>> neighbors;
while (!nearest.empty()) {
neighbors.push_back(nearest.top().second);
nearest.pop();
}
reverse(neighbors.begin(), neighbors.end());
return neighbors;
}
// find the nearest neighbor of a query point
// vector<optional<Point<K, T>>>
void kNearestNeighborsHelper(
Point<K, T> &query,
priority_queue<pair<double, Point<K, T>>,
vector<pair<double, Point<K, T>>>, compareDist> &nearest,
int nearestNeighbors = 1) {
if (!median) {
return;
}
// calculate the distance between the query point and the median point
double distance = query.l2Norm(*median);
// if the distance lesser than the n-th nearest, add it into the heap
// if number of points is less than n, add it into the heap
if (nearest.size() < nearestNeighbors or distance < nearest.top().first) {
if (nearest.size() == nearestNeighbors) {
nearest.pop();
}
nearest.push({distance, *median});
}
// get pointers to next and other branches
int dimension = depth % K;
KDTree<K, T> *nextBranch = nullptr, *otherBranch = nullptr;
// determine which branch to search next based on the current dimension
if (query.coord[dimension] < median->coord[dimension]) {
nextBranch = left;
otherBranch = right;
} else {
nextBranch = right;
otherBranch = left;
}
// recursively search the next branch
if (nextBranch) {
nextBranch->kNearestNeighborsHelper(query, nearest, nearestNeighbors);
}
double distanceToPlane =
pow(query.coord[dimension] - median->coord[dimension], 2);
// if the distance between the query point and the splitting plane is less
// than the current best distance, recursively search the other branch
if (otherBranch and distanceToPlane < nearest.top().first) {
otherBranch->kNearestNeighborsHelper(query, nearest, nearestNeighbors);
}
}
private:
optional<Point<K, T>> median;
KDTree<K, T> *left, *right;
int depth;
};
// Describes one nearest-neighbour scenario: the input to index and query
// (`in`) and the expected answer (`want`). Not referenced by main() in the
// code visible here.
template <int K, typename T> struct TestCase {
  // Input: the points to build the tree from and the point to look up.
  struct In {
    std::vector<Point<K, T>> points;
    Point<K, T> query;
  };

  // Expected output: the nearest point (presumably std::nullopt when no
  // neighbour is expected — confirm with the intended caller).
  struct Want {
    std::optional<Point<K, T>> best;
  };

  In in;
  Want want;
};
// Writes `n` line breaks to standard output, flushing after each one
// (std::endl); a non-positive `n` writes nothing.
void emptyLine(int n = 1) {
  for (int remaining = n; remaining > 0; --remaining) {
    std::cout << std::endl;
  }
}
int main() {
// create a 2-dimensional point vector
using Point2dInt = Point<2, int>;
vector<Point2dInt> points2d;
int limit2d = 11;
for (int i = 1; i < limit2d; i++) {
points2d.push_back(Point2dInt({i, i}));
points2d.push_back(Point2dInt({i, limit2d - i}));
}
// build a 2-dimensional KD tree from the point vector
auto t2 = KDTree<2, int>(points2d);
// print the tree
// t2.show();
// // find the nearest neighbor of {5, 0}
cout << "Nearest Neighbor of {5, 0} is ";
t2.nearestNeighbor(Point2dInt({5, 0}))->show();
emptyLine(2);
cout << "1 Nearest Neighbor of {5, 0} is " << endl;
for (auto neighbor : t2.kNearestNeighbors(Point2dInt({5, 0}), 1)) {
neighbor.show();
cout << " | Distance: " << neighbor.l2Norm(Point2dInt({5, 0})) << endl;
}
// find 5 nearest neighbors of {5, 0}
emptyLine(5);
cout << "5 Nearest Neighbors of {5, 0} are " << endl;
for (auto neighbor : t2.kNearestNeighbors(Point2dInt({5, 0}), 5)) {
neighbor.show();
cout << " | Distance: " << neighbor.l2Norm(Point2dInt({5, 0})) << endl;
}
return 0;
}
Editor is loading...
Leave a Comment