// 1010_matmul_add
// (Captured from a code-paste page: c_cpp snippet by user_3093867, 27 kB.)
// /***************************************************************************************************
//  * Copyright (c) 2023 - 2024 NVIDIA CORPORATION & AFFILIATES. All rights reserved.
//  * SPDX-License-Identifier: BSD-3-Clause
//  *
//  * Redistribution and use in source and binary forms, with or without
//  * modification, are permitted provided that the following conditions are met:
//  *
//  * 1. Redistributions of source code must retain the above copyright notice, this
//  * list of conditions and the following disclaimer.
//  *
//  * 2. Redistributions in binary form must reproduce the above copyright notice,
//  * this list of conditions and the following disclaimer in the documentation
//  * and/or other materials provided with the distribution.
//  *
//  * 3. Neither the name of the copyright holder nor the names of its
//  * contributors may be used to endorse or promote products derived from
//  * this software without specific prior written permission.
//  *
//  * THIS SOFTWARE IS PROVIDED BY THE COPYRIGHT HOLDERS AND CONTRIBUTORS "AS IS"
//  * AND ANY EXPRESS OR IMPLIED WARRANTIES, INCLUDING, BUT NOT LIMITED TO, THE
//  * IMPLIED WARRANTIES OF MERCHANTABILITY AND FITNESS FOR A PARTICULAR PURPOSE ARE
//  * DISCLAIMED. IN NO EVENT SHALL THE COPYRIGHT HOLDER OR CONTRIBUTORS BE LIABLE
//  * FOR ANY DIRECT, INDIRECT, INCIDENTAL, SPECIAL, EXEMPLARY, OR CONSEQUENTIAL
//  * DAMAGES (INCLUDING, BUT NOT LIMITED TO, PROCUREMENT OF SUBSTITUTE GOODS OR
//  * SERVICES; LOSS OF USE, DATA, OR PROFITS; OR BUSINESS INTERRUPTION) HOWEVER
//  * CAUSED AND ON ANY THEORY OF LIABILITY, WHETHER IN CONTRACT, STRICT LIABILITY,
//  * OR TORT (INCLUDING NEGLIGENCE OR OTHERWISE) ARISING IN ANY WAY OUT OF THE USE
//  * OF THIS SOFTWARE, EVEN IF ADVISED OF THE POSSIBILITY OF SUCH DAMAGE.
//  *
//  **************************************************************************************************/
// #include <cstdlib>
// #include <cstdio>
// #include <cassert>

// #include <thrust/host_vector.h>
// #include <thrust/device_vector.h>

// #include <cute/tensor.hpp>

// #include "cutlass/util/print_error.hpp"
// #include "cutlass/util/GPU_Clock.hpp"
// #include "cutlass/util/helper_cuda.hpp"

// template <class ProblemShape, class CtaTiler,
//           class TA, class AStride, class ASmemLayout, class AThreadLayout,
//           class TB, class BStride, class BSmemLayout, class BThreadLayout,
//           class TC, class CStride, class CSmemLayout, class CThreadLayout,
//           class Alpha, class Beta>
// __global__ static
// __launch_bounds__(decltype(size(CThreadLayout{}))::value)
// void
// gemm_device(ProblemShape shape_MNK, CtaTiler cta_tiler,
//             TA const* A, AStride dA, ASmemLayout sA_layout, AThreadLayout tA,
//             TB const* B, BStride dB, BSmemLayout sB_layout, BThreadLayout tB,
//             TC      * C, CStride dC, CSmemLayout          , CThreadLayout tC,
//             Alpha alpha, Beta beta)
// {
//   using namespace cute;

//   // Preconditions
//   CUTE_STATIC_ASSERT_V(rank(shape_MNK) == Int<3>{});                   // (M, N, K)
//   CUTE_STATIC_ASSERT_V(rank(cta_tiler) == Int<3>{});                   // (BLK_M, BLK_N, BLK_K)

//   static_assert(is_static<AThreadLayout>::value);
//   static_assert(is_static<BThreadLayout>::value);
//   static_assert(is_static<CThreadLayout>::value);

//   CUTE_STATIC_ASSERT_V(size(tA) == size(tB));                          // NumThreads
//   CUTE_STATIC_ASSERT_V(size(tC) == size(tA));                          // NumThreads

//   CUTE_STATIC_ASSERT_V(size<0>(cta_tiler) % size<0>(tA) == Int<0>{});  // BLK_M / THR_M
//   CUTE_STATIC_ASSERT_V(size<2>(cta_tiler) % size<1>(tA) == Int<0>{});  // BLK_K / THR_K
//   CUTE_STATIC_ASSERT_V(size<1>(cta_tiler) % size<0>(tB) == Int<0>{});  // BLK_N / THR_N
//   CUTE_STATIC_ASSERT_V(size<2>(cta_tiler) % size<1>(tB) == Int<0>{});  // BLK_K / THR_K
//   CUTE_STATIC_ASSERT_V(size<0>(cta_tiler) % size<0>(tC) == Int<0>{});  // BLK_M / THR_M
//   CUTE_STATIC_ASSERT_V(size<1>(cta_tiler) % size<1>(tC) == Int<0>{});  // BLK_N / THR_N

