Untitled

 avatar
unknown
c_cpp
a year ago
5.6 kB
10
Indexable
#include <mpi.h>
#include <bits/stdc++.h>

using namespace std;
using namespace chrono;

// Exclusive upper bound of the random entries produced by matrix_init():
// drand48() lies in [0, 1), so entries are integers in [0, max_value - 1].
constexpr int max_value = 10;

// Matrix side length (matrices are N x N); assigned from argv[1] in main().
int N;

/*
 * Accumulates the product of two square, row-major int matrices:
 *     C_sub += A_sub * B_sub
 * All three vectors must hold at least block_size * block_size elements.
 * C_sub is added to, not overwritten, so callers can sum partial products
 * across several calls. Prints a start and an end trace line tagged with the
 * caller's rank in MPI_COMM_WORLD.
 *
 * A_sub and B_sub are only read, so they are taken by const reference.
 */
void multiply_matrix(const vector<int> &A_sub, const vector<int> &B_sub, vector<int> &C_sub, int block_size){
    int rank;
    MPI_Comm_rank(MPI_COMM_WORLD, &rank);
    cout << "Process " << rank << " performing local matrix multiplication\n";
    // i-k-j loop order: the innermost loop walks one row of B_sub and one row
    // of C_sub contiguously instead of striding down a column of B_sub.
    for (int i = 0; i < block_size; ++i){
        for (int k = 0; k < block_size; ++k){
            const int a_ik = A_sub[i * block_size + k];
            for (int j = 0; j < block_size; ++j){
                C_sub[i * block_size + j] += a_ik * B_sub[k * block_size + j];
            }
        }
    }
    cout << "Process " << rank << " local matrix multiplication ends\n";
}

/*
 * Fills every element of `matrix` with a pseudo-random integer in
 * [0, max_value - 1]. drand48() returns a double in [0, 1); it is scaled by
 * max_value and truncated toward zero. drand48 is not seeded here.
 */
void matrix_init(vector<int> &matrix){
    // Range-for avoids the signed/unsigned comparison of an `int` index
    // against matrix.size().
    for (int &value : matrix){
        value = static_cast<int>(drand48() * max_value);
    }
}

/*
 * Returns true iff `result` contains exactly the same elements, in the same
 * order, as `expected`. Sizes are compared first, so a shorter `result` is
 * reported as a mismatch rather than read out of bounds, and a longer one is
 * not accepted just because its prefix matches.
 */
bool verify(const vector<int> &expected, const vector<int> &result){
    return expected == result;
}

int main(int argc, char **argv){
    if (argc < 2){
        cerr << "matrix size should be typed\n";
        return -1;
    }
    bool vf = false;
    N = atoi(argv[1]);
    if (argc > 2 && strcmp(argv[2], "-v") == 0)
        vf = true;
    MPI_Init(&argc, &argv);
    int p, rank;
    MPI_Comm_rank(MPI_COMM_WORLD, &rank);
    MPI_Comm_size(MPI_COMM_WORLD, &p);
    int c = 1; // parameter for memory replication
    /*for (int i = 1; i <= p; ++i){
        if (p % (i * i) == 0)
            c = i * i; // 'c' is udimated if 'p' is divisible by 'i * i'.
    }*/
    int q = static_cast<int>(sqrt(p / c));
    if (q * q != p / c){
        cout << "Processor number not perfect square!\n";
        MPI_Finalize();
        return -1;
    }
    int dims[3] = {q, q, c};
    int periods[3] = {1, 1, 1};
    MPI_Comm grid_comm;
    MPI_Cart_create(MPI_COMM_WORLD, 3, dims, periods, 1, &grid_comm);

    int block_size = N / q;

    vector<int> A(N * N), B(N * N);
    vector<int> A_sub(block_size * block_size), B_sub(block_size * block_size);
    vector<int> C_serial(N * N, 0), C(N * N, 0), C_sub(block_size * block_size, 0);
    if (rank == 0){
        matrix_init(A);
        matrix_init(B);
        if(vf)
            multiply_matrix(A, B, C_serial, block_size);
    }

    MPI_Barrier(MPI_COMM_WORLD);

    MPI_Datatype block_type, block_type_resized;
    MPI_Type_vector(block_size, block_size, N, MPI_INT, &block_type);
    MPI_Type_create_resized(block_type, 0, block_size * sizeof(int), &block_type_resized);
    MPI_Type_commit(&block_type_resized);

    cout << rank << " Type commit\n";

    MPI_Scatter(A.data(), 1, block_type_resized, A_sub.data(), block_size * block_size, MPI_INT, 0, MPI_COMM_WORLD);//need to !!
    MPI_Scatter(B.data(), 1, block_type_resized, B_sub.data(), block_size * block_size, MPI_INT, 0, MPI_COMM_WORLD); // need to !!

    cout << rank << " Scatter\n";

    multiply_matrix(A_sub, B_sub, C_sub, block_size);
    int right_src, right_dst, down_src, down_dst;
    MPI_Cart_shift(grid_comm, 1, 1, &right_src, &right_dst);
    MPI_Cart_shift(grid_comm, 0, 1, &down_src, &down_dst);

    cout << rank << " Initial Shift\n";

    MPI_Status status;
    for (int i = 0; i < c; i++){
        MPI_Sendrecv_replace(A_sub.data(), block_size * block_size, MPI_INT, right_dst, 0, right_src, 0, grid_comm, &status);
        MPI_Sendrecv_replace(B_sub.data(), block_size * block_size, MPI_INT, down_dst, 1, down_src, 1, grid_comm, &status);
        multiply_matrix(A_sub, B_sub, C_sub, block_size);
    }

    cout << rank << " Algorithm\n";

    MPI_Barrier(MPI_COMM_WORLD);

    vector<int> recvcounts(p, block_size * block_size), displs(p);
    if (rank == 0){
        for (int i = 0; i < displs.size(); ++i){
            int row_block_index = i / q; // Row index in the block grid
            int col_block_index = i % q; // Column index in the block grid
            displs[i] = (row_block_index * block_size * N) + (col_block_index * block_size);
        }
    }

    MPI_Barrier(MPI_COMM_WORLD);

    MPI_Gatherv(
        C_sub.data(),            // Send buffer: the subblock from this process
        block_size * block_size, // Number of elements in send buffer
        MPI_INT,                 // Type of send buffer elements
        C.data(),                // Receive buffer (only on root)
        recvcounts.data(),       // Amount of data to receive from each process (only on root)
        displs.data(),           // Displacements where to place incoming data (only on root)
        MPI_INT,                 // Type of receive buffer elements
        0,                       // Root process
        MPI_COMM_WORLD           // Communicator
    );
    
    cout << rank << " Gather\n";

    MPI_Barrier(MPI_COMM_WORLD);

    if (rank == 0 && vf){
        bool result = verify(C_serial, C);
        if (result == false)
            cout << "Incorrect!\n";
        else
            cout << "Correct!\n";
    }

    MPI_Type_free(&block_type_resized);
    MPI_Type_free(&block_type);
    MPI_Comm_free(&grid_comm);
    MPI_Finalize();

    return 0;
}
Editor is loading...
Leave a Comment