Untitled
unknown
plain_text
2 years ago
4.5 kB
11
Indexable
#include <algorithm>
#include <cmath>
#include <cstdlib>
#include <iostream>
#include <limits>
#include <memory>
#include <stdexcept>
#include <vector>
using namespace std;
// simple matrix multiplication
// Intended base-case cutoff for Strassen's recursion: operands of this size
// or smaller are multiplied with the naive algorithm instead of being split.
// A typed constant rather than a macro so it obeys normal scope and type rules.
const int THRESHOLD = 2;
// Returns the smallest power of two that is >= n (1 for any n <= 1).
// Throws std::overflow_error when that power of two does not fit in an int,
// instead of overflowing the signed accumulator (undefined behavior).
int nextPowerOfTwo(int n) {
    // Largest power of two representable as an int (2^30 for a 32-bit int).
    const int max_power = std::numeric_limits<int>::max() / 2 + 1;
    if (n > max_power) {
        throw std::overflow_error("nextPowerOfTwo: result does not fit in int.");
    }
    int power = 1;
    while (power < n) {
        power *= 2;  // safe: power < n <= max_power, so power * 2 <= max_power
    }
    return power;
}
// Dense, row-major matrix of ints.
//
// Storage is allocated once as rows x cols and zero-initialized. A separate
// logical "view" (view_rows x view_cols, initially the full size) is what
// getRows(), getCols() and print() report. setView() can shrink it, for
// example to hide zero padding, without reallocating. get()/set() always
// address the full storage and perform no bounds checking.
//
// Storage lives in a std::vector, so copies are deep and cleanup is automatic.
class Matrix {
    int rows, cols;            // allocated storage size
    int view_rows, view_cols;  // logical size reported to callers
    std::vector<int> data;     // rows * cols elements, row-major
public:
    // Creates an r x c matrix of zeros.
    // Throws std::invalid_argument if r or c is negative.
    Matrix(int r, int c) : rows(r), cols(c), view_rows(r), view_cols(c) {
        if (r < 0 || c < 0) {
            throw std::invalid_argument("Matrix dimensions must be non-negative.");
        }
        data.assign(static_cast<std::size_t>(r) * static_cast<std::size_t>(c), 0);
    }

    // Stores val at row i, column j of the full storage.
    void set(int i, int j, int val) {
        data[index(i, j)] = val;
    }

    // Returns the value at row i, column j of the full storage.
    int get(int i, int j) const {
        return data[index(i, j)];
    }

    // Logical (view) row count.
    int getRows() const {
        return view_rows;
    }

    // Logical (view) column count.
    int getCols() const {
        return view_cols;
    }

    // Sets the logical size reported by getRows()/getCols()/print().
    // Throws std::invalid_argument if the view is negative or larger than
    // the allocated storage, since later reads through it would be out of bounds.
    void setView(int r, int c) {
        if (r < 0 || c < 0 || r > rows || c > cols) {
            throw std::invalid_argument("Matrix view exceeds allocated storage.");
        }
        view_rows = r;
        view_cols = c;
    }

