Untitled
unknown
plain_text
a year ago
2.1 kB
5
Indexable
def selective_scan_2D_2(self, x, delta, A, B, C, D, *, scan_fn=None):
    """Run a 2D selective scan: a horizontal scan along rows, then a vertical scan along columns.

    Based on the original 1D ``pscan`` implementation, applied twice. The
    sequence axis ``L`` is folded into a square ``H x W`` grid in row-major
    order (``l = h * W + w``).

    The horizontal pass computes row-wise states::

        g[h, w] = dA[h, w] * g[h, w-1] + BX[h, w]

    The vertical pass then scans those states down each column::

        s[h, w] = dA[h, w] * s[h-1, w] + g[h, w]

    Args:
        x: (BS, L, ED) input sequence.
        delta: (BS, L, ED) step sizes.
        A: (ED, N) state matrix.
        B: (BS, L, N) input projection.
        C: (BS, L, N) output projection.
        D: (ED) skip-connection weights.
        scan_fn: optional 1D scan called as ``scan_fn(A, X)`` on
            (batch, length, ED, N) tensors, scanning over dim 1 and returning
            the states ``h_t = A_t * h_{t-1} + X_t`` (the contract assumed of
            ``pscan``). Defaults to ``pscan``. Pass e.g. a sequential
            reference scan for debugging.

    Returns:
        y: (BS, L, ED) output sequence.

    Raises:
        ValueError: if ``L`` is not a perfect square.
    """
    scan = pscan if scan_fn is None else scan_fn

    # get sizes
    BS, L, ED = x.size()
    _, _, N = B.size()
    H = W = math.isqrt(L)  # assume imgs are square
    if H * W != L:
        raise ValueError(
            f"sequence length {L} is not a perfect square; "
            "cannot fold it into a square H x W grid"
        )

    deltaA = torch.exp(delta.unsqueeze(-1) * A)  # (BS, L, ED, N)
    # deltaA is A_bar with ZOH discretization
    deltaB = delta.unsqueeze(-1) * B.unsqueeze(2)  # (BS, L, ED, N)
    # deltaB is B_bar with Euler discretization
    BX = deltaB * x.unsqueeze(-1)  # (BS, L, ED, N)

    # STAGE 1: Horizontal scan, one sequence per image row: (BS, H*W) -> (BS*H, W).
    # The scan may modify its first argument in place, so pass a clone to keep
    # deltaA intact for the vertical stage.
    hs = scan(
        deltaA.reshape(BS * H, W, ED, N).clone(),
        BX.reshape(BS * H, W, ED, N),
    )  # (BS*H, W, ED, N)

    # STAGE 2: Vertical scan, one sequence per image column: (BS, H, W) -> (BS*W, H).
    # transpose() yields non-contiguous tensors (and the scan output may be
    # non-contiguous too), so use reshape, which copies when needed, not view.
    deltaA_cols = (
        deltaA.reshape(BS, H, W, ED, N).transpose(1, 2).reshape(BS * W, H, ED, N)
    )
    hs_cols = hs.reshape(BS, H, W, ED, N).transpose(1, 2).reshape(BS * W, H, ED, N)
    # Scan the row-wise states (not BX) so both directions contribute to the result.
    hs = scan(deltaA_cols, hs_cols)  # (BS*W, H, ED, N)

    # STAGE 3: calculate output
    # reshape hs from BS*W, H -> BS, W, H -> BS, H, W -> BS, L
    hs = hs.reshape(BS, W, H, ED, N).transpose(1, 2).reshape(BS, L, ED, N)
    y = (hs @ C.unsqueeze(-1)).squeeze(3)  # (BS, L, ED, N) @ (BS, L, N, 1) -> (BS, L, ED)
    # x was never re-bound, so it still has shape (BS, L, ED) and broadcasts with y.
    return y + D * x
Editor is loading...
Leave a Comment