Untitled
unknown
plain_text
a year ago
8.9 kB
2
Indexable
Never
#include <algorithm>
#include <chrono>
#include <cmath>
#include <cstddef>
#include <cstdlib>
#include <iostream>
#include <limits>
#include <stdexcept>

using namespace std;

// Edge length at or below which strassen() hands the work to matmul().
#define THRESHOLD 64

// Returns the smallest power of two that is >= n (1 for n <= 1).
// Throws std::overflow_error when that power does not fit in an int.
int nextPowerOfTwo(int n) {
    int power = 1;
    while (power < n) {
        if (power > std::numeric_limits<int>::max() / 2) {
            throw std::overflow_error("Next power of two does not fit in int.");
        }
        power *= 2;
    }
    return power;
}

// Row-major integer matrix.
//
// A Matrix is either an *owner* (it allocated `data` and frees it in the
// destructor) or a read-only *view* onto a rectangular window of another
// Matrix, starting at (offsetX, offsetY) in the parent: offsetX is added to the
// row index and offsetY to the column index. A view does not keep its parent
// alive: the parent must outlive every view created from it.
//
// Owners additionally carry a "view size" (view_rows x view_cols) that may be
// smaller than the allocated size; all accessors are bounds-checked against the
// view size.
class Matrix {
    int rows, cols;
    int view_rows, view_cols;
    bool owner;

public:
    int *data;             // owned buffer; nullptr for views
    Matrix *parentMatrix;  // source matrix for views; nullptr for owners
    int offsetX, offsetY;  // row / column offset of a view inside its parent

    // Creates an owning r x c matrix with every element set to zero.
    // Throws std::invalid_argument when a dimension is negative.
    Matrix(int r, int c)
        : rows(r), cols(c), view_rows(r), view_cols(c), owner(true),
          data(nullptr), parentMatrix(nullptr), offsetX(0), offsetY(0) {
        if (r < 0 || c < 0) {
            throw std::invalid_argument("Matrix dimensions must be non-negative.");
        }
        // Multiply in size_t so large shapes cannot overflow int.
        data = new int[static_cast<std::size_t>(r) * static_cast<std::size_t>(c)]();
    }

    // Creates a non-owning r x c view of `parentMatrix` whose top-left corner is
    // at row `offsetX`, column `offsetY` of the parent.
    // Throws std::invalid_argument when the parent is null.
    Matrix(int r, int c, Matrix *parentMatrix, int offsetX = 0, int offsetY = 0)
        : rows(r), cols(c), view_rows(r), view_cols(c), owner(false),
          data(nullptr), parentMatrix(parentMatrix), offsetX(offsetX), offsetY(offsetY) {
        if (parentMatrix == nullptr) {
            throw std::invalid_argument("View requires a parent matrix.");
        }
    }

    // The class owns a raw buffer, so copying would lead to a double delete.
    Matrix(const Matrix &) = delete;
    Matrix &operator=(const Matrix &) = delete;

    ~Matrix() {
        if (owner) {
            delete[] data;
        }
    }

    // Writes `val` at (i, j).
    // Throws std::invalid_argument when the index is outside the current view
    // or when this matrix is a (read-only) view of another matrix.
    void set(int i, int j, int val) {
        if (i < 0 || i >= view_rows || j < 0 || j >= view_cols) {
            throw std::invalid_argument("Index out of bounds.");
        }
        if (!owner) {
            throw std::invalid_argument("Cannot modify non-owned matrix.");
        }
        data[static_cast<std::size_t>(i) * cols + j] = val;
    }

    // Reads the element at (i, j); views forward the read to their parent.
    // Throws std::invalid_argument when the index is outside the current view.
    int get(int i, int j) const {
        if (i < 0 || i >= view_rows || j < 0 || j >= view_cols) {
            throw std::invalid_argument("Index out of bounds.");
        }
        if (owner) {
            return data[static_cast<std::size_t>(i) * cols + j];
        }
        return parentMatrix->get(offsetX + i, offsetY + j);
    }

    int getRows() const { return view_rows; }
    int getCols() const { return view_cols; }

    // Restricts the visible size to r x c without touching the storage.
    // Throws std::invalid_argument when the requested size is negative or
    // larger than the size the matrix was constructed with.
    void setView(int r, int c) {
        if (r < 0 || r > rows || c < 0 || c > cols) {
            throw std::invalid_argument("View exceeds matrix dimensions.");
        }
        view_rows = r;
        view_cols = c;
    }

