// Untitled
//
// mail@pastecode.io avatar
// unknown
// plain_text
// a year ago
// 8.9 kB
// 2
// Indexable
// Never
#include <algorithm>
#include <chrono>     // timing in main()
#include <cmath>
#include <cstddef>
#include <cstdlib>
#include <iostream>
#include <limits>
#include <memory>
#include <stdexcept>
using namespace std;

// simple matrix multiplication

// Cut-over size for strassen(): when the padded (power-of-two) dimension is
// less than or equal to this value, the operands are multiplied with matmul()
// instead of being split further.
#define THRESHOLD 64

// Returns the smallest power of two that is greater than or equal to n.
// Any n <= 1 (including zero and negative values) yields 1.
//
// Throws std::overflow_error when the answer would not fit in an int
// (n > 2^30); without this guard the doubling below would overflow a signed
// int, which is undefined behaviour and could loop forever.
int nextPowerOfTwo(int n) {
    int power = 1;
    while (power < n) {
        // Doubling is only safe while power <= INT_MAX / 2.
        if (power > std::numeric_limits<int>::max() / 2) {
            throw std::overflow_error("nextPowerOfTwo: result does not fit in int.");
        }
        power *= 2;
    }
    return power;
}

// Row-major integer matrix that comes in two flavours:
//   * an owner, which allocates its own rows x cols buffer and frees it in
//     the destructor, and
//   * a view, which owns nothing and forwards every read to a rectangular
//     window of another Matrix (parentMatrix) starting at (offsetX, offsetY).
//
// Both flavours expose a "view size" (view_rows x view_cols) that bounds
// get()/set() and is what getRows()/getCols() report; setView() can shrink
// it without touching the underlying storage.
//
// The class is non-copyable: an implicit copy of an owner would make two
// objects delete the same buffer.
class Matrix {
    int rows, cols;            // allocated size; cols is the row stride of data
    int view_rows, view_cols;  // logical size used for bounds checks and reporting
    bool owner;                // true when this object allocated data and must free it
public:
    int *data;             // owned buffer (rows * cols ints); nullptr for views
    Matrix *parentMatrix;  // matrix a view reads from; nullptr for owners
    int offsetX, offsetY;  // row offset / column offset of a view inside parentMatrix

    // Creates an owning r x c matrix with every element set to zero.
    // Throws std::invalid_argument if either dimension is negative.
    Matrix(int r, int c)
        : rows(r), cols(c), view_rows(r), view_cols(c), owner(true),
          data(nullptr), parentMatrix(nullptr), offsetX(0), offsetY(0) {
        if (r < 0 || c < 0) {
            throw std::invalid_argument("Matrix dimensions must be non-negative.");
        }
        // Multiply in size_t so large dimensions cannot overflow int; the
        // trailing () value-initialises the buffer to zero.
        data = new int[static_cast<std::size_t>(r) * static_cast<std::size_t>(c)]();
    }

    // Creates a non-owning r x c view whose element (i, j) is
    // parentMatrix->get(offsetX + i, offsetY + j). The parent must outlive
    // the view. Throws std::invalid_argument if parentMatrix is null or a
    // dimension is negative.
    Matrix(int r, int c, Matrix* parentMatrix, int offsetX = 0, int offsetY = 0)
        : rows(r), cols(c), view_rows(r), view_cols(c), owner(false),
          data(nullptr), parentMatrix(parentMatrix), offsetX(offsetX), offsetY(offsetY) {
        if (r < 0 || c < 0) {
            throw std::invalid_argument("Matrix dimensions must be non-negative.");
        }
        if (parentMatrix == nullptr) {
            throw std::invalid_argument("View requires a parent matrix.");
        }
    }

    Matrix(const Matrix &) = delete;
    Matrix &operator=(const Matrix &) = delete;

    // Releases the buffer of an owner; views have nothing to free.
    ~Matrix() {
        if (owner) {
            delete[] data;
        }
    }

    // Writes val at (i, j).
    // Throws std::invalid_argument if (i, j) is outside the current view or
    // if this matrix is a view (views are read-only).
    void set(int i, int j, int val) {
        if (i < 0 || i >= view_rows || j < 0 || j >= view_cols) {
            throw std::invalid_argument("Index out of bounds.");
        }
        if (!owner) {
            throw std::invalid_argument("Cannot modify non-owned matrix.");
        }
        data[i * cols + j] = val;
    }

