Untitled

 avatar
unknown
plain_text
a year ago
2.1 kB
5
Indexable
def selective_scan_2D_2(self, x, delta, A, B, C, D):
    """Run a separable 2D selective scan over a square, image-shaped sequence.

    Based on the original 1D ``pscan`` implementation, the scan is applied
    twice. First it runs along each row (horizontal). Then it runs along each
    column (vertical), with the horizontal hidden states as the input of the
    vertical pass. The sequence axis ``L`` is treated as a row-major
    flattening of an ``H x W`` grid with ``H == W``.

    Args:
        self: not used by the body; kept so the signature is unchanged.
        x:     (BS, L, ED) input sequence.
        delta: (BS, L, ED) step sizes used for discretization.
        A:     (ED, N) state matrix, discretized below as ``exp(delta * A)``.
        B:     (BS, L, N) input projection.
        C:     (BS, L, N) output projection.
        D:     (ED,) skip-connection weights.

    Returns:
        y: (BS, L, ED) output sequence.

    Raises:
        ValueError: if ``L`` is not a perfect square.
    """
    # get sizes
    BS, L, ED = x.size()
    _, _, N = B.size()

    # assume imgs are square; isqrt avoids float rounding of sqrt for large L
    H = W = math.isqrt(L)
    if H * W != L:
        raise ValueError(
            f"selective_scan_2D_2 expects a square image (L == H*W with H == W), got L={L}"
        )

    # STAGE 1: horizontal pscan along each row.
    # reshape BS, L -> BS*H, W (rows of length W). New names are used so the
    # original (BS, L, ED) `x` is still available for the skip term in stage 3.
    # reshape (rather than view) also tolerates non-contiguous inputs.
    x_rows = x.reshape(BS * H, W, ED)
    delta_rows = delta.reshape(BS * H, W, ED)
    B_rows = B.reshape(BS * H, W, N)

    deltaA = torch.exp(delta_rows.unsqueeze(-1) * A)  # (BS*H, W, ED, N)  A_bar with ZOH discretization
    deltaB = delta_rows.unsqueeze(-1) * B_rows.unsqueeze(2)  # (BS*H, W, ED, N)  B_bar with Euler discretization
    BX = deltaB * x_rows.unsqueeze(-1)  # (BS*H, W, ED, N)
    # clone so deltaA is preserved for the vertical pass, in case pscan
    # modifies its first argument in place
    hs = pscan(deltaA.clone(), BX)  # (BS*H, W, ED, N)

    # STAGE 2: vertical pscan along each column, driven by the horizontal states.
    # BS*H, W -> BS, H, W -> BS, W, H -> BS*W, H
    # transpose makes the tensor non-contiguous, so reshape (not view) is required.
    deltaA = deltaA.reshape(BS, H, W, ED, N).transpose(1, 2).reshape(BS * W, H, ED, N)
    hs = hs.reshape(BS, H, W, ED, N).transpose(1, 2).reshape(BS * W, H, ED, N)
    hs = pscan(deltaA, hs)  # (BS*W, H, ED, N)

    # STAGE 3: calculate output
    # reshape hs from BS*W, H -> BS, W, H -> BS, H, W -> BS, L
    hs = hs.reshape(BS, W, H, ED, N).transpose(1, 2).reshape(BS, L, ED, N)

    # (BS, L, ED, N) @ (BS, L, N, 1) -> (BS, L, ED, 1) -> (BS, L, ED)
    y = (hs @ C.unsqueeze(-1)).squeeze(3)
    y = y + D * x  # x still has its original (BS, L, ED) shape

    return y
Editor is loading...
Leave a Comment