NLP with Transformers: multi-head self-attention (MHSA)
import torch
from torch import nn

class MultiHeadAttention(nn.Module):
    def __init__(self, config):
        super().__init__()
        embed_dim = config.hidden_size
        num_heads = config.num_attention_heads
        head_dim = embed_dim // num_heads
        # One independent head per num_attention_heads; each head projects
        # down to head_dim so the concatenated output is embed_dim wide again
        self.heads = nn.ModuleList(
            [AttentionHead(embed_dim, head_dim) for _ in range(num_heads)]
        )
        self.output_linear = nn.Linear(embed_dim, embed_dim)

    def forward(self, hidden_state):
        # Run every head on the same input and concatenate the results
        # along the last (feature) dimension
        x = torch.cat([h(hidden_state) for h in self.heads], dim=-1)
        # Final linear projection mixes information across the heads
        x = self.output_linear(x)
        return x
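The snippet references AttentionHead without defining it. Below is a minimal sketch of a single head, assuming the usual scaled dot-product layout with separate query, key, and value projections; the class name and constructor signature come from the snippet above, while the layer names (q, k, v) and everything inside forward are assumptions.

import torch
from torch import nn
import torch.nn.functional as F

class AttentionHead(nn.Module):
    def __init__(self, embed_dim, head_dim):
        super().__init__()
        # Assumed layout: independent projections for query, key, and value
        self.q = nn.Linear(embed_dim, head_dim)
        self.k = nn.Linear(embed_dim, head_dim)
        self.v = nn.Linear(embed_dim, head_dim)

    def forward(self, hidden_state):
        # hidden_state: (batch, seq_len, embed_dim)
        q = self.q(hidden_state)
        k = self.k(hidden_state)
        v = self.v(hidden_state)
        # Scaled dot-product attention: scores are (batch, seq_len, seq_len)
        scores = torch.bmm(q, k.transpose(1, 2)) / (q.size(-1) ** 0.5)
        weights = F.softmax(scores, dim=-1)
        return torch.bmm(weights, v)

Assuming a Hugging Face-style config object (anything exposing hidden_size and num_attention_heads), the layer can be exercised like this:

from transformers import AutoConfig

config = AutoConfig.from_pretrained("bert-base-uncased")  # hidden_size=768, 12 heads
mha = MultiHeadAttention(config)
x = torch.randn(1, 5, config.hidden_size)  # (batch, seq_len, embed_dim)
out = mha(x)                               # same shape: (1, 5, 768)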