NLP with Transformers: multi-head self-attention (MHSA)

import torch
from torch import nn


class MultiHeadAttention(nn.Module):
    def __init__(self, config):
        super().__init__()
        embed_dim = config.hidden_size
        num_heads = config.num_attention_heads
        head_dim = embed_dim // num_heads
        # One module per head; each head maps embed_dim -> head_dim,
        # so concatenating all heads restores the full embed_dim.
        self.heads = nn.ModuleList(
            [AttentionHead(embed_dim, head_dim) for _ in range(num_heads)]
        )
        self.output_linear = nn.Linear(embed_dim, embed_dim)

    def forward(self, hidden_state):
        # Run every head on the same input, concatenate along the feature
        # dimension, then mix the heads with a final linear projection.
        x = torch.cat([h(hidden_state) for h in self.heads], dim=-1)
        x = self.output_linear(x)
        return x
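
The class above depends on an AttentionHead module that the paste does not include. A minimal sketch consistent with the constructor (each head projects embed_dim down to head_dim) is shown below; the helper name scaled_dot_product_attention and the exact q/k/v layout are assumptions, not part of the original paste.

from math import sqrt
import torch.nn.functional as F


def scaled_dot_product_attention(query, key, value):
    # Scale the dot products by sqrt(d_k) to keep softmax logits stable.
    dim_k = query.size(-1)
    scores = torch.bmm(query, key.transpose(1, 2)) / sqrt(dim_k)
    weights = F.softmax(scores, dim=-1)
    return torch.bmm(weights, value)


class AttentionHead(nn.Module):
    def __init__(self, embed_dim, head_dim):
        super().__init__()
        # Independent query/key/value projections for this head.
        self.q = nn.Linear(embed_dim, head_dim)
        self.k = nn.Linear(embed_dim, head_dim)
        self.v = nn.Linear(embed_dim, head_dim)

    def forward(self, hidden_state):
        return scaled_dot_product_attention(
            self.q(hidden_state), self.k(hidden_state), self.v(hidden_state))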
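A quick smoke test, assuming a stand-in config object with the two attributes the constructor reads; in practice the config would typically come from Hugging Face Transformers (e.g. AutoConfig.from_pretrained("bert-base-uncased")), which exposes the same hidden_size and num_attention_heads fields.

from types import SimpleNamespace

# Hypothetical stand-in for a model config; BERT-base-like sizes.
config = SimpleNamespace(hidden_size=768, num_attention_heads=12)

mha = MultiHeadAttention(config)
hidden_state = torch.randn(1, 5, config.hidden_size)  # (batch, seq_len, embed_dim)
out = mha(hidden_state)
print(out.shape)  # torch.Size([1, 5, 768]) -- same shape as the input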