import torch
from torch import nn


def _gqa_attn(query, key, value, attention_mask=None, scale_attn_weights=False,
              causal_mask_flag=False, dropout=0.0, local_window_size=None, sink_tokens=1):
    """Grouped-query attention (GQA).

    Expected shapes:
        query: (batch_size, num_query_heads, query_len, head_dim)
        key, value: (batch_size, num_kv_heads, key_len, head_dim)
        attention_mask: additive mask of shape (batch_size, query_len, key_len),
            sized for the sequence after any sink tokens are prepended; 0 keeps
            a position, a large negative value masks it out.

    num_query_heads must be an integer multiple of num_kv_heads.
    Returns (attn_output, attn_weights).
    """

    # Validate tensor ranks before moving on
    if not query.ndim == key.ndim == value.ndim == 4:
        raise ValueError(f"Expected query, key, and value to be 4-dimensional, but got shapes "
                         f"{query.shape}, {key.shape}, and {value.shape}.")


    if sink_tokens > 0:
        # Prepend the first `sink_tokens` positions as dedicated attention sinks
        # (StreamingLLM, http://arxiv.org/abs/2309.17453).
        sink_query = query[:, :, :sink_tokens, :]
        sink_key = key[:, :, :sink_tokens, :]
        sink_value = value[:, :, :sink_tokens, :]

        query = torch.cat([sink_query, query], dim=2)
        key = torch.cat([sink_key, key], dim=2)
        value = torch.cat([sink_value, value], dim=2)


    """
    Expected shapes: (batch_size, num_heads, query_len, query_dim) similar to _upcast_and_reordered_attn
    """
    batch_size, num_heads, query_len, query_dim = query.shape


    # Standard scaled dot-product scaling: when scale_attn_weights is set,
    # scale the query by 1 / sqrt(head_dim).
    scale_factor = 1.0
    if scale_attn_weights:
        scale_factor /= float(value.size(-1)) ** 0.5

    # Note: GPT-2's scale_attn_by_inverse_layer_idx (dividing the weights by
    # layer_idx + 1) is not applied here.

    query = query * scale_factor

    # Determine how many query heads share each key/value head.
    # For example, 4 query heads and 2 key/value heads give n_groups = 2.
    n_groups = query.size(1) // key.size(1)

    if n_groups > 1:
        # Reshape the query heads to (batch_size, n_groups, num_kv_heads, query_len, head_dim)
        # so that each group of query heads lines up with the key/value heads.
        grouped_shape = (batch_size, n_groups, num_heads // n_groups, query_len, query_dim)
        query_grouped = query.reshape(grouped_shape)

        # key.unsqueeze(1).transpose(-2, -1): (batch_size, 1, num_kv_heads, head_dim, key_len).
        # The explicit group dimension makes the matmul broadcast over groups rather
        # than accidentally pairing the key's batch dimension with the groups.
        attn_weights_grouped = torch.matmul(query_grouped, key.unsqueeze(1).transpose(-2, -1))

        # attn_weights_grouped: (batch_size, n_groups, num_kv_heads, query_len, key_len).
        # Summing over the group dimension collapses the scores back to one set of
        # attention weights per key/value head.
        attn_weights = attn_weights_grouped.sum(dim=1)

        # attn_weights: (batch_size, num_kv_heads, query_len, key_len)

    else:
        # With a single group this reduces to the standard attention scores.
        attn_weights = torch.matmul(query, key.transpose(-2, -1))
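    # Worked shape example (with the tensors from the demo below, sink_tokens=0):
    #   query_grouped:                 (2, 2, 2, 6, 5)
    #   key, unsqueezed + transposed:  (2, 1, 2, 5, 6)
    #   attn_weights_grouped:          (2, 2, 2, 6, 6)
    #   attn_weights:                  (2, 2, 6, 6)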


    # Local (sliding-window) attention: mask out pairs of positions that are more
    # than local_window_size apart. Assumes query_len == key_len, which holds here
    # since queries and keys come from the same sequence.
    if local_window_size is not None:
        max_seq_len = query.size(-2)
        indices = torch.arange(max_seq_len, device=query.device)
        expanded_indices = indices.unsqueeze(-1).expand(max_seq_len, max_seq_len)
        distance_matrix = torch.abs(expanded_indices - indices.unsqueeze(0))
        attn_weights.masked_fill_(distance_matrix > local_window_size, float('-inf'))


    if attention_mask is not None:
        # Additive attention mask of shape (batch_size, query_len, key_len):
        # 0 keeps a position, a large negative value masks it out.
        # unsqueeze(1) adds the head dimension so the mask broadcasts over heads.
        attn_weights = attn_weights + attention_mask.unsqueeze(1)


    # Causal masking ensures that the attention mechanism doesn't attend to "future" tokens.
    if causal_mask_flag:
        # Lower-triangular matrix: True at or below the diagonal, False above it.
        causal_mask = torch.ones((query.size(0), query.size(2), key.size(2)),
                                 device=query.device, dtype=torch.bool).tril_()
        # Add a head dimension so the mask broadcasts over heads rather than relying
        # on batch_size happening to match the number of heads.
        causal_mask = causal_mask.unsqueeze(1)
        # The fill value must be a tensor with the same dtype and device as attn_weights,
        # otherwise torch.where raises a dtype/device mismatch error.
        mask_value = torch.finfo(attn_weights.dtype).min
        mask_value = torch.full([], mask_value, dtype=attn_weights.dtype, device=attn_weights.device)
        attn_weights = torch.where(causal_mask, attn_weights, mask_value)

    # Clamp the logits to a range where softmax is numerically stable.
    attn_weights = torch.clamp(attn_weights, min=-1000, max=1000)

    # Softmax normalization to get the attention probabilities
    attn_weights = nn.functional.softmax(attn_weights, dim=-1)


    # Apply dropout if specified
    if dropout > 0.0:
        attn_weights = nn.functional.dropout(attn_weights, p=dropout)

    # Compute the output by multiplying the attention scores with the value tensor.
    attn_output = torch.matmul(attn_weights, value)
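    # attn_output: (batch_size, num_kv_heads, query_len, head_dim) when n_groups > 1,
    # since the weights above were collapsed to one set per key/value head.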

    return attn_output, attn_weights


if __name__ == '__main__':
    # Setting random seed for reproducibility
    torch.manual_seed(42)



    ## Test case with num_groups > 1: 4 query heads share 2 key/value heads
    # Example tensor shapes
    batch_size = 2
    query_num_heads = 4
    key_value_num_heads = 2
    query_len = 6
    key_len = 6
    dim = 5

    # Generate example tensors
    query = torch.randn(batch_size, query_num_heads, query_len, dim)
    key = torch.randn(batch_size, key_value_num_heads, key_len, dim)
    value = torch.randn(batch_size, key_value_num_heads, key_len, dim)


    # Number of sink tokens. For StreamingLLM, see the paper:
    # Efficient Streaming Language Models with Attention Sinks
    # http://arxiv.org/abs/2309.17453
    sink_tokens = 0


    shape = (batch_size, query_len+sink_tokens, query_len+sink_tokens)
    attention_mask = torch.zeros(shape)

    # Additive attention mask: 0 keeps a position, a large negative value (-1e9)
    # masks it out. Here a handful of positions are masked at random.
    indices = torch.randperm(attention_mask.numel())
    attention_mask.view(-1)[indices[:5]] = -1e9


    # Run the function
    attn_output, attn_weights = _gqa_attn(query, key, value, attention_mask, scale_attn_weights=True, causal_mask_flag=True, dropout=0.1, local_window_size=3, sink_tokens=sink_tokens)

    print("Attention Output:", attn_output.shape)
    print("Attention Weights:", attn_weights.shape)


    # Print attn weights
    print(f"attn_weights = {attn_weights}")


    # Slice out sink-token weights (empty when sink_tokens == 0)
    sink_attn = attn_weights[:, :, :sink_tokens, :]
    print(f"sink_attn = {sink_attn}")

    # Slice out main attn weights
    main_attn = attn_weights[:, :, sink_tokens:, :]
    print(f"main_attn = {main_attn}")



    # Uncomment the block below to test the num_groups == 1 path instead.

    # # Setting random seed for reproducibility
    # torch.manual_seed(42)

    # # Example tensor shapes
    # batch_size = 2
    # query_num_heads = 2  # equal to key_value_num_heads, so n_groups == 1
    # key_value_num_heads = 2
    # query_len = 3
    # key_len = 3
    # dim = 5

    # # Generate example tensors
    # query = torch.randn(batch_size, query_num_heads, query_len, dim)
    # key = torch.randn(batch_size, key_value_num_heads, key_len, dim)
    # value = torch.randn(batch_size, key_value_num_heads, key_len, dim)

    # # Additive attention mask: 0 keeps a position, -1e9 (a large negative value) masks it out
    # attention_mask = torch.tensor([
    #     [[0., -1e9, -1e9], [0., 0., -1e9], [0., 0., 0.]],
    #     [[0., -1e9, -1e9], [0., 0., -1e9], [0., 0., 0.]]
    # ])

    # # Run the function
    # attn_output, attn_weights = _gqa_attn(query, key, value, attention_mask, scale_attn_weights=True, causal_mask_flag=True, dropout=0.1)

    # print("Attention Output:", attn_output.shape)
    # print("Attention Weights:", attn_weights.shape)