Bitonic sort: still produces slightly wrong output

 avatar
unknown
c_cpp
9 months ago
3.4 kB
2
Indexable
#include <bits/stdc++.h>
#include <mpi.h>

using namespace std;

// In-place bitonic sort of arr[0..n).
//
// Precondition: n must be 0 or a power of two. For any other n the partner
// index i ^ j can point past the end of the array.
//
// @param arr  array to sort in place
// @param n    number of elements (0 or a power of two)
// @param up   nonzero sorts ascending, zero sorts descending
void bitonic_sort(float* arr, int n, int up) {
    const bool ascending = (up != 0);
    // k: length of the bitonic sequences being merged in this pass.
    for (int k = 2; k <= n; k *= 2) {
        // j: distance between compared elements within the current merge.
        for (int j = k / 2; j > 0; j /= 2) {
            for (int i = 0; i < n; i++) {
                const int l = i ^ j;
                if (l <= i) continue;  // handle each pair once, from its lower index
                // Sub-sequences alternate direction by bit k of i. In the final
                // pass (k == n) every i < n has (i & k) == 0, so the whole array
                // ends up in the requested direction.
                const bool block_ascending = ((i & k) == 0) == ascending;
                if (block_ascending ? arr[i] > arr[l] : arr[i] < arr[l]) {
                    std::swap(arr[i], arr[l]);
                }
            }
        }
    }
}

// Distributed bitonic sort over MPI_COMM_WORLD.
//
// Each of the `size` ranks holds `chunk_size` floats in `chunk`. On return,
// every chunk is sorted ascending, and every value on rank r is <= every value
// on rank r+1. The concatenation of the chunks in rank order is therefore
// globally sorted.
//
// Preconditions: `size` is a power of two and all ranks pass the same
// chunk_size. The local pre-sort uses bitonic_sort, so chunk_size should also
// be 0 or a power of two. NaN values are not supported (they break ordering).
//
// Algorithm: run the bitonic sorting network over ranks, with each
// compare-exchange replaced by a merge-split. The two partners merge their
// sorted chunks; one keeps the lower half, the other the upper half. Unlike an
// element-wise min/max, this keeps every chunk sorted between rounds and
// partitions the values correctly.
void bitonic_sort_mpi(float* chunk, int chunk_size, int rank, int size) {
    // Sort local chunk ascending.
    bitonic_sort(chunk, chunk_size, 1);

    // Heap buffers (not a VLA on the stack), reused across all rounds.
    std::vector<float> partner_chunk(chunk_size);
    std::vector<float> merged(2 * static_cast<std::size_t>(chunk_size));

    // k: number of ranks in each bitonic sequence being merged.
    for (int k = 2; k <= size; k *= 2) {
        // j: rank distance between merge-split partners.
        for (int j = k / 2; j > 0; j /= 2) {
            const int partner = rank ^ j;
            MPI_Sendrecv(chunk, chunk_size, MPI_FLOAT, partner, 0,
                         partner_chunk.data(), chunk_size, MPI_FLOAT, partner, 0,
                         MPI_COMM_WORLD, MPI_STATUS_IGNORE);

            // Groups of k ranks alternate sort direction. In the final round
            // (k == size) rank & k is 0 for every rank, so all are ascending.
            const bool ascending = (rank & k) == 0;
            // In an ascending group the lower rank keeps the smaller half; in
            // a descending group it keeps the larger half. The partner computes
            // the complementary decision, so no values are lost or duplicated.
            const bool keep_low = (rank < partner) == ascending;

            std::merge(chunk, chunk + chunk_size,
                       partner_chunk.begin(), partner_chunk.end(),
                       merged.begin());
            if (keep_low) {
                std::copy(merged.begin(), merged.begin() + chunk_size, chunk);
            } else {
                std::copy(merged.end() - chunk_size, merged.end(), chunk);
            }
        }
    }
}

int main(int argc, char* argv[]) {
    MPI_Init(&argc, &argv);

    int rank, size;
    MPI_Comm_rank(MPI_COMM_WORLD, &rank);
    MPI_Comm_size(MPI_COMM_WORLD, &size);

    if (argc != 4) {
        if (rank == 0) {
            std::cerr << "Usage: " << argv[0] << " <array size> <input file> <output file>\n";
        }
        MPI_Finalize();
        return 1;
    }

    int N = std::atoi(argv[1]);
    int padding_N = 1 << (int)std::ceil(std::log2(N));
    char* input_filename = argv[2];
    char* output_filename = argv[3];

    if (size & (size - 1)) {
        if (rank == 0) std::cerr << "Please use a process count that is a power of 2.\n";
        MPI_Finalize();
        return 1;
    }

    int chunksize = padding_N / size;
    float* chunk = new float[chunksize];
    int total_offset = chunksize * rank;

    MPI_File input_file, output_file;
    MPI_File_open(MPI_COMM_WORLD, input_filename, MPI_MODE_RDONLY, MPI_INFO_NULL, &input_file);


    if (total_offset + chunksize < N) {
        MPI_File_read_at(input_file, rank * chunksize * sizeof(float), chunk, chunksize, MPI_FLOAT, MPI_STATUS_IGNORE);
    } else {
        if (total_offset >= N) {
            std::fill(chunk, chunk + chunksize, FLT_MAX);
        } else {
            MPI_File_read_at(input_file, sizeof(float) * total_offset, chunk, N - total_offset, MPI_FLOAT, MPI_STATUS_IGNORE);
            std::fill(chunk + (N - total_offset), chunk + chunksize, FLT_MAX);
        }
    }

    MPI_File_close(&input_file);

    bitonic_sort_mpi(chunk, chunksize, rank, size);

    MPI_File_open(MPI_COMM_WORLD, output_filename, MPI_MODE_CREATE | MPI_MODE_WRONLY, MPI_INFO_NULL, &output_file);

    MPI_File_write_at(output_file, rank * chunksize * sizeof(float), chunk, std::min(chunksize, N - total_offset), MPI_FLOAT, MPI_STATUS_IGNORE);

    MPI_File_close(&output_file);

    delete[] chunk;
    MPI_Finalize();

    return 0;
}
Editor is loading...
Leave a Comment