1010_matmul_add_fuse_cutlass
user_3093867
c_cpp
5 months ago
5.7 kB
2
Indexable
#include <iostream> #include <cuda_runtime.h> #include "cutlass/cutlass.h" #include "cutlass/gemm/device/gemm.h" #include "cutlass/util/host_tensor.h" #include "cutlass/util/tensor_view_io.h" #include "cutlass/util/reference/host/tensor_fill.h" #include "cutlass/util/reference/host/tensor_copy.h" #include "cutlass/util/reference/host/tensor_compare.h" #include "cutlass/util/reference/host/gemm.h" using ElementInputA = cutlass::half_t; using ElementInputB = cutlass::half_t; using ElementOutput = cutlass::half_t; using ElementAccumulator = float; using LayoutInputA = cutlass::layout::RowMajor; using LayoutInputB = cutlass::layout::ColumnMajor; using LayoutOutput = cutlass::layout::ColumnMajor; using MMAOp = cutlass::arch::OpClassTensorOp; using SmArch = cutlass::arch::Sm75; using ShapeMMAThreadBlock = cutlass::gemm::GemmShape<128, 128, 32>; using ShapeMMAWarp = cutlass::gemm::GemmShape<64, 64, 32>; using ShapeMMAOp = cutlass::gemm::GemmShape<16, 8, 8>; using EpilogueOp = cutlass::epilogue::thread::LinearCombination< ElementOutput, 128 / cutlass::sizeof_bits<ElementOutput>::value, ElementAccumulator, ElementAccumulator >; using Gemm = cutlass::gemm::device::Gemm< ElementInputA, LayoutInputA, ElementInputB, LayoutInputB, ElementOutput, LayoutOutput, ElementAccumulator, MMAOp, SmArch, ShapeMMAThreadBlock, ShapeMMAWarp, ShapeMMAOp, EpilogueOp >; // Helper function to print a few elements of a matrix // Helper function to print a few elements of a matrix template <typename T, typename Layout> void print_matrix_elements(cutlass::HostTensor<T, Layout>& tensor, const char* name, int num_elements = 5) { std::cout << "First " << num_elements << " elements of " << name << ":" << std::endl; for (int i = 0; i < num_elements; ++i) { typename Layout::TensorCoord coord; if constexpr (std::is_same_v<Layout, cutlass::layout::RowMajor>) { coord = typename Layout::TensorCoord{0, i}; } else { coord = typename Layout::TensorCoord{i, 0}; } std::cout << float(tensor.host_ref().at(coord)) << " "; } std::cout << std::endl; } int run_gemm(int m, int n, int k, float alpha, float beta) { cutlass::gemm::GemmCoord problem_size(m, n, k); cutlass::HostTensor<ElementInputA, LayoutInputA> tensor_a(problem_size.mk()); cutlass::HostTensor<ElementInputB, LayoutInputB> tensor_b(problem_size.kn()); cutlass::HostTensor<ElementOutput, LayoutOutput> tensor_c(problem_size.mn()); cutlass::HostTensor<ElementOutput, LayoutOutput> tensor_d(problem_size.mn()); cutlass::reference::host::TensorFillRandomUniform( tensor_a.host_view(), 1, ElementInputA(4), ElementInputA(-4), 0); cutlass::reference::host::TensorFillRandomUniform( tensor_b.host_view(), 1, ElementInputB(4), ElementInputB(-4), 0); cutlass::reference::host::TensorFill(tensor_c.host_view()); cutlass::reference::host::TensorFill(tensor_d.host_view()); tensor_a.sync_device(); tensor_b.sync_device(); tensor_c.sync_device(); tensor_d.sync_device(); typename Gemm::Arguments arguments{ problem_size, tensor_a.device_ref(), tensor_b.device_ref(), tensor_c.device_ref(), tensor_d.device_ref(), {ElementAccumulator(alpha), ElementAccumulator(beta)} }; Gemm gemm_op; cutlass::Status status = gemm_op.initialize(arguments); if (status != cutlass::Status::kSuccess) { std::cerr << "Failed to initialize CUTLASS Gemm." << std::endl; return -1; } status = gemm_op(); if (status != cutlass::Status::kSuccess) { std::cerr << "Failed to run CUTLASS Gemm." << std::endl; return -1; } // Print some elements of the input matrices print_matrix_elements(tensor_a, "Matrix A"); print_matrix_elements(tensor_b, "Matrix B"); cudaDeviceSynchronize(); tensor_d.sync_host(); // Print some elements of the output matrix print_matrix_elements(tensor_d, "Result Matrix D"); // cutlass::HostTensor<ElementOutput, LayoutOutput> reference_d(problem_size.mn()); // cutlass::reference::host::Gemm< // ElementInputA, LayoutInputA, // ElementInputB, LayoutInputB, // ElementOutput, LayoutOutput, // ElementAccumulator, ElementAccumulator // > reference_gemm; // reference_gemm( // problem_size, // alpha, // tensor_a.host_ref(), // tensor_b.host_ref(), // beta, // tensor_c.host_ref(), // reference_d.host_ref() // ); // bool passed = cutlass::reference::host::TensorEquals( // reference_d.host_view(), // tensor_d.host_view() // ); // std::cout << (passed ? "GEMM test passed" : "GEMM test failed") << std::endl; // return passed ? 0 : -1; return 0; } int main() { // int m = 8192; // int n = 8192; // int k = 8192; // int m = 4096; // int n = 4096; // int k = 4096; // int m = 2048; // int n = 2048; // int k = 2048; // int m = 1024; // int n = 1024; // int k = 1024; // int m = 512; // int n = 512; // int k = 512; // int m = 256; // int n = 256; // int k = 256; int m = 128; int n = 128; int k = 128; float alpha = 1.0f; float beta = 0.0f; return run_gemm(m, n, k, alpha, beta); }
Editor is loading...
Leave a Comment