//   static_assert(is_static<ASmemLayout>::value);
//   static_assert(is_static<BSmemLayout>::value);
//   static_assert(is_static<CSmemLayout>::value);

//   CUTE_STATIC_ASSERT_V(size<0>(ASmemLayout{}) == size<0>(cta_tiler));  // BLK_M
//   CUTE_STATIC_ASSERT_V(size<0>(CSmemLayout{}) == size<0>(cta_tiler));  // BLK_M
//   CUTE_STATIC_ASSERT_V(size<0>(BSmemLayout{}) == size<1>(cta_tiler));  // BLK_N
//   CUTE_STATIC_ASSERT_V(size<1>(CSmemLayout{}) == size<1>(cta_tiler));  // BLK_N
//   CUTE_STATIC_ASSERT_V(size<1>(ASmemLayout{}) == size<2>(cta_tiler));  // BLK_K
//   CUTE_STATIC_ASSERT_V(size<1>(BSmemLayout{}) == size<2>(cta_tiler));  // BLK_K

//   CUTE_STATIC_ASSERT_V(congruent(select<0,2>(shape_MNK), dA));         // dA strides for shape MK
//   CUTE_STATIC_ASSERT_V(congruent(select<1,2>(shape_MNK), dB));         // dB strides for shape NK
//   CUTE_STATIC_ASSERT_V(congruent(select<0,1>(shape_MNK), dC));         // dC strides for shape MN

//   //
//   // Full and Tiled Tensors
//   //

//   // Represent the full tensors
//   Tensor mA = make_tensor(make_gmem_ptr(A), select<0,2>(shape_MNK), dA); // (M,K)
//   Tensor mB = make_tensor(make_gmem_ptr(B), select<1,2>(shape_MNK), dB); // (N,K)
//   Tensor mC = make_tensor(make_gmem_ptr(C), select<0,1>(shape_MNK), dC); // (M,N)

//   // Get the appropriate blocks for this thread block
//   auto cta_coord = make_coord(blockIdx.x, blockIdx.y, _);              // (m,n,k)
//   Tensor gA = local_tile(mA, cta_tiler, cta_coord, Step<_1, X,_1>{});  // (BLK_M,BLK_K,k)
//   Tensor gB = local_tile(mB, cta_tiler, cta_coord, Step< X,_1,_1>{});  // (BLK_N,BLK_K,k)
//   Tensor gC = local_tile(mC, cta_tiler, cta_coord, Step<_1,_1, X>{});  // (BLK_M,BLK_N)

//   // Shared memory buffers
//   __shared__ TA smemA[cosize_v<ASmemLayout>];
//   __shared__ TB smemB[cosize_v<BSmemLayout>];
//   Tensor sA = make_tensor(make_smem_ptr(smemA), sA_layout);            // (BLK_M,BLK_K)
//   Tensor sB = make_tensor(make_smem_ptr(smemB), sB_layout);            // (BLK_N,BLK_K)

//   //
//   // Partition the copying of A and B tiles across the threads
//   //

//   // TUTORIAL: Example of simple raked partitioning of ThreadLayouts tA|tB over data A|B tiles

//   Tensor tAgA = local_partition(gA, tA, threadIdx.x);                  // (THR_M,THR_K,k)
//   Tensor tAsA = local_partition(sA, tA, threadIdx.x);                  // (THR_M,THR_K)

//   Tensor tBgB = local_partition(gB, tB, threadIdx.x);                  // (THR_N,THR_K,k)
//   Tensor tBsB = local_partition(sB, tB, threadIdx.x);                  // (THR_N,THR_K)

//   CUTE_STATIC_ASSERT_V(size<0>(tAgA) == size<0>(tAsA));                // THR_M
//   CUTE_STATIC_ASSERT_V(size<1>(tAgA) == size<1>(tAsA));                // THR_K
//   CUTE_STATIC_ASSERT_V(size<0>(tBgB) == size<0>(tBsB));                // THR_N
//   CUTE_STATIC_ASSERT_V(size<1>(tBgB) == size<1>(tBsB));                // THR_K

//   //
//   // Define A/B partitioning and C accumulators
//   //

//   // TUTORIAL: Example of partitioning via projections of a ThreadLayout tC