    // Writes the visible region one row per line as "[v0 v1 ... ]".
    // Defaults to std::cout; uses '\n' so output is not flushed per row.
    void print(std::ostream &os = std::cout) const {
        for (int i = 0; i < view_rows; i++) {
            os << "[";
            for (int j = 0; j < view_cols; j++) {
                os << get(i, j) << " ";
            }
            os << "]" << '\n';
        }
    }

private:
    // Flat offset of (i, j) in row-major storage; size_t avoids int overflow.
    std::size_t index(int i, int j) const {
        return static_cast<std::size_t>(i) * static_cast<std::size_t>(cols) +
               static_cast<std::size_t>(j);
    }
};
// Naive triple-loop product of a (n x m) and b (m x p), using each operand's
// getRows()/getCols() size.
// Returns a newly allocated n x p matrix; the caller is responsible for deleting it.
// Throws std::invalid_argument if a's column count differs from b's row count.
Matrix* matmul(Matrix *a, Matrix *b) {
    const int shared = a->getCols();
    if (shared != b->getRows()) {
        throw std::invalid_argument("Matrix dimensions do not match.");
    }
    const int out_rows = a->getRows();
    const int out_cols = b->getCols();
    Matrix *product = new Matrix(out_rows, out_cols);
    for (int r = 0; r < out_rows; ++r) {
        for (int col = 0; col < out_cols; ++col) {
            int acc = 0;
            for (int t = 0; t < shared; ++t) {
                acc += a->get(r, t) * b->get(t, col);
            }
            product->set(r, col, acc);
        }
    }
    return product;
}
// Copies the visible getRows() x getCols() region of `a` into the top-left
// corner of a new new_rows x new_cols matrix and fills every other cell with 0.
// If the target is smaller than `a`, the excess rows/columns are dropped.
// Returns a newly allocated matrix; the caller is responsible for deleting it.
Matrix* pad(Matrix *a, int new_rows, int new_cols) {
    Matrix *padded = new Matrix(new_rows, new_cols);
    const int src_rows = a->getRows();
    const int src_cols = a->getCols();
    for (int i = 0; i < new_rows; i++) {
        for (int j = 0; j < new_cols; j++) {
            if (i < src_rows && j < src_cols) {
                padded->set(i, j, a->get(i, j));
            } else {
                padded->set(i, j, 0);
            }
        }
    }
    // Previously missing: falling off the end of a non-void function is UB.
    return padded;
}
// Copies the four new_size x new_size quadrants of `a` into a11 (top-left),
// a12 (top-right), a21 (bottom-left) and a22 (bottom-right).
// `a` must have at least 2 * new_size rows and columns and each quadrant
// matrix at least new_size x new_size; Matrix::get/set do not bounds-check.
void split_matrix(Matrix *a, Matrix *a11, Matrix *a12, Matrix *a21, Matrix *a22, int new_size) {
    for (int row = 0; row < new_size; ++row) {
        const int lower = row + new_size;  // matching row in the bottom half
        for (int col = 0; col < new_size; ++col) {
            const int right = col + new_size;  // matching column in the right half
            a11->set(row, col, a->get(row, col));
            a12->set(row, col, a->get(row, right));
            a21->set(row, col, a->get(lower, col));
            a22->set(row, col, a->get(lower, right));
        }
    }
}
// Inverse of split_matrix: writes the small_size x small_size quadrants
// a11 (top-left), a12 (top-right), a21 (bottom-left) and a22 (bottom-right)
// into `a`, which must be at least 2 * small_size in each dimension.
// NOTE: big_size is never read; quadrant offsets come from small_size alone
// (callers presumably pass big_size == 2 * small_size).
void merge_matrix(Matrix *a, Matrix *a11, Matrix *a12, Matrix *a21, Matrix *a22, int big_size, int small_size) {
    (void)big_size;
    for (int row = 0; row < small_size; ++row) {
        const int down = row + small_size;
        for (int col = 0; col < small_size; ++col) {
            const int across = col + small_size;
            a->set(row, col, a11->get(row, col));
            a->set(row, across, a12->get(row, col));
            a->set(down, col, a21->get(row, col));
            a->set(down, across, a22->get(row, col));
        }
    }
}
// Element-wise sum a + b over the top-left n x n region of each operand.
// Returns a newly allocated n x n matrix; the caller is responsible for deleting it.
Matrix* add(Matrix *a, Matrix *b, int n) {
    Matrix *total = new Matrix(n, n);
    for (int r = 0; r < n; ++r) {
        for (int col = 0; col < n; ++col) {
            const int lhs = a->get(r, col);
            const int rhs = b->get(r, col);
            total->set(r, col, lhs + rhs);
        }
    }
    return total;
}
// Element-wise difference a - b over the top-left n x n region of each operand.
// Returns a newly allocated n x n matrix; the caller is responsible for deleting it.
Matrix* sub(Matrix *a, Matrix *b, int n) {
    Matrix *diff = new Matrix(n, n);
    for (int r = 0; r < n; ++r) {
        for (int col = 0; col < n; ++col) {
            const int lhs = a->get(r, col);
            const int rhs = b->get(r, col);
            diff->set(r, col, lhs - rhs);
        }
    }
    return diff;
}
// Copies the visible region of `src` into the top-left corner of a new
// n x n matrix and zero-fills the rest. The caller owns the result.
static Matrix* strassen_pad_square(Matrix *src, int n) {
    Matrix *out = new Matrix(n, n);
    const int src_rows = src->getRows();
    const int src_cols = src->getCols();
    for (int i = 0; i < n; i++) {
        for (int j = 0; j < n; j++) {
            out->set(i, j, (i < src_rows && j < src_cols) ? src->get(i, j) : 0);
        }
    }
    return out;
}

