import torch
import torch.nn as nn
import einops


class MultiHeadAttention(nn.Module):
    def __init__(self, d_model, num_head):
        super().__init__()
        assert d_model % num_head == 0
        self.num_head = num_head
        self.qkv_proj = nn.Linear(d_model, 3 * d_model)
        self.out_proj = nn.Linear(d_model, d_model)

    def forward(self, x, attn_mask):
        """
        x: [batch, seq, d_model]
        """
        qkv = self.qkv_proj(x)
        # Split the fused projection into q, k, v, each shaped [batch, head, seq, head_dim].
        q, k, v = einops.rearrange(
            qkv, "b t (k h d) -> b k h t d", k=3, h=self.num_head
        ).unbind(1)
        # Disable the math fallback so SDPA must use the flash or memory-efficient
        # kernel; it raises an error if neither can be applied.
        with torch.backends.cuda.sdp_kernel(enable_math=False, enable_mem_efficient=True):
            attn_v = torch.nn.functional.scaled_dot_product_attention(
                q, k, v, dropout_p=0.0, attn_mask=attn_mask
            )
        # Merge heads back into the model dimension: [batch, seq, d_model].
        attn_v = einops.rearrange(attn_v, "b h t d -> b t (h d)")
        return self.out_proj(attn_v)


class TransformerLayer(nn.Module):
    def __init__(self, d_model, num_head, ff_factor, dropout):
        super().__init__()
        self.mha = MultiHeadAttention(d_model, num_head)
        self.layer_norm1 = nn.LayerNorm(d_model)
        self.linear1 = nn.Linear(d_model, ff_factor * d_model)
        self.linear2 = nn.Linear(ff_factor * d_model, d_model)
        self.layer_norm2 = nn.LayerNorm(d_model)
        self.dropout = dropout

    def forward(self, x, attn_mask):
        # Pre-norm residual blocks; pass the training flag so dropout is disabled at eval time.
        x = x + nn.functional.dropout(
            self.mha(self.layer_norm1(x), attn_mask), p=self.dropout, training=self.training
        )
        x = x + nn.functional.dropout(
            self._ff_block(self.layer_norm2(x)), p=self.dropout, training=self.training
        )
        return x

    def _ff_block(self, x):
        x = nn.functional.dropout(
            nn.functional.relu(self.linear1(x)), p=self.dropout, training=self.training
        )
        x = self.linear2(x)
        return x
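
# A minimal usage sketch with assumed hyperparameters (d_model=256, 8 heads, ff_factor=4)
# showing the expected input shape and a causal boolean mask in the convention used by
# scaled_dot_product_attention (True = may attend). The sdp_kernel context only constrains
# CUDA backends, so on CPU this falls back to the standard implementation.
if __name__ == "__main__":
    d_model, num_head, seq_len, batch = 256, 8, 16, 2
    layer = TransformerLayer(d_model=d_model, num_head=num_head, ff_factor=4, dropout=0.1)

    x = torch.randn(batch, seq_len, d_model)
    # Lower-triangular causal mask; SDPA broadcasts [seq, seq] over batch and heads.
    causal_mask = torch.tril(torch.ones(seq_len, seq_len, dtype=torch.bool))

    layer.eval()  # disables dropout via the training flag
    with torch.no_grad():
        out = layer(x, attn_mask=causal_mask)
    print(out.shape)  # torch.Size([2, 16, 256])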