wmma_example
unknown
c_cpp
a month ago
1.9 kB
1
Indexable
Never
// Tensor-core GEMM kernel built on the warp-level WMMA API (requires SM70+).
//
// Each warp owns one WMMA_M x WMMA_N output tile and accumulates the product
// of the matching A row-panel and B column-panel over K in steps of WMMA_K,
// then writes c = alpha * (a * b) + beta * c for that tile.
//
// Parameters:
//   a      M x K input, column-major, leading dimension M.
//   b      K x N input, column-major, leading dimension K.
//   c      M x N input/output, column-major, leading dimension M.
//   M,N,K  problem dimensions.
//   alpha  scale applied to the accumulated product.
//   beta   scale applied to the existing contents of c.
//
// Launch layout: the warp index along x selects the tile row (warpM) and the
// thread index along y selects the tile column (warpN).
//
// Preconditions (not checked here):
//   - assumes blockDim.x is a multiple of warpSize so that all lanes of a warp
//     compute the same warpM/warpN; the wmma::*_sync calls are warp-collective
//     and sit behind a branch on those values -- TODO confirm against the
//     launch site.
//   - the bounds guards test only the origin of each tile, while the wmma
//     loads/stores touch a full tile, so M, N and K are presumably multiples
//     of WMMA_M, WMMA_N and WMMA_K -- verify against the caller.
__global__ void wmma_example(half *a, half *b, float *c, int M, int N, int K, float alpha, float beta) {
    // Leading dimensions. Packed with no transpositions.
    const int lda = M;
    const int ldb = K;
    const int ldc = M;

    // Tile using a 2D grid.
    const int warpM = (blockIdx.x * blockDim.x + threadIdx.x) / warpSize;
    const int warpN = (blockIdx.y * blockDim.y + threadIdx.y);

    // Tile origin handled by this warp; invariant across the k-loop and shared
    // by the A row, the B column and the C tile.
    const int aRow = warpM * WMMA_M;
    const int bCol = warpN * WMMA_N;

    // Declare the fragments.
    wmma::fragment<wmma::matrix_a, WMMA_M, WMMA_N, WMMA_K, half, wmma::col_major> a_frag;
    wmma::fragment<wmma::matrix_b, WMMA_M, WMMA_N, WMMA_K, half, wmma::col_major> b_frag;
    wmma::fragment<wmma::accumulator, WMMA_M, WMMA_N, WMMA_K, float> acc_frag;
    wmma::fragment<wmma::accumulator, WMMA_M, WMMA_N, WMMA_K, float> c_frag;

    wmma::fill_fragment(acc_frag, 0.0f);

    // Loop over k.
    for (int i = 0; i < K; i += WMMA_K) {
        const int aCol = i;
        const int bRow = i;

        // Debug trace: device printf is serialized and every lane prints, so
        // this is for debugging only.
        printf("Thread (%d, %d) working on aRow = %d, aCol = %d, bRow = %d, bCol = %d\n", threadIdx.x, threadIdx.y, aRow, aCol, bRow, bCol);

        // Bounds checking (tile origin only).
        if (aRow < M && aCol < K && bRow < K && bCol < N) {
            // Widen the column offset before multiplying so large matrices do
            // not overflow 32-bit int arithmetic.
            half *aTile = a + aRow + static_cast<size_t>(aCol) * lda;
            half *bTile = b + bRow + static_cast<size_t>(bCol) * ldb;

            // Load the inputs.
            printf("Thread (%d, %d) loading A from address %p\n", threadIdx.x, threadIdx.y, (void*)aTile);
            printf("Thread (%d, %d) loading B from address %p\n", threadIdx.x, threadIdx.y, (void*)bTile);
            printf("INSIDE--Thread (%d, %d) working on aRow = %d, aCol = %d, bRow = %d, bCol = %d\n", threadIdx.x, threadIdx.y, aRow, aCol, bRow, bCol);
            wmma::load_matrix_sync(a_frag, aTile, lda);
            wmma::load_matrix_sync(b_frag, bTile, ldb);

            // Perform the matrix multiplication.
            wmma::mma_sync(acc_frag, a_frag, b_frag, acc_frag);
        }
    }

    // NOTE(review): the pasted source stopped after the k-loop, leaving the
    // function unclosed and c, ldc, alpha, beta and c_frag unused. The
    // epilogue below is reconstructed from those declarations following the
    // standard GEMM form; confirm it matches the intended original.
    const int cRow = aRow;
    const int cCol = bCol;

    if (cRow < M && cCol < N) {
        float *cTile = c + cRow + static_cast<size_t>(cCol) * ldc;

        // Load the current value of c, scale it by beta, and add the
        // accumulated product scaled by alpha.
        wmma::load_matrix_sync(c_frag, cTile, ldc, wmma::mem_col_major);

        // acc_frag and c_frag have the same fragment type, so element t of one
        // corresponds to element t of the other.
        for (int t = 0; t < c_frag.num_elements; t++) {
            c_frag.x[t] = alpha * acc_frag.x[t] + beta * c_frag.x[t];
        }

        // Store the output.
        wmma::store_matrix_sync(cTile, c_frag, ldc, wmma::mem_col_major);
    }
}
Leave a Comment