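import torch
import torch.nn as nn
from einops import rearrange

# ViTBlock below depends on four helpers -- PreNorm, MultiHeadSelfAttention,
# FFN, and Conv -- that are defined elsewhere in the source module and are not
# part of this paste. The definitions here are a minimal sketch of what such
# helpers conventionally look like (pre-norm wrapper, standard multi-head
# self-attention, GELU MLP, YOLOv5-style Conv); the original implementations
# may differ.

class PreNorm(nn.Module):
    # Apply LayerNorm to the input before calling the wrapped module.
    def __init__(self, dim, fn):
        super().__init__()
        self.norm = nn.LayerNorm(dim)
        self.fn = fn

    def forward(self, x):
        return self.fn(self.norm(x))


class FFN(nn.Module):
    # Two-layer feed-forward network with GELU activation and dropout.
    def __init__(self, dim, hidden_dim, dropout=0.0):
        super().__init__()
        self.net = nn.Sequential(
            nn.Linear(dim, hidden_dim),
            nn.GELU(),
            nn.Dropout(dropout),
            nn.Linear(hidden_dim, dim),
            nn.Dropout(dropout),
        )

    def forward(self, x):
        return self.net(x)


class MultiHeadSelfAttention(nn.Module):
    # Standard multi-head self-attention over the token axis (dim -2);
    # written with einops ellipsis so it also accepts the 4D
    # (batch, patch_offset, tokens, dim) layout produced by ViTBlock.forward.
    def __init__(self, dim, num_heads, dim_head):
        super().__init__()
        inner = num_heads * dim_head
        self.num_heads = num_heads
        self.scale = dim_head ** -0.5
        self.to_qkv = nn.Linear(dim, inner * 3, bias=False)
        self.to_out = nn.Linear(inner, dim)

    def forward(self, x):
        qkv = self.to_qkv(x).chunk(3, dim=-1)
        q, k, v = (rearrange(t, '... n (h e) -> ... h n e', h=self.num_heads) for t in qkv)
        attn = (q @ k.transpose(-2, -1) * self.scale).softmax(dim=-1)
        out = rearrange(attn @ v, '... h n e -> ... n (h e)')
        return self.to_out(out)


class Conv(nn.Module):
    # Conv2d + BatchNorm2d + optional SiLU: a simplified version of the
    # YOLOv5-style Conv whose signature ViTBlock uses (k=..., act=...).
    def __init__(self, c1, c2, k=1, s=1, act=True):
        super().__init__()
        self.conv = nn.Conv2d(c1, c2, k, s, k // 2, bias=False)
        self.bn = nn.BatchNorm2d(c2)
        self.act = nn.SiLU() if act else nn.Identity()

    def forward(self, x):
        return self.act(self.bn(self.conv(x)))
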
class ViTBlock(nn.Module):
    # Vision Transformer encoder block (https://arxiv.org/abs/2010.11929)
    # applied to a CNN feature map.
    def __init__(self, c1, c2, patch_size, num_heads, num_layers):
        super().__init__()
        self.layers = nn.ModuleList([])
        # Fixed transformer hyper-parameters: per-head width, MLP hidden
        # width, and dropout rate.
        dim_head = 64
        mlp_dim = 64
        dropout = 0.1
        self.patch_size = patch_size
        # Each layer is a pre-norm attention block followed by a pre-norm MLP.
        for _ in range(num_layers):
            self.layers.append(nn.ModuleList([
                PreNorm(c1, MultiHeadSelfAttention(c1, num_heads, dim_head)),
                PreNorm(c1, FFN(c1, mlp_dim, dropout))
            ]))
        # 1x1 convolution mapping the transformer width c1 to c2 output channels.
        self.conv_out = Conv(c1, c2, k=1, act=False)

    def forward(self, x):
        _, _, h, w = x.shape  # h and w must be divisible by patch_size
        # Split the (b, d, h, w) feature map into patch_size x patch_size patches:
        # axis 1 indexes the offset within a patch, axis 2 the patch grid position.
        x = rearrange(x, 'b d (h ph) (w pw) -> b (ph pw) (h w) d',
                      ph=self.patch_size, pw=self.patch_size)
        for attn, ff in self.layers:
            x = attn(x) + x  # residual self-attention
            x = ff(x) + x    # residual feed-forward
        # Fold the tokens back into a (b, d, h, w) feature map.
        x = rearrange(x, 'b (ph pw) (h w) d -> b d (h ph) (w pw)',
                      h=h // self.patch_size, w=w // self.patch_size,
                      ph=self.patch_size, pw=self.patch_size)
        return self.conv_out(x)
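

# Example usage (a sketch; shapes are illustrative). The spatial size must be
# divisible by patch_size, and c1 is the transformer width.
if __name__ == '__main__':
    block = ViTBlock(c1=64, c2=128, patch_size=4, num_heads=4, num_layers=2)
    x = torch.randn(1, 64, 32, 32)  # (batch, c1, H, W)
    print(block(x).shape)           # expected: torch.Size([1, 128, 32, 32])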