    // Reads the element at (i, j); for a view the read is forwarded to the
    // parent at the shifted coordinates (and bounds-checked there as well).
    // Throws std::invalid_argument if (i, j) is outside the current view.
    int get(int i, int j) const {
        if (i < 0 || i >= view_rows || j < 0 || j >= view_cols) {
            throw std::invalid_argument("Index out of bounds.");
        }
        if (owner) {
            return data[i * cols + j];
        }
        return parentMatrix->get(offsetX + i, offsetY + j);
    }

    // Number of rows in the current view.
    int getRows() const {
        return view_rows;
    }

    // Number of columns in the current view.
    int getCols() const {
        return view_cols;
    }

    // Restricts the visible region to the top-left r x c corner. The storage
    // and its row stride are unchanged.
    // Throws std::invalid_argument if the requested view is negative or
    // larger than the size the matrix was constructed with, since get()/set()
    // would otherwise index past the buffer.
    void setView(int r, int c) {
        if (r < 0 || r > rows || c < 0 || c > cols) {
            throw std::invalid_argument("View exceeds matrix dimensions.");
        }
        view_rows = r;
        view_cols = c;
    }

    // Prints the current view to stdout, one bracketed row per line.
    void print() const {
        for (int i = 0; i < view_rows; i++) {
            std::cout << "[";
            for (int j = 0; j < view_cols; j++) {
                std::cout << get(i, j) << " ";
            }
            std::cout << "]" << std::endl;
        }
    }

    // Returns true when both matrices have the same view size and identical
    // elements within that view.
    bool equals(Matrix *other) const {
        if (view_rows != other->getRows() || view_cols != other->getCols()) {
            return false;
        }
        for (int i = 0; i < view_rows; i++) {
            for (int j = 0; j < view_cols; j++) {
                if (get(i, j) != other->get(i, j)) {
                    return false;
                }
            }
        }
        return true;
    }
};

// Schoolbook matrix product. Allocates and returns a new owning matrix of
// size rows(a) x cols(b) whose entry (row, col) is the sum over k of
// a(row, k) * b(k, col). The caller is responsible for deleting the result.
// Throws std::invalid_argument when cols(a) != rows(b).
Matrix* matmul(Matrix *a, Matrix *b) {
    const int inner = a->getCols();
    if (inner != b->getRows()) {
        throw std::invalid_argument("Matrix dimensions do not match.");
    }

    const int outRows = a->getRows();
    const int outCols = b->getCols();
    Matrix *product = new Matrix(outRows, outCols);

    for (int row = 0; row < outRows; ++row) {
        for (int col = 0; col < outCols; ++col) {
            int acc = 0;
            for (int k = 0; k < inner; ++k) {
                acc += a->get(row, k) * b->get(k, col);
            }
            product->set(row, col, acc);
        }
    }
    return product;
}

// Copies a into the top-left corner of a freshly allocated
// new_rows x new_cols owning matrix and fills every remaining cell with 0.
// Cells of a that fall outside the new size are not copied.
// The caller is responsible for deleting the result.
Matrix* pad(Matrix *a, int new_rows, int new_cols) {
    Matrix *result = new Matrix(new_rows, new_cols);
    const int srcRows = a->getRows();
    const int srcCols = a->getCols();

    for (int r = 0; r < new_rows; ++r) {
        const bool rowInSource = r < srcRows;
        for (int c = 0; c < new_cols; ++c) {
            const bool inSource = rowInSource && c < srcCols;
            result->set(r, c, inSource ? a->get(r, c) : 0);
        }
    }
    return result;
}

