// 1010_matmul_add_fuse_cutlass
// (Snippet recovered from a paste site; original paste metadata:
//  user_3093867, c_cpp, 5.7 kB.)
#include <algorithm>
#include <iostream>
#include <type_traits>

#include <cuda_runtime.h>

#include "cutlass/cutlass.h"
#include "cutlass/gemm/device/gemm.h"
#include "cutlass/util/host_tensor.h"
#include "cutlass/util/tensor_view_io.h"
#include "cutlass/util/reference/host/tensor_fill.h"
#include "cutlass/util/reference/host/tensor_copy.h"
#include "cutlass/util/reference/host/tensor_compare.h"
#include "cutlass/util/reference/host/gemm.h"

// ---- Element types --------------------------------------------------------
// Operands and output are half precision; products are accumulated in float.
using ElementInputA      = cutlass::half_t;
using ElementInputB      = cutlass::half_t;
using ElementOutput      = cutlass::half_t;
using ElementAccumulator = float;

// ---- Memory layouts -------------------------------------------------------
using LayoutInputA = cutlass::layout::RowMajor;
using LayoutInputB = cutlass::layout::ColumnMajor;
using LayoutOutput = cutlass::layout::ColumnMajor;

// ---- Math operator class and target architecture tag ----------------------
using MMAOp  = cutlass::arch::OpClassTensorOp;
using SmArch = cutlass::arch::Sm75;

// ---- Tile shapes, each written as GemmShape<M, N, K> ----------------------
using ShapeMMAThreadBlock = cutlass::gemm::GemmShape<128, 128, 32>;  // per threadblock
using ShapeMMAWarp        = cutlass::gemm::GemmShape<64, 64, 32>;    // per warp
using ShapeMMAOp          = cutlass::gemm::GemmShape<16, 8, 8>;      // per MMA instruction

// ---- Epilogue -------------------------------------------------------------
// Linear-combination epilogue. The second argument is the element count per
// access: 128 bits divided by the bit width of ElementOutput.
using EpilogueOp = cutlass::epilogue::thread::LinearCombination<
    ElementOutput, 128 / cutlass::sizeof_bits<ElementOutput>::value,
    ElementAccumulator, ElementAccumulator>;

// ---- Device-level GEMM assembled from the pieces above --------------------
using Gemm = cutlass::gemm::device::Gemm<
    ElementInputA, LayoutInputA,
    ElementInputB, LayoutInputB,
    ElementOutput, LayoutOutput,
    ElementAccumulator,
    MMAOp, SmArch,
    ShapeMMAThreadBlock, ShapeMMAWarp, ShapeMMAOp,
    EpilogueOp>;

// Prints the leading entries of a 2-D host tensor to std::cout.
//
// For a RowMajor tensor the entries (0, 0), (0, 1), ... of the first row are
// printed; for any other layout the entries (0, 0), (1, 0), ... of the first
// column are printed. Values are read through host_ref(), so the caller is
// responsible for calling sync_host() first if the device copy is newer.
//
// tensor       : tensor to inspect (not modified).
// name         : label used in the header line.
// num_elements : maximum number of entries to print; clamped to the extent of
//                the walked dimension so the read never leaves the tensor.
template <typename T, typename Layout>
void print_matrix_elements(cutlass::HostTensor<T, Layout>& tensor, const char* name, int num_elements = 5) {
    constexpr bool kWalkRow = std::is_same_v<Layout, cutlass::layout::RowMajor>;

    const auto extent = tensor.extent();
    const int walked_dim = static_cast<int>(kWalkRow ? extent.column() : extent.row());
    const int fixed_dim  = static_cast<int>(kWalkRow ? extent.row() : extent.column());

    // An empty tensor has no row/column 0 to read; otherwise never print more
    // entries than the walked dimension actually holds.
    const int count = (fixed_dim > 0) ? std::max(0, std::min(num_elements, walked_dim)) : 0;

    std::cout << "First " << count << " elements of " << name << ":" << std::endl;
    for (int i = 0; i < count; ++i) {
        typename Layout::TensorCoord coord;
        if constexpr (kWalkRow) {
            coord = typename Layout::TensorCoord{0, i};
        } else {
            coord = typename Layout::TensorCoord{i, 0};
        }
        std::cout << float(tensor.host_ref().at(coord)) << " ";
    }
    std::cout << std::endl;
}