// Strassen recursion on two n x n operands, where n is a power of two.
// At or below THRESHOLD (or at n == 1, which cannot be halved) it falls back
// to the naive matmul(). The caller owns the returned n x n matrix.
static Matrix* strassen_square(Matrix *a, Matrix *b, int n) {
    if (n <= THRESHOLD || n < 2) {
        return matmul(a, b);
    }
    typedef std::unique_ptr<Matrix> Owned;  // frees every temporary, even if an allocation throws
    const int h = n / 2;

    Owned a11(new Matrix(h, h));
    Owned a12(new Matrix(h, h));
    Owned a21(new Matrix(h, h));
    Owned a22(new Matrix(h, h));
    Owned b11(new Matrix(h, h));
    Owned b12(new Matrix(h, h));
    Owned b21(new Matrix(h, h));
    Owned b22(new Matrix(h, h));
    split_matrix(a, a11.get(), a12.get(), a21.get(), a22.get(), h);
    split_matrix(b, b11.get(), b12.get(), b21.get(), b22.get(), h);

    // M1 = (A11 + A22)(B11 + B22)
    Owned s1(add(a11.get(), a22.get(), h));
    Owned s2(add(b11.get(), b22.get(), h));
    Owned m1(strassen_square(s1.get(), s2.get(), h));
    // M2 = (A21 + A22) B11
    Owned s3(add(a21.get(), a22.get(), h));
    Owned m2(strassen_square(s3.get(), b11.get(), h));
    // M3 = A11 (B12 - B22)
    Owned s4(sub(b12.get(), b22.get(), h));
    Owned m3(strassen_square(a11.get(), s4.get(), h));
    // M4 = A22 (B21 - B11)
    Owned s5(sub(b21.get(), b11.get(), h));
    Owned m4(strassen_square(a22.get(), s5.get(), h));
    // M5 = (A11 + A12) B22
    Owned s6(add(a11.get(), a12.get(), h));
    Owned m5(strassen_square(s6.get(), b22.get(), h));
    // M6 = (A21 - A11)(B11 + B12)
    Owned s7(sub(a21.get(), a11.get(), h));
    Owned s8(add(b11.get(), b12.get(), h));
    Owned m6(strassen_square(s7.get(), s8.get(), h));
    // M7 = (A12 - A22)(B21 + B22)
    Owned s9(sub(a12.get(), a22.get(), h));
    Owned s10(add(b21.get(), b22.get(), h));
    Owned m7(strassen_square(s9.get(), s10.get(), h));

    // C11 = M1 + M4 - M5 + M7
    Owned t1(add(m1.get(), m4.get(), h));
    Owned t2(sub(t1.get(), m5.get(), h));
    Owned c11(add(t2.get(), m7.get(), h));
    // C12 = M3 + M5
    Owned c12(add(m3.get(), m5.get(), h));
    // C21 = M2 + M4
    Owned c21(add(m2.get(), m4.get(), h));
    // C22 = M1 - M2 + M3 + M6
    Owned t3(sub(m1.get(), m2.get(), h));
    Owned t4(add(t3.get(), m3.get(), h));
    Owned c22(add(t4.get(), m6.get(), h));

    Matrix *c = new Matrix(n, n);
    merge_matrix(c, c11.get(), c12.get(), c21.get(), c22.get(), n, h);
    return c;
}

// Multiplies a (m x k) by b (k x p) using Strassen's algorithm.
// Both operands are zero-padded to n x n, where n is the next power of two
// >= max(m, k, p). They are multiplied recursively, and the top-left m x p
// block is copied into a fresh matrix whose storage is exactly m x p.
// Returns a newly allocated matrix; the caller is responsible for deleting it.
// Throws std::invalid_argument if a's column count differs from b's row count.
Matrix* strassen(Matrix *a, Matrix *b) {
    if (a->getCols() != b->getRows()) {
        throw std::invalid_argument("Matrix dimensions do not match.");
    }
    const int rows = a->getRows();
    const int inner = a->getCols();
    const int cols = b->getCols();
    const int n = nextPowerOfTwo(std::max(rows, std::max(inner, cols)));

    std::unique_ptr<Matrix> a_sq(strassen_pad_square(a, n));
    std::unique_ptr<Matrix> b_sq(strassen_pad_square(b, n));
    std::unique_ptr<Matrix> c_sq(strassen_square(a_sq.get(), b_sq.get(), n));

    Matrix *c = new Matrix(rows, cols);
    for (int i = 0; i < rows; i++) {
        for (int j = 0; j < cols; j++) {
            c->set(i, j, c_sq->get(i, j));
        }
    }
    return c;
}
int main() {
const int A_ROWS = 16;
const int A_COLS = 32;
const int B_ROWS = 32;
const int B_COLS = 48;
Matrix *a = new Matrix(A_ROWS, A_COLS);
Matrix *b = new Matrix(B_ROWS, B_COLS);
// same seed
srand(2);
// init randomly
for (int i = 0; i < A_ROWS; i++) {
for (int j = 0; j < A_COLS; j++) {
a->set(i, j, rand() % 10);
}
}
for (int i = 0; i < B_ROWS; i++) {
for (int j = 0; j < B_COLS; j++) {
b->set(i, j, rand() % 10);
}
}
cout << "Main:" << endl;
cout << "Matrix A:" << endl;
a->print();
cout << "Matrix B:" << endl;
b->print();
Matrix *c_normal = matmul(a, b);
cout << "Matrix C normal:" << endl;
c_normal->print();
Matrix *c_strassen = strassen(a, b);
cout << "Matrix C strassen:" << endl;
c_strassen->print();
return 0;
}Editor is loading...