    // Prints the visible part of the matrix, one bracketed row per line.
    void print() const {
        for (int i = 0; i < view_rows; i++) {
            cout << "[";
            for (int j = 0; j < view_cols; j++) {
                cout << get(i, j) << " ";
            }
            cout << "]" << endl;
        }
    }

    // Returns true when both matrices have the same visible size and contents.
    // A null `other` compares unequal.
    bool equals(Matrix *other) const {
        if (other == nullptr) {
            return false;
        }
        if (view_rows != other->getRows() || view_cols != other->getCols()) {
            return false;
        }
        for (int i = 0; i < view_rows; i++) {
            for (int j = 0; j < view_cols; j++) {
                if (get(i, j) != other->get(i, j)) {
                    return false;
                }
            }
        }
        return true;
    }
};

// Naive triple-loop product a * b.
// Returns a newly allocated matrix that the caller must delete.
// Throws std::invalid_argument when a's column count differs from b's row count.
Matrix* matmul(Matrix *a, Matrix *b) {
    if (a->getCols() != b->getRows()) {
        throw std::invalid_argument("Matrix dimensions do not match.");
    }
    const int out_rows = a->getRows();
    const int out_cols = b->getCols();
    const int inner = a->getCols();
    Matrix *c = new Matrix(out_rows, out_cols);
    for (int i = 0; i < out_rows; i++) {
        for (int j = 0; j < out_cols; j++) {
            int sum = 0;
            for (int k = 0; k < inner; k++) {
                sum += a->get(i, k) * b->get(k, j);
            }
            c->set(i, j, sum);
        }
    }
    return c;
}

// Copies `a` into the top-left corner of a new new_rows x new_cols matrix and
// fills the remainder with zeros. Cells of `a` that fall outside the new size
// are dropped. The caller must delete the result.
Matrix* pad(Matrix *a, int new_rows, int new_cols) {
    const int src_rows = a->getRows();
    const int src_cols = a->getCols();
    Matrix *padded = new Matrix(new_rows, new_cols);
    for (int i = 0; i < new_rows; i++) {
        for (int j = 0; j < new_cols; j++) {
            const bool inside = i < src_rows && j < src_cols;
            padded->set(i, j, inside ? a->get(i, j) : 0);
        }
    }
    return padded;
}

// Creates four new_size x new_size views onto the quadrants of `a`.
//
// The output pointers are taken by reference so the caller actually receives
// the views (by-value pointers would leave the caller's variables untouched).
// The caller must delete all four views, and `a` must outlive them.
void split_matrix(Matrix *a, Matrix *&a11, Matrix *&a12, Matrix *&a21, Matrix *&a22, int new_size) {
    a11 = new Matrix(new_size, new_size, a, 0, 0);
    a12 = new Matrix(new_size, new_size, a, 0, new_size);
    a21 = new Matrix(new_size, new_size, a, new_size, 0);
    a22 = new Matrix(new_size, new_size, a, new_size, new_size);
}

// Writes the four small_size x small_size quadrants into `a`, which must be an
// owning matrix of at least 2*small_size rows and columns. `big_size` is
// accepted for interface compatibility and is not used.
void merge_matrix(Matrix *a, Matrix *a11, Matrix *a12, Matrix *a21, Matrix *a22, int big_size, int small_size) {
    (void)big_size;
    for (int i = 0; i < small_size; i++) {
        for (int j = 0; j < small_size; j++) {
            a->set(i, j, a11->get(i, j));
            a->set(i, j + small_size, a12->get(i, j));
            a->set(i + small_size, j, a21->get(i, j));
            a->set(i + small_size, j + small_size, a22->get(i, j));
        }
    }
}

// Element-wise sum of the top-left n x n parts of a and b.
// The caller must delete the result.
Matrix* add(Matrix *a, Matrix *b, int n) {
    Matrix *c = new Matrix(n, n);
    for (int i = 0; i < n; i++) {
        for (int j = 0; j < n; j++) {
            c->set(i, j, a->get(i, j) + b->get(i, j));
        }
    }
    return c;
}

