class ViTBlock(nn.Module):
# Vision Transformer https://arxiv.org/abs/2010.11929
    def __init__(self, c1, c2, patch_size, num_heads, num_layers):
        super().__init__()
        self.layers = nn.ModuleList([])
        dim_head = 64   # per-head dimension of the self-attention layers
        mlp_dim = 64    # hidden width of the feed-forward network
        dropout = 0.1
        self.patch_size = patch_size
        for _ in range(num_layers):
            self.layers.append(nn.ModuleList([
                PreNorm(c1, MultiHeadSelfAttention(c1, num_heads, dim_head)),
                PreNorm(c1, FFN(c1, mlp_dim, dropout)),
            ]))
        # 1x1 projection from c1 to c2 channels, no activation
        self.conv_out = Conv(c1, c2, k=1, act=False)

    def forward(self, x):
        # x: (b, c1, H, W); H and W must be divisible by patch_size
        _, _, h, w = x.shape
        # Unfold into non-overlapping patch_size x patch_size patches:
        # (b, d, H, W) -> (b, patch_size*patch_size, num_patches, d), where the last
        # dim is the embedding and num_patches is the token sequence seen by attention.
        x = rearrange(x, 'b d (h ph) (w pw) -> b (ph pw) (h w) d',
                      ph=self.patch_size, pw=self.patch_size)
        for attn, ff in self.layers:
            x = attn(x) + x   # pre-norm self-attention with residual connection
            x = ff(x) + x     # pre-norm feed-forward with residual connection
        # Fold the token sequence back into a (b, d, H, W) feature map
        x = rearrange(x, 'b (ph pw) (h w) d -> b d (h ph) (w pw)',
                      h=h // self.patch_size, w=w // self.patch_size,
                      ph=self.patch_size, pw=self.patch_size)
        return self.conv_out(x)
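

# Usage sketch (illustrative, not part of the original module): the hyperparameters
# below are arbitrary values chosen only to show the expected tensor shapes, and it
# assumes PreNorm, MultiHeadSelfAttention, FFN and Conv are the helpers defined
# elsewhere in this file. The input spatial size must be divisible by patch_size.
if __name__ == "__main__":
    import torch

    block = ViTBlock(c1=256, c2=256, patch_size=2, num_heads=4, num_layers=2)
    x = torch.randn(1, 256, 32, 32)   # (batch, channels, height, width); 32 % 2 == 0
    y = block(x)
    print(y.shape)                    # expected: torch.Size([1, 256, 32, 32])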