Untitled

 avatar
unknown
python
2 years ago
1.8 kB
5
Indexable
class MultiHeadAttention(nn.Module):
    """Multi-head self-attention with a single fused Q/K/V projection.

    Args:
        d_model: Width of the input and output features. Must be divisible
            by ``num_head``.
        num_head: Number of attention heads the projected features are
            split into.
    """

    def __init__(self, d_model, num_head):
        super().__init__()
        # Every head gets an equal slice of the model width.
        assert d_model % num_head == 0

        # One matmul produces queries, keys and values side by side.
        self.qkv_proj = nn.Linear(d_model, 3 * d_model)
        self.out_proj = nn.Linear(d_model, d_model)
        self.num_head = num_head

    def forward(self, x, attn_mask):
        """Apply self-attention to ``x``.

        Args:
            x: Input of shape ``[batch, seq, d_model]``.
            attn_mask: Forwarded unchanged as ``attn_mask`` to
                ``scaled_dot_product_attention``.

        Returns:
            Tensor of shape ``[batch, seq, d_model]``.
        """
        fused = self.qkv_proj(x)
        # Split the last axis into (q/k/v, head, head_dim) and move the
        # q/k/v axis next to batch so it can be unbound into three tensors
        # of shape [batch, head, seq, head_dim].
        stacked = einops.rearrange(
            fused, "b t (k h d) -> b k h t d", k=3, h=self.num_head
        )
        q, k, v = stacked.unbind(1)

        # The math fallback is switched off and the memory-efficient kernel is
        # explicitly enabled (the flash flag is left at its default), so the
        # call errors out when no fused kernel can handle the inputs instead
        # of silently taking the slow path.
        fused_only = torch.backends.cuda.sdp_kernel(
            enable_math=False, enable_mem_efficient=True
        )
        with fused_only:
            attended = torch.nn.functional.scaled_dot_product_attention(
                q, k, v, dropout_p=0.0, attn_mask=attn_mask
            )

        # Merge the heads back into a single feature axis.
        merged = einops.rearrange(attended, "b h t d -> b t (h d)")
        return self.out_proj(merged)


class TransformerLayer(nn.Module):
    """Pre-norm transformer block: self-attention, then a ReLU feed-forward net.

    Each sub-layer is fed a layer-normalised copy of its input, and its
    dropped-out output is added back onto the un-normalised residual stream.

    Args:
        d_model: Width of the residual stream; used for the attention module,
            both ``LayerNorm`` layers and the feed-forward projections.
        num_head: Number of attention heads, forwarded to
            ``MultiHeadAttention``.
        ff_factor: Multiplier for the feed-forward hidden width, which is
            ``ff_factor * d_model``.
        dropout: Dropout probability applied to the output of each sub-layer
            and to the feed-forward hidden activations.
    """

    def __init__(self, d_model, num_head, ff_factor, dropout):
        super().__init__()
        self.mha = MultiHeadAttention(d_model, num_head)

        self.layer_norm1 = nn.LayerNorm(d_model)
        self.linear1 = nn.Linear(d_model, ff_factor * d_model)
        self.linear2 = nn.Linear(ff_factor * d_model, d_model)
        self.layer_norm2 = nn.LayerNorm(d_model)

        # Stored as a plain probability (not an nn.Dropout module) so the
        # attribute keeps its original meaning for existing callers.
        self.dropout = dropout

    def forward(self, x, attn_mask):
        """Run the attention and feed-forward sub-layers with residuals.

        Args:
            x: Input tensor; its last dimension is normalised by the
                ``LayerNorm(d_model)`` layers.
            attn_mask: Passed through unchanged to the attention module.

        Returns:
            Tensor with the same shape as ``x``.
        """
        x = x + self._dropout(self.mha(self.layer_norm1(x), attn_mask))
        x = x + self._dropout(self._ff_block(self.layer_norm2(x)))
        return x

    def _ff_block(self, x):
        """Position-wise feed-forward net: linear -> ReLU -> dropout -> linear."""
        hidden = self._dropout(nn.functional.relu(self.linear1(x)))
        return self.linear2(hidden)

    def _dropout(self, x):
        """Apply dropout only while the module is in training mode.

        ``nn.functional.dropout`` defaults to ``training=True``; without the
        explicit flag, activations would still be randomly zeroed and rescaled
        after ``.eval()``, making inference non-deterministic.
        """
        return nn.functional.dropout(x, p=self.dropout, training=self.training)
Editor is loading...
Leave a Comment