// Splits a into four new_size x new_size quadrants without copying: each
// output is a newly allocated non-owning view onto a.
//
//   a11 = top-left      a12 = top-right
//   a21 = bottom-left   a22 = bottom-right
//
// The four outputs are taken by reference so the caller's pointers are
// actually updated (passing them by value left the caller's pointers
// untouched and leaked the views). The caller must delete all four views and
// must keep a alive for as long as they are used.
void split_matrix(Matrix *a, Matrix *&a11, Matrix *&a12, Matrix *&a21, Matrix *&a22, int new_size) {
    // Build all four views first so that a failed allocation cannot leak the
    // ones already created or leave the outputs half-assigned.
    std::unique_ptr<Matrix> topLeft(new Matrix(new_size, new_size, a, 0, 0));
    std::unique_ptr<Matrix> topRight(new Matrix(new_size, new_size, a, 0, new_size));
    std::unique_ptr<Matrix> bottomLeft(new Matrix(new_size, new_size, a, new_size, 0));
    std::unique_ptr<Matrix> bottomRight(new Matrix(new_size, new_size, a, new_size, new_size));

    a11 = topLeft.release();
    a12 = topRight.release();
    a21 = bottomLeft.release();
    a22 = bottomRight.release();
}

// Assembles four small_size x small_size quadrants into a:
//
//   [ a11 | a12 ]
//   [ a21 | a22 ]
//
// a must be writable and at least (2 * small_size) square in its current
// view, otherwise Matrix::set throws. big_size is accepted for interface
// symmetry but is not used.
void merge_matrix(Matrix *a, Matrix *a11, Matrix *a12, Matrix *a21, Matrix *a22, int big_size, int small_size) {
    (void)big_size;
    for (int row = 0; row < small_size; ++row) {
        const int lowerRow = row + small_size;
        for (int col = 0; col < small_size; ++col) {
            const int rightCol = col + small_size;
            a->set(row, col, a11->get(row, col));
            a->set(row, rightCol, a12->get(row, col));
            a->set(lowerRow, col, a21->get(row, col));
            a->set(lowerRow, rightCol, a22->get(row, col));
        }
    }
}

// Element-wise sum of the top-left n x n regions of a and b, returned as a
// newly allocated owning matrix that the caller must delete.
Matrix* add(Matrix *a, Matrix *b, int n) {
    Matrix *total = new Matrix(n, n);
    for (int row = 0; row < n; ++row) {
        for (int col = 0; col < n; ++col) {
            const int lhs = a->get(row, col);
            const int rhs = b->get(row, col);
            total->set(row, col, lhs + rhs);
        }
    }
    return total;
}

// Element-wise difference (a - b) of the top-left n x n regions of a and b,
// returned as a newly allocated owning matrix that the caller must delete.
Matrix* sub(Matrix *a, Matrix *b, int n) {
    Matrix *difference = new Matrix(n, n);
    for (int row = 0; row < n; ++row) {
        for (int col = 0; col < n; ++col) {
            const int lhs = a->get(row, col);
            const int rhs = b->get(row, col);
            difference->set(row, col, lhs - rhs);
        }
    }
    return difference;
}

