#include <omp.h>

#include <algorithm>
#include <array>
#include <cassert>
#include <cmath>
#include <cstdio>
#include <fstream>
#include <iostream>
#include <limits>
#include <queue>
#include <vector>

#include "collision.h"
#include "io.h"
#include "sim_validator.h"

/**
 * ObjectId packs three fields into one 32-bit unsigned integer:
 * bit 30 - 31: Home bits (type 0-3 of the cell containing the particle)
 * bit 26 - 29: Control bits
 * bit 0 - 25: Particle ID
 */
class ObjectId {
public:
    // Field masks for the packed value:
    //   bits 30-31: home cell type (0-3)
    //   bits 26-29: control bits (in main() one flag per cell type the object touches)
    //   bits  0-25: particle index
    // Unsigned literals keep the shifts out of the sign bit of a signed int.
    static constexpr unsigned int homeBitMask = 0b11u << 30;
    static constexpr unsigned int controlBitMask = 0b1111u << 26;
    static constexpr unsigned int particleIdBitMask = (1u << 26) - 1;

    // All fields start at zero.
    ObjectId() : objectId(0) {}

    unsigned int getHomeBits() const { return (objectId & homeBitMask) >> 30; }
    unsigned int getControlBits() const { return (objectId & controlBitMask) >> 26; }
    unsigned int getParticleId() const { return objectId & particleIdBitMask; }

    // Each setter masks its argument to the width of its own field, so an
    // out-of-range value is truncated instead of spilling into (and corrupting)
    // the neighbouring fields.
    void setHomeBits(unsigned int homeBit) {
        objectId = (objectId & ~homeBitMask) | ((homeBit << 30) & homeBitMask);
    }
    void setControlBits(unsigned int controlBits) {
        objectId = (objectId & ~controlBitMask) | ((controlBits << 26) & controlBitMask);
    }
    void setParticleId(int particleId) {
        objectId = (objectId & ~particleIdBitMask) | (static_cast<unsigned int>(particleId) & particleIdBitMask);
    }

    // Debug dump of the three fields; %u matches the unsigned getter results.
    void print() const { printf("objectId: %u | %u | %u\n", getHomeBits(), getControlBits(), getParticleId()); }

private:
    unsigned int objectId;
};

// One cell's collision work item. The two pointers are non-owning views into
// the sorted object-id array built in main(): `homeCellPtr` points at the
// `numHomeCells` consecutive entries whose home cell is `cellId`, and
// `phantomCellPtr` at the first of the `numPhantomCells` entries that only
// overlap it (nullptr when there are none).
struct CellCollision {
    int cellId;
    int numHomeCells;
    int numPhantomCells;
    ObjectId* homeCellPtr;
    ObjectId* phantomCellPtr;

    // Debug dump of the cell id and the particle ids of its home and phantom entries.
    void print() const {
        std::cout << "====Cell Collision====" << "\n";
        std::cout << "Cell ID: " << cellId << "\n";
        std::cout << "Num Home Cells: " << numHomeCells << "\n";
        for (int i = 0; i < numHomeCells; i++) {
            std::cout << homeCellPtr[i].getParticleId() << " ";
        }
        std::cout << "\n";
        std::cout << "Num Phantom Cells: " << numPhantomCells << "\n";
        for (int i = 0; i < numPhantomCells; i++) {
            std::cout << phantomCellPtr[i].getParticleId() << " ";
        }
        std::cout << "\n";
    }
};

// Returns the checkerboard type (0-3) of the cell at (xCoord, yCoord).
// Bit 0 is the x parity and bit 1 is the y parity, so two cells of the same
// type are never adjacent, not even diagonally.
// Parity is taken with `& 1` rather than `% 2`. For negative coordinates `% 2`
// yields -1, which would give a negative type and a left shift of a negative value.
inline int determineCellType(int xCoord, int yCoord) {
    const int xParity = xCoord & 1;
    const int yParity = yCoord & 1;
    return (yParity << 1) | xParity;
}

// Returns the checkerboard type (0-3) of the cell with row-major index
// `cellIndex` in a grid `cellsPerSide` cells wide. For non-negative indices
// this inverts getCellFromCoords (index = y * cellsPerSide + x).
// Precondition: cellsPerSide > 0.
inline int determineCellTypeFromIndex(int cellIndex, int cellsPerSide) {
    int xCoord = cellIndex % cellsPerSide;
    int yCoord = cellIndex / cellsPerSide;
    return determineCellType(xCoord, yCoord);
}

// True when `coord` is a valid cell coordinate along one axis of a grid that
// has `cellsPerSide` cells per side, i.e. 0 <= coord < cellsPerSide.
inline bool isSideInBounds(int coord, int cellsPerSide) {
    if (coord < 0) {
        return false;
    }
    return coord < cellsPerSide;
}

