Untitled

 avatar
unknown
plain_text
2 months ago
2.3 kB
3
Indexable
import sys
import timeit

import numpy as np
from numba import cuda, njit, prange, float32


def max_cpu(A, B):
    """
        Straightforward element-wise maximum that uses the CPU only.

        The nested loops deliberately avoid vectorized NumPy operations; this
        element-by-element form is the baseline that a JIT compiler such as
        Numba can accelerate.

        A and B must be 2-D arrays of the same shape. The result is a new
        uint8 array of that shape, so input values are expected to lie in
        [0, 255]. Raises ValueError if the shapes differ.
    """
    if A.shape != B.shape:
        raise ValueError('A and B must have the same shape')
    rows, cols = A.shape

    # Every element is written by the loops below, so no zero-fill is needed.
    C = np.empty((rows, cols), dtype=np.uint8)
    for i in range(rows):
        for j in range(cols):
            a = A[i, j]
            b = B[i, j]
            C[i, j] = a if a > b else b
    return C


@njit(parallel=True)
def max_numba(A, B):
    """
        Element-wise maximum compiled by Numba, parallelized with prange.

        Only the outer (row) loop uses prange. Numba parallelizes the
        outermost prange and serializes any nested one, so an inner prange
        would just behave like range while suggesting otherwise.

        A and B must be 2-D arrays of the same shape. The result is a new
        uint8 array of that shape, so input values are expected to lie in
        [0, 255]. Raises ValueError if the shapes differ.
    """
    rows = A.shape[0]
    cols = A.shape[1]
    if B.shape[0] != rows or B.shape[1] != cols:
        raise ValueError('A and B must have the same shape')

    # Every element is written by the loops below, so no zero-fill is needed.
    C = np.empty((rows, cols), dtype=np.uint8)
    for i in prange(rows):
        for j in range(cols):
            a = A[i, j]
            b = B[i, j]
            C[i, j] = a if a > b else b
    return C

def max_gpu(A, B):
    """
        Element-wise maximum of two equally-shaped 2-D arrays on the GPU.

        Copies A and B to the device, launches `max_kernel` over a 2-D grid
        that covers every element, and copies the uint8 result back to the
        host. Input values are expected to lie in [0, 255].
        Raises ValueError if A and B have different shapes.
    """
    if A.shape != B.shape:
        raise ValueError('A and B must have the same shape')
    rows, cols = A.shape
    if rows == 0 or cols == 0:
        # A launch with a zero-sized grid is invalid; nothing to compute.
        return np.zeros((rows, cols), dtype=np.uint8)

    dev_A = cuda.to_device(A)
    dev_B = cuda.to_device(B)
    dev_C = cuda.device_array((rows, cols), np.uint8)

    # 32 x 8 = 256 threads per block. The x axis (fastest-varying thread
    # index) spans columns, so neighbouring threads touch neighbouring
    # elements of a row (assuming C-contiguous arrays, the default for
    # cuda.device_array). The grid is rounded up so partial edge tiles are
    # covered; the kernel's bounds check discards the surplus threads.
    threads = (32, 8)
    blocks = ((cols + threads[0] - 1) // threads[0],
              (rows + threads[1] - 1) // threads[1])
    max_kernel[blocks, threads](dev_A, dev_B, dev_C)
    return dev_C.copy_to_host()

@cuda.jit
def max_kernel(A, B, C):
    """
        Write max(A[i, j], B[i, j]) into C[i, j], one thread per element.

        Expects a 2-D launch whose x axis covers columns and whose y axis
        covers rows; threads that fall outside C (from rounding the grid up)
        do nothing.
    """
    j, i = cuda.grid(2)
    if i < C.shape[0] and j < C.shape[1]:
        a = A[i, j]
        b = B[i, j]
        C[i, j] = a if a > b else b


def verify_solution():
    """
        Check every implementation against NumPy's reference np.maximum.

        Builds two random 1000x1000 integer matrices with values in [0, 255],
        runs each implementation on them and prints a pass/fail line.
        On the first failure it exits with status 1, so shells and CI can
        detect the failure.
    """
    A = np.random.randint(0, 256, (1000, 1000))
    B = np.random.randint(0, 256, (1000, 1000))
    expected = np.maximum(A, B)  # reference result, computed once

    implementations = (
        ('max_cpu', max_cpu),
        ('max_numba', max_numba),
        ('max_gpu', max_gpu),
    )
    for name, impl in implementations:
        # array_equal also checks the shape; `np.all(x == y)` could pass
        # through broadcasting if an implementation returned the wrong shape.
        if not np.array_equal(impl(A, B), expected):
            print(f'[-] {name} failed')
            sys.exit(1)
        print(f'[+] {name} passed')

    print('[+] All tests passed\n')


# this is the comparison function - keep it as it is.
def max_comparison():
    """
        Benchmark max_cpu, max_numba and max_gpu on the same random inputs.

        Each implementation is called 20 times per measurement, and the
        measurement is repeated 3 times. The smallest of the three totals is
        printed. It is in seconds and covers 20 calls, not one call.
        Each timed call includes everything the function does per call
        (e.g. any host<->device copies in max_gpu).
    """
    # Two random 1000x1000 integer matrices with values in [0, 255].
    A = np.random.randint(0, 256, (1000, 1000))
    B = np.random.randint(0, 256, (1000, 1000))

    def timer(f):
        # repeat(3, 20): 3 measurements of 20 calls each; min() picks the
        # least-disturbed measurement, as the timeit docs recommend.
        return min(timeit.Timer(lambda: f(A, B)).repeat(3, 20))

    print('[*] CPU:', timer(max_cpu))
    print('[*] Numba:', timer(max_numba))
    print('[*] CUDA:', timer(max_gpu))


if __name__ == '__main__':
    # Correctness first: verify_solution() stops the process on the first
    # failing implementation, so the benchmark only runs when all pass.
    # Numba/CUDA functions compile lazily on first call, so running the
    # checks first also keeps compilation out of the timings.
    for step in (verify_solution, max_comparison):
        step()
Editor is loading...
Leave a Comment