wmma_example

mail@pastecode.io avatar
unknown
c_cpp
5 months ago
1.9 kB
1
Indexable
// Tensor-core (WMMA) kernel; the part visible here accumulates the product of
// tiles of `a` and `b` into `acc_frag`, stepping along the K dimension in
// increments of WMMA_K.
//
// Parameters:
//   a, b   - half-precision inputs, addressed below as base + row + col * ld
//            (with lda = M and ldb = K) and loaded into col_major fragments.
//   c      - float buffer; not referenced in the lines visible here.
//   M,N,K  - problem dimensions used for leading dimensions and bounds checks.
//   alpha, beta - not referenced in the lines visible here.
//
// WMMA_M / WMMA_N / WMMA_K and the `wmma` namespace are defined outside this
// block.
// NOTE(review): this listing ends before the function's closing brace; `c`,
// `alpha`, `beta`, `ldc` and `c_frag` are presumably used in the missing
// tail - confirm against the full source.
__global__ void wmma_example(half *a, half *b, float *c, int M, int N, int K,
                             float alpha, float beta) {
  // Leading dimensions. Packed with no transpositions.
  int lda = M;
  int ldb = K;
  int ldc = M;

  // Tile using a 2D grid
  // warpM: global x thread index divided by warpSize; warpN: global y thread
  // index. They select which output tile (row block / column block) is used
  // for the A and B offsets below.
  int warpM = (blockIdx.x * blockDim.x + threadIdx.x) / warpSize;
  int warpN = (blockIdx.y * blockDim.y + threadIdx.y);

  // Declare the fragments
  // Both input fragments use the col_major layout; the accumulators are float.
  wmma::fragment<wmma::matrix_a, WMMA_M, WMMA_N, WMMA_K, half, wmma::col_major>
      a_frag;
  wmma::fragment<wmma::matrix_b, WMMA_M, WMMA_N, WMMA_K, half, wmma::col_major>
      b_frag;
  wmma::fragment<wmma::accumulator, WMMA_M, WMMA_N, WMMA_K, float> acc_frag;
  // c_frag is declared but not used in the lines visible here.
  wmma::fragment<wmma::accumulator, WMMA_M, WMMA_N, WMMA_K, float> c_frag;

  // Start the running sum at zero before accumulating along K.
  wmma::fill_fragment(acc_frag, 0.0f);

  // Loop over k
  for (int i = 0; i < K; i += WMMA_K) {
    // Top-left coordinates of the current A tile (row block warpM, k offset i)
    int aRow = warpM * WMMA_M;
    int aCol = i;

    // Top-left coordinates of the current B tile (k offset i, column block
    // warpN)
    int bRow = i;
    int bCol = warpN * WMMA_N;
    // Debug trace: there is no lane guard, so every thread prints this line
    // on every k-step.
    printf("Thread (%d, %d) working on aRow = %d, aCol = %d, bRow = %d, bCol = %d\n", 
      threadIdx.x, threadIdx.y, aRow, aCol, bRow, bCol);

    // Bounds checking
    // Only the tile's starting coordinates are tested here.
    // NOTE(review): nothing checks that a whole WMMA_M x WMMA_K (or
    // WMMA_K x WMMA_N) tile fits; presumably M, N, K are expected to be
    // multiples of the tile sizes - confirm against the launch code.
    // NOTE(review): the WMMA *_sync calls below are warp-collective, so this
    // condition should evaluate identically for every lane of a warp - confirm
    // that the block dimensions guarantee that.
    if (aRow < M && aCol < K && bRow < K && bCol < N) {
      // Load the inputs
      // Debug trace of the tile base addresses passed to load_matrix_sync.
      printf("Thread (%d, %d) loading A from address %p\n", 
       threadIdx.x, threadIdx.y, (void*)(a + aRow + aCol * lda));
      printf("Thread (%d, %d) loading B from address %p\n",
        threadIdx.x, threadIdx.y, (void*)(b + bRow + bCol * ldb));
      printf("INSIDE--Thread (%d, %d) working on aRow = %d, aCol = %d, bRow = %d, bCol = %d\n", 
      threadIdx.x, threadIdx.y, aRow, aCol, bRow, bCol);

      // Tile base pointer = base + row + col * leading dimension.
      wmma::load_matrix_sync(a_frag, a + aRow + aCol * lda, lda);
      wmma::load_matrix_sync(b_frag, b + bRow + bCol * ldb, ldb);

      // Perform the matrix multiplication
      // acc_frag = a_frag * b_frag + acc_frag (accumulates across k-steps).
      wmma::mma_sync(acc_frag, a_frag, b_frag, acc_frag);
    }
  }
Leave a Comment