1010_matmul_add
// /***************************************************************************************************
//  * Copyright (c) 2023 - 2024 NVIDIA CORPORATION & AFFILIATES. All rights reserved.
//  * SPDX-License-Identifier: BSD-3-Clause
//  *
//  * Redistribution and use in source and binary forms, with or without
//  * modification, are permitted provided that the following conditions are met:
//  *
//  * 1. Redistributions of source code must retain the above copyright notice, this
//  *    list of conditions and the following disclaimer.
//  *
//  * 2. Redistributions in binary form must reproduce the above copyright notice,
//  *    this list of conditions and the following disclaimer in the documentation
//  *    and/or other materials provided with the distribution.
//  *
//  * 3. Neither the name of the copyright holder nor the names of its
//  *    contributors may be used to endorse or promote products derived from
//  *    this software without specific prior written permission.
//  *
//  * THIS SOFTWARE IS PROVIDED BY THE COPYRIGHT HOLDERS AND CONTRIBUTORS "AS IS"
//  * AND ANY EXPRESS OR IMPLIED WARRANTIES, INCLUDING, BUT NOT LIMITED TO, THE
//  * IMPLIED WARRANTIES OF MERCHANTABILITY AND FITNESS FOR A PARTICULAR PURPOSE ARE
//  * DISCLAIMED. IN NO EVENT SHALL THE COPYRIGHT HOLDER OR CONTRIBUTORS BE LIABLE
//  * FOR ANY DIRECT, INDIRECT, INCIDENTAL, SPECIAL, EXEMPLARY, OR CONSEQUENTIAL
//  * DAMAGES (INCLUDING, BUT NOT LIMITED TO, PROCUREMENT OF SUBSTITUTE GOODS OR
//  * SERVICES; LOSS OF USE, DATA, OR PROFITS; OR BUSINESS INTERRUPTION) HOWEVER
//  * CAUSED AND ON ANY THEORY OF LIABILITY, WHETHER IN CONTRACT, STRICT LIABILITY,
//  * OR TORT (INCLUDING NEGLIGENCE OR OTHERWISE) ARISING IN ANY WAY OUT OF THE USE
//  * OF THIS SOFTWARE, EVEN IF ADVISED OF THE POSSIBILITY OF SUCH DAMAGE.
//  *
//  **************************************************************************************************/
//
// #include <cstdlib>
// #include <cstdio>
// #include <cassert>
//
// #include <thrust/host_vector.h>
// #include <thrust/device_vector.h>
//
// #include <cute/tensor.hpp>
//
// #include "cutlass/util/print_error.hpp"
// #include "cutlass/util/GPU_Clock.hpp"
// #include "cutlass/util/helper_cuda.hpp"
//
// template <class ProblemShape, class CtaTiler,
//           class TA, class AStride, class ASmemLayout, class AThreadLayout,
//           class TB, class BStride, class BSmemLayout, class BThreadLayout,
//           class TC, class CStride, class CSmemLayout, class CThreadLayout,
//           class Alpha, class Beta>
// __global__ static
// __launch_bounds__(decltype(size(CThreadLayout{}))::value)
// void
// gemm_device(ProblemShape shape_MNK, CtaTiler cta_tiler,
//             TA const* A, AStride dA, ASmemLayout sA_layout, AThreadLayout tA,
//             TB const* B, BStride dB, BSmemLayout sB_layout, BThreadLayout tB,
//             TC      * C, CStride dC, CSmemLayout          , CThreadLayout tC,
//             Alpha alpha, Beta beta)
// {
//   using namespace cute;
//
//   // Preconditions
//   CUTE_STATIC_ASSERT_V(rank(shape_MNK) == Int<3>{});                   // (M, N, K)
//   CUTE_STATIC_ASSERT_V(rank(cta_tiler) == Int<3>{});                   // (BLK_M, BLK_N, BLK_K)
//
//   static_assert(is_static<AThreadLayout>::value);
//   static_assert(is_static<BThreadLayout>::value);
//   static_assert(is_static<CThreadLayout>::value);
//
//   CUTE_STATIC_ASSERT_V(size(tA) == size(tB));                          // NumThreads
//   CUTE_STATIC_ASSERT_V(size(tC) == size(tA));                          // NumThreads
//
//   CUTE_STATIC_ASSERT_V(size<0>(cta_tiler) % size<0>(tA) == Int<0>{});  // BLK_M / THR_M
//   CUTE_STATIC_ASSERT_V(size<2>(cta_tiler) % size<1>(tA) == Int<0>{});  // BLK_K / THR_K
//   CUTE_STATIC_ASSERT_V(size<1>(cta_tiler) % size<0>(tB) == Int<0>{});  // BLK_N / THR_N
//   CUTE_STATIC_ASSERT_V(size<2>(cta_tiler) % size<1>(tB) == Int<0>{});  // BLK_K / THR_K
//   CUTE_STATIC_ASSERT_V(size<0>(cta_tiler) % size<0>(tC) == Int<0>{});  // BLK_M / THR_M
//   CUTE_STATIC_ASSERT_V(size<1>(cta_tiler) % size<1>(tC) == Int<0>{});  // BLK_N / THR_N
//
//   static_assert(is_static<ASmemLayout>::value);
//   static_assert(is_static<BSmemLayout>::value);
//   static_assert(is_static<CSmemLayout>::value);
//
//   CUTE_STATIC_ASSERT_V(size<0>(ASmemLayout{}) == size<0>(cta_tiler));  // BLK_M
//   CUTE_STATIC_ASSERT_V(size<0>(CSmemLayout{}) == size<0>(cta_tiler));  // BLK_M
//   CUTE_STATIC_ASSERT_V(size<0>(BSmemLayout{}) == size<1>(cta_tiler));  // BLK_N
//   CUTE_STATIC_ASSERT_V(size<1>(CSmemLayout{}) == size<1>(cta_tiler));  // BLK_N
//   CUTE_STATIC_ASSERT_V(size<1>(ASmemLayout{}) == size<2>(cta_tiler));  // BLK_K
//   CUTE_STATIC_ASSERT_V(size<1>(BSmemLayout{}) == size<2>(cta_tiler));  // BLK_K
//
//   CUTE_STATIC_ASSERT_V(congruent(select<0,2>(shape_MNK), dA));         // dA strides for shape MK
//   CUTE_STATIC_ASSERT_V(congruent(select<1,2>(shape_MNK), dB));         // dB strides for shape NK
//   CUTE_STATIC_ASSERT_V(congruent(select<0,1>(shape_MNK), dC));         // dC strides for shape MN
//
//   //
//   // Full and Tiled Tensors
//   //
//
//   // Represent the full tensors
//   Tensor mA = make_tensor(make_gmem_ptr(A), select<0,2>(shape_MNK), dA); // (M,K)
//   Tensor mB = make_tensor(make_gmem_ptr(B), select<1,2>(shape_MNK), dB); // (N,K)
//   Tensor mC = make_tensor(make_gmem_ptr(C), select<0,1>(shape_MNK), dC); // (M,N)
//
//   // Get the appropriate blocks for this thread block
//   auto cta_coord = make_coord(blockIdx.x, blockIdx.y, _);              // (m,n,k)
//   Tensor gA = local_tile(mA, cta_tiler, cta_coord, Step<_1, X,_1>{});  // (BLK_M,BLK_K,k)
//   Tensor gB = local_tile(mB, cta_tiler, cta_coord, Step< X,_1,_1>{});  // (BLK_N,BLK_K,k)
//   Tensor gC = local_tile(mC, cta_tiler, cta_coord, Step<_1,_1, X>{});  // (BLK_M,BLK_N)
//
//   // Shared memory buffers
//   __shared__ TA smemA[cosize_v<ASmemLayout>];
//   __shared__ TB smemB[cosize_v<BSmemLayout>];
//   Tensor sA = make_tensor(make_smem_ptr(smemA), sA_layout);            // (BLK_M,BLK_K)
//   Tensor sB = make_tensor(make_smem_ptr(smemB), sB_layout);            // (BLK_N,BLK_K)
//
//   //
//   // Partition the copying of A and B tiles across the threads
//   //
//
//   // TUTORIAL: Example of simple raked partitioning of ThreadLayouts tA|tB over data A|B tiles
//   Tensor tAgA = local_partition(gA, tA, threadIdx.x);                  // (THR_M,THR_K,k)
//   Tensor tAsA = local_partition(sA, tA, threadIdx.x);                  // (THR_M,THR_K)
//   Tensor tBgB = local_partition(gB, tB, threadIdx.x);                  // (THR_N,THR_K,k)
//   Tensor tBsB = local_partition(sB, tB, threadIdx.x);                  // (THR_N,THR_K)
//
//   CUTE_STATIC_ASSERT_V(size<0>(tAgA) == size<0>(tAsA));                // THR_M
//   CUTE_STATIC_ASSERT_V(size<1>(tAgA) == size<1>(tAsA));                // THR_K
//   CUTE_STATIC_ASSERT_V(size<0>(tBgB) == size<0>(tBsB));                // THR_N
//   CUTE_STATIC_ASSERT_V(size<1>(tBgB) == size<1>(tBsB));                // THR_K
//
//   //
//   // Define A/B partitioning and C accumulators
//   //
//
//   // TUTORIAL: Example of partitioning via projections of a ThreadLayout tC
//
//   // Partition sA (M,K) by the rows of tC
//   Tensor tCsA = local_partition(sA, tC, threadIdx.x, Step<_1, X>{});   // (THR_M,BLK_K)
//   // Partition sB (N,K) by the cols of tC
//   Tensor tCsB = local_partition(sB, tC, threadIdx.x, Step< X,_1>{});   // (THR_N,BLK_K)
//   // Partition gC (M,N) by the tile of tC
//   Tensor tCgC = local_partition(gC, tC, threadIdx.x, Step<_1,_1>{});   // (THR_M,THR_N)
//
//   // Allocate the accumulators -- same shape/layout as the partitioned data
//   Tensor tCrC = make_tensor_like(tCgC);                                // (THR_M,THR_N)
//
//   CUTE_STATIC_ASSERT_V(size<0>(tCrC) == size<0>(tCgC));                // THR_M
//   CUTE_STATIC_ASSERT_V(size<0>(tCrC) == size<0>(tCsA));                // THR_M
//   CUTE_STATIC_ASSERT_V(size<1>(tCrC) == size<1>(tCgC));                // THR_N
//   CUTE_STATIC_ASSERT_V(size<1>(tCrC) == size<0>(tCsB));                // THR_N
//   CUTE_STATIC_ASSERT_V(size<1>(tCsA) == size<1>(tCsB));                // BLK_K
//
//   // Clear the accumulators
//   clear(tCrC);
//
// #if 0
//   if(thread0()) {
//     print(" mA : "); print( mA); print("\n");
//     print(" gA : "); print( gA); print("\n");
//     print(" sA : "); print( sA); print("\n");
//     print("tAgA : "); print(tAgA); print("\n");
//     print("tAsA : "); print(tAsA); print("\n");
//   }
// #endif
//
// #if 0
//   if(thread0()) {
//     print(" mB : "); print( mB); print("\n");
//     print(" gB : "); print( gB); print("\n");
//     print(" sB : "); print( sB); print("\n");
//     print("tBgB : "); print(tBgB); print("\n");
//     print("tBsB : "); print(tBsB); print("\n");
//   }
// #endif
//
// #if 0
//   if(thread0()) {
//     print(" mC : "); print( mC); print("\n");
//     print(" gC : "); print( gC); print("\n");
//     print("tCsA : "); print(tCsA); print("\n");
//     print("tCsB : "); print(tCsB); print("\n");
//     print("tCgC : "); print(tCgC); print("\n");
//     print("tCrC : "); print(tCrC); print("\n");
//   }
// #endif
//
// #if 1
//
//   // TUTORIAL: Example of a simple mainloop that read tiles of data into shared memory,
//   //           and then computes on those tiles.
//   //   copy(.) operates on the global and shared memory via the tA|tB partitioning
//   //   gemm(.) operates on the shared and register memory via the tC partitioning
//
//   auto K_TILE_MAX = size<2>(tAgA);
//
//   for (int k_tile = 0; k_tile < K_TILE_MAX; ++k_tile)
//   {
//     // Copy gmem to smem with tA|tB thread-partitioned tensors
//     copy(tAgA(_,_,k_tile), tAsA);      // A   (THR_M,THR_K) -> (THR_M,THR_K)
//     copy(tBgB(_,_,k_tile), tBsB);      // B   (THR_N,THR_K) -> (THR_N,THR_K)
//
//     // TUTORIAL: The above call to copy(tAgA(_,_,k_tile), tAsA) is equivalent to
//     //   Tensor tAgAk = tAgA(_,_,k_tile);
//     //   CUTE_UNROLL
//     //   for (int i = 0; i < size(tAsA); ++i) {
//     //     tAsA(i) = tAgAk(i);
//     //   }
//
//     cp_async_fence();        // Label the end of (potential) cp.async instructions
//     cp_async_wait<0>();      // Sync on all (potential) cp.async instructions
//     __syncthreads();         // Wait for all threads to write to smem
//
//     // Compute gemm on tC thread-partitioned smem
//     gemm(tCsA, tCsB, tCrC);            // (THR_M,THR_N) += (THR_M,BLK_K) * (THR_N,BLK_K)
//
//     // TUTORIAL: The above call to gemm(tCsA, tCsB, tCrC) is equivalent to
//     //   CUTE_UNROLL
//     //   for (int k = 0; k < size<1>(tCsA); ++k) {
//     //     CUTE_UNROLL
//     //     for (int m = 0; m < size<0>(tCrC); ++m) {
//     //       CUTE_UNROLL
//     //       for (int n = 0; n < size<1>(tCrC); ++n) {
//     //         tCrC(m,n) += tCsA(m,k) * tCsB(n,k);
//     //       }
//     //     }
//     //   }
//
//     __syncthreads();         // Wait for all threads to read from smem
//   }
//
// #endif
//
//   //
//   // Epilogue
//   //
//
//   axpby(alpha, tCrC, beta, tCgC);
//
//   // TUTORIAL: The above call to axpby(alpha, tCrC, beta, tCgC) is equivalent to
//   //   CUTE_UNROLL
//   //   for (int i = 0; i < size(tCsA); ++i) {
//   //     tCgC(i) = alpha * tCrC(i) + beta * tCgC(i);
//   //   }
// }
//
// // Elementwise addition kernel
// template <typename T>
// __global__ void elementwiseAddKernel(T* C, const T* D, int M, int N) {
//   int row = blockIdx.y * blockDim.y + threadIdx.y;
//   int col = blockIdx.x * blockDim.x + threadIdx.x;
//   if (row < M && col < N) {
//     int idx = row * N + col;
//     C[idx] += D[idx];
//   }
// }
//
// // Setup params for an NT GEMM
// // Use m-major smem sA, n-major smem sB, and mn-major threads tA|tB
// template <class TA, class TB, class TC,
//           class Alpha, class Beta>
// void
// gemm_nt(int m, int n, int k,
//         Alpha alpha,
//         TA const* A, int ldA,
//         TB const* B, int ldB,
//         Beta beta,
//         TC      * C, int ldC,
//         cudaStream_t stream = 0)
// {
//   using namespace cute;
//
//   // Define shapes (dynamic)
//   auto M = int(m);
//   auto N = int(n);
//   auto K = int(k);
//   auto prob_shape = make_shape(M, N, K);                     // (M, N, K)
//
//   // Define NT strides (mixed)
//   auto dA = make_stride(Int<1>{}, ldA);                      // (dM, dK)
//   auto dB = make_stride(Int<1>{}, ldB);                      // (dN, dK)
//   auto dC = make_stride(Int<1>{}, ldC);                      // (dM, dN)
//
//   // Define CTA tile sizes (static)
//   auto bM = Int<128>{};
//   auto bN = Int<128>{};
//   auto bK = Int<  8>{};
//   auto cta_tiler = make_shape(bM, bN, bK);                   // (BLK_M, BLK_N, BLK_K)
//
//   // Define the smem layouts (static)
//   auto sA = make_layout(make_shape(bM, bK));                 // (m,k) -> smem_idx; m-major
//   auto sB = make_layout(make_shape(bN, bK));                 // (n,k) -> smem_idx; n-major
//   auto sC = make_layout(make_shape(bM, bN));                 // (m,n) -> smem_idx; m-major
//
//   // Define the thread layouts (static)
//   auto tA = make_layout(make_shape(Int<32>{}, Int< 8>{}));   // (m,k) -> thr_idx
//   auto tB = make_layout(make_shape(Int<32>{}, Int< 8>{}));   // (n,k) -> thr_idx
//   auto tC = make_layout(make_shape(Int<16>{}, Int<16>{}));   // (m,n) -> thr_idx
//
//   dim3 dimBlock(size(tC));
//   dim3 dimGrid(size(ceil_div(M, bM)),
//                size(ceil_div(N, bN)));
//   gemm_device<<<dimGrid, dimBlock, 0, stream>>>
//       (prob_shape, cta_tiler,
//        A, dA, sA, tA,
//        B, dB, sB, tB,
//        C, dC, sC, tC,
//        alpha, beta);
// }
//
// // Setup params for a TN GEMM
// // Use padded m-major smem sA, padded n-major smem sB, and k-major threads tA|tB
// template <class TA, class TB, class TC,
//           class Alpha, class Beta>
// void
// gemm_tn(int m, int n, int k,
//         Alpha alpha,
//         TA const* A, int ldA,
//         TB const* B, int ldB,
//         Beta beta,
//         TC      * C, int ldC,
//         cudaStream_t stream = 0)
// {
//   using namespace cute;
//
//   // Define shapes (dynamic)
//   auto M = int(m);
//   auto N = int(n);
//   auto K = int(k);
//   auto prob_shape = make_shape(M, N, K);                     // (M, N, K)
//
//   // Define TN strides (mixed)
//   auto dA = make_stride(ldA, Int<1>{});                      // (dM, dK)
//   auto dB = make_stride(ldB, Int<1>{});                      // (dN, dK)
//   auto dC = make_stride(Int<1>{}, ldC);                      // (dM, dN)
//
//   // Define CTA tile sizes (static)
//   auto bM = Int<128>{};
//   auto bN = Int<128>{};
//   auto bK = Int<  8>{};
//   auto cta_tiler = make_shape(bM, bN, bK);                   // (BLK_M, BLK_N, BLK_K)
//
//   // Define the smem layouts (static)
//   auto sA = make_layout(make_shape(bM,bK), LayoutRight{});   // (m,k) -> smem_idx; k-major
//   auto sB = make_layout(make_shape(bN,bK), LayoutRight{});   // (n,k) -> smem_idx; k-major
//   auto sC = make_layout(make_shape(bM, bN));                 // (m,n) -> smem_idx; m-major
//
//   // Define the thread layouts (static)
//   auto tA = make_layout(make_shape(Int<32>{}, Int< 8>{}), LayoutRight{});  // (m,k) -> thr_idx; k-major
//   auto tB = make_layout(make_shape(Int<32>{}, Int< 8>{}), LayoutRight{});  // (n,k) -> thr_idx; k-major
//   auto tC = make_layout(make_shape(Int<16>{}, Int<16>{}));                 // (m,n) -> thr_idx; m-major
//
//   dim3 dimBlock(size(tC));
//   dim3 dimGrid(size(ceil_div(M, bM)),
//                size(ceil_div(N, bN)));
//   gemm_device<<<dimGrid, dimBlock, 0, stream>>>
//       (prob_shape, cta_tiler,
//        A, dA, sA, tA,
//        B, dB, sB, tB,
//        C, dC, sC, tC,
//        alpha, beta);
// }
//
// template <class TA, class TB, class TC, class TD,
//           class Alpha, class Beta>
// void
// gemm(char transA, char transB, int m, int n, int k,
//      Alpha alpha,
//      TA const* A, int ldA,
//      TB const* B, int ldB,
//      Beta beta,
//      TC      * C, int ldC,
//      TD const* D,             // New parameter for the D matrix
//      cudaStream_t stream = 0)
// {
//   if (transA == 'N' && transB == 'T') {
//     gemm_nt(m, n, k, alpha, A, ldA, B, ldB, beta, C, ldC, stream);
//   } else if (transA == 'T' && transB == 'N') {
//     gemm_tn(m, n, k, alpha, A, ldA, B, ldB, beta, C, ldC, stream);
//   } else {
//     assert(false && "Not implemented");
//     return;
//   }
//
//   // Launch elementwise addition kernel
//   dim3 blockDim(32, 32);
//   dim3 gridDim((n + blockDim.x - 1) / blockDim.x,
//                (m + blockDim.y - 1) / blockDim.y);
//   elementwiseAddKernel<<<gridDim, blockDim, 0, stream>>>(C, D, m, n);
// }
//
// int main(int argc, char** argv)
// {
//   int m = 128;
//   if (argc >= 2)
//     sscanf(argv[1], "%d", &m);
//
//   int n = 128;
//   if (argc >= 3)
//     sscanf(argv[2], "%d", &n);
//
//   int k = 128;
//   if (argc >= 4)
//     sscanf(argv[3], "%d", &k);
//
//   char transA = 'N';
//   if (argc >= 5)
//     sscanf(argv[4], "%c", &transA);
//
//   char transB = 'T';
//   if (argc >= 6)
//     sscanf(argv[5], "%c", &transB);
//
//   using TA = float;
//   using TB = float;
//   using TC = float;
//   using TI = float;
//
//   TI alpha = 1.0;
//   TI beta  = 0.0;
//
//   std::cout << "M = " << m << std::endl;
//   std::cout << "N = " << n << std::endl;
//   std::cout << "K = " << k << std::endl;
//   std::cout << "C = A^" << transA << " B^" << transB << std::endl;
//
//   cute::device_init(0);
//
//   thrust::host_vector<TA> h_A(m*k);
//   thrust::host_vector<TB> h_B(n*k);
//   thrust::host_vector<TC> h_C(m*n);
//
//   for (int j = 0; j < m*k; ++j) h_A[j] = static_cast<TA>( 2*(rand() / double(RAND_MAX)) - 1 );
//   for (int j = 0; j < n*k; ++j) h_B[j] = static_cast<TB>( 2*(rand() / double(RAND_MAX)) - 1 );
//   for (int j = 0; j < m*n; ++j) h_C[j] = static_cast<TC>(-1);
//
//   thrust::device_vector<TA> d_A = h_A;
//   thrust::device_vector<TB> d_B = h_B;
//   thrust::device_vector<TC> d_C = h_C;
//
//   // Add D matrix
//   thrust::host_vector<TC> h_D(m*n);
//   for (int j = 0; j < m*n; ++j) h_D[j] = static_cast<TC>( 2*(rand() / double(RAND_MAX)) - 1 );
//   thrust::device_vector<TC> d_D = h_D;
//
//   double gflops = (2.0*m*n*k) * 1e-9;
//   const int timing_iterations = 100;
//   GPU_Clock timer;
//
//   int ldA = 0, ldB = 0, ldC = m;
//   if (transA == 'N') {
//     ldA = m;
//   } else if (transA == 'T') {
//     ldA = k;
//   } else {
//     assert(false);
//   }
//   if (transB == 'N') {
//     ldB = k;
//   } else if (transB == 'T') {
//     ldB = n;
//   } else {
//     assert(false);
//   }
//
//   // Run once
//   d_C = h_C;
//   gemm(transA, transB, m, n, k,
//        alpha,
//        d_A.data().get(), ldA,
//        d_B.data().get(), ldB,
//        beta,
//        d_C.data().get(), ldC,
//        d_D.data().get());     // Add D matrix to the GEMM call
//   CUTE_CHECK_LAST();
//   // thrust::host_vector<TC> cute_result = d_C;
//
//   // Timing iterations
//   // timer.start();
//   // for (int i = 0; i < timing_iterations; ++i) {
//   //   gemm(transA, transB, m, n, k,
//   //        alpha,
//   //        d_A.data().get(), ldA,
//   //        d_B.data().get(), ldB,
//   //        beta,
//   //        d_C.data().get(), ldC,
//   //        d_D.data().get());     // Add D matrix to the GEMM call
//   // }
//   // double cute_time = timer.seconds() / timing_iterations;
//   // CUTE_CHECK_LAST();
//   // printf("CUTE_GEMM + ElementwiseAdd: [%6.1f]GFlop/s (%6.4f)ms\n", gflops / cute_time, cute_time*1000);
//
//   return 0;
// }

