// 1010_matmul_add
// /***************************************************************************************************
// * Copyright (c) 2023 - 2024 NVIDIA CORPORATION & AFFILIATES. All rights reserved.
// * SPDX-License-Identifier: BSD-3-Clause
// *
// * Redistribution and use in source and binary forms, with or without
// * modification, are permitted provided that the following conditions are met:
// *
// * 1. Redistributions of source code must retain the above copyright notice, this
// * list of conditions and the following disclaimer.
// *
// * 2. Redistributions in binary form must reproduce the above copyright notice,
// * this list of conditions and the following disclaimer in the documentation
// * and/or other materials provided with the distribution.
// *
// * 3. Neither the name of the copyright holder nor the names of its
// * contributors may be used to endorse or promote products derived from
// * this software without specific prior written permission.
// *
// * THIS SOFTWARE IS PROVIDED BY THE COPYRIGHT HOLDERS AND CONTRIBUTORS "AS IS"
// * AND ANY EXPRESS OR IMPLIED WARRANTIES, INCLUDING, BUT NOT LIMITED TO, THE
// * IMPLIED WARRANTIES OF MERCHANTABILITY AND FITNESS FOR A PARTICULAR PURPOSE ARE
// * DISCLAIMED. IN NO EVENT SHALL THE COPYRIGHT HOLDER OR CONTRIBUTORS BE LIABLE
// * FOR ANY DIRECT, INDIRECT, INCIDENTAL, SPECIAL, EXEMPLARY, OR CONSEQUENTIAL
// * DAMAGES (INCLUDING, BUT NOT LIMITED TO, PROCUREMENT OF SUBSTITUTE GOODS OR
// * SERVICES; LOSS OF USE, DATA, OR PROFITS; OR BUSINESS INTERRUPTION) HOWEVER
// * CAUSED AND ON ANY THEORY OF LIABILITY, WHETHER IN CONTRACT, STRICT LIABILITY,
// * OR TORT (INCLUDING NEGLIGENCE OR OTHERWISE) ARISING IN ANY WAY OUT OF THE USE
// * OF THIS SOFTWARE, EVEN IF ADVISED OF THE POSSIBILITY OF SUCH DAMAGE.
// *
// **************************************************************************************************/
// #include <cstdlib>
// #include <cstdio>
// #include <cassert>
// #include <thrust/host_vector.h>
// #include <thrust/device_vector.h>
// #include <cute/tensor.hpp>
// #include "cutlass/util/print_error.hpp"
// #include "cutlass/util/GPU_Clock.hpp"
// #include "cutlass/util/helper_cuda.hpp"
// template <class ProblemShape, class CtaTiler,
// class TA, class AStride, class ASmemLayout, class AThreadLayout,
// class TB, class BStride, class BSmemLayout, class BThreadLayout,
// class TC, class CStride, class CSmemLayout, class CThreadLayout,
// class Alpha, class Beta>
// __global__ static
// __launch_bounds__(decltype(size(CThreadLayout{}))::value)
// void
// gemm_device(ProblemShape shape_MNK, CtaTiler cta_tiler,
// TA const* A, AStride dA, ASmemLayout sA_layout, AThreadLayout tA,
// TB const* B, BStride dB, BSmemLayout sB_layout, BThreadLayout tB,
// TC * C, CStride dC, CSmemLayout , CThreadLayout tC,
// Alpha alpha, Beta beta)
// {
// using namespace cute;
// // Preconditions
// CUTE_STATIC_ASSERT_V(rank(shape_MNK) == Int<3>{}); // (M, N, K)
// CUTE_STATIC_ASSERT_V(rank(cta_tiler) == Int<3>{}); // (BLK_M, BLK_N, BLK_K)
// static_assert(is_static<AThreadLayout>::value);
// static_assert(is_static<BThreadLayout>::value);
// static_assert(is_static<CThreadLayout>::value);
// CUTE_STATIC_ASSERT_V(size(tA) == size(tB)); // NumThreads
// CUTE_STATIC_ASSERT_V(size(tC) == size(tA)); // NumThreads
// CUTE_STATIC_ASSERT_V(size<0>(cta_tiler) % size<0>(tA) == Int<0>{}); // BLK_M / THR_M
// CUTE_STATIC_ASSERT_V(size<2>(cta_tiler) % size<1>(tA) == Int<0>{}); // BLK_K / THR_K
// CUTE_STATIC_ASSERT_V(size<1>(cta_tiler) % size<0>(tB) == Int<0>{}); // BLK_N / THR_N
// CUTE_STATIC_ASSERT_V(size<2>(cta_tiler) % size<1>(tB) == Int<0>{}); // BLK_K / THR_K
// CUTE_STATIC_ASSERT_V(size<0>(cta_tiler) % size<0>(tC) == Int<0>{}); // BLK_M / THR_M
// CUTE_STATIC_ASSERT_V(size<1>(cta_tiler) % size<1>(tC) == Int<0>{}); // BLK_N / THR_N
// static_assert(is_static<ASmemLayout>::value);
// static_assert(is_static<BSmemLayout>::value);
// static_assert(is_static<CSmemLayout>::value);
// CUTE_STATIC_ASSERT_V(size<0>(ASmemLayout{}) == size<0>(cta_tiler)); // BLK_M
// CUTE_STATIC_ASSERT_V(size<0>(CSmemLayout{}) == size<0>(cta_tiler)); // BLK_M
// CUTE_STATIC_ASSERT_V(size<0>(BSmemLayout{}) == size<1>(cta_tiler)); // BLK_N
// CUTE_STATIC_ASSERT_V(size<1>(CSmemLayout{}) == size<1>(cta_tiler)); // BLK_N
// CUTE_STATIC_ASSERT_V(size<1>(ASmemLayout{}) == size<2>(cta_tiler)); // BLK_K
// CUTE_STATIC_ASSERT_V(size<1>(BSmemLayout{}) == size<2>(cta_tiler)); // BLK_K
// CUTE_STATIC_ASSERT_V(congruent(select<0,2>(shape_MNK), dA)); // dA strides for shape MK
// CUTE_STATIC_ASSERT_V(congruent(select<1,2>(shape_MNK), dB)); // dB strides for shape NK
// CUTE_STATIC_ASSERT_V(congruent(select<0,1>(shape_MNK), dC)); // dC strides for shape MN
// //
// // Full and Tiled Tensors
// //
// // Represent the full tensors
// Tensor mA = make_tensor(make_gmem_ptr(A), select<0,2>(shape_MNK), dA); // (M,K)
// Tensor mB = make_tensor(make_gmem_ptr(B), select<1,2>(shape_MNK), dB); // (N,K)
// Tensor mC = make_tensor(make_gmem_ptr(C), select<0,1>(shape_MNK), dC); // (M,N)
// // Get the appropriate blocks for this thread block
// auto cta_coord = make_coord(blockIdx.x, blockIdx.y, _); // (m,n,k)
// Tensor gA = local_tile(mA, cta_tiler, cta_coord, Step<_1, X,_1>{}); // (BLK_M,BLK_K,k)
// Tensor gB = local_tile(mB, cta_tiler, cta_coord, Step< X,_1,_1>{}); // (BLK_N,BLK_K,k)
// Tensor gC = local_tile(mC, cta_tiler, cta_coord, Step<_1,_1, X>{}); // (BLK_M,BLK_N)
// // Shared memory buffers
// __shared__ TA smemA[cosize_v<ASmemLayout>];
// __shared__ TB smemB[cosize_v<BSmemLayout>];
// Tensor sA = make_tensor(make_smem_ptr(smemA), sA_layout); // (BLK_M,BLK_K)
// Tensor sB = make_tensor(make_smem_ptr(smemB), sB_layout); // (BLK_N,BLK_K)
// //
// // Partition the copying of A and B tiles across the threads
// //
// // TUTORIAL: Example of simple raked partitioning of ThreadLayouts tA|tB over data A|B tiles
// Tensor tAgA = local_partition(gA, tA, threadIdx.x); // (THR_M,THR_K,k)
// Tensor tAsA = local_partition(sA, tA, threadIdx.x); // (THR_M,THR_K)
// Tensor tBgB = local_partition(gB, tB, threadIdx.x); // (THR_N,THR_K,k)
// Tensor tBsB = local_partition(sB, tB, threadIdx.x); // (THR_N,THR_K)
// CUTE_STATIC_ASSERT_V(size<0>(tAgA) == size<0>(tAsA)); // THR_M
// CUTE_STATIC_ASSERT_V(size<1>(tAgA) == size<1>(tAsA)); // THR_K
// CUTE_STATIC_ASSERT_V(size<0>(tBgB) == size<0>(tBsB)); // THR_N
// CUTE_STATIC_ASSERT_V(size<1>(tBgB) == size<1>(tBsB)); // THR_K
// //
// // Define A/B partitioning and C accumulators
// //
// // TUTORIAL: Example of partitioning via projections of a ThreadLayout tC
// // Partition sA (M,K) by the rows of tC
// Tensor tCsA = local_partition(sA, tC, threadIdx.x, Step<_1, X>{}); // (THR_M,BLK_K)
// // Partition sB (N,K) by the cols of tC
// Tensor tCsB = local_partition(sB, tC, threadIdx.x, Step< X,_1>{}); // (THR_N,BLK_K)
// // Partition gC (M,N) by the tile of tC
// Tensor tCgC = local_partition(gC, tC, threadIdx.x, Step<_1,_1>{}); // (THR_M,THR_N)
// // Allocate the accumulators -- same shape/layout as the partitioned data
// Tensor tCrC = make_tensor_like(tCgC); // (THR_M,THR_N)
// CUTE_STATIC_ASSERT_V(size<0>(tCrC) == size<0>(tCgC)); // THR_M
// CUTE_STATIC_ASSERT_V(size<0>(tCrC) == size<0>(tCsA)); // THR_M
// CUTE_STATIC_ASSERT_V(size<1>(tCrC) == size<1>(tCgC)); // THR_N
// CUTE_STATIC_ASSERT_V(size<1>(tCrC) == size<0>(tCsB)); // THR_N
// CUTE_STATIC_ASSERT_V(size<1>(tCsA) == size<1>(tCsB)); // BLK_K
// // Clear the accumulators
// clear(tCrC);
// #if 0
// if(thread0()) {
// print(" mA : "); print( mA); print("\n");
// print(" gA : "); print( gA); print("\n");
// print(" sA : "); print( sA); print("\n");
// print("tAgA : "); print(tAgA); print("\n");
// print("tAsA : "); print(tAsA); print("\n");
// }
// #endif
// #if 0
// if(thread0()) {
// print(" mB : "); print( mB); print("\n");
// print(" gB : "); print( gB); print("\n");
// print(" sB : "); print( sB); print("\n");
// print("tBgB : "); print(tBgB); print("\n");
// print("tBsB : "); print(tBsB); print("\n");
// }
// #endif
// #if 0
// if(thread0()) {
// print(" mC : "); print( mC); print("\n");
// print(" gC : "); print( gC); print("\n");
// print("tCsA : "); print(tCsA); print("\n");
// print("tCsB : "); print(tCsB); print("\n");
// print("tCgC : "); print(tCgC); print("\n");
// print("tCrC : "); print(tCrC); print("\n");
// }
// #endif
// #if 1
// // TUTORIAL: Example of a simple mainloop that read tiles of data into shared memory,
// // and then computes on those tiles.
// // copy(.) operates on the global and shared memory via the tA|tB partitioning
// // gemm(.) operates on the shared and register memory via the tC partitioning
// auto K_TILE_MAX = size<2>(tAgA);
// for (int k_tile = 0; k_tile < K_TILE_MAX; ++k_tile)
// {
// // Copy gmem to smem with tA|tB thread-partitioned tensors
// copy(tAgA(_,_,k_tile), tAsA); // A (THR_M,THR_K) -> (THR_M,THR_K)
// copy(tBgB(_,_,k_tile), tBsB); // B (THR_N,THR_K) -> (THR_N,THR_K)
// // TUTORIAL: The above call to copy(tAgA(_,_,k_tile), tAsA) is equivalent to
// // Tensor tAgAk = tAgA(_,_,k_tile);
// // CUTE_UNROLL
// // for (int i = 0; i < size(tAsA); ++i) {
// // tAsA(i) = tAgAk(i);
// // }
// cp_async_fence(); // Label the end of (potential) cp.async instructions
// cp_async_wait<0>(); // Sync on all (potential) cp.async instructions
// __syncthreads(); // Wait for all threads to write to smem
// // Compute gemm on tC thread-partitioned smem
// gemm(tCsA, tCsB, tCrC); // (THR_M,THR_N) += (THR_M,BLK_K) * (THR_N,BLK_K)
// // TUTORIAL: The above call to gemm(tCsA, tCsB, tCrC) is equivalent to
// // CUTE_UNROLL
// // for (int k = 0; k < size<1>(tCsA); ++k) {
// // CUTE_UNROLL
// // for (int m = 0; m < size<0>(tCrC); ++m) {
// // CUTE_UNROLL
// // for (int n = 0; n < size<1>(tCrC); ++n) {
// // tCrC(m,n) += tCsA(m,k) * tCsB(n,k);
// // }
// // }
// // }
// __syncthreads(); // Wait for all threads to read from smem
// }
// #endif
// //
// // Epilogue
// //
// axpby(alpha, tCrC, beta, tCgC);
// // TUTORIAL: The above call to axpby(alpha, tCrC, beta, tCgC) is equivalent to
// // CUTE_UNROLL
// // for (int i = 0; i < size(tCrC); ++i) {
// // tCgC(i) = alpha * tCrC(i) + beta * tCgC(i);
// // }
// }
// // Elementwise addition kernel
// template <typename T>
// __global__ void elementwiseAddKernel(T* C, const T* D, int M, int N) {
// int row = blockIdx.y * blockDim.y + threadIdx.y;
// int col = blockIdx.x * blockDim.x + threadIdx.x;
// if (row < M && col < N) {
// int idx = row * N + col;
// C[idx] += D[idx];
// }
// }
// // Setup params for an NT GEMM
// // Use m-major smem sA, n-major smem sB, and mn-major threads tA|tB
// template <class TA, class TB, class TC,
// class Alpha, class Beta>
// void
// gemm_nt(int m, int n, int k,
// Alpha alpha,
// TA const* A, int ldA,
// TB const* B, int ldB,
// Beta beta,
// TC * C, int ldC,
// cudaStream_t stream = 0)
// {
// using namespace cute;
// // Define shapes (dynamic)
// auto M = int(m);
// auto N = int(n);
// auto K = int(k);
// auto prob_shape = make_shape(M, N, K); // (M, N, K)
// // Define NT strides (mixed)
// auto dA = make_stride(Int<1>{}, ldA); // (dM, dK)
// auto dB = make_stride(Int<1>{}, ldB); // (dN, dK)
// auto dC = make_stride(Int<1>{}, ldC); // (dM, dN)
// // Define CTA tile sizes (static)
// auto bM = Int<128>{};
// auto bN = Int<128>{};
// auto bK = Int< 8>{};
// auto cta_tiler = make_shape(bM, bN, bK); // (BLK_M, BLK_N, BLK_K)
// // Define the smem layouts (static)
// auto sA = make_layout(make_shape(bM, bK)); // (m,k) -> smem_idx; m-major
// auto sB = make_layout(make_shape(bN, bK)); // (n,k) -> smem_idx; n-major
// auto sC = make_layout(make_shape(bM, bN)); // (m,n) -> smem_idx; m-major
// // Define the thread layouts (static)
// auto tA = make_layout(make_shape(Int<32>{}, Int< 8>{})); // (m,k) -> thr_idx
// auto tB = make_layout(make_shape(Int<32>{}, Int< 8>{})); // (n,k) -> thr_idx
// auto tC = make_layout(make_shape(Int<16>{}, Int<16>{})); // (m,n) -> thr_idx
// dim3 dimBlock(size(tC));
// dim3 dimGrid(size(ceil_div(M, bM)),
// size(ceil_div(N, bN)));
// gemm_device<<<dimGrid, dimBlock, 0, stream>>>
// (prob_shape, cta_tiler,
// A, dA, sA, tA,
// B, dB, sB, tB,
// C, dC, sC, tC,
// alpha, beta);
// }
// // Setup params for a TN GEMM
// // Use padded m-major smem sA, padded n-major smem sB, and k-major threads tA|tB
// template <class TA, class TB, class TC,
// class Alpha, class Beta>
// void
// gemm_tn(int m, int n, int k,
// Alpha alpha,
// TA const* A, int ldA,
// TB const* B, int ldB,
// Beta beta,
// TC * C, int ldC,
// cudaStream_t stream = 0)
// {
// using namespace cute;
// // Define shapes (dynamic)
// auto M = int(m);
// auto N = int(n);
// auto K = int(k);
// auto prob_shape = make_shape(M, N, K); // (M, N, K)
// // Define TN strides (mixed)
// auto dA = make_stride(ldA, Int<1>{}); // (dM, dK)
// auto dB = make_stride(ldB, Int<1>{}); // (dN, dK)
// auto dC = make_stride(Int<1>{}, ldC); // (dM, dN)
// // Define CTA tile sizes (static)
// auto bM = Int<128>{};
// auto bN = Int<128>{};
// auto bK = Int< 8>{};
// auto cta_tiler = make_shape(bM, bN, bK); // (BLK_M, BLK_N, BLK_K)
// // Define the smem layouts (static)
// auto sA = make_layout(make_shape(bM,bK), LayoutRight{}); // (m,k) -> smem_idx; k-major
// auto sB = make_layout(make_shape(bN,bK), LayoutRight{}); // (n,k) -> smem_idx; k-major
// auto sC = make_layout(make_shape(bM, bN)); // (m,n) -> smem_idx; m-major
// // Define the thread layouts (static)
// auto tA = make_layout(make_shape(Int<32>{}, Int< 8>{}), LayoutRight{}); // (m,k) -> thr_idx; k-major
// auto tB = make_layout(make_shape(Int<32>{}, Int< 8>{}), LayoutRight{}); // (n,k) -> thr_idx; k-major
// auto tC = make_layout(make_shape(Int<16>{}, Int<16>{})); // (m,n) -> thr_idx; m-major
// dim3 dimBlock(size(tC));
// dim3 dimGrid(size(ceil_div(M, bM)),
// size(ceil_div(N, bN)));
// gemm_device<<<dimGrid, dimBlock, 0, stream>>>
// (prob_shape, cta_tiler,
// A, dA, sA, tA,
// B, dB, sB, tB,
// C, dC, sC, tC,
// alpha, beta);
// }
// template <class TA, class TB, class TC, class TD,
// class Alpha, class Beta>
// void
// gemm(char transA, char transB, int m, int n, int k,
// Alpha alpha,
// TA const* A, int ldA,
// TB const* B, int ldB,
// Beta beta,
// TC * C, int ldC,
// TD const* D, // New parameter for the D matrix
// cudaStream_t stream = 0)
// {
// if (transA == 'N' && transB == 'T') {
// gemm_nt(m, n, k, alpha, A, ldA, B, ldB, beta, C, ldC, stream);
// } else if (transA == 'T' && transB == 'N') {
// gemm_tn(m, n, k, alpha, A, ldA, B, ldB, beta, C, ldC, stream);
// } else {
// assert(false && "Not implemented");
// return;
// }
// // Launch elementwise addition kernel
// dim3 blockDim(32, 32);
// dim3 gridDim((n + blockDim.x - 1) / blockDim.x,
// (m + blockDim.y - 1) / blockDim.y);
// elementwiseAddKernel<<<gridDim, blockDim, 0, stream>>>(C, D, m, n);
// }
// int main(int argc, char** argv)
// {
// int m = 128;
// if (argc >= 2)
// sscanf(argv[1], "%d", &m);
// int n = 128;
// if (argc >= 3)
// sscanf(argv[2], "%d", &n);
// int k = 128;
// if (argc >= 4)
// sscanf(argv[3], "%d", &k);
// char transA = 'N';
// if (argc >= 5)
// sscanf(argv[4], "%c", &transA);
// char transB = 'T';
// if (argc >= 6)
// sscanf(argv[5], "%c", &transB);
// using TA = float;
// using TB = float;
// using TC = float;
// using TI = float;
// TI alpha = 1.0;
// TI beta = 0.0;
// std::cout << "M = " << m << std::endl;
// std::cout << "N = " << n << std::endl;
// std::cout << "K = " << k << std::endl;
// std::cout << "C = A^" << transA << " B^" << transB << std::endl;
// cute::device_init(0);
// thrust::host_vector<TA> h_A(m*k);
// thrust::host_vector<TB> h_B(n*k);
// thrust::host_vector<TC> h_C(m*n);
// for (int j = 0; j < m*k; ++j) h_A[j] = static_cast<TA>( 2*(rand() / double(RAND_MAX)) - 1 );
// for (int j = 0; j < n*k; ++j) h_B[j] = static_cast<TB>( 2*(rand() / double(RAND_MAX)) - 1 );
// for (int j = 0; j < m*n; ++j) h_C[j] = static_cast<TC>(-1);
// thrust::device_vector<TA> d_A = h_A;
// thrust::device_vector<TB> d_B = h_B;
// thrust::device_vector<TC> d_C = h_C;
// // Add D matrix
// thrust::host_vector<TC> h_D(m*n);
// for (int j = 0; j < m*n; ++j) h_D[j] = static_cast<TC>( 2*(rand() / double(RAND_MAX)) - 1 );
// thrust::device_vector<TC> d_D = h_D;
// double gflops = (2.0*m*n*k) * 1e-9;
// const int timing_iterations = 100;
// GPU_Clock timer;
// int ldA = 0, ldB = 0, ldC = m;
// if (transA == 'N') {
// ldA = m;
// } else if (transA == 'T') {
// ldA = k;
// } else {
// assert(false);
// }
// if (transB == 'N') {
// ldB = k;
// } else if (transB == 'T') {
// ldB = n;
// } else {
// assert(false);
// }
// // Run once
// d_C = h_C;
// gemm(transA, transB, m, n, k,
// alpha,
// d_A.data().get(), ldA,
// d_B.data().get(), ldB,
// beta,
// d_C.data().get(), ldC,
// d_D.data().get()); // Add D matrix to the GEMM call
// CUTE_CHECK_LAST();
// // thrust::host_vector<TC> cute_result = d_C;
// // Timing iterations
// // timer.start();
// // for (int i = 0; i < timing_iterations; ++i) {
// // gemm(transA, transB, m, n, k,
// // alpha,
// // d_A.data().get(), ldA,
// // d_B.data().get(), ldB,
// // beta,
// // d_C.data().get(), ldC,
// // d_D.data().get()); // Add D matrix to the GEMM call
// // }
// // double cute_time = timer.seconds() / timing_iterations;
// // CUTE_CHECK_LAST();
// // printf("CUTE_GEMM + ElementwiseAdd: [%6.1f]GFlop/s (%6.4f)ms\n", gflops / cute_time, cute_time*1000);
// return 0;
// }
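//
// Active implementation. The CuTe tutorial kernel above is kept for reference only (commented out).
// What follows uses the CUTLASS device-level GEMM API (D = alpha * A * B + beta * C) plus a small
// hand-written CUDA kernel that adds a second matrix E elementwise into the GEMM output D.
//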
#include <iostream>
#include <type_traits>   // std::is_same_v, used in print_matrix_elements
#include <cuda_runtime.h>
#include "cutlass/cutlass.h"
#include "cutlass/gemm/device/gemm.h"
#include "cutlass/util/host_tensor.h"
#include "cutlass/util/tensor_view_io.h"
#include "cutlass/util/reference/host/tensor_fill.h"
#include "cutlass/util/reference/host/tensor_copy.h"
#include "cutlass/util/reference/host/tensor_compare.h"
#include "cutlass/util/reference/host/gemm.h"
using ElementInputA = cutlass::half_t;
using ElementInputB = cutlass::half_t;
using ElementOutput = cutlass::half_t;
using ElementAccumulator = float;
using LayoutInputA = cutlass::layout::RowMajor;
using LayoutInputB = cutlass::layout::ColumnMajor;
using LayoutOutput = cutlass::layout::ColumnMajor;
using MMAOp = cutlass::arch::OpClassTensorOp;
using SmArch = cutlass::arch::Sm75;
using ShapeMMAThreadBlock = cutlass::gemm::GemmShape<128, 128, 32>;
using ShapeMMAWarp = cutlass::gemm::GemmShape<64, 64, 32>;
using ShapeMMAOp = cutlass::gemm::GemmShape<16, 8, 8>;
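// Epilogue: a linear combination D = alpha * accumulator + beta * C, with
// 128-bit vectorized accesses (8 half elements per access).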
using EpilogueOp = cutlass::epilogue::thread::LinearCombination<
    ElementOutput,
    128 / cutlass::sizeof_bits<ElementOutput>::value,
    ElementAccumulator,
    ElementAccumulator
>;
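// Device-level GEMM operator assembled from the element types, layouts, tile shapes,
// and epilogue defined above; threadblock swizzling, alignments, and stage count are
// left at their CUTLASS defaults.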
using Gemm = cutlass::gemm::device::Gemm<
    ElementInputA,
    LayoutInputA,
    ElementInputB,
    LayoutInputB,
    ElementOutput,
    LayoutOutput,
    ElementAccumulator,
    MMAOp,
    SmArch,
    ShapeMMAThreadBlock,
    ShapeMMAWarp,
    ShapeMMAOp,
    EpilogueOp
>;
// Elementwise addition kernel
template <typename T>
__global__ void elementwiseAddKernel(T* C, const T* D, int M, int N) {
    int row = blockIdx.y * blockDim.y + threadIdx.y;
    int col = blockIdx.x * blockDim.x + threadIdx.x;
    if (row < M && col < N) {
        // Flat index over all M*N entries. Because both matrices are stored with the
        // same layout, adding element idx of D into element idx of C is correct
        // whether that layout is row- or column-major.
        int idx = row * N + col;
        C[idx] += D[idx];
    }
}
// Helper function to print a few elements of a matrix
template <typename T, typename Layout>
void print_matrix_elements(cutlass::HostTensor<T, Layout>& tensor, const char* name, int num_elements = 5) {
    std::cout << "First " << num_elements << " elements of " << name << ":" << std::endl;
    for (int i = 0; i < num_elements; ++i) {
        typename Layout::TensorCoord coord;
        if constexpr (std::is_same_v<Layout, cutlass::layout::RowMajor>) {
            coord = typename Layout::TensorCoord{0, i};   // walk the first row
        } else {
            coord = typename Layout::TensorCoord{i, 0};   // walk the first column
        }
        std::cout << float(tensor.host_ref().at(coord)) << " ";
    }
    std::cout << std::endl;
}
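// Runs one GEMM + elementwise add on an m x n x k problem:
//   1. allocate host/device tensors A, B, C, D, E
//   2. fill A, B, E with random data, zero C and D, and copy everything to the device
//   3. run the CUTLASS GEMM (D = alpha * A * B + beta * C)
//   4. launch elementwiseAddKernel to compute D += E
//   5. copy D back to the host and print a few elements of the inputs and the result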
int run_gemm(int m, int n, int k, float alpha, float beta) {
    cutlass::gemm::GemmCoord problem_size(m, n, k);
    cutlass::HostTensor<ElementInputA, LayoutInputA> tensor_a(problem_size.mk());  // A: M x K
    cutlass::HostTensor<ElementInputB, LayoutInputB> tensor_b(problem_size.kn());  // B: K x N
    cutlass::HostTensor<ElementOutput, LayoutOutput> tensor_c(problem_size.mn());  // C: M x N (epilogue source)
    cutlass::HostTensor<ElementOutput, LayoutOutput> tensor_d(problem_size.mn());  // D: M x N (GEMM output)
    cutlass::HostTensor<ElementOutput, LayoutOutput> tensor_e(problem_size.mn());  // E: M x N, added elementwise to D
    // Fill A and B with random values in [-4, 4]; C and D are zero-filled.
    cutlass::reference::host::TensorFillRandomUniform(
        tensor_a.host_view(),
        1,                  // seed
        ElementInputA(4),   // max
        ElementInputA(-4),  // min
        0);
    cutlass::reference::host::TensorFillRandomUniform(
        tensor_b.host_view(),
        1,
        ElementInputB(4),
        ElementInputB(-4),
        0);
    cutlass::reference::host::TensorFill(tensor_c.host_view());
    cutlass::reference::host::TensorFill(tensor_d.host_view());
    cutlass::reference::host::TensorFillRandomUniform( // Fill tensor_e with random values in [-1, 1]
        tensor_e.host_view(),
        1,
        ElementOutput(1),
        ElementOutput(-1),
        0);
    // Copy host-side data to the device allocations.
    tensor_a.sync_device();
    tensor_b.sync_device();
    tensor_c.sync_device();
    tensor_d.sync_device();
    tensor_e.sync_device();
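    // Gemm::Arguments bundles the problem size, the device-side tensor references
    // (A, B, C as the epilogue source, D as the destination), and the epilogue
    // scalars alpha and beta.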
    typename Gemm::Arguments arguments{
        problem_size,
        tensor_a.device_ref(),
        tensor_b.device_ref(),
        tensor_c.device_ref(),
        tensor_d.device_ref(),
        {ElementAccumulator(alpha), ElementAccumulator(beta)}
    };
    Gemm gemm_op;
    cutlass::Status status = gemm_op.initialize(arguments);
    if (status != cutlass::Status::kSuccess) {
        std::cerr << "Failed to initialize CUTLASS Gemm." << std::endl;
        return -1;
    }
    status = gemm_op();
    if (status != cutlass::Status::kSuccess) {
        std::cerr << "Failed to run CUTLASS Gemm." << std::endl;
        return -1;
    }
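    // The CUTLASS GEMM above and the addition kernel below both run on the default
    // stream, so the add is guaranteed to see the completed GEMM output in D.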
    // Launch the elementwise addition kernel: D += E, one thread per element.
    dim3 block(32, 32);
    dim3 grid((n + block.x - 1) / block.x, (m + block.y - 1) / block.y);
    elementwiseAddKernel<<<grid, block>>>(
        tensor_d.device_data(),   // device_data() already returns cutlass::half_t*
        tensor_e.device_data(),
        m, n
    );
    // Print some elements of the input matrices
    print_matrix_elements(tensor_a, "Matrix A");
    print_matrix_elements(tensor_b, "Matrix B");
    print_matrix_elements(tensor_e, "Matrix E (for elementwise addition)");
    // Wait for the kernels to finish, then copy the result back to the host.
    cudaDeviceSynchronize();
    tensor_d.sync_host();
    // Print some elements of the output matrix
    print_matrix_elements(tensor_d, "Result Matrix D (after GEMM and elementwise addition)");
    // Optional host-side verification (disabled): compute a reference GEMM on the CPU,
    // add E elementwise, and compare against the device result.
    // cutlass::HostTensor<ElementOutput, LayoutOutput> reference_d(problem_size.mn());
    // cutlass::reference::host::Gemm<
    //     ElementInputA, LayoutInputA,
    //     ElementInputB, LayoutInputB,
    //     ElementOutput, LayoutOutput,
    //     ElementAccumulator, ElementAccumulator
    // > reference_gemm;
    // reference_gemm(
    //     problem_size,
    //     alpha,
    //     tensor_a.host_ref(),
    //     tensor_b.host_ref(),
    //     beta,
    //     tensor_c.host_ref(),
    //     reference_d.host_ref()
    // );
    // // Perform elementwise addition on CPU for reference
    // for (int i = 0; i < m * n; ++i) {
    //     reference_d.host_ref().data()[i] += tensor_e.host_ref().data()[i];
    // }
    // bool passed = cutlass::reference::host::TensorEquals(
    //     reference_d.host_view(),
    //     tensor_d.host_view()
    // );
    // std::cout << (passed ? "GEMM test passed" : "GEMM test failed") << std::endl;
    // return passed ? 0 : -1;
    return 0;
}
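// Build sketch (paths and file name are assumptions; adjust to your CUTLASS checkout):
//   nvcc -std=c++17 -arch=sm_75 -I<cutlass>/include -I<cutlass>/tools/util/include \
//        1010_matmul_add.cu -o matmul_add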
int main() {
    // Problem size (alternative sizes with m = n = k: 256, 512, 1024, 2048, 4096, 8192).
    int m = 128;
    int n = 128;
    int k = 128;
    float alpha = 1.0f;
    float beta = 0.0f;
    return run_gemm(m, n, k, alpha, beta);
}