# Scraped paste-site metadata (kept verbatim; not part of the code):
# Untitled | unknown | python | a year ago | 2.1 kB | 16 | Indexable
def selective_scan_2D_2(self, x, delta, A, B, C, D, H=None, W=None):
    """2D selective scan: a horizontal ``pscan`` followed by a vertical one.

    Based on the original 1D ``pscan`` implementation, applied twice. First
    each image row is scanned as an independent sequence. Then each column of
    the resulting row states is scanned, so the vertical pass builds on
    (rather than discards) the horizontal hidden states. This assumes
    ``pscan(A, X)`` computes ``h_t = A_t * h_{t-1} + X_t`` along dim 1;
    TODO confirm against the pscan implementation.

    The flattened sequence axis is assumed to be row-major
    (``l = h * W + w``); TODO confirm against the caller's flattening order.

    Args:
        x:     (BS, L, ED) input sequence.
        delta: (BS, L, ED) step sizes (Δ).
        A:     (ED, N) state matrix, discretized here with ZOH.
        B:     (BS, L, N) input projection, discretized here with Euler.
        C:     (BS, L, N) output projection.
        D:     (ED,) skip-connection weights.
        H, W:  optional image height/width. When both are omitted the image
               is assumed square with side sqrt(L) (the original behaviour).
               When only one is given, the other is inferred as L // given.

    Returns:
        y: (BS, L, ED)

    Raises:
        ValueError: if L != H * W (e.g. L is not a perfect square when H and
            W are omitted).
    """
    BS, L, ED = x.size()
    N = B.size(-1)

    # Resolve the image grid. Rounding guards against sqrt returning
    # k - epsilon for a perfect square k**2.
    if H is None and W is None:
        H = W = int(round(math.sqrt(L)))  # assume square images
    elif H is None:
        H = L // W
    elif W is None:
        W = L // H
    if H * W != L:
        raise ValueError(
            f"sequence length {L} does not match image grid H*W = {H}*{W}"
        )

    # STAGE 1: horizontal scan -- every image row is an independent sequence.
    # (BS, H*W, ...) -> (BS*H, W, ...). reshape (not view) tolerates
    # non-contiguous inputs. The original x is kept intact for the skip term.
    x_rows = x.reshape(BS * H, W, ED)
    delta_rows = delta.reshape(BS * H, W, ED)
    B_rows = B.reshape(BS * H, W, N)

    deltaA = torch.exp(delta_rows.unsqueeze(-1) * A)         # (BS*H, W, ED, N)  A_bar, ZOH discretization
    deltaB = delta_rows.unsqueeze(-1) * B_rows.unsqueeze(2)  # (BS*H, W, ED, N)  B_bar, Euler discretization
    BX = deltaB * x_rows.unsqueeze(-1)                        # (BS*H, W, ED, N)

    # pscan may modify its first argument in place, so scan a copy and keep
    # deltaA intact for the vertical pass.
    hs = pscan(deltaA.clone(), BX)                            # (BS*H, W, ED, N)

    # STAGE 2: vertical scan over the horizontal states -- every column is a
    # sequence. (BS*H, W) -> (BS, H, W) -> (BS, W, H) -> (BS*W, H).
    # reshape is required after swapaxes because the result is non-contiguous.
    deltaA_cols = torch.swapaxes(
        deltaA.reshape(BS, H, W, ED, N), 1, 2
    ).reshape(BS * W, H, ED, N)
    hs_cols = torch.swapaxes(
        hs.reshape(BS, H, W, ED, N), 1, 2
    ).reshape(BS * W, H, ED, N)
    # Scan hs (not BX): the vertical recurrence propagates the row states.
    hs = pscan(deltaA_cols, hs_cols)                          # (BS*W, H, ED, N)

    # STAGE 3: back to row-major (BS, L, ED, N) and project with C.
    # (BS*W, H) -> (BS, W, H) -> (BS, H, W) -> (BS, L)
    hs = torch.swapaxes(hs.reshape(BS, W, H, ED, N), 1, 2).reshape(BS, L, ED, N)
    y = (hs @ C.unsqueeze(-1)).squeeze(-1)  # (BS, L, ED, N) @ (BS, L, N, 1) -> (BS, L, ED)
    y = y + D * x                           # skip connection on the original (BS, L, ED) x
    return y
# (Scraped page footer, not part of the code: "Leave a Comment")