//   // Partition sA (M,K) by the rows of tC
//   Tensor tCsA = local_partition(sA, tC, threadIdx.x, Step<_1, X>{});   // (THR_M,BLK_K)
//   // Partition sB (N,K) by the cols of tC
//   Tensor tCsB = local_partition(sB, tC, threadIdx.x, Step< X,_1>{});   // (THR_N,BLK_K)
//   // Partition gC (M,N) by the tile of tC
//   Tensor tCgC = local_partition(gC, tC, threadIdx.x, Step<_1,_1>{});   // (THR_M,THR_N)

//   // Allocate the accumulators -- same shape/layout as the partitioned data
//   Tensor tCrC = make_tensor_like(tCgC);                                // (THR_M,THR_N)

//   CUTE_STATIC_ASSERT_V(size<0>(tCrC) == size<0>(tCgC));                // THR_M
//   CUTE_STATIC_ASSERT_V(size<0>(tCrC) == size<0>(tCsA));                // THR_M
//   CUTE_STATIC_ASSERT_V(size<1>(tCrC) == size<1>(tCgC));                // THR_N
//   CUTE_STATIC_ASSERT_V(size<1>(tCrC) == size<0>(tCsB));                // THR_N
//   CUTE_STATIC_ASSERT_V(size<1>(tCsA) == size<1>(tCsB));                // BLK_K

//   // Clear the accumulators
//   clear(tCrC);

// #if 0
//   if(thread0()) {
//     print("  mA : "); print(  mA); print("\n");
//     print("  gA : "); print(  gA); print("\n");
//     print("  sA : "); print(  sA); print("\n");
//     print("tAgA : "); print(tAgA); print("\n");
//     print("tAsA : "); print(tAsA); print("\n");
//   }
// #endif

// #if 0
//   if(thread0()) {
//     print("  mB : "); print(  mB); print("\n");
//     print("  gB : "); print(  gB); print("\n");
//     print("  sB : "); print(  sB); print("\n");
//     print("tBgB : "); print(tBgB); print("\n");
//     print("tBsB : "); print(tBsB); print("\n");
//   }
// #endif

// #if 0
//   if(thread0()) {
//     print("  mC : "); print(  mC); print("\n");
//     print("  gC : "); print(  gC); print("\n");
//     print("tCsA : "); print(tCsA); print("\n");
//     print("tCsB : "); print(tCsB); print("\n");
//     print("tCgC : "); print(tCgC); print("\n");
//     print("tCrC : "); print(tCrC); print("\n");
//   }
// #endif

// #if 1

//   // TUTORIAL: Example of a simple mainloop that read tiles of data into shared memory,
//   //           and then computes on those tiles.
//   //   copy(.) operates on the global and shared memory via the tA|tB partitioning
//   //   gemm(.) operates on the shared and register memory via the tC partitioning

//   auto K_TILE_MAX = size<2>(tAgA);

//   for (int k_tile = 0; k_tile < K_TILE_MAX; ++k_tile)
//   {
//     // Copy gmem to smem with tA|tB thread-partitioned tensors
//     copy(tAgA(_,_,k_tile), tAsA);      // A   (THR_M,THR_K) -> (THR_M,THR_K)
//     copy(tBgB(_,_,k_tile), tBsB);      // B   (THR_N,THR_K) -> (THR_N,THR_K)

//     // TUTORIAL: The above call to copy(tAgA(_,_,k_tile), tAsA) is equivalent to
//     //   Tensor tAgAk = tAgA(_,_,k_tile);
//     //   CUTE_UNROLL
//     //   for (int i = 0; i < size(tAsA); ++i) {
//     //     tAsA(i) = tAgAk(i);
//     //   }

//     cp_async_fence();        // Label the end of (potential) cp.async instructions
//     cp_async_wait<0>();      // Sync on all (potential) cp.async instructions
//     __syncthreads();         // Wait for all threads to write to smem

//     // Compute gemm on tC thread-partitioned smem
//     gemm(tCsA, tCsB, tCrC);            // (THR_M,THR_N) += (THR_M,BLK_K) * (THR_N,BLK_K)

//     // TUTORIAL: The above call to gemm(tCsA, tCsB, tCrC) is equivalent to
//     //   CUTE_UNROLL
//     //   for (int k = 0; k < size<1>(tCsA); ++k) {
//     //     CUTE_UNROLL
//     //     for (int m = 0; m < size<0>(tCrC); ++m) {
//     //       CUTE_UNROLL
//     //       for (int n = 0; n < size<1>(tCrC); ++n) {
//     //         tCrC(m,n) += tCsA(m,k) * tCsB(n,k);
//     //       }
//     //     }
//     //   }

//     __syncthreads();         // Wait for all threads to read from smem
//   }

// #endif

//   //
//   // Epilogue
//   //

//   axpby(alpha, tCrC, beta, tCgC);

