// 1010_matmul_add_fuse_cutlass
// (pastebin export metadata removed: user_3093867 / c_cpp / "a year ago" / 5.7 kB)
#include <iostream>
#include <type_traits>

#include <cuda_runtime.h>

#include "cutlass/cutlass.h"
#include "cutlass/gemm/device/gemm.h"
#include "cutlass/util/host_tensor.h"
#include "cutlass/util/tensor_view_io.h"
#include "cutlass/util/reference/host/tensor_fill.h"
#include "cutlass/util/reference/host/tensor_copy.h"
#include "cutlass/util/reference/host/tensor_compare.h"
#include "cutlass/util/reference/host/gemm.h"
// --- Element types ---
// Inputs and output are half precision; accumulation is single precision
// so Tensor Core products are summed in fp32 before conversion back to fp16.
using ElementInputA = cutlass::half_t;
using ElementInputB = cutlass::half_t;
using ElementOutput = cutlass::half_t;
using ElementAccumulator = float;
// --- Memory layouts ---
// A is row-major, B is column-major, C/D are column-major.
using LayoutInputA = cutlass::layout::RowMajor;
using LayoutInputB = cutlass::layout::ColumnMajor;
using LayoutOutput = cutlass::layout::ColumnMajor;
// Use Tensor Core MMA instructions, targeting the Turing (SM75) architecture.
using MMAOp = cutlass::arch::OpClassTensorOp;
using SmArch = cutlass::arch::Sm75;
// Tile shapes (M, N, K) at each level of the CUTLASS hierarchy:
// threadblock tile, warp tile, and the per-instruction MMA shape
// (16x8x8 is the Turing fp16 Tensor Core instruction shape).
using ShapeMMAThreadBlock = cutlass::gemm::GemmShape<128, 128, 32>;
using ShapeMMAWarp = cutlass::gemm::GemmShape<64, 64, 32>;
using ShapeMMAOp = cutlass::gemm::GemmShape<16, 8, 8>;
// Epilogue: D = alpha * accumulator + beta * C, fused into the GEMM.
// The second argument is the number of elements per vectorized memory
// access (128 bits / 16 bits per half = 8 elements).
using EpilogueOp = cutlass::epilogue::thread::LinearCombination<
ElementOutput,
128 / cutlass::sizeof_bits<ElementOutput>::value,
ElementAccumulator,
ElementAccumulator
>;
// Fully-specified CUTLASS device-level GEMM kernel type.
using Gemm = cutlass::gemm::device::Gemm<
ElementInputA,
LayoutInputA,
ElementInputB,
LayoutInputB,
ElementOutput,
LayoutOutput,
ElementAccumulator,
MMAOp,
SmArch,
ShapeMMAThreadBlock,
ShapeMMAWarp,
ShapeMMAOp,
EpilogueOp
>;
// Prints the first `num_elements` elements of a host tensor for quick visual
// inspection. Row-major tensors are read along row 0; column-major tensors
// are read down column 0, so the printed values are always contiguous in
// memory. Assumes num_elements does not exceed the corresponding extent.
template <typename T, typename Layout>
void print_matrix_elements(cutlass::HostTensor<T, Layout>& tensor, const char* name, int num_elements = 5) {
    std::cout << "First " << num_elements << " elements of " << name << ":" << std::endl;
    for (int i = 0; i < num_elements; ++i) {
        typename Layout::TensorCoord coord;
        if constexpr (std::is_same_v<Layout, cutlass::layout::RowMajor>) {
            coord = typename Layout::TensorCoord{0, i};  // walk along row 0
        } else {
            coord = typename Layout::TensorCoord{i, 0};  // walk down column 0
        }
        // Convert to float so half_t prints as a number rather than raw bits.
        std::cout << float(tensor.host_ref().at(coord)) << " ";
    }
    std::cout << std::endl;
}
/// Runs D = alpha * (A @ B) + beta * C on the device with CUTLASS and checks
/// the result against a host reference GEMM.
///
/// A is m x k (row-major fp16), B is k x n (column-major fp16), C and D are
/// m x n (column-major fp16); accumulation is in fp32 on both paths, so the
/// device and host results are compared for exact equality.
///
/// Returns 0 on success; -1 on any CUTLASS/CUDA failure or a verification
/// mismatch.
int run_gemm(int m, int n, int k, float alpha, float beta) {
    cutlass::gemm::GemmCoord problem_size(m, n, k);

    // Host tensors with mirrored device allocations.
    cutlass::HostTensor<ElementInputA, LayoutInputA> tensor_a(problem_size.mk());
    cutlass::HostTensor<ElementInputB, LayoutInputB> tensor_b(problem_size.kn());
    cutlass::HostTensor<ElementOutput, LayoutOutput> tensor_c(problem_size.mn());
    cutlass::HostTensor<ElementOutput, LayoutOutput> tensor_d(problem_size.mn());

    // Fill A and B with random values in [-4, 4] (seed = 1, 0 fractional bits
    // so every value is exactly representable in fp16); C and D start zeroed.
    cutlass::reference::host::TensorFillRandomUniform(
        tensor_a.host_view(), 1, ElementInputA(4), ElementInputA(-4), 0);
    cutlass::reference::host::TensorFillRandomUniform(
        tensor_b.host_view(), 1, ElementInputB(4), ElementInputB(-4), 0);
    cutlass::reference::host::TensorFill(tensor_c.host_view());
    cutlass::reference::host::TensorFill(tensor_d.host_view());

    tensor_a.sync_device();
    tensor_b.sync_device();
    tensor_c.sync_device();
    tensor_d.sync_device();

    // split-k defaults to 1, so no extra device workspace is needed.
    typename Gemm::Arguments arguments{
        problem_size,
        tensor_a.device_ref(),
        tensor_b.device_ref(),
        tensor_c.device_ref(),
        tensor_d.device_ref(),
        {ElementAccumulator(alpha), ElementAccumulator(beta)}
    };

    Gemm gemm_op;
    cutlass::Status status = gemm_op.initialize(arguments);
    if (status != cutlass::Status::kSuccess) {
        std::cerr << "Failed to initialize CUTLASS Gemm." << std::endl;
        return -1;
    }

    status = gemm_op();
    if (status != cutlass::Status::kSuccess) {
        std::cerr << "Failed to run CUTLASS Gemm." << std::endl;
        return -1;
    }

    // Print some elements of the input matrices
    print_matrix_elements(tensor_a, "Matrix A");
    print_matrix_elements(tensor_b, "Matrix B");

    // Surface any asynchronous kernel-execution error before reading results
    // (a kernel fault would otherwise be silently ignored).
    cudaError_t cuda_err = cudaDeviceSynchronize();
    if (cuda_err != cudaSuccess) {
        std::cerr << "CUDA error: " << cudaGetErrorString(cuda_err) << std::endl;
        return -1;
    }
    tensor_d.sync_host();

    // Print some elements of the output matrix
    print_matrix_elements(tensor_d, "Result Matrix D");

    // Host reference GEMM (fp32 accumulation, matching the device epilogue),
    // writing into a separate tensor so C is left untouched.
    cutlass::HostTensor<ElementOutput, LayoutOutput> reference_d(problem_size.mn());
    cutlass::reference::host::Gemm<
        ElementInputA, LayoutInputA,
        ElementInputB, LayoutInputB,
        ElementOutput, LayoutOutput,
        ElementAccumulator, ElementAccumulator
    > reference_gemm;
    reference_gemm(
        problem_size,
        alpha,
        tensor_a.host_ref(),
        tensor_b.host_ref(),
        beta,
        tensor_c.host_ref(),
        reference_d.host_ref()
    );

    bool passed = cutlass::reference::host::TensorEquals(
        reference_d.host_view(),
        tensor_d.host_view()
    );
    std::cout << (passed ? "GEMM test passed" : "GEMM test failed") << std::endl;
    return passed ? 0 : -1;
}
/// Entry point: runs a 128 x 128 x 128 half-precision GEMM computing
/// D = alpha * (A @ B) + beta * C. Returns run_gemm's status (0 on success).
int main() {
    // Problem size. Larger powers of two (256 up to 8192) were previously
    // used here for benchmarking; 128 keeps the run small and fast.
    int m = 128;
    int n = 128;
    int k = 128;
    float alpha = 1.0f;  // scales the A*B product
    float beta = 0.0f;   // scales the C addend (C is zero-filled anyway)
    return run_gemm(m, n, k, alpha, beta);
}