// Multiplies a by b with Strassen's algorithm and returns a newly allocated
// matrix whose view is rows(a) x cols(b). The caller must delete the result.
//
// Both operands are zero-padded to an n x n square, n being the next power of
// two that covers all four input dimensions. When n <= THRESHOLD the padded
// operands are multiplied with matmul(); otherwise they are cut into four
// half-size quadrants, seven half-size products are computed recursively, and
// the results are recombined.
//
// Throws std::invalid_argument when cols(a) != rows(b).
Matrix* strassen(Matrix *a, Matrix *b) {
    // Validate before allocating anything so a mismatch cannot leak.
    if (a->getCols() != b->getRows()) {
        throw std::invalid_argument("Matrix dimensions do not match.");
    }

    // Every intermediate matrix is held by a unique_ptr so it is released on
    // all exits, including exceptions thrown by deeper calls.
    using Owned = std::unique_ptr<Matrix>;

    const int out_rows = a->getRows();
    const int out_cols = b->getCols();
    const int max_dim = std::max(std::max(a->getRows(), a->getCols()),
                                 std::max(b->getRows(), b->getCols()));
    const int n = nextPowerOfTwo(max_dim);

    // pad matrices
    Owned a_pad(pad(a, n, n));
    Owned b_pad(pad(b, n, n));

    if (n <= THRESHOLD) {
        Owned small(matmul(a_pad.get(), b_pad.get()));
        // Trim the padding so the base case reports the same shape as the
        // recursive case.
        small->setView(out_rows, out_cols);
        return small.release();
    }

    const int new_size = n / 2;

    // Quadrant views onto the padded operands. They are created here rather
    // than through an out-parameter helper so that the pointers used below
    // are always initialised.
    Owned a11(new Matrix(new_size, new_size, a_pad.get(), 0, 0));
    Owned a12(new Matrix(new_size, new_size, a_pad.get(), 0, new_size));
    Owned a21(new Matrix(new_size, new_size, a_pad.get(), new_size, 0));
    Owned a22(new Matrix(new_size, new_size, a_pad.get(), new_size, new_size));
    Owned b11(new Matrix(new_size, new_size, b_pad.get(), 0, 0));
    Owned b12(new Matrix(new_size, new_size, b_pad.get(), 0, new_size));
    Owned b21(new Matrix(new_size, new_size, b_pad.get(), new_size, 0));
    Owned b22(new Matrix(new_size, new_size, b_pad.get(), new_size, new_size));

    // Owning wrappers around add/sub/strassen so temporaries passed between
    // them are freed as soon as the enclosing expression finishes.
    auto sum = [new_size](const Owned &x, const Owned &y) {
        return Owned(add(x.get(), y.get(), new_size));
    };
    auto diff = [new_size](const Owned &x, const Owned &y) {
        return Owned(sub(x.get(), y.get(), new_size));
    };
    auto product = [](const Owned &x, const Owned &y) {
        return Owned(strassen(x.get(), y.get()));
    };

    // compute p1-p7
    Owned p1 = product(a11, diff(b12, b22));
    Owned p2 = product(sum(a11, a12), b22);
    Owned p3 = product(sum(a21, a22), b11);
    Owned p4 = product(a22, diff(b21, b11));
    Owned p5 = product(sum(a11, a22), sum(b11, b22));
    Owned p6 = product(diff(a12, a22), sum(b21, b22));
    Owned p7 = product(diff(a11, a21), sum(b11, b12));

    // compute c11-c22
    Owned c11 = sum(diff(sum(p5, p4), p2), p6);  // p5 + p4 - p2 + p6
    Owned c12 = sum(p1, p2);                     // p1 + p2
    Owned c21 = sum(p3, p4);                     // p3 + p4
    Owned c22 = diff(diff(sum(p5, p1), p3), p7); // p5 + p1 - p3 - p7

    // merge matrices, then hide the padding rows/columns from the caller
    Owned c(new Matrix(n, n));
    merge_matrix(c.get(), c11.get(), c12.get(), c21.get(), c22.get(), n, new_size);
    c->setView(out_rows, out_cols);

    return c.release();
}

int main() {
    const int A_ROWS = 1600;
    const int A_COLS = 320;
    const int B_ROWS = 320;
    const int B_COLS = 480;

    Matrix *a = new Matrix(A_ROWS, A_COLS);
    Matrix *b = new Matrix(B_ROWS, B_COLS);

    // same seed
    srand(2);

    // init randomly
    for (int i = 0; i < A_ROWS; i++) {
        for (int j = 0; j < A_COLS; j++) {
            a->set(i, j, rand() % 10);
        }
    }
    for (int i = 0; i < B_ROWS; i++) {
        for (int j = 0; j < B_COLS; j++) {
            b->set(i, j, rand() % 10);
        }
    }

    cout << "Main:" << endl;

    // cout << "Matrix A:" << endl;
    // a->print();

    // cout << "Matrix B:" << endl;
    // b->print();

    // Time it
    auto start = chrono::high_resolution_clock::now();
    Matrix *c_normal = matmul(a, b);
    auto end = chrono::high_resolution_clock::now();
    auto duration = chrono::duration_cast<chrono::microseconds>(end - start);
    cout << "Normal multiplication took " << duration.count() << " microseconds." << endl;
    // cout << "Matrix C normal:" << endl;
    // c_normal->print();

    start = chrono::high_resolution_clock::now();
    Matrix *c_strassen = strassen(a, b);
    end = chrono::high_resolution_clock::now();
    duration = chrono::duration_cast<chrono::microseconds>(end - start);
    cout << "Strassen multiplication took " << duration.count() << " microseconds." << endl;
    // cout << "Matrix C strassen:" << endl;
    // c_strassen->print();

    if (c_normal->equals(c_strassen)) {
        cout << "Matrices are EQUAL." << endl;
    } else {
        cout << "Matrices are NOT EQUAL." << endl;
    }

    delete a;
    delete b;
    delete c_normal;
    delete c_strassen;

    return 0;
}