#include <iostream>
#include <cuda_runtime.h>

#include "cutlass/cutlass.h"
#include "cutlass/gemm/device/gemm.h"
#include "cutlass/util/host_tensor.h"
#include "cutlass/util/tensor_view_io.h"
#include "cutlass/util/reference/host/tensor_fill.h"
#include "cutlass/util/reference/host/tensor_copy.h"
#include "cutlass/util/reference/host/tensor_compare.h"
#include "cutlass/util/reference/host/gemm.h"

// CUTLASS GEMM configuration: half-precision inputs/outputs, float accumulation,
// tensor-core math on Turing (Sm75).
using ElementInputA = cutlass::half_t;
using ElementInputB = cutlass::half_t;
using ElementOutput = cutlass::half_t;
using ElementAccumulator = float;

using LayoutInputA = cutlass::layout::RowMajor;
using LayoutInputB = cutlass::layout::ColumnMajor;
using LayoutOutput = cutlass::layout::ColumnMajor;

using MMAOp = cutlass::arch::OpClassTensorOp;
using SmArch = cutlass::arch::Sm75;

// Threadblock, warp, and instruction tile shapes
using ShapeMMAThreadBlock = cutlass::gemm::GemmShape<128, 128, 32>;
using ShapeMMAWarp = cutlass::gemm::GemmShape<64, 64, 32>;
using ShapeMMAOp = cutlass::gemm::GemmShape<16, 8, 8>;