// Element-wise difference (a - b) of the top-left n x n parts of a and b.
// The caller must delete the result.
Matrix* sub(Matrix *a, Matrix *b, int n) {
    Matrix *c = new Matrix(n, n);
    for (int i = 0; i < n; i++) {
        for (int j = 0; j < n; j++) {
            c->set(i, j, a->get(i, j) - b->get(i, j));
        }
    }
    return c;
}
Matrix* strassen(Matrix *a, Matrix *b) { int max_dim = max(a->getRows(), max(a->getCols(), max(b->getRows(), b->getCols()))); int n = nextPowerOfTwo(max_dim); // pad matrices Matrix *a_pad = pad(a, n, n); Matrix *b_pad = pad(b, n, n); if (a->getCols() != b->getRows()) { throw std::invalid_argument("Matrix dimensions do not match."); } if (n <= THRESHOLD) { return matmul(a_pad, b_pad); } int new_size = n / 2; // split matrices Matrix *a11; Matrix *a12; Matrix *a21; Matrix *a22; Matrix *b11; Matrix *b12; Matrix *b21; Matrix *b22; split_matrix(a_pad, a11, a12, a21, a22, new_size); split_matrix(b_pad, b11, b12, b21, b22, new_size); // compute p1-p7 Matrix *p1 = strassen(a11, sub(b12, b22, new_size)); Matrix *p2 = strassen(add(a11, a12, new_size), b22); Matrix *p3 = strassen(add(a21, a22, new_size), b11); Matrix *p4 = strassen(a22, sub(b21, b11, new_size)); Matrix *p5 = strassen(add(a11, a22, new_size), add(b11, b22, new_size)); Matrix *p6 = strassen(sub(a12, a22, new_size), add(b21, b22, new_size)); Matrix *p7 = strassen(sub(a11, a21, new_size), add(b11, b12, new_size)); // compute c11-c22 Matrix *c11 = add(sub(add(p5, p4, new_size), p2, new_size), p6, new_size); Matrix *c12 = add(p1, p2, new_size); Matrix *c21 = add(p3, p4, new_size); Matrix *c22 = sub(sub(add(p5, p1, new_size), p3, new_size), p7, new_size); // merge matrices Matrix *c = new Matrix(n, n); merge_matrix(c, c11, c12, c21, c22, n, new_size); c->setView(a->getRows(), b->getCols()); // free memory delete a_pad; delete b_pad; delete a11; delete a12; delete a21; delete a22; delete b11; delete b12; delete b21; delete b22; delete p1; delete p2; delete p3; delete p4; delete p5; delete p6; delete p7; delete c11; delete c12; delete c21; delete c22; return c; } int main() { const int A_ROWS = 1600; const int A_COLS = 320; const int B_ROWS = 320; const int B_COLS = 480; Matrix *a = new Matrix(A_ROWS, A_COLS); Matrix *b = new Matrix(B_ROWS, B_COLS); // same seed srand(2); // init randomly for (int i = 0; i < A_ROWS; 
i++) { for (int j = 0; j < A_COLS; j++) { a->set(i, j, rand() % 10); } } for (int i = 0; i < B_ROWS; i++) { for (int j = 0; j < B_COLS; j++) { b->set(i, j, rand() % 10); } } cout << "Main:" << endl; // cout << "Matrix A:" << endl; // a->print(); // cout << "Matrix B:" << endl; // b->print(); // Time it auto start = chrono::high_resolution_clock::now(); Matrix *c_normal = matmul(a, b); auto end = chrono::high_resolution_clock::now(); auto duration = chrono::duration_cast<chrono::microseconds>(end - start); cout << "Normal multiplication took " << duration.count() << " microseconds." << endl; // cout << "Matrix C normal:" << endl; // c_normal->print(); start = chrono::high_resolution_clock::now(); Matrix *c_strassen = strassen(a, b); end = chrono::high_resolution_clock::now(); duration = chrono::duration_cast<chrono::microseconds>(end - start); cout << "Strassen multiplication took " << duration.count() << " microseconds." << endl; // cout << "Matrix C strassen:" << endl; // c_strassen->print(); if (c_normal->equals(c_strassen)) { cout << "Matrices are EQUAL." << endl; } else { cout << "Matrices are NOT EQUAL." << endl; } delete a; delete b; delete c_normal; delete c_strassen; return 0; }