KNN - User

 avatar
unknown
plain_text
a year ago
3.1 kB
3
Indexable
#include<algorithm>
#include<list>
#include<queue>
#include<set>
#include<unordered_map>
#include<vector>
 
// Capacity limits. Only MAX_ADD is referenced in this file (it sizes pool_node,
// i.e. the number of addSample calls between init calls). MAX_DEL and
// MAX_PREIDICT presumably describe per-test call limits. TODO confirm.
constexpr int MAX_ADD      = 20005;
constexpr int MAX_DEL      = 1000;
constexpr int MAX_PREIDICT = 10000;
 
using namespace std;
// One stored sample: its id, position (x kept as col, y kept as row), class
// label, a deletion flag, and scratch space for the distance to the latest query.
struct Node {
    int id;
    int row, col, value;
    int cur_dist;   // Manhattan distance to the point last passed to calNewDist
    bool isDel;     // set by deleteSample; predict skips samples with this flag

    // (Re)initialise this slot as a live sample mID at (mX, mY) with label mC.
    void set(int mID, int mX, int mY, int mC){
        id = mID;
        value = mC;
        row = mY;
        col = mX;
        isDel = false;
    }

    // Store |col - mX| + |row - mY| into cur_dist.
    void calNewDist(int mX, int mY){
        const int dx = (col >= mX) ? col - mX : mX - col;
        const int dy = (row >= mY) ? row - mY : mY - row;
        cur_dist = dx + dy;
    }
};
// Sample storage. Slots are handed out in order by addSample and only recycled
// when init() resets idx_pool back to 0.
int idx_pool{};             // next free slot in pool_node
Node pool_node[MAX_ADD];    // slot storage

int KK{}; // K: number of nearest neighbours that vote in predict()
int LL{}; // L: maximum Manhattan distance for a sample to count as a neighbour

// Spatial buckets: cell [r][c] holds pool_node indices (not ids) of samples
// with (y-1)/100 == r and (x-1)/100 == c. 40x40 cells of width 100 assume
// coordinates in 1-4000.
list<int> listNode[40][40];
unordered_map<int, int> mapNode; // sample id -> index into pool_node
 
// Start a fresh test case: remember K and L, forget every sample id, rewind the
// sample pool, and empty every grid bucket.
void init(int K, int L){
    KK = K;
    LL = L;
    idx_pool = 0;
    mapNode.clear();

    for (auto& gridRow : listNode)
        for (auto& cell : gridRow)
            cell.clear();
}
 
void addSample(int mID, int mX, int mY, int mC){
    pool_node[idx_pool].set(mID, mX, mY, mC);
    mapNode[mID] = idx_pool;
     
    listNode[(mY-1)/100][(mX-1)/100].push_back(idx_pool);
 
    idx_pool++;
    return;
}
 
// Mark sample mID as deleted. The node stays in its grid bucket and predict
// skips it via isDel. An unknown id is ignored: find() is used instead of
// operator[], which would insert the id mapped to slot 0 and wrongly delete
// whichever sample occupies pool_node[0].
void deleteSample(int mID){
    auto found = mapNode.find(mID);
    if (found == mapNode.end()) return;

    pool_node[found->second].isDel = true;
    mapNode.erase(found);
}
 
struct cmp {
    bool operator () (int idx1, int idx2){
        if (pool_node[idx1].cur_dist == pool_node[idx2].cur_dist){
            if (pool_node[idx1].col == pool_node[idx2].col)
                return pool_node[idx1].row > pool_node[idx2].row;      
            else
                return pool_node[idx1].col > pool_node[idx2].col;
        }
        else return pool_node[idx1].cur_dist > pool_node[idx2].cur_dist;
    }
};
 
int res[11] = {0}; // 1-10
int predict(int mX, int mY){
    for (int i = 1; i <= 10; i++)
        res[i] = 0;
     
 
    priority_queue<int, vector<int>, cmp> heap;
 
	int bgR = (mY-1)/100 - 1; if (bgR < 0) bgR = 0;
    int bgC = (mX-1)/100 - 1; if (bgC < 0) bgC = 0;
    int eR = (mY-1)/100 + 1; if (eR > 39) eR = 39;
    int eC = (mX-1)/100 + 1; if (eC > 39) eC = 39;
 
    for (int r = bgR; r <= eR; r++){
        for (int c = bgC; c <= eC; c++){
            //if(listNode[r][c].size() > maxx) maxx = listNode[r][c].size();
            for (auto it : listNode[r][c]){
                int i = it;
                if (pool_node[i].isDel) continue;
                pool_node[i].calNewDist(mX,mY);
                int x = pool_node[i].cur_dist;
                if (x > LL) continue;
                heap.push(i);
            }
        }
    }
    if (heap.size() < KK) return -1;
    int ka = KK;
    while (ka--){
        int idx = heap.top();
        heap.pop();
        res[pool_node[idx].value]++;
    }
 
    int max = 0; 
    int eqI = 0;
    for (int i = 1; i <= 10; i++){
        if (res[i] > max){
            max = res[i];
            eqI = i;
        }
    }
 
    return eqI; // 1-10
}
Editor is loading...
Leave a Comment