import torch.nn as nn
from einops import rearrange  # pip install einops

# PreNorm, MultiHeadSelfAttention, FFN and Conv are expected to be defined
# elsewhere in the surrounding model code (see the sketch after this block).

class ViTBlock(nn.Module):
    # Vision Transformer https://arxiv.org/abs/2010.11929
    def __init__(self, c1, c2, patch_size, num_heads, num_layers):
        super().__init__()
        dim_head = 64   # per-head attention dimension
        mlp_dim = 64    # hidden width of the feed-forward network
        dropout = 0.1
        self.patch_size = patch_size
        self.layers = nn.ModuleList([])
        for _ in range(num_layers):
            self.layers.append(nn.ModuleList([
                PreNorm(c1, MultiHeadSelfAttention(c1, num_heads, dim_head)),
                PreNorm(c1, FFN(c1, mlp_dim, dropout)),
            ]))
        # 1x1 convolution mapping c1 channels to c2 after the transformer stack
        self.conv_out = Conv(c1, c2, k=1, act=False)

    def forward(self, x):
        # x: (batch, c1, H, W); H and W must be divisible by patch_size
        _, _, h, w = x.shape
        # Unfold the feature map into patch_size*patch_size groups of tokens
        x = rearrange(x, 'b d (h ph) (w pw) -> b (ph pw) (h w) d',
                      ph=self.patch_size, pw=self.patch_size)
        # Pre-norm transformer encoder with residual connections
        for attn, ff in self.layers:
            x = attn(x) + x
            x = ff(x) + x
        # Fold the tokens back into a (batch, c1, H, W) feature map
        x = rearrange(x, 'b (ph pw) (h w) d -> b d (h ph) (w pw)',
                      h=h // self.patch_size, w=w // self.patch_size,
                      ph=self.patch_size, pw=self.patch_size)
        return self.conv_out(x)
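
The block above depends on four helpers that are not part of this paste: PreNorm, MultiHeadSelfAttention, FFN and Conv. Below is a minimal runnable sketch of each, assuming the common einops-style ViT conventions (LayerNorm before each sublayer, a two-layer MLP, multi-head attention over (batch, patch, token, dim) tensors) and a YOLOv5-style Conv for the 1x1 projection. These are assumptions for illustration; the original definitions may differ. A quick shape check is included at the end.

import torch
import torch.nn as nn
from einops import rearrange

class PreNorm(nn.Module):                 # apply LayerNorm before the wrapped fn
    def __init__(self, dim, fn):
        super().__init__()
        self.norm = nn.LayerNorm(dim)
        self.fn = fn
    def forward(self, x):
        return self.fn(self.norm(x))

class FFN(nn.Module):                     # token-wise two-layer MLP
    def __init__(self, dim, hidden_dim, dropout=0.):
        super().__init__()
        self.net = nn.Sequential(
            nn.Linear(dim, hidden_dim), nn.GELU(), nn.Dropout(dropout),
            nn.Linear(hidden_dim, dim), nn.Dropout(dropout))
    def forward(self, x):
        return self.net(x)

class MultiHeadSelfAttention(nn.Module):  # attends over (b, p, n, d) tokens
    def __init__(self, dim, heads, dim_head):
        super().__init__()
        inner = heads * dim_head
        self.heads = heads
        self.scale = dim_head ** -0.5
        self.to_qkv = nn.Linear(dim, inner * 3, bias=False)
        self.to_out = nn.Linear(inner, dim)
    def forward(self, x):
        q, k, v = self.to_qkv(x).chunk(3, dim=-1)
        q, k, v = (rearrange(t, 'b p n (h d) -> b p h n d', h=self.heads)
                   for t in (q, k, v))
        attn = (q @ k.transpose(-2, -1) * self.scale).softmax(dim=-1)
        out = rearrange(attn @ v, 'b p h n d -> b p n (h d)')
        return self.to_out(out)

class Conv(nn.Module):                    # conv + BN (+ optional SiLU), YOLOv5-style
    def __init__(self, c1, c2, k=1, act=True):
        super().__init__()
        self.conv = nn.Conv2d(c1, c2, k, bias=False)
        self.bn = nn.BatchNorm2d(c2)
        self.act = nn.SiLU() if act else nn.Identity()
    def forward(self, x):
        return self.act(self.bn(self.conv(x)))

# Shape check: a 32x32 feature map with 2x2 patches keeps its spatial size
block = ViTBlock(c1=64, c2=128, patch_size=2, num_heads=4, num_layers=2)
y = block(torch.randn(1, 64, 32, 32))
print(y.shape)  # torch.Size([1, 128, 32, 32])

Note that the rearrange in forward groups pixels by their position within each patch, so attention mixes information across patches (as in MobileViT-style blocks) rather than across pixels inside one patch; the output feature map keeps the input's spatial resolution, with only the channel count changed by the final 1x1 convolution.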