// Capacity limits. Every node handed out by init()/add() comes from NodePool,
// and each distinct id inserted into HashTbl occupies one slot that is never
// freed (remove() leaves hash entries in place). Distinct ids can therefore
// never exceed MAX_NODE, and the table must be strictly larger than the pool
// so the linear-probing loops in hashAdd/hashFind always reach an empty slot.
constexpr int MAX_NODE = 18000;
constexpr int HASH_SIZE = 30000;
static_assert(HASH_SIZE > MAX_NODE, "HashTbl must be larger than NodePool so probing terminates");

// One node of the forest. Top-level nodes (created by init) have no parent.
struct Node {
    Node* parent;    // nullptr for a top-level node
    Node* child[3];  // at most three children; unused slots are nullptr
    int num;         // this node's own value plus the values of all attached descendants
    bool valid;      // cleared when the node or one of its ancestors is removed
} NodePool[MAX_NODE];

int PoolCnt;   // number of NodePool entries handed out since the last init()
int GroupCnt;  // number of top-level nodes; init() places them at NodePool[0..GroupCnt-1]

// Open-addressing hash slot mapping an id to its node.
// key == 0 marks an empty slot, so id 0 cannot be stored.
struct HashNode {
    int key;
    Node* data;
} HashTbl[HASH_SIZE];
// Looks up the node registered under `key` using linear probing.
// Returns the stored pointer, or nullptr if `key` is absent. The returned node
// may have valid == false (removed nodes stay in the table), so callers must
// check. Key 0 is the empty-slot sentinel and is never found.
Node* hashFind(int key) {
    // Reduce as unsigned: `key % HASH_SIZE` is negative for a negative key, and
    // converting that to an unsigned index would read far outside HashTbl.
    // For non-negative keys this yields the same slot as a plain `%`.
    unsigned long h = static_cast<unsigned int>(key) % HASH_SIZE;
    int cnt = HASH_SIZE;  // bound the scan to one full lap of the table
    while (HashTbl[h].key != 0 && cnt--) {
        if (HashTbl[h].key == key) {
            return HashTbl[h].data;
        }
        h = (h + 1) % HASH_SIZE;
    }
    return nullptr;
}
// Registers `data` under `key` with linear probing; if `key` is already present
// its pointer is overwritten. The probe loop has no iteration limit, so it
// relies on the table never being full (one slot per distinct key, and
// HASH_SIZE exceeds MAX_NODE). Key 0 is the empty-slot sentinel: storing it
// produces an entry that looks empty and cannot be found again.
void hashAdd(int key, Node* data) {
    // Reduce as unsigned so a negative key maps into [0, HASH_SIZE) instead of
    // becoming an out-of-range index; non-negative keys get the same slot as `%`.
    unsigned long h = static_cast<unsigned int>(key) % HASH_SIZE;
    while (HashTbl[h].key != 0) {
        if (HashTbl[h].key == key) {
            HashTbl[h].data = data;
            return;
        }
        h = (h + 1) % HASH_SIZE;
    }
    HashTbl[h].key = key;
    HashTbl[h].data = data;
}
// Marks `node` and every node beneath it as removed by clearing `valid`.
// Links and `num` values are left untouched. A nullptr argument does nothing.
// Recursion depth equals the height of the subtree.
void delNode(Node* node) {
    if (node != nullptr) {
        node->valid = false;
        for (Node* kid : node->child) {
            delNode(kid);
        }
    }
}
void init(int N, int mId[], int mNum[]) {
PoolCnt = 0;
for (int i = 0; i < HASH_SIZE; ++i) {
HashTbl[i].key = 0;
}
GroupCnt = N;
for (int i = 0; i < N; ++i) {
Node* node = &NodePool[PoolCnt++];
hashAdd(mId[i], node);
node->valid = true;
node->parent = nullptr;
node->child[0] = node->child[1] = node->child[2] = nullptr;
node->num = mNum[i];
}
}
int add(int mId, int mNum, int mParent) {
Node* node = &NodePool[PoolCnt++];
Node* parent = hashFind(mParent);
if (parent->child[0] == nullptr) {
parent->child[0] = node;
} else if (parent->child[1] == nullptr) {
parent->child[1] = node;
} else if (parent->child[2] == nullptr) {
parent->child[2] = node;
} else {
return -1;
}
Node* curr = parent;
while (curr) {
curr->num += mNum;
curr = curr->parent;
}
hashAdd(mId, node);
node->valid = true;
node->parent = parent;
node->child[0] = node->child[1] = node->child[2] = nullptr;
node->num = mNum;
return parent->num;
}
int remove(int mId) {
Node* node = hashFind(mId);
if (!node || !node->valid) {
return -1;
}
Node* parent = node->parent;
if (parent->child[0] == node) {
parent->child[0] = nullptr;
} else if (parent->child[1] == node) {
parent->child[1] = nullptr;
} else if (parent->child[2] == node) {
parent->child[2] = nullptr;
}
Node* curr = parent;
while (curr) {
curr->num -= node->num;
curr = curr->parent;
}
delNode(node);
return node->num;
}
// Binary-searches a cap L in [1, largest top-level total] and returns the
// largest L for which the sum over top-level nodes of min(num, L) is <= K.
// If every top-level total fits (their plain sum is <= K), this is the largest
// total. Returns 0 when no cap in that range satisfies the bound, or when there
// are no positive totals.
int distribute(int K) {
    int low = 1, high = 0;
    for (int i = 0; i < GroupCnt; ++i) {
        if (NodePool[i].num > high)
            high = NodePool[i].num;
    }
    while (low <= high) {
        int mid = low + (high - low) / 2;
        // 64-bit accumulator: the capped totals of many groups can add up past
        // INT_MAX, which would be signed-overflow UB with an int.
        long long sum = 0;
        for (int i = 0; i < GroupCnt; ++i) {
            sum += (NodePool[i].num <= mid) ? NodePool[i].num : mid;
        }
        if (sum <= K) {
            low = mid + 1;
        } else {
            high = mid - 1;
        }
    }
    return high;
}