// Untitled
//
// Paste-site header (not part of the program): mail@pastecode.io avatar,
// unknown, plain_text, a year ago, 5.9 kB, 2, Indexable, Never
#include <cstddef>
#include <cstdlib>
#include <iostream>
#include <stdexcept>
#include <vector>
using namespace std;

// simple matrix multiplication

// Size cutoff for strassen(): when n <= THRESHOLD it multiplies with matmul()
// directly instead of splitting the operands into quadrants.
#define THRESHOLD 2

/// Dense matrix of ints with value semantics.
///
/// Elements live in one contiguous row-major buffer that is zero-filled on
/// construction and released automatically. Copying or assigning a Matrix
/// therefore yields an independent deep copy instead of two objects sharing
/// (and later double-freeing) the same storage.
class Matrix {
    int rows, cols;
    std::vector<int> cells;  // rows * cols elements, row-major

    // Rejects negative dimensions before any storage is sized from them.
    static int checked_dim(int d) {
        if (d < 0) {
            throw std::invalid_argument("Matrix dimensions must be non-negative.");
        }
        return d;
    }

    // Flat position of element (i, j) in `cells`. Indices are not range-checked.
    std::size_t offset(int i, int j) const {
        return static_cast<std::size_t>(i) * static_cast<std::size_t>(cols)
             + static_cast<std::size_t>(j);
    }
public:
    /// Creates an r x c matrix with every element set to 0.
    /// @throws std::invalid_argument if r or c is negative.
    Matrix(int r, int c)
        : rows(checked_dim(r)),
          cols(checked_dim(c)),
          cells(static_cast<std::size_t>(rows) * static_cast<std::size_t>(cols), 0) {}

    /// Stores val at row i, column j (no bounds checking).
    void set(int i, int j, int val) {
        cells[offset(i, j)] = val;
    }

    /// Returns the element at row i, column j (no bounds checking).
    int get(int i, int j) const {
        return cells[offset(i, j)];
    }

    /// Number of rows.
    int getRows() const {
        return rows;
    }

    /// Number of columns.
    int getCols() const {
        return cols;
    }

    /// Writes the matrix to standard output, one bracketed row per line, each
    /// element followed by a single space.
    void print() const {
        for (int i = 0; i < rows; i++) {
            std::cout << "[";
            for (int j = 0; j < cols; j++) {
                std::cout << get(i, j) << " ";
            }
            std::cout << "]" << '\n';
        }
    }
};

// Textbook triple-loop product of a and b.
// Returns a newly allocated matrix with a's row count and b's column count;
// nothing in this function frees it, so the caller must delete it.
// Throws std::invalid_argument when a's column count differs from b's row count.
Matrix* matmul(Matrix *a, Matrix *b) {
    const int inner = a->getCols();
    if (inner != b->getRows()) {
        throw std::invalid_argument("Matrix dimensions do not match.");
    }

    const int out_rows = a->getRows();
    const int out_cols = b->getCols();
    Matrix *product = new Matrix(out_rows, out_cols);

    for (int row = 0; row < out_rows; ++row) {
        for (int col = 0; col < out_cols; ++col) {
            // Dot product of row `row` of a with column `col` of b.
            int dot = 0;
            for (int t = 0; t < inner; ++t) {
                dot += a->get(row, t) * b->get(t, col);
            }
            product->set(row, col, dot);
        }
    }
    return product;
}

// Copies the four (n/2) x (n/2) quadrants of `a` into the given blocks:
// a11 = top-left, a12 = top-right, a21 = bottom-left, a22 = bottom-right.
// Only rows/columns 0 .. 2*(n/2)-1 of `a` are read.
void split_matrix(Matrix *a, Matrix *a11, Matrix *a12, Matrix *a21, Matrix *a22, int n) {
    const int half = n / 2;
    for (int r = 0; r < half; ++r) {
        const int lower_r = r + half;
        for (int c = 0; c < half; ++c) {
            const int right_c = c + half;
            a11->set(r, c, a->get(r, c));
            a12->set(r, c, a->get(r, right_c));
            a21->set(r, c, a->get(lower_r, c));
            a22->set(r, c, a->get(lower_r, right_c));
        }
    }
}

// Writes four (n/2) x (n/2) blocks back into `a` as its quadrants:
// a11 -> top-left, a12 -> top-right, a21 -> bottom-left, a22 -> bottom-right.
// Only rows/columns 0 .. 2*(n/2)-1 of `a` are written.
void merge_matrix(Matrix *a, Matrix *a11, Matrix *a12, Matrix *a21, Matrix *a22, int n) {
    const int half = n / 2;
    for (int r = 0; r < half; ++r) {
        const int lower_r = r + half;
        for (int c = 0; c < half; ++c) {
            const int right_c = c + half;
            a->set(r, c, a11->get(r, c));
            a->set(r, right_c, a12->get(r, c));
            a->set(lower_r, c, a21->get(r, c));
            a->set(lower_r, right_c, a22->get(r, c));
        }
    }
}

// Element-wise sum of the leading n x n entries of a and b.
// Returns a newly allocated n x n matrix; the caller must delete it.
Matrix* add(Matrix *a, Matrix *b, int n) {
    Matrix *total = new Matrix(n, n);
    for (int row = 0; row < n; ++row) {
        for (int col = 0; col < n; ++col) {
            const int lhs = a->get(row, col);
            const int rhs = b->get(row, col);
            total->set(row, col, lhs + rhs);
        }
    }
    return total;
}

