gemm
python
import torch


def warp(A_Fragment, B_Fragment):
    thread_x = 4
    thread_y = 32 // thread_x  # 32 threads per warp, arranged 4 x 8
    _accumulator = torch.zeros(A_Fragment.shape[0], B_Fragment.shape[1], dtype=torch.int64)
    Mr = A_Fragment.shape[0] // thread_x
    Nr = B_Fragment.shape[1] // thread_y
    '''
    Each warp has 32 threads. CUTLASS organizes the threads within the same warp
    in a 4 x 8 or an 8 x 4 fashion, such that
        mW/mR = thread_x = 4, nW/nR = thread_y = 8, or
        mW/mR = thread_x = 8, nW/nR = thread_y = 4.
    '''
    ## A warp of 32 threads performs the outer product between A_Fragment and B_Fragment.
    for i in range(thread_x):
        for j in range(thread_y):
            ### Load into the register file.
            thread_a = torch.squeeze(A_Fragment[i * Mr:(i + 1) * Mr, :])  # shape (Mr,)
            thread_b = torch.squeeze(B_Fragment[:, j * Nr:(j + 1) * Nr])  # shape (Nr,)
            ## Each thread accumulates an Mr x Nr outer product.
            _accumulator[i * Mr:(i + 1) * Mr, j * Nr:(j + 1) * Nr] += torch.outer(thread_a, thread_b)
    return _accumulator


def thread_block(matrix_tile_a, matrix_tile_b):
    warp_size_x = 4
    warp_size_y = 2
    thread_block_accumulator = torch.zeros(matrix_tile_a.shape[0], matrix_tile_b.shape[1], dtype=torch.int64)
    Mw = matrix_tile_a.shape[0] // warp_size_x
    Nw = matrix_tile_b.shape[1] // warp_size_y
    for i in range(warp_size_x):
        for j in range(warp_size_y):
            matrix_warp_a = matrix_tile_a[i * Mw:(i + 1) * Mw, :]
            matrix_warp_b = matrix_tile_b[:, j * Nw:(j + 1) * Nw]
            warp_accumulator = torch.zeros(Mw, Nw, dtype=torch.int64)
            # At each step k of the loop, perform a rank-1 update and accumulate it into the warp accumulator.
            for k in range(matrix_tile_a.shape[1]):  # CAN'T BE PARALLEL
                A_Fragment = matrix_warp_a[:, k:k + 1]
                B_Fragment = matrix_warp_b[k:k + 1, :]
                ### Call a warp to perform the outer products.
                ### While the warp performs the outer products, the next A_Fragment / B_Fragment can be prefetched.
                warp_accumulator += warp(A_Fragment, B_Fragment)
            ## __syncthreads()
            thread_block_accumulator[i * Mw:(i + 1) * Mw, j * Nw:(j + 1) * Nw] = warp_accumulator
    return thread_block_accumulator


# Initialize the two input matrices.
a = torch.randint(1, 100000, (512, 32), dtype=torch.int64)
b = torch.randint(1, 100000, (32, 512), dtype=torch.int64)

tile_size_x = 4
tile_size_y = 4
result = torch.zeros((a.shape[0], b.shape[1]), dtype=torch.int64)
Ks = a.shape[1] // tile_size_x  ## width of a tile of matrix A and height of a tile of matrix B
Ms = a.shape[0] // tile_size_x  ## height of a tile of matrix A
Ns = b.shape[1] // tile_size_y  ## width of a tile of matrix B

for i in range(tile_size_x):
    for j in range(tile_size_y):
        K_accumulator = torch.zeros(Ms, Ns, dtype=torch.int64)
        ### At each step k of the loop, perform a tile matrix product and accumulate it into matrix C.
        for k in range(tile_size_x):  ### CAN'T BE PARALLEL
            ### Load matrix_tile_a, matrix_tile_b into shared memory.
            matrix_tile_a = a[i * Ms:(i + 1) * Ms, k * Ks:(k + 1) * Ks]
            matrix_tile_b = b[k * Ks:(k + 1) * Ks, j * Ns:(j + 1) * Ns]
            #### Call a thread block to multiply matrix_tile_a and matrix_tile_b; the output is one tile of matrix C.
            K_accumulator += thread_block(matrix_tile_a, matrix_tile_b)
            # Reference check: K_accumulator += torch.matmul(matrix_tile_a, matrix_tile_b)
        result[i * Ms:(i + 1) * Ms, j * Ns:(j + 1) * Ns] = K_accumulator

result_truth = torch.matmul(a, b)
print(result_truth - result)  # should be all zeros
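
# --- Minimal sanity sketch, assuming the 512 x 32 @ 32 x 512 problem shape and the
# tile constants defined above. It spells out the tile hierarchy the script implies
# (thread-block tile -> warp tile -> per-thread outer product) and confirms that the
# tiled result matches torch.matmul exactly, as it should with int64 arithmetic.
assert (Ms, Ks, Ns) == (128, 8, 128)      # thread-block tile: 128 x 128, depth 8
assert (Ms // 4, Ns // 2) == (32, 64)     # warp tile (Mw, Nw)
assert (32 // 4, 64 // 8) == (8, 8)       # thread tile (Mr, Nr): one 8 x 8 outer product per thread
assert torch.equal(result, result_truth)  # exact match, no floating-point rounding involved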