//   // TUTORIAL: The above call to axpby(alpha, tCrC, beta, tCgC) is equivalent to
//   //   CUTE_UNROLL
//   //   for (int i = 0; i < size(tCsA); ++i) {
//   //     tCgC(i) = alpha * tCrC(i) + beta * tCgC(i);
//   //   }
// }


//   // Elementwise addition kernel
// template <typename T>
// __global__ void elementwiseAddKernel(T* C, const T* D, int M, int N) {
//     int row = blockIdx.y * blockDim.y + threadIdx.y;
//     int col = blockIdx.x * blockDim.x + threadIdx.x;
    
//     if (row < M && col < N) {
//         int idx = row * N + col;
//         C[idx] += D[idx];
//     }
// }

// // Setup params for an NT GEMM
// // Use m-major smem sA, n-major smem sB, and mn-major threads tA|tB
// template <class TA, class TB, class TC,
//           class Alpha, class Beta>
// void
// gemm_nt(int m, int n, int k,
//         Alpha alpha,
//         TA const* A, int ldA,
//         TB const* B, int ldB,
//         Beta beta,
//         TC      * C, int ldC,
//         cudaStream_t stream = 0)
// {
//   using namespace cute;

//   // Define shapes (dynamic)
//   auto M = int(m);
//   auto N = int(n);
//   auto K = int(k);
//   auto prob_shape = make_shape(M, N, K);                     // (M, N, K)

//   // Define NT strides (mixed)
//   auto dA = make_stride(Int<1>{}, ldA);                      // (dM, dK)
//   auto dB = make_stride(Int<1>{}, ldB);                      // (dN, dK)
//   auto dC = make_stride(Int<1>{}, ldC);                      // (dM, dN)

//   // Define CTA tile sizes (static)
//   auto bM = Int<128>{};
//   auto bN = Int<128>{};
//   auto bK = Int<  8>{};
//   auto cta_tiler = make_shape(bM, bN, bK);                   // (BLK_M, BLK_N, BLK_K)

//   // Define the smem layouts (static)
//   auto sA = make_layout(make_shape(bM, bK));                 // (m,k) -> smem_idx; m-major
//   auto sB = make_layout(make_shape(bN, bK));                 // (n,k) -> smem_idx; n-major
//   auto sC = make_layout(make_shape(bM, bN));                 // (m,n) -> smem_idx; m-major

//   // Define the thread layouts (static)
//   auto tA = make_layout(make_shape(Int<32>{}, Int< 8>{}));   // (m,k) -> thr_idx
//   auto tB = make_layout(make_shape(Int<32>{}, Int< 8>{}));   // (n,k) -> thr_idx
//   auto tC = make_layout(make_shape(Int<16>{}, Int<16>{}));   // (m,n) -> thr_idx

//   dim3 dimBlock(size(tC));
//   dim3 dimGrid(size(ceil_div(M, bM)),
//                size(ceil_div(N, bN)));
//   gemm_device<<<dimGrid, dimBlock, 0, stream>>>
//       (prob_shape, cta_tiler,
//        A, dA, sA, tA,
//        B, dB, sB, tB,
//        C, dC, sC, tC,
//        alpha, beta);




  



// }

// // Setup params for a TN GEMM
// // Use padded m-major smem sA, padded n-major smem sB, and k-major threads tA|tB
// template <class TA, class TB, class TC,
//           class Alpha, class Beta>
// void
// gemm_tn(int m, int n, int k,
//         Alpha alpha,
//         TA const* A, int ldA,
//         TB const* B, int ldB,
//         Beta beta,
//         TC      * C, int ldC,
//         cudaStream_t stream = 0)
// {
//   using namespace cute;

//   // Define shapes (dynamic)
//   auto M = int(m);
//   auto N = int(n);
//   auto K = int(k);
//   auto prob_shape = make_shape(M, N, K);                     // (M, N, K)

//   // Define TN strides (mixed)
//   auto dA = make_stride(ldA, Int<1>{});                      // (dM, dK)
//   auto dB = make_stride(ldB, Int<1>{});                      // (dN, dK)
//   auto dC = make_stride(Int<1>{}, ldC);                      // (dM, dN)

//   // Define CTA tile sizes (static)
//   auto bM = Int<128>{};
//   auto bN = Int<128>{};
//   auto bK = Int<  8>{};
//   auto cta_tiler = make_shape(bM, bN, bK);                   // (BLK_M, BLK_N, BLK_K)

//   // Define the smem layouts (static)
//   auto sA = make_layout(make_shape(bM,bK), LayoutRight{});   // (m,k) -> smem_idx; k-major
//   auto sB = make_layout(make_shape(bN,bK), LayoutRight{});   // (n,k) -> smem_idx; k-major
//   auto sC = make_layout(make_shape(bM, bN));                 // (m,n) -> smem_idx; m-major