int run_gemm(int m, int n, int k, float alpha, float beta) {
    cutlass::gemm::GemmCoord problem_size(m, n, k);

    cutlass::HostTensor<ElementInputA, LayoutInputA> tensor_a(problem_size.mk());
    cutlass::HostTensor<ElementInputB, LayoutInputB> tensor_b(problem_size.kn());
    cutlass::HostTensor<ElementOutput, LayoutOutput> tensor_c(problem_size.mn());
    cutlass::HostTensor<ElementOutput, LayoutOutput> tensor_d(problem_size.mn());

    cutlass::reference::host::TensorFillRandomUniform(
        tensor_a.host_view(),
        1,
        ElementInputA(4),
        ElementInputA(-4),
        0);

    cutlass::reference::host::TensorFillRandomUniform(
        tensor_b.host_view(),
        1,
        ElementInputB(4),
        ElementInputB(-4),
        0);

    cutlass::reference::host::TensorFill(tensor_c.host_view());
    cutlass::reference::host::TensorFill(tensor_d.host_view());

    tensor_a.sync_device();
    tensor_b.sync_device();
    tensor_c.sync_device();
    tensor_d.sync_device();

    typename Gemm::Arguments arguments{
        problem_size,
        tensor_a.device_ref(),
        tensor_b.device_ref(),
        tensor_c.device_ref(),
        tensor_d.device_ref(),
        {ElementAccumulator(alpha), ElementAccumulator(beta)}
    };

    Gemm gemm_op;
    
    cutlass::Status status = gemm_op.initialize(arguments);

    if (status != cutlass::Status::kSuccess) {
        std::cerr << "Failed to initialize CUTLASS Gemm." << std::endl;
        return -1;
    }

    status = gemm_op();
    
    if (status != cutlass::Status::kSuccess) {
        std::cerr << "Failed to run CUTLASS Gemm." << std::endl;
        return -1;
    }


    // Print some elements of the input matrices
    print_matrix_elements(tensor_a, "Matrix A");
    print_matrix_elements(tensor_b, "Matrix B");


    cudaDeviceSynchronize();
    tensor_d.sync_host();

    // Print some elements of the output matrix
    print_matrix_elements(tensor_d, "Result Matrix D");


    // cutlass::HostTensor<ElementOutput, LayoutOutput> reference_d(problem_size.mn());
    // cutlass::reference::host::Gemm<
    //     ElementInputA, LayoutInputA,
    //     ElementInputB, LayoutInputB,
    //     ElementOutput, LayoutOutput,
    //     ElementAccumulator, ElementAccumulator
    // > reference_gemm;

    // reference_gemm(
    //     problem_size,
    //     alpha,
    //     tensor_a.host_ref(),
    //     tensor_b.host_ref(),
    //     beta,
    //     tensor_c.host_ref(),
    //     reference_d.host_ref()
    // );

    // bool passed = cutlass::reference::host::TensorEquals(
    //     reference_d.host_view(),
    //     tensor_d.host_view()
    // );

    

    // std::cout << (passed ? "GEMM test passed" : "GEMM test failed") << std::endl;

    // return passed ? 0 : -1;
    return 0;
}

// Program entry point: runs a single square GEMM and uses run_gemm's return
// value as the process exit status.
int main() {
    // Edge length shared by m, n and k. Earlier experiments, left disabled in
    // the original, used 256, 512, 1024, 2048, 4096 and 8192.
    const int edge = 128;

    // Epilogue scalars passed through to run_gemm.
    const float alpha = 1.0f;
    const float beta = 0.0f;

    return run_gemm(edge, edge, edge, alpha, beta);
}
// (End of pasted snippet; trailing paste-site text "Editor is loading..." /
//  "Leave a Comment" was not part of the program.)