import math

import torch


class MultiHeadAttention(torch.nn.Module):
    def __init__(self, num_heads, input_size, key_size, value_size):
        super().__init__()
        # Store the number of heads and the projection sizes
        self.num_heads = num_heads
        self.input_size = input_size
        self.key_size = key_size
        self.value_size = value_size

        # Define the linear transformations for the keys, values, and queries
        self.key_projection = torch.nn.Linear(input_size, key_size)
        self.value_projection = torch.nn.Linear(input_size, value_size)
        self.query_projection = torch.nn.Linear(input_size, key_size)

        # Define the linear transformation for the output
        self.output_projection = torch.nn.Linear(value_size, input_size)

    def forward(self, inputs, attention_mask):
        # Project the inputs to generate the keys, values, and queries
        keys = self.key_projection(inputs)
        values = self.value_projection(inputs)
        queries = self.query_projection(inputs)

        # Split the keys, values, and queries into num_heads pieces along the
        # feature dimension (torch.chunk yields num_heads pieces, whereas
        # torch.split(x, num_heads) would yield pieces of size num_heads)
        keys = torch.chunk(keys, self.num_heads, dim=-1)
        values = torch.chunk(values, self.num_heads, dim=-1)
        queries = torch.chunk(queries, self.num_heads, dim=-1)

        # Compute the scaled dot-product attention scores for each head,
        # scaling by the per-head key dimension
        head_dim = self.key_size // self.num_heads
        scores = []
        for i in range(self.num_heads):
            scores.append(
                torch.matmul(queries[i], keys[i].transpose(-2, -1))
                / math.sqrt(head_dim)
            )

        # Apply the attention mask to the scores: masked positions (mask == 0)
        # are set to -inf so they receive zero weight after the softmax
        scores = [
            score.masked_fill(attention_mask == 0, float("-inf"))
            for score in scores
        ]

        # Normalize the scores using softmax
        scores = [torch.nn.functional.softmax(score, dim=-1) for score in scores]

        # Apply the attention scores to the values
        attended_values = []
        for i in range(self.num_heads):
            attended_values.append(torch.matmul(scores[i], values[i]))

        # Concatenate the attended values and project them to the output size
        attended_values = torch.cat(attended_values, dim=-1)
        attended_values = self.output_projection(attended_values)

        return attended_values
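

# Minimal usage sketch, not part of the original paste: the shapes and
# hyperparameters below (batch of 2, sequence length 10, 64-dimensional
# inputs, 4 heads, 32-dimensional key/value projections) are illustrative
# assumptions chosen so that key_size divides evenly by num_heads.
if __name__ == "__main__":
    batch_size, seq_len, input_size = 2, 10, 64
    model = MultiHeadAttention(
        num_heads=4, input_size=input_size, key_size=32, value_size=32
    )
    inputs = torch.randn(batch_size, seq_len, input_size)
    # Mask of shape (batch_size, seq_len, seq_len): 1 = attend, 0 = ignore
    attention_mask = torch.ones(batch_size, seq_len, seq_len)
    output = model(inputs, attention_mask)
    print(output.shape)  # torch.Size([2, 10, 64])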