//   // Define the thread layouts (static)
//   auto tA = make_layout(make_shape(Int<32>{}, Int< 8>{}), LayoutRight{});  // (m,k) -> thr_idx; k-major
//   auto tB = make_layout(make_shape(Int<32>{}, Int< 8>{}), LayoutRight{});  // (n,k) -> thr_idx; k-major
//   auto tC = make_layout(make_shape(Int<16>{}, Int<16>{}));                 // (m,n) -> thr_idx; m-major

//   dim3 dimBlock(size(tC));
//   dim3 dimGrid(size(ceil_div(M, bM)),
//                size(ceil_div(N, bN)));
//   gemm_device<<<dimGrid, dimBlock, 0, stream>>>
//       (prob_shape, cta_tiler,
//        A, dA, sA, tA,
//        B, dB, sB, tB,
//        C, dC, sC, tC,
//        alpha, beta);
// }

// template <class TA, class TB, class TC, class TD,
//           class Alpha, class Beta>
// void
// gemm(char transA, char transB, int m, int n, int k,
//      Alpha alpha,
//      TA const* A, int ldA,
//      TB const* B, int ldB,
//      Beta beta,
//      TC      * C, int ldC,
//      TD const* D,  // New parameter for the D matrix
//      cudaStream_t stream = 0)
// {
//   if (transA == 'N' && transB == 'T') {
//     gemm_nt(m, n, k, alpha, A, ldA, B, ldB, beta, C, ldC, stream);
//   } else if (transA == 'T' && transB == 'N') {
//     gemm_tn(m, n, k, alpha, A, ldA, B, ldB, beta, C, ldC, stream);
//   } else {
//     assert(false && "Not implemented");
//     return;
//   }

//   // Launch elementwise addition kernel
//   dim3 blockDim(32, 32);
//   dim3 gridDim((n + blockDim.x - 1) / blockDim.x, 
//                (m + blockDim.y - 1) / blockDim.y);
//   elementwiseAddKernel<<<gridDim, blockDim, 0, stream>>>(C, D, m, n);
// }


// int main(int argc, char** argv)
// {
//   int m = 128;
//   if (argc >= 2)
//     sscanf(argv[1], "%d", &m);

//   int n = 128;
//   if (argc >= 3)
//     sscanf(argv[2], "%d", &n);

//   int k = 128;
//   if (argc >= 4)
//     sscanf(argv[3], "%d", &k);

//   char transA = 'N';
//   if (argc >= 5)
//     sscanf(argv[4], "%c", &transA);

//   char transB = 'T';
//   if (argc >= 6)
//     sscanf(argv[5], "%c", &transB);

//   using TA = float;
//   using TB = float;
//   using TC = float;
//   using TI = float;

//   TI alpha = 1.0;
//   TI beta  = 0.0;

//   std::cout << "M = " << m << std::endl;
//   std::cout << "N = " << n << std::endl;
//   std::cout << "K = " << k << std::endl;
//   std::cout << "C = A^" << transA << " B^" << transB << std::endl;

//   cute::device_init(0);

//   thrust::host_vector<TA> h_A(m*k);
//   thrust::host_vector<TB> h_B(n*k);
//   thrust::host_vector<TC> h_C(m*n);

//   for (int j = 0; j < m*k; ++j) h_A[j] = static_cast<TA>( 2*(rand() / double(RAND_MAX)) - 1 );
//   for (int j = 0; j < n*k; ++j) h_B[j] = static_cast<TB>( 2*(rand() / double(RAND_MAX)) - 1 );
//   for (int j = 0; j < m*n; ++j) h_C[j] = static_cast<TC>(-1);

//   thrust::device_vector<TA> d_A = h_A;
//   thrust::device_vector<TB> d_B = h_B;
//   thrust::device_vector<TC> d_C = h_C;

//   // Add D matrix
//   thrust::host_vector<TC> h_D(m*n);
//   for (int j = 0; j < m*n; ++j) h_D[j] = static_cast<TC>( 2*(rand() / double(RAND_MAX)) - 1 );
//   thrust::device_vector<TC> d_D = h_D;

//   double gflops = (2.0*m*n*k) * 1e-9;

//   const int timing_iterations = 100;
//   GPU_Clock timer;

//   int ldA = 0, ldB = 0, ldC = m;

//   if (transA == 'N') {
//     ldA = m;
//   } else if (transA == 'T') {
//     ldA = k;
//   } else {
//     assert(false);
//   }

//   if (transB == 'N') {
//     ldB = k;
//   } else if (transB == 'T') {
//     ldB = n;
//   } else {
//     assert(false);
//   }
//   // Run once
//   d_C = h_C;
//   gemm(transA, transB, m, n, k,
//        alpha,
//        d_A.data().get(), ldA,
//        d_B.data().get(), ldB,
//        beta,
//        d_C.data().get(), ldC,
//        d_D.data().get());  // Add D matrix to the GEMM call
//   CUTE_CHECK_LAST();
//   // thrust::host_vector<TC> cute_result = d_C;