// Element-wise difference (a minus b) of the leading n x n entries.
// Returns a newly allocated n x n matrix; the caller must delete it.
Matrix* sub(Matrix *a, Matrix *b, int n) {
    Matrix *diff = new Matrix(n, n);
    for (int row = 0; row < n; ++row) {
        for (int col = 0; col < n; ++col) {
            const int lhs = a->get(row, col);
            const int rhs = b->get(row, col);
            diff->set(row, col, lhs - rhs);
        }
    }
    return diff;
}

// Multiplies a * b using Strassen's seven-product recursion.
//
// Parameters:
//   a, b - operands; a's column count must equal b's row count.
//   n    - side length of the (square) operands.
// Returns a newly allocated product matrix; the caller must delete it.
// Throws std::invalid_argument when a's column count differs from b's row count.
//
// The quadrant split is only valid when both operands are exactly n x n and n
// is even. In every other case (including n <= THRESHOLD) the product is
// computed by matmul(), so the result is always the full, correct product
// rather than a partially filled matrix.
Matrix* strassen(Matrix *a, Matrix *b, int n) {
    if (a->getCols() != b->getRows()) {
        throw std::invalid_argument("Matrix dimensions do not match.");
    }

    // b's row count already equals a's column count after the check above.
    const bool is_n_by_n =
        a->getRows() == n && a->getCols() == n && b->getCols() == n;
    if (n <= THRESHOLD || n % 2 != 0 || !is_n_by_n) {
        return matmul(a, b);
    }

    const int half = n / 2;

    Matrix *a11 = new Matrix(half, half);
    Matrix *a12 = new Matrix(half, half);
    Matrix *a21 = new Matrix(half, half);
    Matrix *a22 = new Matrix(half, half);

    Matrix *b11 = new Matrix(half, half);
    Matrix *b12 = new Matrix(half, half);
    Matrix *b21 = new Matrix(half, half);
    Matrix *b22 = new Matrix(half, half);

    // Split the matrices into four equal-sized sub-matrices
    split_matrix(a, a11, a12, a21, a22, n);
    split_matrix(b, b11, b12, b21, b22, n);

    // Calculate the 7 products. Each sum/difference operand is held in a named
    // pointer so it can be freed as soon as its product has been computed.
    Matrix *lhs = add(a11, a22, half);
    Matrix *rhs = add(b11, b22, half);
    Matrix *m1 = strassen(lhs, rhs, half);   // (A11 + A22)(B11 + B22)
    delete lhs;
    delete rhs;

    lhs = add(a21, a22, half);
    Matrix *m2 = strassen(lhs, b11, half);   // (A21 + A22) B11
    delete lhs;

    rhs = sub(b12, b22, half);
    Matrix *m3 = strassen(a11, rhs, half);   // A11 (B12 - B22)
    delete rhs;

    rhs = sub(b21, b11, half);
    Matrix *m4 = strassen(a22, rhs, half);   // A22 (B21 - B11)
    delete rhs;

    lhs = add(a11, a12, half);
    Matrix *m5 = strassen(lhs, b22, half);   // (A11 + A12) B22
    delete lhs;

    lhs = sub(a21, a11, half);
    rhs = add(b11, b12, half);
    Matrix *m6 = strassen(lhs, rhs, half);   // (A21 - A11)(B11 + B12)
    delete lhs;
    delete rhs;

    lhs = sub(a12, a22, half);
    rhs = add(b21, b22, half);
    Matrix *m7 = strassen(lhs, rhs, half);   // (A12 - A22)(B21 + B22)
    delete lhs;
    delete rhs;

    // Combine the products into the four result quadrants, freeing the
    // intermediate sums once each quadrant is complete.
    Matrix *step1 = add(m1, m4, half);
    Matrix *step2 = sub(step1, m5, half);
    Matrix *c11 = add(step2, m7, half);      // M1 + M4 - M5 + M7
    delete step1;
    delete step2;

    Matrix *c12 = add(m3, m5, half);         // M3 + M5
    Matrix *c21 = add(m2, m4, half);         // M2 + M4

    step1 = add(m1, m3, half);
    step2 = sub(step1, m2, half);
    Matrix *c22 = add(step2, m6, half);      // M1 + M3 - M2 + M6
    delete step1;
    delete step2;

    // Merge the results into a single matrix
    Matrix *c = new Matrix(a->getRows(), b->getCols());
    merge_matrix(c, c11, c12, c21, c22, n);

    // Deallocate memory for temporary matrices
    delete a11;
    delete a12;
    delete a21;
    delete a22;

    delete b11;
    delete b12;
    delete b21;
    delete b22;

    delete m1;
    delete m2;
    delete m3;
    delete m4;
    delete m5;
    delete m6;
    delete m7;

    delete c11;
    delete c12;
    delete c21;
    delete c22;

    return c;
}

int main() {
    const int A_ROWS = 16;
    const int A_COLS = 16;
    const int B_ROWS = 16;
    const int B_COLS = 16;

    Matrix *a = new Matrix(A_ROWS, A_COLS);
    Matrix *b = new Matrix(B_ROWS, B_COLS);

    // same seed
    srand(0);

    // init randomly
    for (int i = 0; i < A_ROWS; i++) {
        for (int j = 0; j < A_COLS; j++) {
            a->set(i, j, rand() % 10);
        }
    }
    for (int i = 0; i < B_ROWS; i++) {
        for (int j = 0; j < B_COLS; j++) {
            b->set(i, j, rand() % 10);
        }
    }

    cout << "Matrix A:" << endl;
    a->print();

    cout << "Matrix B:" << endl;
    b->print();

    Matrix *c_normal = matmul(a, b);
    Matrix *c_strassen = strassen(a, b, A_ROWS);

    cout << "Matrix C normal:" << endl;
    c_normal->print();

    cout << "Matrix C strassen:" << endl;
    c_strassen->print();

    return 0;
}