// Row-major index of the cell at column `xCoord`, row `yCoord` in a grid
// that is `cellsPerSide` cells wide.
inline int getCellFromCoords(int xCoord, int yCoord, int cellsPerSide) {
    const int rowStart = yCoord * cellsPerSide;
    return rowStart + xCoord;
}

int main(int argc, char* argv[]) {
    // Read arguments and input file
    Params params{};
    std::vector<Particle> particles;
    read_args(argc, argv, params, particles);

    // Set number of threads
#if CHECK == 1
    // Initialize collision checker
    SimulationValidator validator(params.param_particles, params.square_size, params.param_radius);
    // Initialize with starting positions
    // Uncomment the line below to enable visualization (makes program much
    // slower) validator.enable_viz_output("test.out");

    // code start here
    int cellLength = params.param_radius * 2 * 1.5;
    int cellsPerSide = params.square_size / cellLength;

    for (int i = 0; i < params.param_steps; i++) {
        std::vector<int> cellIdArray(params.param_particles * 4, -1);

        std::vector<ObjectId> objectIdArray(params.param_particles * 4, ObjectId());

        // update all particle's position, and add to cells
#pragma omp parallel for
        for (int j = 0; j < params.param_particles; j++) {
            Particle* particle = &(particles[j]);
            particle->loc.x += particle->vel.x;
            particle->loc.y += particle->vel.y;

            double xPos = particle->loc.x / cellLength;
            double yPos = particle->loc.y / cellLength;

            int xCoord = std::max(0, std::min(static_cast<int>(xPos), cellsPerSide - 1));
            int yCoord = std::max(0, std::min(static_cast<int>(yPos), cellsPerSide - 1));

            int cellIndex = getCellFromCoords(xCoord, yCoord, cellsPerSide);

            // update cellIdArray and objectIdArray
            cellIdArray[j * 4] = cellIndex;                              // h-cell
            cellIdArray[(j * 4) + 1] = std::numeric_limits<int>::max();  // p-cell for left-right overlap
            cellIdArray[(j * 4) + 2] = std::numeric_limits<int>::max();  // p-cell for top-bottom overlap
            cellIdArray[(j * 4) + 3] = std::numeric_limits<int>::max();  // p-cell for diagonal overlap

            int controlBits = 0;

            int homeCellType = determineCellType(xCoord, yCoord);
            controlBits |= 1 << (homeCellType);

            double xDisplacementWithinCell = particle->loc.x - (xCoord * cellLength);
            double yDisplacementWithinCell = particle->loc.y - (yCoord * cellLength);

            int xShift = 0;
            int yShift = 0;

            if (xDisplacementWithinCell < params.param_radius) {
                xShift = -1;
            } else if (xDisplacementWithinCell > cellLength - params.param_radius) {
                xShift = 1;

            if (yDisplacementWithinCell < params.param_radius) {
                yShift = -1;
            } else if (yDisplacementWithinCell > cellLength - params.param_radius) {
                yShift = 1;

            // check for left-right overlap
            if (xShift) {
                if (isSideInBounds(xCoord + xShift, cellsPerSide)) {
                    int xCellIndex = (cellIndex + xShift);
                    cellIdArray[(j * 4) + 1] = xCellIndex;
                    int xCellType = determineCellType(xCoord + xShift, yCoord);
                    // int xCellType = determineCellTypeFromIndex(xCellIndex, cellsPerSide);
                    controlBits |= (1 << xCellType);

            // check for top-bottom overlap
            if (yShift) {
                if (isSideInBounds(yCoord + yShift, cellsPerSide)) {
                    int yCellIndex = cellIndex + (cellsPerSide * yShift);
                    cellIdArray[(j * 4) + 2] = yCellIndex;
                    int yCellType = determineCellType(xCoord, yCoord + yShift);
                    // int yCellType = determineCellType(xCoord, yCoord + yShift); // Corrected line
                    // int yCellType = determineCellType(xCoord, yCoord - 1);
                    controlBits |= (1 << yCellType);

            // check for diagonal overlap
            if (xShift && yShift) {
                if (isSideInBounds(xCoord + xShift, cellsPerSide) && isSideInBounds(yCoord + yShift, cellsPerSide)) {
                    int xyCellIndex = cellIndex + (cellsPerSide * yShift) + xShift;
                    cellIdArray[(j * 4) + 3] = xyCellIndex;
                    int xyCellType = determineCellType(xCoord + xShift, yCoord + yShift);
                    // int xyCellType = determineCellTypeFromIndex(xyCellIndex, cellsPerSide);
                    controlBits |= (1 << xyCellType);
            if (j == 439 || j == 3602) {
                std::cout << "home cell type is: " << homeCellType << "control bits are: " << controlBits << " particle index is:" << j << "\n";
            ObjectId objectId = ObjectId();
            for (int k = 0; k < 4; k++) {
                objectIdArray[j * 4 + k] = objectId;

            if (i == 10) {
                printf("in building arrays: particleId: %d\n", j);
                for (int k = 0; k < 4; k++) {
                    printf("cellId: %d\n", cellIdArray[j * 4 + k]);
                    objectIdArray[j * 4 + k].print();

        // for every 4 elements, (e.g. idx 0, 4, 8, ...)
        // move first element into a new H array
        // move the next 3 elements into a new P array
        std::vector<int> hCellIdArray;
        std::vector<ObjectId> hObjectIdArray;

        std::vector<int> pCellIdArray;
        pCellIdArray.reserve(params.param_particles * 3);
        std::vector<ObjectId> pObjectIdArray;
        pObjectIdArray.reserve(params.param_particles * 3);

        // Separate the home and phantom cells
        for (int j = 0; j < params.param_particles; j++) {
            int hIdx = j * 4;  // Home cell index (first element in each 4-element block)

            // Home cell
            int cellId = cellIdArray[hIdx];
            assert(cellIdArray[hIdx] != std::numeric_limits<int>::max());
            assert(objectIdArray[hIdx].getHomeBits() ==
                   static_cast<unsigned int>(determineCellTypeFromIndex(cellId, cellsPerSide)));


            // Phantom cells (next 3 elements)
            for (int k = 1; k < 4; ++k) {
                int pIdx = hIdx + k;  // Phantom cell index
                if (cellIdArray[pIdx] != std::numeric_limits<int>::max()) {

        // Concatenate the H array first, then the P array
        cellIdArray.insert(cellIdArray.end(), hCellIdArray.begin(), hCellIdArray.end());
        cellIdArray.insert(cellIdArray.end(), pCellIdArray.begin(), pCellIdArray.end());

        objectIdArray.insert(objectIdArray.end(), hObjectIdArray.begin(), hObjectIdArray.end());
        objectIdArray.insert(objectIdArray.end(), pObjectIdArray.begin(), pObjectIdArray.end());

        // Perform the radix sort to sort by cellId
        constexpr int bitsPerPass = 8;
        constexpr int numBuckets = 1 << bitsPerPass;                // 256 buckets for 8 bits
        constexpr int numPasses = (sizeof(int) * 8) / bitsPerPass;  // Number of passes (4 passes for 32-bit int)

        int numElements = cellIdArray.size();
        std::vector<int> sortedCellIdArray(numElements);
        std::vector<ObjectId> sortedObjectIdArray(numElements);

        // Iterate over each bit pass
        for (int pass = 0; pass < numPasses; ++pass) {
            int shift = pass * bitsPerPass;  // Determine the current bit shift

            // Initialize count array for the current pass
            std::vector<int> count(numBuckets, 0);

// Count occurrences of each bucket
#pragma omp parallel
                std::vector<int> localCount(numBuckets, 0);

#pragma omp for nowait
                for (int i = 0; i < numElements; ++i) {
                    int bucket = (cellIdArray[i] >> shift) & (numBuckets - 1);

#pragma omp critical
                    for (int i = 0; i < numBuckets; i++) {
                        count[i] += localCount[i];

            // Compute prefix sum to get positions
            for (int i = 1; i < numBuckets; ++i) {
                count[i] += count[i - 1];

            // Sort based on current digit
            for (int i = numElements - 1; i >= 0; --i) {
                int bucket = (cellIdArray[i] >> shift) & (numBuckets - 1);
                sortedCellIdArray[--count[bucket]] = cellIdArray[i];
                sortedObjectIdArray[count[bucket]] = objectIdArray[i];

            // Copy sorted arrays back for the next pass
            std::swap(cellIdArray, sortedCellIdArray);
            std::swap(objectIdArray, sortedObjectIdArray);

        if (i == 10) {
            printf("-----sorted arrays-----\n");
            for (int t = 0; t < std::ssize(cellIdArray); t++) {
                printf("cellId: %d\n", cellIdArray[t]);
                bool isHome = objectIdArray[t].getHomeBits() ==
                              static_cast<unsigned int>(determineCellTypeFromIndex(cellIdArray[t], cellsPerSide));
                printf("isHome: %d\n", isHome);
            printf("-----sorted arrays end-----\n");

        std::vector<CellCollision> cellCollisionList;

        // check cellIdArray
        // if new cellId -> if it is a home cell -> add to collision cell list
        for (int j = 0; j < std::ssize(cellIdArray); j++) {
            int cellId = cellIdArray[j];
            if (cellId == std::numeric_limits<int>::max()) break;

            int cellType = determineCellTypeFromIndex(cellId, cellsPerSide);
            ObjectId objectId = objectIdArray[j];

            if (i == 10) {
                printf("creating collisionlist: cellId: %d\n", cellId);

            if (objectId.getHomeBits() == static_cast<unsigned int>(cellType)) {
                int numH = 0;
                int numP = 0;
                ObjectId* hPtr = &(objectIdArray[j]);
                ObjectId* pPtr = nullptr;
                while (cellIdArray[j] == cellId) {
                    if (objectIdArray[j].getHomeBits() == static_cast<unsigned int>(cellType)) {
                    } else {
                        if (pPtr == nullptr) {
                            pPtr = &(objectIdArray[j]);
                j--;  // wacky, replace this
                cellCollisionList.emplace_back(cellId, numH, numP, hPtr, pPtr);
            } else {

        bool collisionsOccurred = true;
        while (collisionsOccurred) {
            collisionsOccurred = false;

#pragma omp parallel for
            for (int j = 0; j < std::ssize(cellCollisionList); j++) {
                CellCollision cc = cellCollisionList[j];

                // debug
                if (i == 10) {

                // process home cells wall collision
                for (int k = 0; k < cc.numHomeCells; k++) {
                    auto objectId = cc.homeCellPtr[k];

                    int particleIndex = objectId.getParticleId();
                    Particle* particle = &(particles[particleIndex]);

                    if (is_wall_collision(particle->loc, particle->vel, params.square_size, params.param_radius)) {
                        resolve_wall_collision(particle->loc, particle->vel, params.square_size, params.param_radius);
                        collisionsOccurred = true;

                // process home cell - home cell collision and home cell - phantom cell collision
                for (int k = 0; k < cc.numHomeCells; k++) {
                    auto objectId = cc.homeCellPtr[k];
                    int curParticleIndex = objectId.getParticleId();
                    Particle* curParticle = &(particles[curParticleIndex]);

                    // home cell - home cell collision
                    for (int l = k + 1; l < cc.numHomeCells; l++) {
                        auto otherObjectId = cc.homeCellPtr[l];

                        int otherParticleIndex = otherObjectId.getParticleId();
                        if ((curParticleIndex == 439 && otherParticleIndex == 3602) || (otherParticleIndex == 439 && curParticleIndex == 3602)) { 
                            std::cout << "home home here" << "\n";
                        Particle* otherParticle = &(particles[otherParticleIndex]);

                        if (is_particle_collision(curParticle->loc, curParticle->vel, otherParticle->loc,
                                                  otherParticle->vel, params.param_radius)) {
                            resolve_particle_collision(curParticle->loc, curParticle->vel, otherParticle->loc,
                            collisionsOccurred = true;

                    // home cell - phantom cell collision
                    for (int l = 0; l < cc.numPhantomCells; l++) {
                        auto otherObjectId = cc.phantomCellPtr[l];

                        // int particleHomeCell = objectId.getHomeBits();
                        // int otherHomeCell = otherObjectId.getHomeBits();
                        // int otherControlBits = otherObjectId.getControlBits();

                        // // bool otherHasParticle = (otherControlBits & (1 << particleHomeCell)); // not necessary
                        // bool particleHasOther = (objectId.getControlBits() & (1 << otherHomeCell));

                        // if ((otherHomeCell > particleHomeCell) && particleHasOther) {
                        //     if (i == 10) {
                        //         printf("skipping: particle: %d, other: %d, pHome: %d, oHome: %d, oControl: %d\n",
                        //                curParticleIndex, otherObjectId.getParticleId(), otherHomeCell,
                        //                particleHomeCell, otherControlBits);
                        //     }
                        //     continue;
                        // }

                        int otherParticleIndex = otherObjectId.getParticleId();
                        Particle* otherParticle = &(particles[otherParticleIndex]);
                        if ((curParticleIndex == 439 && otherParticleIndex == 3602) || (otherParticleIndex == 439 && curParticleIndex == 3602)) { 
                            std::cout << "home phantom here" << "\n";
                        if (is_particle_collision(curParticle->loc, curParticle->vel, otherParticle->loc,
                                                  otherParticle->vel, params.param_radius)) {
                            resolve_particle_collision(curParticle->loc, curParticle->vel, otherParticle->loc,
                            collisionsOccurred = true;
#if CHECK == 1

// code end here
#if CHECK == 1
// Check final positions
// validator.validate_step(particles);