//   // Timing iterations
//   // timer.start();
//   // for (int i = 0; i < timing_iterations; ++i) {
//   //   gemm(transA, transB, m, n, k,
//   //        alpha,
//   //        d_A.data().get(), ldA,
//   //        d_B.data().get(), ldB,
//   //        beta,
//   //        d_C.data().get(), ldC,
//   //        d_D.data().get());  // Add D matrix to the GEMM call
//   // }
//   // double cute_time = timer.seconds() / timing_iterations;
//   // CUTE_CHECK_LAST();
//   // printf("CUTE_GEMM + ElementwiseAdd: [%6.1f]GFlop/s  (%6.4f)ms\n", gflops / cute_time, cute_time*1000);

//   return 0;
// }

#include <cstddef>
#include <iostream>
#include <type_traits>

#include <cuda_runtime.h>

#include "cutlass/cutlass.h"
#include "cutlass/gemm/device/gemm.h"
#include "cutlass/util/host_tensor.h"
#include "cutlass/util/tensor_view_io.h"
#include "cutlass/util/reference/host/tensor_fill.h"
#include "cutlass/util/reference/host/tensor_copy.h"
#include "cutlass/util/reference/host/tensor_compare.h"
#include "cutlass/util/reference/host/gemm.h"

// ---------------------------------------------------------------------------
// Configuration of the CUTLASS device-level GEMM used by run_gemm().
//
//   Operand   Element type   Layout
//   A         half_t         row-major
//   B         half_t         column-major
//   C / D     half_t         column-major   (accumulator and epilogue math: float)
// ---------------------------------------------------------------------------
using ElementInputA      = cutlass::half_t;
using ElementInputB      = cutlass::half_t;
using ElementOutput      = cutlass::half_t;
using ElementAccumulator = float;

using LayoutInputA = cutlass::layout::RowMajor;
using LayoutInputB = cutlass::layout::ColumnMajor;
using LayoutOutput = cutlass::layout::ColumnMajor;

// Tensor-core operator class, targeting the SM75 (Turing) architecture tag.
using MMAOp  = cutlass::arch::OpClassTensorOp;
using SmArch = cutlass::arch::Sm75;

// Tiling hierarchy as (M, N, K): threadblock tile, warp tile, MMA instruction shape.
using ShapeMMAThreadBlock = cutlass::gemm::GemmShape<128, 128, 32>;
using ShapeMMAWarp        = cutlass::gemm::GemmShape< 64,  64, 32>;
using ShapeMMAOp          = cutlass::gemm::GemmShape< 16,   8,  8>;

// Output elements per 128-bit vectorized epilogue access (8 for a 16-bit half_t).
constexpr int kEpilogueElementsPerAccess = 128 / cutlass::sizeof_bits<ElementOutput>::value;

// Epilogue: D = alpha * accumulator + beta * C.
using EpilogueOp = cutlass::epilogue::thread::LinearCombination<
    ElementOutput,               // element type written to D
    kEpilogueElementsPerAccess,  // elements per vectorized access
    ElementAccumulator,          // accumulator element type
    ElementAccumulator>;         // compute type for the alpha/beta math

using Gemm = cutlass::gemm::device::Gemm<
    ElementInputA, LayoutInputA,                    // A
    ElementInputB, LayoutInputB,                    // B
    ElementOutput, LayoutOutput,                    // C and D
    ElementAccumulator,                             // accumulator
    MMAOp, SmArch,                                  // operator class, target arch
    ShapeMMAThreadBlock, ShapeMMAWarp, ShapeMMAOp,  // tile shapes
    EpilogueOp>;

// Elementwise addition kernel: C[i] += D[i] for every element of two M x N buffers.
//
// Launch contract:
//   - 2D grid/block: the .x dimension spans columns (N), the .y dimension spans rows (M).
//     The grid must cover at least N x M threads; surplus threads are masked off.
//   - No shared memory is used; runs on whichever stream it is launched on.
//   - The flat index row * N + col visits each of the M*N elements exactly once, so the
//     kernel is valid for any packed (contiguous, M*N-element) storage order, provided
//     C and D use the same one.
template <typename T>
__global__ void elementwiseAddKernel(T* C, const T* D, int M, int N) {
    int const row = blockIdx.y * blockDim.y + threadIdx.y;
    int const col = blockIdx.x * blockDim.x + threadIdx.x;

    if (row < M && col < N) {
        // Compute the flat index in 64 bits: row * N overflows int once M*N > INT_MAX.
        size_t const idx = static_cast<size_t>(row) * static_cast<size_t>(N)
                         + static_cast<size_t>(col);
        C[idx] += D[idx];
    }
}

