import torch
import torch.nn as nn
import einops


class MultiHeadAttention(nn.Module):
def __init__(self, d_model, num_head):
super().__init__()
assert d_model % num_head == 0
self.num_head = num_head
self.qkv_proj = nn.Linear(d_model, 3 * d_model)
self.out_proj = nn.Linear(d_model, d_model)
def forward(self, x, attn_mask):
"""
x: [batch, seq, d_model]
"""
        qkv = self.qkv_proj(x)
        # Split the fused projection into q, k, v, each [batch, num_head, seq, head_dim].
        q, k, v = einops.rearrange(
            qkv, "b t (k h d) -> b k h t d", k=3, h=self.num_head
        ).unbind(1)
        # Disable the math fallback so only the fused flash / memory-efficient
        # kernels can run; an error is raised if neither backend supports the
        # inputs (flash does not accept an explicit attn_mask, which is why the
        # memory-efficient kernel is kept enabled here).
        with torch.backends.cuda.sdp_kernel(enable_math=False, enable_mem_efficient=True):
attn_v = torch.nn.functional.scaled_dot_product_attention(
q, k, v, dropout_p=0.0, attn_mask=attn_mask
)
attn_v = einops.rearrange(attn_v, "b h t d -> b t (h d)")
return self.out_proj(attn_v)
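

# Illustrative helper (an assumption, not part of the original paste): builds a
# boolean causal mask in the layout scaled_dot_product_attention accepts, where
# True marks positions that may be attended to; it broadcasts over batch and heads.
def make_causal_mask(seq_len, device=None):
    # Lower-triangular [seq_len, seq_len] boolean mask.
    return torch.tril(torch.ones(seq_len, seq_len, dtype=torch.bool, device=device))

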
class TransformerLayer(nn.Module):
def __init__(self, d_model, num_head, ff_factor, dropout):
super().__init__()
self.mha = MultiHeadAttention(d_model, num_head)
self.layer_norm1 = nn.LayerNorm(d_model)
self.linear1 = nn.Linear(d_model, ff_factor * d_model)
self.linear2 = nn.Linear(ff_factor * d_model, d_model)
self.layer_norm2 = nn.LayerNorm(d_model)
self.dropout = dropout
def forward(self, x, attn_mask):
        # Pre-norm residual blocks; functional dropout needs the module's training
        # flag, otherwise it is also applied at eval time.
        x = x + nn.functional.dropout(
            self.mha(self.layer_norm1(x), attn_mask), p=self.dropout, training=self.training
        )
        x = x + nn.functional.dropout(
            self._ff_block(self.layer_norm2(x)), p=self.dropout, training=self.training
        )
return x
def _ff_block(self, x):
        x = nn.functional.dropout(
            nn.functional.relu(self.linear1(x)), p=self.dropout, training=self.training
        )
x = self.linear2(x)
        return x
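

# Minimal usage sketch (hyperparameters below are assumptions, not from the
# original paste): run one TransformerLayer forward pass with a causal mask and
# check the output shape. Note that the sdp_kernel context above only restricts
# the CUDA backends; behavior on CPU can vary with the PyTorch version.
if __name__ == "__main__":
    batch, seq, d_model, num_head = 2, 16, 64, 4
    layer = TransformerLayer(d_model, num_head, ff_factor=4, dropout=0.1)
    x = torch.randn(batch, seq, d_model)
    attn_mask = make_causal_mask(seq)
    out = layer(x, attn_mask)
    assert out.shape == (batch, seq, d_model)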