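// The following lines appear to be metadata from a paste site; they are not code.
// Untitled
// unknown
// plain_text
// a year ago
// 5.9 kB
// 2
// Indexable
// Never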
#include <cstddef>
#include <iostream>
#include <stdexcept>
#include <vector>

// Kept because code elsewhere in this file may rely on unqualified std names.
using namespace std;

// Simple matrix multiplication: a small int matrix type, the naive product,
// and the split/merge/add helpers used by Strassen's algorithm.

// strassen() stops recursing and calls matmul() once n <= THRESHOLD.
#define THRESHOLD 2

// Dense rows x cols matrix of ints, stored row-major in a std::vector.
// Every element starts at 0. get()/set() do not bounds-check their indices.
// Because the storage is a std::vector, the compiler-generated copy, move and
// destructor are all correct: copying a Matrix yields an independent deep copy
// (the previous raw-pointer version double-freed on copy).
class Matrix {
    int rows_;
    int cols_;
    std::vector<int> cells_;  // rows_ * cols_ elements, row-major

    // Flat index of element (i, j) inside cells_.
    std::size_t offset(int i, int j) const {
        return static_cast<std::size_t>(i) * static_cast<std::size_t>(cols_)
             + static_cast<std::size_t>(j);
    }

public:
    // Creates an r x c matrix filled with zeros.
    // Throws std::invalid_argument if r or c is negative.
    Matrix(int r, int c) : rows_(r), cols_(c) {
        if (r < 0 || c < 0) {
            throw std::invalid_argument("Matrix dimensions must be non-negative.");
        }
        cells_.assign(static_cast<std::size_t>(r) * static_cast<std::size_t>(c), 0);
    }

    void set(int i, int j, int val) { cells_[offset(i, j)] = val; }
    int get(int i, int j) const { return cells_[offset(i, j)]; }
    int getRows() const { return rows_; }
    int getCols() const { return cols_; }

    // Writes each row as "[v0 v1 ... ]" followed by a newline (to std::cout by default).
    void print(std::ostream &os = std::cout) const {
        for (int i = 0; i < rows_; i++) {
            os << '[';
            for (int j = 0; j < cols_; j++) {
                os << get(i, j) << ' ';
            }
            os << "]\n";
        }
    }
};

// Naive triple-loop product a * b.
// Returns a newly allocated a->getRows() x b->getCols() matrix; the caller owns
// it and must delete it. Sums are accumulated in int, so large entries can overflow.
// Throws std::invalid_argument if a's column count differs from b's row count.
Matrix* matmul(Matrix *a, Matrix *b) {
    if (a->getCols() != b->getRows()) {
        throw std::invalid_argument("Matrix dimensions do not match.");
    }
    const int rows = a->getRows();
    const int cols = b->getCols();
    const int inner = a->getCols();
    Matrix *c = new Matrix(rows, cols);
    for (int i = 0; i < rows; i++) {
        for (int j = 0; j < cols; j++) {
            int sum = 0;
            for (int k = 0; k < inner; k++) {
                sum += a->get(i, k) * b->get(k, j);
            }
            c->set(i, j, sum);
        }
    }
    return c;
}

// Split matrix: copies the four (n/2) x (n/2) quadrants of the top-left n x n
// block of a into a11 (top-left), a12 (top-right), a21 (bottom-left) and
// a22 (bottom-right). Each destination must be at least (n/2) x (n/2).
// For odd n the last row and column of the block are not copied.
void split_matrix(Matrix *a, Matrix *a11, Matrix *a12, Matrix *a21, Matrix *a22, int n) {
    const int h = n / 2;
    for (int i = 0; i < h; i++) {
        for (int j = 0; j < h; j++) {
            a11->set(i, j, a->get(i, j));
            a12->set(i, j, a->get(i, j + h));
            a21->set(i, j, a->get(i + h, j));
            a22->set(i, j, a->get(i + h, j + h));
        }
    }
}

// Merge matrix: the inverse of split_matrix. Writes the (n/2) x (n/2)
// quadrants a11, a12, a21, a22 into the top-left n x n block of a.
// For odd n the last row and column of the block are left untouched.
void merge_matrix(Matrix *a, Matrix *a11, Matrix *a12, Matrix *a21, Matrix *a22, int n) {
    const int h = n / 2;
    for (int i = 0; i < h; i++) {
        for (int j = 0; j < h; j++) {
            a->set(i, j, a11->get(i, j));
            a->set(i, j + h, a12->get(i, j));
            a->set(i + h, j, a21->get(i, j));
            a->set(i + h, j + h, a22->get(i, j));
        }
    }
}

// Add matrix: returns a newly allocated n x n matrix holding a + b, computed
// element-wise over the top-left n x n cells. The caller owns the result.
Matrix* add(Matrix *a, Matrix *b, int n) {
    Matrix *c = new Matrix(n, n);
    for (int i = 0; i < n; i++) {
        for (int j = 0; j < n; j++) {
            c->set(i, j, a->get(i, j) + b->get(i, j));
        }
    }
    return c;
}
subtract matrix Matrix* sub(Matrix *a, Matrix *b, int n) { Matrix *c = new Matrix(n, n); for (int i = 0; i < n; i++) { for (int j = 0; j < n; j++) { c->set(i, j, a->get(i, j) - b->get(i, j)); } } return c; } Matrix* strassen(Matrix *a, Matrix *b, int n) { if (a->getCols() != b->getRows()) { throw std::invalid_argument("Matrix dimensions do not match."); } if (n <= THRESHOLD) { return matmul(a, b); } int new_size = n / 2; Matrix *a11 = new Matrix(new_size, new_size); Matrix *a12 = new Matrix(new_size, new_size); Matrix *a21 = new Matrix(new_size, new_size); Matrix *a22 = new Matrix(new_size, new_size); Matrix *b11 = new Matrix(new_size, new_size); Matrix *b12 = new Matrix(new_size, new_size); Matrix *b21 = new Matrix(new_size, new_size); Matrix *b22 = new Matrix(new_size, new_size); // Split the matrices into four equal-sized sub-matrices split_matrix(a, a11, a12, a21, a22, n); split_matrix(b, b11, b12, b21, b22, n); // Calculate the 7 products of the sub-matrices using Strassen's algorithm Matrix *m1 = strassen(add(a11, a22, new_size), add(b11, b22, new_size), new_size); Matrix *m2 = strassen(add(a21, a22, new_size), b11, new_size); Matrix *m3 = strassen(a11, sub(b12, b22, new_size), new_size); Matrix *m4 = strassen(a22, sub(b21, b11, new_size), new_size); Matrix *m5 = strassen(add(a11, a12, new_size), b22, new_size); Matrix *m6 = strassen(sub(a21, a11, new_size), add(b11, b12, new_size), new_size); Matrix *m7 = strassen(sub(a12, a22, new_size), add(b21, b22, new_size), new_size); // Combine the results into the output matrix Matrix *c11 = add(sub(add(m1, m4, new_size), m5, new_size), m7, new_size); Matrix *c12 = add(m3, m5, new_size); Matrix *c21 = add(m2, m4, new_size); Matrix *c22 = add(sub(add(m1, m3, new_size), m2, new_size), m6, new_size); // Merge the results into a single matrix Matrix *c = new Matrix(a->getRows(), b->getCols()); merge_matrix(c, c11, c12, c21, c22, n); // Deallocate memory for temporary matrices delete a11; delete a12; delete a21; delete 
a22; delete b11; delete b12; delete b21; delete b22; delete m1; delete m2; delete m3; delete m4; delete m5; delete m6; delete m7; delete c11; delete c12; delete c21; delete c22; return c; } int main() { const int A_ROWS = 16; const int A_COLS = 16; const int B_ROWS = 16; const int B_COLS = 16; Matrix *a = new Matrix(A_ROWS, A_COLS); Matrix *b = new Matrix(B_ROWS, B_COLS); // same seed srand(0); // init randomly for (int i = 0; i < A_ROWS; i++) { for (int j = 0; j < A_COLS; j++) { a->set(i, j, rand() % 10); } } for (int i = 0; i < B_ROWS; i++) { for (int j = 0; j < B_COLS; j++) { b->set(i, j, rand() % 10); } } cout << "Matrix A:" << endl; a->print(); cout << "Matrix B:" << endl; b->print(); Matrix *c_normal = matmul(a, b); Matrix *c_strassen = strassen(a, b, A_ROWS); cout << "Matrix C normal:" << endl; c_normal->print(); cout << "Matrix C strassen:" << endl; c_strassen->print(); return 0; }