import math

import torch


class MultiHeadAttention(torch.nn.Module):
    def __init__(self, num_heads, input_size, key_size, value_size):
        super().__init__()
        assert key_size % num_heads == 0, "key_size must be divisible by num_heads"
        assert value_size % num_heads == 0, "value_size must be divisible by num_heads"
        # Store the number of heads and the projection sizes
        self.num_heads = num_heads
        self.input_size = input_size
        self.key_size = key_size
        self.value_size = value_size
        # Define the linear transformations for the keys, values, and queries
        self.key_projection = torch.nn.Linear(input_size, key_size)
        self.value_projection = torch.nn.Linear(input_size, value_size)
        self.query_projection = torch.nn.Linear(input_size, key_size)
        # Define the linear transformation for the output
        self.output_projection = torch.nn.Linear(value_size, input_size)

    def forward(self, inputs, attention_mask=None):
        # Project the inputs to generate the keys, values, and queries
        keys = self.key_projection(inputs)
        values = self.value_projection(inputs)
        queries = self.query_projection(inputs)
        # Split the keys, values, and queries into num_heads chunks along the last dimension
        keys = torch.chunk(keys, self.num_heads, dim=-1)
        values = torch.chunk(values, self.num_heads, dim=-1)
        queries = torch.chunk(queries, self.num_heads, dim=-1)
        # Compute the scaled dot-product attention scores for each head,
        # scaling by the per-head key dimension
        head_dim = self.key_size // self.num_heads
        scores = []
        for i in range(self.num_heads):
            scores.append(torch.matmul(queries[i], keys[i].transpose(-2, -1)) / math.sqrt(head_dim))
        # Apply the attention mask to the scores (mask holds 1 for positions to keep, 0 to ignore)
        if attention_mask is not None:
            scores = [score.masked_fill(attention_mask == 0, float("-inf")) for score in scores]
        # Normalize the scores using softmax
        scores = [torch.nn.functional.softmax(score, dim=-1) for score in scores]
        # Apply the attention scores to the values
        attended_values = []
        for i in range(self.num_heads):
            attended_values.append(torch.matmul(scores[i], values[i]))
        # Concatenate the attended values and project them back to the input size
        attended_values = torch.cat(attended_values, dim=-1)
        attended_values = self.output_projection(attended_values)
        return attended_values
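

# Minimal usage sketch (the batch size, sequence length, and hyperparameters below are
# arbitrary illustrative choices, not values from the original snippet): run one forward
# pass with an all-ones attention mask and check that the output keeps the input shape.
if __name__ == "__main__":
    batch_size, seq_len, input_size = 2, 5, 16
    attention = MultiHeadAttention(num_heads=4, input_size=input_size, key_size=32, value_size=32)
    inputs = torch.randn(batch_size, seq_len, input_size)
    # Mask of shape (batch_size, seq_len, seq_len): 1 = attend, 0 = ignore
    attention_mask = torch.ones(batch_size, seq_len, seq_len)
    outputs = attention(inputs, attention_mask)
    print(outputs.shape)  # torch.Size([2, 5, 16])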