tc kernel
/* Copyright (c) 1993-2017, NVIDIA CORPORATION. All rights reserved.
 *
 * Redistribution and use in source and binary forms, with or without
 * modification, are permitted provided that the following conditions
 * are met:
 *  * Redistributions of source code must retain the above copyright
 *    notice, this list of conditions and the following disclaimer.
 *  * Redistributions in binary form must reproduce the above copyright
 *    notice, this list of conditions and the following disclaimer in the
 *    documentation and/or other materials provided with the distribution.
 *  * Neither the name of NVIDIA CORPORATION nor the names of its
 *    contributors may be used to endorse or promote products derived
 *    from this software without specific prior written permission.
 *
 * THIS SOFTWARE IS PROVIDED BY THE COPYRIGHT HOLDERS ``AS IS'' AND ANY
 * EXPRESS OR IMPLIED WARRANTIES, INCLUDING, BUT NOT LIMITED TO, THE
 * IMPLIED WARRANTIES OF MERCHANTABILITY AND FITNESS FOR A PARTICULAR
 * PURPOSE ARE DISCLAIMED. IN NO EVENT SHALL THE COPYRIGHT OWNER OR
 * CONTRIBUTORS BE LIABLE FOR ANY DIRECT, INDIRECT, INCIDENTAL, SPECIAL,
 * EXEMPLARY, OR CONSEQUENTIAL DAMAGES (INCLUDING, BUT NOT LIMITED TO,
 * PROCUREMENT OF SUBSTITUTE GOODS OR SERVICES; LOSS OF USE, DATA, OR
 * PROFITS; OR BUSINESS INTERRUPTION) HOWEVER CAUSED AND ON ANY THEORY
 * OF LIABILITY, WHETHER IN CONTRACT, STRICT LIABILITY, OR TORT
 * (INCLUDING NEGLIGENCE OR OTHERWISE) ARISING IN ANY WAY OUT OF THE USE
 * OF THIS SOFTWARE, EVEN IF ADVISED OF THE POSSIBILITY OF SUCH DAMAGE.
 */

// #include <cublas_v2.h>
// #include <curand.h>
#include <stdio.h>

#include <cuda_fp16.h>  // half, __float2half_rn, __half2float
#include <mma.h>

using namespace nvcuda;

// Error-checking macro for CUDA runtime calls.
#define cudaErrCheck(stat) \
  { cudaErrCheck_((stat), __FILE__, __LINE__); }
void cudaErrCheck_(cudaError_t stat, const char *file, int line) {
  if (stat != cudaSuccess) {
    fprintf(stderr, "CUDA Error: %s %s %d\n", cudaGetErrorString(stat), file,
            line);
  }
}

// Error-checking macros for the (currently disabled) cuBLAS/cuRAND paths.
// #define cublasErrCheck(stat) \
//   { cublasErrCheck_((stat), __FILE__, __LINE__); }
// void cublasErrCheck_(cublasStatus_t stat, const char *file, int line) {
//   if (stat != CUBLAS_STATUS_SUCCESS) {
//     fprintf(stderr, "cuBLAS Error: %d %s %d\n", stat, file, line);
//   }
// }

// #define curandErrCheck(stat) \
//   { curandErrCheck_((stat), __FILE__, __LINE__); }
// void curandErrCheck_(curandStatus_t stat, const char *file, int line) {
//   if (stat != CURAND_STATUS_SUCCESS) {
//     fprintf(stderr, "cuRAND Error: %d %s %d\n", stat, file, line);
//   }
// }

// Must be multiples of 16 for the WMMA code to work.
#define MATRIX_M 64
#define MATRIX_N 64
#define MATRIX_K 64

// The WMMA fragment dimensions used throughout this example.
const int WMMA_M = 16;
const int WMMA_N = 16;
const int WMMA_K = 16;
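// Hedged addition (not in the original paste): the matrices below are stored
// column-major, so element (row, col) of a matrix with leading dimension ld
// lives at offset row + col * ld. This illustrative helper spells that out;
// it is hypothetical and unused, since the kernel inlines the same
// arithmetic directly.
__host__ __device__ inline int colMajorIndex(int row, int col, int ld) {
  return row + col * ld;
}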
// Performs an M x N x K GEMM (C = alpha*A*B + beta*C) assuming:
//  1) Matrices are packed in memory.
//  2) M, N and K are multiples of 16.
//  3) Neither A nor B is transposed.
// Note: this is NOT a high-performance kernel; it is for demonstration
// purposes only. For high-performance code, use the GEMM provided by cuBLAS.
__global__ void wmma_example(half *a, half *b, float *c, int M, int N, int K,
                             float alpha, float beta) {
  // Leading dimensions. Packed with no transpositions.
  int lda = M;
  int ldb = K;
  int ldc = M;

  // Tile using a 2D grid of warps: each warp computes one 16x16 output tile.
  int warpM = (blockIdx.x * blockDim.x + threadIdx.x) / warpSize;
  int warpN = (blockIdx.y * blockDim.y + threadIdx.y);

  // Declare the fragments.
  wmma::fragment<wmma::matrix_a, WMMA_M, WMMA_N, WMMA_K, half,
                 wmma::col_major>
      a_frag;
  wmma::fragment<wmma::matrix_b, WMMA_M, WMMA_N, WMMA_K, half,
                 wmma::col_major>
      b_frag;
  wmma::fragment<wmma::accumulator, WMMA_M, WMMA_N, WMMA_K, float> acc_frag;
  wmma::fragment<wmma::accumulator, WMMA_M, WMMA_N, WMMA_K, float> c_frag;

  wmma::fill_fragment(acc_frag, 0.0f);

  // Loop over K, one fragment at a time.
  for (int i = 0; i < K; i += WMMA_K) {
    int aRow = warpM * WMMA_M;
    int aCol = i;
    int bRow = i;
    int bCol = warpN * WMMA_N;

    // printf("Thread (%d, %d) working on aRow = %d, aCol = %d, bRow = %d, bCol = %d\n",
    //        threadIdx.x, threadIdx.y, aRow, aCol, bRow, bCol);

    // Bounds checking.
    if (aRow < M && aCol < K && bRow < K && bCol < N) {
      // Load the inputs.
      // printf("Thread (%d, %d) loading A from address %p\n",
      //        threadIdx.x, threadIdx.y, (half *)(a + aRow + aCol * lda));
      wmma::load_matrix_sync(a_frag, a + aRow + aCol * lda, lda);
      wmma::load_matrix_sync(b_frag, b + bRow + bCol * ldb, ldb);

      // Perform the matrix multiplication.
      wmma::mma_sync(acc_frag, a_frag, b_frag, acc_frag);
    }
  }

  // Manual unroll of the K loop for K = 32, kept for reference
  // (bounds checks omitted in this experiment):
  // int aRow[2]; int aCol[2]; int bRow[2]; int bCol[2];
  // aRow[0] = warpM * WMMA_M; aCol[0] = 0;
  // bRow[0] = 0;              bCol[0] = warpN * WMMA_N;
  // aRow[1] = warpM * WMMA_M; aCol[1] = 16;
  // bRow[1] = 16;             bCol[1] = warpN * WMMA_N;
  //
  // wmma::load_matrix_sync(a_frag, a + aRow[0] + aCol[0] * lda, lda);
  // wmma::load_matrix_sync(b_frag, b + bRow[0] + bCol[0] * ldb, ldb);
  // wmma::mma_sync(acc_frag, a_frag, b_frag, acc_frag);
  //
  // wmma::load_matrix_sync(a_frag, a + aRow[1] + aCol[1] * lda, lda);
  // wmma::load_matrix_sync(b_frag, b + bRow[1] + bCol[1] * ldb, ldb);
  // wmma::mma_sync(acc_frag, a_frag, b_frag, acc_frag);

  // Load in the current value of C, scale it by beta, and add our result
  // scaled by alpha.
  int cRow = warpM * WMMA_M;
  int cCol = warpN * WMMA_N;

  if (cRow < M && cCol < N) {
    wmma::load_matrix_sync(c_frag, c + cRow + cCol * ldc, ldc,
                           wmma::mem_col_major);

#pragma unroll
    for (int i = 0; i < c_frag.num_elements; i++) {
      c_frag.x[i] = alpha * acc_frag.x[i] + beta * c_frag.x[i];
    }

    // Store the output.
    wmma::store_matrix_sync(c + cRow + cCol * ldc, c_frag, ldc,
                            wmma::mem_col_major);
  }
}

// Earlier version relying on the implicit float-to-half conversion:
// __global__ void convertFp32ToFp16(half *out, float *in, int n) {
//   int idx = blockDim.x * blockIdx.x + threadIdx.x;
//   if (idx < n) {
//     out[idx] = in[idx];
//   }
// }

// Converts n fp32 values to fp16 with round-to-nearest-even.
__global__ void convertFp32ToFp16(half *out, float *in, int n) {
  int idx = blockDim.x * blockIdx.x + threadIdx.x;
  if (idx < n) {
    out[idx] = __float2half_rn(in[idx]);
  }
}
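// Hedged addition (not in the original paste): a small host-side
// illustration of the precision loss introduced by the fp16 round trip that
// convertFp32ToFp16 performs on the device. fp16 has a 10-bit mantissa, so a
// value like 0.1f is not exactly representable. This helper is hypothetical
// and never called; it only documents the rounding behavior.
static void showFp16Rounding(void) {
  float x = 0.1f;
  // Round to fp16 and back, the same path the device data takes.
  float roundTripped = __half2float(__float2half_rn(x));
  // Prints roughly 0.10000000 -> 0.09997559, which is why any correctness
  // check on the WMMA output needs a tolerance.
  printf("fp32 %.8f -> fp16 -> fp32 %.8f\n", x, roundTripped);
}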
int main(int argc, char *argv[]) {
  float *a_fp32;
  float *b_fp32;
  half *a_fp16;
  half *b_fp16;

  float *c_wmma;
  float *c_host_wmma;

  //! Host memory allocation, including host memory for the result.
  float *h_a_fp32 = (float *)malloc(MATRIX_M * MATRIX_K * sizeof(float));
  float *h_b_fp32 = (float *)malloc(MATRIX_K * MATRIX_N * sizeof(float));
  float *h_c = (float *)malloc(MATRIX_M * MATRIX_N * sizeof(float));
  c_host_wmma = (float *)malloc(MATRIX_M * MATRIX_N * sizeof(float));

  // Check for successful host memory allocation.
  if (!h_a_fp32 || !h_b_fp32 || !h_c || !c_host_wmma) {
    fprintf(stderr, "Host memory allocation failed\n");
    return -1;
  }

  //! Fill host memory with values in [-1, 1].
  for (int i = 0; i < MATRIX_M * MATRIX_K; ++i) {
    h_a_fp32[i] = float(i % 255 - 127) / 127;
  }
  for (int i = 0; i < MATRIX_K * MATRIX_N; ++i) {
    h_b_fp32[i] = float(i % 255 - 127) / 127;
  }
  for (int i = 0; i < MATRIX_M * MATRIX_N; ++i) {
    h_c[i] = float(i % 255 - 127) / 127;
  }

  //! Device memory allocation
  cudaErrCheck(
      cudaMalloc((void **)&a_fp32, MATRIX_M * MATRIX_K * sizeof(float)));
  cudaErrCheck(
      cudaMalloc((void **)&b_fp32, MATRIX_K * MATRIX_N * sizeof(float)));
  cudaErrCheck(
      cudaMalloc((void **)&a_fp16, MATRIX_M * MATRIX_K * sizeof(half)));
  cudaErrCheck(
      cudaMalloc((void **)&b_fp16, MATRIX_K * MATRIX_N * sizeof(half)));
  cudaErrCheck(
      cudaMalloc((void **)&c_wmma, MATRIX_M * MATRIX_N * sizeof(float)));

  //! Copy host data to device
  cudaErrCheck(cudaMemcpy(a_fp32, h_a_fp32,
                          MATRIX_M * MATRIX_K * sizeof(float),
                          cudaMemcpyHostToDevice));
  cudaErrCheck(cudaMemcpy(b_fp32, h_b_fp32,
                          MATRIX_K * MATRIX_N * sizeof(float),
                          cudaMemcpyHostToDevice));
  cudaErrCheck(cudaMemcpy(c_wmma, h_c, MATRIX_M * MATRIX_N * sizeof(float),
                          cudaMemcpyHostToDevice));

  // Random initialization via cuRAND, currently disabled:
  // curandErrCheck(curandCreateGenerator(&gen, CURAND_RNG_PSEUDO_DEFAULT));
  // curandErrCheck(curandSetPseudoRandomGeneratorSeed(gen, 1337ULL));
  // curandErrCheck(curandGenerateUniform(gen, a_fp32, MATRIX_M * MATRIX_K));
  // curandErrCheck(curandGenerateUniform(gen, b_fp32, MATRIX_K * MATRIX_N));

  //! Convert the fp32 inputs to fp16 on the device.
  //* A host loop cannot directly access device memory, so this does not work:
  // for (int i = 0; i < MATRIX_M * MATRIX_K; ++i) {
  //   a_fp16[i] = __float2half_rn(a_fp32[i]);
  // }
  // for (int i = 0; i < MATRIX_K * MATRIX_N; ++i) {
  //   b_fp16[i] = __float2half_rn(b_fp32[i]);
  // }

  //* 4 warps per block
  // convertFp32ToFp16<<<(MATRIX_M * MATRIX_K + 127) / 128, 128>>>(
  //     a_fp16, a_fp32, MATRIX_M * MATRIX_K);
  // convertFp32ToFp16<<<(MATRIX_K * MATRIX_N + 127) / 128, 128>>>(
  //     b_fp16, b_fp32, MATRIX_K * MATRIX_N);

  //* 1 warp per block
  convertFp32ToFp16<<<(MATRIX_M * MATRIX_K + 31) / 32, 32>>>(
      a_fp16, a_fp32, MATRIX_M * MATRIX_K);
  convertFp32ToFp16<<<(MATRIX_K * MATRIX_N + 31) / 32, 32>>>(
      b_fp16, b_fp32, MATRIX_K * MATRIX_N);

  //! Scaling factors for the C matrix.
  float alpha = 2.0f;
  float beta = 2.0f;

  printf("\nM = %d, N = %d, K = %d. alpha = %f, beta = %f\n\n", MATRIX_M,
         MATRIX_N, MATRIX_K, alpha, beta);

  // First: using WMMA
  dim3 gridDim;
  dim3 blockDim;

  // blockDim.x must be a multiple of warpSize.
  // With 128x4 threads a block holds 16 warps; each warp computes one 16x16
  // output tile, so a 4x4 arrangement of warps covers a 64x64 output tile.
  // Here blockDim is 32x1: one warp per block, one 16x16 tile per block.
  blockDim.x = 32;
  blockDim.y = 1;

  gridDim.x = (MATRIX_M + (WMMA_M * blockDim.x / 32 - 1)) /
              (WMMA_M * blockDim.x / 32);
  gridDim.y = (MATRIX_N + WMMA_N * blockDim.y - 1) / (WMMA_N * blockDim.y);
  // For M = N = 64 and a 32x1 block this gives gridDim.x = gridDim.y = 4.
  printf("gridDim.x = %d, gridDim.y = %d\n", gridDim.x, gridDim.y);

  printf("Running with wmma...\n");
  wmma_example<<<gridDim, blockDim>>>(a_fp16, b_fp16, c_wmma, MATRIX_M,
                                      MATRIX_N, MATRIX_K, alpha, beta);

  // Error checking
  cudaError_t err = cudaDeviceSynchronize();
  if (err != cudaSuccess) {
    printf("CUDA Error: %s\n", cudaGetErrorString(err));
  }

  printf("\nChecking results...\n");
  cudaErrCheck(cudaMemcpy(c_host_wmma, c_wmma,
                          MATRIX_M * MATRIX_N * sizeof(float),
                          cudaMemcpyDeviceToHost));
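  // Hedged addition (not in the original paste): the program prints
  // "Checking results..." above but never validates the output, so this is a
  // minimal host-side reference check. It recomputes
  // D = alpha*A*B + beta*C in fp32, rounding A and B through fp16 the same
  // way the device path does, and compares with a loose tolerance to absorb
  // fp16/Tensor Core rounding.
  {
    int errors = 0;
    for (int col = 0; col < MATRIX_N; col++) {
      for (int row = 0; row < MATRIX_M; row++) {
        float acc = 0.0f;
        for (int k = 0; k < MATRIX_K; k++) {
          // Round inputs through fp16, matching convertFp32ToFp16.
          float av =
              __half2float(__float2half_rn(h_a_fp32[row + k * MATRIX_M]));
          float bv =
              __half2float(__float2half_rn(h_b_fp32[k + col * MATRIX_K]));
          acc += av * bv;
        }
        float ref = alpha * acc + beta * h_c[row + col * MATRIX_M];
        float got = c_host_wmma[row + col * MATRIX_M];
        // Relative tolerance of 1e-2, floored at an absolute 1e-2.
        float diff = ref - got;
        if (diff < 0.0f) diff = -diff;
        float mag = ref < 0.0f ? -ref : ref;
        if (mag < 1.0f) mag = 1.0f;
        if (diff > 1e-2f * mag) {
          if (errors < 10) {
            printf("Mismatch at (%d, %d): got %f, expected %f\n", row, col,
                   got, ref);
          }
          errors++;
        }
      }
    }
    if (errors) {
      printf("FAILED: %d mismatches\n", errors);
    } else {
      printf("PASSED\n");
    }
  }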
  // Free host memory
  free(h_a_fp32);
  free(h_b_fp32);
  free(h_c);
  free(c_host_wmma);

  // Free device memory
  cudaErrCheck(cudaFree(a_fp32));
  cudaErrCheck(cudaFree(b_fp32));
  cudaErrCheck(cudaFree(a_fp16));
  cudaErrCheck(cudaFree(b_fp16));
  cudaErrCheck(cudaFree(c_wmma));

  cudaErrCheck(cudaDeviceReset());

  return 0;
}
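// Hedged addition (not in the original paste): WMMA requires Tensor Cores,
// i.e. compute capability 7.0 (Volta) or newer, so a build line along these
// lines is assumed (the file name tc_kernel.cu is hypothetical):
//
//   nvcc -arch=sm_70 -o tc_kernel tc_kernel.cu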