// Helper function to print a few elements of a matrix.
//
// Prints up to `num_elements` values from the HOST copy of `tensor`. For RowMajor it
// walks along row 0 (columns 0, 1, ...); for any other layout it walks down column 0
// (rows 0, 1, ...). For RowMajor/ColumnMajor this yields the first elements in memory
// order. Only host memory is read, so data produced on the device must be copied back
// with sync_host() first. The count is clamped to the extent of the walked dimension
// (and to zero), so small tensors are never read out of bounds.
template <typename T, typename Layout>
void print_matrix_elements(cutlass::HostTensor<T, Layout>& tensor, const char* name, int num_elements = 5) {
    using Coord = typename Layout::TensorCoord;
    constexpr bool kWalkRow = std::is_same_v<Layout, cutlass::layout::RowMajor>;

    // Number of elements actually present along the dimension being walked.
    auto const extent = tensor.extent();
    int const available = kWalkRow ? extent.column() : extent.row();
    int count = num_elements < available ? num_elements : available;
    if (count < 0) {
        count = 0;
    }

    std::cout << "First " << count << " elements of " << name << ":" << std::endl;
    for (int i = 0; i < count; ++i) {
        Coord const coord = kWalkRow ? Coord{0, i} : Coord{i, 0};
        std::cout << float(tensor.host_ref().at(coord)) << " ";
    }
    std::cout << std::endl;
}

// Launches elementwiseAddKernel to compute accum[i] += addend[i] over a packed m x n
// buffer on the default stream. Returns the launch status from cudaGetLastError();
// errors raised while the kernel executes surface at the next synchronizing call.
static cudaError_t launch_elementwise_add(ElementOutput* accum, ElementOutput const* addend, int m, int n) {
    if (m <= 0 || n <= 0) {
        return cudaSuccess;  // Nothing to add; a zero-sized grid would be an invalid launch.
    }
    dim3 const block(32, 32);
    dim3 const grid((n + block.x - 1) / block.x, (m + block.y - 1) / block.y);
    elementwiseAddKernel<ElementOutput><<<grid, block>>>(accum, addend, m, n);
    return cudaGetLastError();
}

// Recomputes D = alpha * A * B + beta * C, then D += E, on the host and compares the
// result exactly (TensorEquals) with `computed`, which must already hold the synced
// device result. Exact comparison relies on the integer-valued fills in run_gemm
// (`bits` = 0 in TensorFillRandomUniform). It assumes accumulation stays exact in float
// and both paths round identically to half. TODO confirm for very large K.
static bool verify_against_host_reference(
    cutlass::gemm::GemmCoord const& problem_size,
    float alpha,
    float beta,
    cutlass::HostTensor<ElementInputA, LayoutInputA>& tensor_a,
    cutlass::HostTensor<ElementInputB, LayoutInputB>& tensor_b,
    cutlass::HostTensor<ElementOutput, LayoutOutput>& tensor_c,
    cutlass::HostTensor<ElementOutput, LayoutOutput>& tensor_e,
    cutlass::HostTensor<ElementOutput, LayoutOutput>& computed) {
    cutlass::HostTensor<ElementOutput, LayoutOutput> reference_d(problem_size.mn());

    cutlass::reference::host::Gemm<
        ElementInputA, LayoutInputA,
        ElementInputB, LayoutInputB,
        ElementOutput, LayoutOutput,
        ElementAccumulator, ElementAccumulator
    > reference_gemm;

    reference_gemm(
        problem_size,
        ElementAccumulator(alpha),
        tensor_a.host_ref(),
        tensor_b.host_ref(),
        ElementAccumulator(beta),
        tensor_c.host_ref(),
        reference_d.host_ref());

    // Same flat addition the device kernel performs; both buffers are packed m x n.
    size_t const count = static_cast<size_t>(problem_size.m()) * static_cast<size_t>(problem_size.n());
    for (size_t i = 0; i < count; ++i) {
        reference_d.host_data()[i] += tensor_e.host_data()[i];
    }

    return cutlass::reference::host::TensorEquals(reference_d.host_view(), computed.host_view());
}

