wmma_example
unknown
c_cpp
a year ago
1.9 kB
12
Indexable
__global__ void wmma_example(half *a, half *b, float *c, int M, int N, int K,
float alpha, float beta) {
// Leading dimensions. Packed with no transpositions.
int lda = M;
int ldb = K;
int ldc = M;
// Tile using a 2D grid
int warpM = (blockIdx.x * blockDim.x + threadIdx.x) / warpSize;
int warpN = (blockIdx.y * blockDim.y + threadIdx.y);
// Declare the fragments
wmma::fragment<wmma::matrix_a, WMMA_M, WMMA_N, WMMA_K, half, wmma::col_major>
a_frag;
wmma::fragment<wmma::matrix_b, WMMA_M, WMMA_N, WMMA_K, half, wmma::col_major>
b_frag;
wmma::fragment<wmma::accumulator, WMMA_M, WMMA_N, WMMA_K, float> acc_frag;
wmma::fragment<wmma::accumulator, WMMA_M, WMMA_N, WMMA_K, float> c_frag;
wmma::fill_fragment(acc_frag, 0.0f);
// Loop over k
for (int i = 0; i < K; i += WMMA_K) {
int aRow = warpM * WMMA_M;
int aCol = i;
int bRow = i;
int bCol = warpN * WMMA_N;
printf("Thread (%d, %d) working on aRow = %d, aCol = %d, bRow = %d, bCol = %d\n",
threadIdx.x, threadIdx.y, aRow, aCol, bRow, bCol);
// Bounds checking
if (aRow < M && aCol < K && bRow < K && bCol < N) {
// Load the inputs
printf("Thread (%d, %d) loading A from address %p\n",
threadIdx.x, threadIdx.y, (void*)(a + aRow + aCol * lda));
printf("Thread (%d, %d) loading B from address %p\n",
threadIdx.x, threadIdx.y, (void*)(b + bRow + bCol * ldb));
printf("INSIDE--Thread (%d, %d) working on aRow = %d, aCol = %d, bRow = %d, bCol = %d\n",
threadIdx.x, threadIdx.y, aRow, aCol, bRow, bCol);
wmma::load_matrix_sync(a_frag, a + aRow + aCol * lda, lda);
wmma::load_matrix_sync(b_frag, b + bRow + bCol * ldb, ldb);
// Perform the matrix multiplication
wmma::mma_sync(acc_frag, a_frag, b_frag, acc_frag);
}
}Editor is loading...
Leave a Comment