# Paste-site header residue (not code): mail@pastecode.io avatar, 2 years ago, 2.2 kB
import torch

class MultiHeadAttention(torch.nn.Module):
    """Multi-head scaled dot-product self-attention.

    The inputs are projected to queries, keys and values, the projections are
    split along the feature axis into ``num_heads`` equally sized heads,
    attention is computed independently per head, and the concatenated head
    outputs are projected back to ``input_size``.

    Args:
        num_heads: Number of attention heads. Must be positive and divide
            both ``key_size`` and ``value_size``.
        input_size: Size of the last dimension of the inputs (and outputs).
        key_size: Total query/key projection size, summed over all heads.
        value_size: Total value projection size, summed over all heads.

    Raises:
        ValueError: If ``num_heads`` is not positive, or does not evenly
            divide ``key_size`` or ``value_size``.
    """

    def __init__(self, num_heads: int, input_size: int, key_size: int,
                 value_size: int) -> None:
        # Must run before any sub-module is assigned; otherwise
        # torch.nn.Module refuses the assignment with an AttributeError.
        super().__init__()
        if num_heads <= 0:
            raise ValueError(f"num_heads must be positive, got {num_heads}")
        if key_size % num_heads != 0 or value_size % num_heads != 0:
            raise ValueError(
                f"key_size ({key_size}) and value_size ({value_size}) must "
                f"both be divisible by num_heads ({num_heads})"
            )
        self.num_heads = num_heads
        self.input_size = input_size
        self.key_size = key_size
        self.value_size = value_size
        # Linear transformations for the keys, values, and queries.
        self.key_projection = torch.nn.Linear(input_size, key_size)
        self.value_projection = torch.nn.Linear(input_size, value_size)
        self.query_projection = torch.nn.Linear(input_size, key_size)
        # Linear transformation mapping concatenated heads back to the input size.
        self.output_projection = torch.nn.Linear(value_size, input_size)

    def _split_heads(self, projected):
        """Reshape ``(..., seq, H * d)`` into ``(H, ..., seq, d)``."""
        # torch.chunk takes the NUMBER of chunks (torch.split would take the
        # chunk size). Heads go on a new leading axis so that a mask shaped
        # for a single head broadcasts over every head.
        return torch.stack(torch.chunk(projected, self.num_heads, dim=-1), dim=0)

    def forward(self, inputs, attention_mask=None) -> torch.Tensor:
        """Apply self-attention to ``inputs``.

        Args:
            inputs: Tensor of shape ``(..., seq_len, input_size)``.
            attention_mask: Optional tensor broadcastable to
                ``(..., seq_len, seq_len)`` (query position x key position).
                Entries equal to zero block attention to that key; any other
                value allows it. ``None`` means no masking.

        Returns:
            Tensor of shape ``(..., seq_len, input_size)``.
        """
        queries = self._split_heads(self.query_projection(inputs))
        keys = self._split_heads(self.key_projection(inputs))
        values = self._split_heads(self.value_projection(inputs))

        # Scale by the per-head key dimension, as in scaled dot-product attention.
        head_key_size = self.key_size // self.num_heads
        scores = torch.matmul(queries, keys.transpose(-2, -1)) * head_key_size ** -0.5

        if attention_mask is not None:
            # Blocked positions must get ~zero probability AFTER softmax, so
            # push their scores to the most negative finite value. Multiplying
            # the scores by the mask would leave them at 0, i.e. exp(0) weight.
            # A finite fill (instead of -inf) keeps fully masked rows NaN-free.
            blocked = attention_mask == 0
            scores = scores.masked_fill(blocked, torch.finfo(scores.dtype).min)

        weights = torch.nn.functional.softmax(scores, dim=-1)
        attended = torch.matmul(weights, values)

        # (H, ..., seq, d_v) -> (..., seq, H * d_v), then project to input_size.
        merged = torch.cat(torch.unbind(attended, dim=0), dim=-1)
        return self.output_projection(merged)