// Epilogue: D = alpha * accumulator + beta * C
using EpilogueOp = cutlass::epilogue::thread::LinearCombination<
    ElementOutput,
    128 / cutlass::sizeof_bits<ElementOutput>::value,
    ElementAccumulator,
    ElementAccumulator
>;

using Gemm = cutlass::gemm::device::Gemm<
    ElementInputA, LayoutInputA,
    ElementInputB, LayoutInputB,
    ElementOutput, LayoutOutput,
    ElementAccumulator,
    MMAOp,
    SmArch,
    ShapeMMAThreadBlock,
    ShapeMMAWarp,
    ShapeMMAOp,
    EpilogueOp
>;

// Elementwise addition kernel: C[idx] += D[idx].
// The linear index row * N + col visits every one of the M*N entries exactly once, so this is a
// pure elementwise add as long as C and D share the same (densely packed) layout.
template <typename T>
__global__ void elementwiseAddKernel(T* C, const T* D, int M, int N) {
    int row = blockIdx.y * blockDim.y + threadIdx.y;
    int col = blockIdx.x * blockDim.x + threadIdx.x;
    if (row < M && col < N) {
        int idx = row * N + col;
        C[idx] += D[idx];
    }
}

// Helper function to print a few elements of a matrix
template <typename T, typename Layout>
void print_matrix_elements(cutlass::HostTensor<T, Layout>& tensor, const char* name, int num_elements = 5) {
std::cout << "First " << num_elements << " elements of " << name << ":" << std::endl; for (int i = 0; i < num_elements; ++i) { typename Layout::TensorCoord coord; if constexpr (std::is_same_v<Layout, cutlass::layout::RowMajor>) { coord = typename Layout::TensorCoord{0, i}; } else { coord = typename Layout::TensorCoord{i, 0}; } std::cout << float(tensor.host_ref().at(coord)) << " "; } std::cout << std::endl; } int run_gemm(int m, int n, int k, float alpha, float beta) { cutlass::gemm::GemmCoord problem_size(m, n, k); cutlass::HostTensor<ElementInputA, LayoutInputA> tensor_a(problem_size.mk()); cutlass::HostTensor<ElementInputB, LayoutInputB> tensor_b(problem_size.kn()); cutlass::HostTensor<ElementOutput, LayoutOutput> tensor_c(problem_size.mn()); cutlass::HostTensor<ElementOutput, LayoutOutput> tensor_d(problem_size.mn()); cutlass::HostTensor<ElementOutput, LayoutOutput> tensor_e(problem_size.mn()); // New tensor for elementwise addition cutlass::reference::host::TensorFillRandomUniform( tensor_a.host_view(), 1, ElementInputA(4), ElementInputA(-4), 0); cutlass::reference::host::TensorFillRandomUniform( tensor_b.host_view(), 1, ElementInputB(4), ElementInputB(-4), 0); cutlass::reference::host::TensorFill(tensor_c.host_view()); cutlass::reference::host::TensorFill(tensor_d.host_view()); cutlass::reference::host::TensorFillRandomUniform( // Fill tensor_e with random values tensor_e.host_view(), 1, ElementOutput(1), ElementOutput(-1), 0); tensor_a.sync_device(); tensor_b.sync_device(); tensor_c.sync_device(); tensor_d.sync_device(); tensor_e.sync_device(); typename Gemm::Arguments arguments{ problem_size, tensor_a.device_ref(), tensor_b.device_ref(), tensor_c.device_ref(), tensor_d.device_ref(), {ElementAccumulator(alpha), ElementAccumulator(beta)} }; Gemm gemm_op; cutlass::Status status = gemm_op.initialize(arguments); if (status != cutlass::Status::kSuccess) { std::cerr << "Failed to initialize CUTLASS Gemm." << std::endl; return -1; } status = gemm_op(); if (status != cutlass::Status::kSuccess) { std::cerr << "Failed to run CUTLASS Gemm." << std::endl; return -1; } // Launch elementwise addition kernel dim3 block(32, 32); dim3 grid((n + block.x - 1) / block.x, (m + block.y - 1) / block.y); elementwiseAddKernel<<<grid, block>>>( reinterpret_cast<cutlass::half_t*>(tensor_d.device_data()), reinterpret_cast<cutlass::half_t*>(tensor_e.device_data()), m, n ); // Print some elements of the input matrices print_matrix_elements(tensor_a, "Matrix A"); print_matrix_elements(tensor_b, "Matrix B"); print_matrix_elements(tensor_e, "Matrix E (for elementwise addition)"); cudaDeviceSynchronize(); tensor_d.sync_host(); // Print some elements of the output matrix print_matrix_elements(tensor_d, "Result Matrix D (after GEMM and elementwise addition)"); // cutlass::HostTensor<ElementOutput, LayoutOutput> reference_d(problem_size.mn()); // cutlass::reference::host::Gemm< // ElementInputA, LayoutInputA, // ElementInputB, LayoutInputB, // ElementOutput, LayoutOutput, // ElementAccumulator, ElementAccumulator // > reference_gemm; // reference_gemm( // problem_size, // alpha, // tensor_a.host_ref(), // tensor_b.host_ref(), // beta, // tensor_c.host_ref(), // reference_d.host_ref() // ); // // Perform elementwise addition on CPU for reference // for (int i = 0; i < m * n; ++i) { // reference_d.host_ref().data()[i] += tensor_e.host_ref().data()[i]; // } // bool passed = cutlass::reference::host::TensorEquals( // reference_d.host_view(), // tensor_d.host_view() // ); // std::cout << (passed ? 
"GEMM test passed" : "GEMM test failed") << std::endl; // return passed ? 0 : -1; return 0; } int main() { // int m = 8192; // int n = 8192; // int k = 8192; // int m = 4096; // int n = 4096; // int k = 4096; // int m = 2048; // int n = 2048; // int k = 2048; // int m = 1024; // int n = 1024; // int k = 1024; // int m = 512; // int n = 512; // int k = 512; // int m = 256; // int n = 256; // int k = 256; int m = 128; int n = 128; int k = 128; float alpha = 1.0f; float beta = 0.0f; return run_gemm(m, n, k, alpha, beta); }