# gemm
# pastecode.io page metadata:
# mail@pastecode.io avatar
# unknown
# python
# 2 years ago
# 3.9 kB
# 2
# Indexable
# Never
import torch 

def wrap(A_Fragment, B_Fragment, *, thread_x=4, warp_size=32):
    """Simulate one warp ("wrap") multiplying an A fragment by a B fragment.

    The ``warp_size`` threads of the warp are arranged as a
    ``thread_x`` x ``warp_size // thread_x`` grid. CUTLASS organizes a
    32-thread warp as 4 x 8 or 8 x 4. Thread (i, j) owns an Mr x Nr
    sub-tile of the output. It multiplies its slice of ``A_Fragment`` rows by
    its slice of ``B_Fragment`` columns. For the single-column / single-row
    fragments that ``thread_block`` passes in, that product is an outer product.

    Args:
        A_Fragment: 2-D tensor of shape (Mw, K). K is 1 for an outer product.
        B_Fragment: 2-D tensor of shape (K, Nw).
        thread_x: number of threads along the M dimension (default 4).
        warp_size: total number of threads in the warp (default 32).

    Returns:
        (Mw, Nw) tensor equal to ``A_Fragment @ B_Fragment``. Integer and bool
        inputs accumulate in int64. Floating inputs keep their promoted
        floating dtype.

    Raises:
        ValueError: if ``thread_x`` is not a positive divisor of ``warp_size``.
    """
    if thread_x <= 0 or warp_size % thread_x != 0:
        raise ValueError(
            f"thread_x={thread_x} must be a positive divisor of warp_size={warp_size}"
        )
    thread_y = warp_size // thread_x  # 32 // 4 = 8 with the defaults

    M, N = A_Fragment.shape[0], B_Fragment.shape[1]
    # int64 for integer/bool inputs (as before); float inputs stay float
    # instead of failing on the in-place add into an int64 buffer.
    acc_dtype = torch.promote_types(
        torch.result_type(A_Fragment, B_Fragment), torch.int64
    )
    _accumulator = torch.zeros(M, N, dtype=acc_dtype, device=A_Fragment.device)

    # Ceiling division: every row/column is owned by some thread even when M or
    # N is not a multiple of the thread grid. Slices past the end are clipped
    # (possibly to empty), so trailing threads simply do less or no work.
    Mr = -(-M // thread_x)
    Nr = -(-N // thread_y)

    for i in range(thread_x):
        rows = slice(i * Mr, (i + 1) * Mr)
        # Cast before multiplying so narrow integer dtypes cannot overflow.
        thread_a = A_Fragment[rows, :].to(acc_dtype)
        for j in range(thread_y):
            cols = slice(j * Nr, (j + 1) * Nr)
            thread_b = B_Fragment[:, cols].to(acc_dtype)
            # Keep both operands 2-D: (Mr, 1) @ (1, Nr) is the outer product.
            # squeeze + torch.outer broke when Mr or Nr was 1 (0-d operand).
            _accumulator[rows, cols] += thread_a @ thread_b
    return _accumulator
def thread_block(matrix_tile_a, matrix_tile_b, *, wrap_size_x=4, wrap_size_y=2):
    """Simulate one thread block computing ``matrix_tile_a @ matrix_tile_b``.

    The output tile is split into a ``wrap_size_x`` x ``wrap_size_y`` grid of
    warp tiles ("wrap" is this file's spelling of "warp"). Each warp tile is
    built by stepping through the shared K dimension. At every step, one column
    slice of A and one row slice of B go to ``wrap``, and its result is added
    into the warp tile.

    Args:
        matrix_tile_a: 2-D tensor of shape (M, K).
        matrix_tile_b: 2-D tensor of shape (K, N).
        wrap_size_x: number of warp tiles along M (default 4).
        wrap_size_y: number of warp tiles along N (default 2).

    Returns:
        (M, N) tensor assembled from the per-warp accumulators. Its dtype is
        int64 for integer/bool inputs and the promoted floating dtype for
        floating inputs.

    Raises:
        ValueError: if a warp-grid size is not positive or the inner
            dimensions of the two tiles differ.
    """
    if wrap_size_x <= 0 or wrap_size_y <= 0:
        raise ValueError("wrap_size_x and wrap_size_y must be positive")
    if matrix_tile_a.shape[1] != matrix_tile_b.shape[0]:
        raise ValueError(
            "inner dimensions differ: "
            f"{tuple(matrix_tile_a.shape)} @ {tuple(matrix_tile_b.shape)}"
        )

    M, K = matrix_tile_a.shape[0], matrix_tile_a.shape[1]
    N = matrix_tile_b.shape[1]
    acc_dtype = torch.promote_types(
        torch.result_type(matrix_tile_a, matrix_tile_b), torch.int64
    )
    device = matrix_tile_a.device
    thread_block_accumulator = torch.zeros(M, N, dtype=acc_dtype, device=device)

    # Ceiling division so trailing rows/columns still belong to some warp when
    # M or N is not a multiple of the warp grid. With floor division they were
    # silently left as zeros. Out-of-range slices are clipped to empty.
    Mw = -(-M // wrap_size_x)
    Nw = -(-N // wrap_size_y)

    for i in range(wrap_size_x):
        rows = slice(i * Mw, (i + 1) * Mw)
        matrix_wrap_a = matrix_tile_a[rows, :]
        for j in range(wrap_size_y):
            cols = slice(j * Nw, (j + 1) * Nw)
            matrix_wrap_b = matrix_tile_b[:, cols]
            # Sized from the actual slices so shorter trailing warp tiles fit.
            wrap_accumulator = torch.zeros(
                matrix_wrap_a.shape[0],
                matrix_wrap_b.shape[1],
                dtype=acc_dtype,
                device=device,
            )
            # The K steps are sequential: each adds into the same warp tile.
            for k in range(K):
                A_Fragment = matrix_wrap_a[:, k:k + 1]
                B_Fragment = matrix_wrap_b[k:k + 1, :]
                wrap_accumulator += wrap(A_Fragment, B_Fragment)
            thread_block_accumulator[rows, cols] = wrap_accumulator
    return thread_block_accumulator



# Initialise the two random integer operands (values in [1, 100000)).
a = torch.randint(1, 100000, (512, 32), dtype=torch.int64)
b = torch.randint(1, 100000, (32, 512), dtype=torch.int64)

# Thread-block tiling of result = a @ b. The output is split into a
# tile_size_x x tile_size_y grid of tiles. The shared K dimension is split
# into tile_size_k slices that are accumulated one after another.
tile_size_x = 4  # tiles along M (rows of a and of the result)
tile_size_y = 4  # tiles along N (columns of b and of the result)
tile_size_k = 4  # slices along K (columns of a / rows of b)
result = torch.zeros((a.shape[0], b.shape[1]), dtype=torch.int64)

# Ceiling division, so every index lands in some tile even when a dimension
# is not a multiple of its tile count. Trailing tiles are smaller or empty,
# because slicing past the end is clipped.
Ks = -(-a.shape[1] // tile_size_k)  # width of a tile of a == height of a tile of b
Ms = -(-a.shape[0] // tile_size_x)  # height of a tile of a
Ns = -(-b.shape[1] // tile_size_y)  # width of a tile of b

for i in range(tile_size_x):
    rows = slice(i * Ms, (i + 1) * Ms)
    for j in range(tile_size_y):
        cols = slice(j * Ns, (j + 1) * Ns)
        # Sized from the actual output slice so short trailing tiles fit.
        K_accumulator = torch.zeros_like(result[rows, cols])
        # The K steps are sequential: each one adds into the same output tile.
        for k in range(tile_size_k):
            depth = slice(k * Ks, (k + 1) * Ks)
            # Conceptually these tiles are staged in shared memory; here they
            # are plain tensor slices.
            matrix_tile_a = a[rows, depth]
            matrix_tile_b = b[depth, cols]
            K_accumulator += thread_block(matrix_tile_a, matrix_tile_b)
        result[rows, cols] = K_accumulator

# Compare against PyTorch's own matmul: the difference should be all zeros.
result_truth = torch.matmul(a, b)
print(result_truth - result)
print("exact match:", torch.equal(result_truth, result))