// Runs D = alpha * A * B + beta * C with the CUTLASS `Gemm` configured above. It then
// adds a random matrix E into D with elementwiseAddKernel (D += E).
//
// A (m x k) and B (k x n) are filled with random values in [-4, 4] and C with zeros.
// E is filled with random values in [-1, 1]. All random fills use seed 1 and bits = 0.
// A few elements of A, B, E and the final D are printed.
//
// Parameters:
//   m, n, k      problem size. Sizes violating CUTLASS's alignment requirements are
//                rejected via can_implement().
//   alpha, beta  epilogue scaling factors.
//   verify       when true, recompute the result on the host and compare it exactly.
//                Off by default because the host reference is O(m*n*k) and slow for
//                large problems.
// Returns 0 on success. Returns -1 if CUTLASS rejects or fails the GEMM, if a CUDA
// error occurs, or (with `verify`) if the result does not match the host reference.
int run_gemm(int m, int n, int k, float alpha, float beta, bool verify = false) {
    cutlass::gemm::GemmCoord problem_size(m, n, k);

    cutlass::HostTensor<ElementInputA, LayoutInputA> tensor_a(problem_size.mk());
    cutlass::HostTensor<ElementInputB, LayoutInputB> tensor_b(problem_size.kn());
    cutlass::HostTensor<ElementOutput, LayoutOutput> tensor_c(problem_size.mn());
    cutlass::HostTensor<ElementOutput, LayoutOutput> tensor_d(problem_size.mn());
    cutlass::HostTensor<ElementOutput, LayoutOutput> tensor_e(problem_size.mn());  // Addend for D += E

    cutlass::reference::host::TensorFillRandomUniform(
        tensor_a.host_view(),
        1,
        ElementInputA(4),
        ElementInputA(-4),
        0);

    cutlass::reference::host::TensorFillRandomUniform(
        tensor_b.host_view(),
        1,
        ElementInputB(4),
        ElementInputB(-4),
        0);

    cutlass::reference::host::TensorFill(tensor_c.host_view());
    cutlass::reference::host::TensorFill(tensor_d.host_view());
    cutlass::reference::host::TensorFillRandomUniform(
        tensor_e.host_view(),
        1,
        ElementOutput(1),
        ElementOutput(-1),
        0);

    tensor_a.sync_device();
    tensor_b.sync_device();
    tensor_c.sync_device();
    tensor_d.sync_device();
    tensor_e.sync_device();

    typename Gemm::Arguments arguments{
        problem_size,
        tensor_a.device_ref(),
        tensor_b.device_ref(),
        tensor_c.device_ref(),
        tensor_d.device_ref(),
        {ElementAccumulator(alpha), ElementAccumulator(beta)}
    };

    Gemm gemm_op;

    // Reject unsupported sizes/alignments up front rather than launching a kernel
    // whose vectorized accesses would be invalid for this problem.
    cutlass::Status status = gemm_op.can_implement(arguments);
    if (status != cutlass::Status::kSuccess) {
        std::cerr << "CUTLASS Gemm cannot implement this problem: "
                  << cutlass::cutlassGetStatusString(status) << std::endl;
        return -1;
    }

    status = gemm_op.initialize(arguments);
    if (status != cutlass::Status::kSuccess) {
        std::cerr << "Failed to initialize CUTLASS Gemm." << std::endl;
        return -1;
    }

    status = gemm_op();
    if (status != cutlass::Status::kSuccess) {
        std::cerr << "Failed to run CUTLASS Gemm." << std::endl;
        return -1;
    }

    // D += E, on the same (default) stream as the GEMM, so it runs after it.
    cudaError_t cuda_status = launch_elementwise_add(tensor_d.device_data(), tensor_e.device_data(), m, n);
    if (cuda_status != cudaSuccess) {
        std::cerr << "Failed to launch elementwise addition kernel: "
                  << cudaGetErrorString(cuda_status) << std::endl;
        return -1;
    }

    // These read host copies only, so they may overlap with the device work.
    print_matrix_elements(tensor_a, "Matrix A");
    print_matrix_elements(tensor_b, "Matrix B");
    print_matrix_elements(tensor_e, "Matrix E (for elementwise addition)");

    // Surfaces any error raised while the GEMM or the addition kernel executed.
    cuda_status = cudaDeviceSynchronize();
    if (cuda_status != cudaSuccess) {
        std::cerr << "CUDA error while running GEMM + elementwise addition: "
                  << cudaGetErrorString(cuda_status) << std::endl;
        return -1;
    }
    tensor_d.sync_host();

    print_matrix_elements(tensor_d, "Result Matrix D (after GEMM and elementwise addition)");

    if (!verify) {
        return 0;
    }

    bool const passed = verify_against_host_reference(
        problem_size, alpha, beta, tensor_a, tensor_b, tensor_c, tensor_e, tensor_d);
    std::cout << (passed ? "GEMM test passed" : "GEMM test failed") << std::endl;
    return passed ? 0 : -1;
}

int main() {
    // Square problem edge length. Other sizes used with this driver:
    // 256, 512, 1024, 2048, 4096, 8192.
    const int size = 128;
    const int m = size;
    const int n = size;
    const int k = size;

    // Epilogue scaling: D = alpha * (A * B) + beta * C.
    const float alpha = 1.0f;
    const float beta  = 0.0f;

    return run_gemm(m, n, k, alpha, beta);
}
// (End of pasted file; trailing web-page UI text removed.)