Untitled

mail@pastecode.io avatar
unknown
plain_text
a year ago
2.3 kB
14
Indexable
Never
import torch
from einops import rearrange


class GPTMLP(torch.nn.Module):
    """Feed-forward sub-layer: bias-free Linear -> GELU -> bias-free Linear.

    The input's last dimension (1024) is expanded to 4096, passed through
    GELU, and projected back down to 1024.
    """

    def __init__(self):
        super().__init__()
        width, inner_width = 1024, 4096
        # Creation order is kept (fc_in, then fc_out) because it determines
        # both the state_dict key order and the order of random initialisation.
        self.fc_in = torch.nn.Linear(
            in_features=width, out_features=inner_width, bias=False
        )
        self.fc_out = torch.nn.Linear(
            in_features=inner_width, out_features=width, bias=False
        )

    def forward(self, x):
        """Apply the feed-forward transform to the last dimension of ``x``.

        Args:
            x: Floating-point tensor whose last dimension has size 1024.

        Returns:
            A tensor with the same shape as ``x``.
        """
        hidden = torch.nn.functional.gelu(self.fc_in(x))
        return self.fc_out(hidden)


class GPTSelfAttn(torch.nn.Module):
    """Causal multi-head self-attention with fixed sizes (width 1024, 16 heads).

    A fused bias-free projection produces queries, keys and values, each of
    which is split into 16 heads. The attention logits are scaled, masked
    additively so that a position cannot attend to later positions, and
    normalised with softmax. The per-head results are merged and passed
    through a bias-free output projection.

    The precomputed mask covers 2048 x 2048 positions, so longer sequences
    are not supported.
    """

    def __init__(self):
        super().__init__()
        self.qkv_proj = torch.nn.Linear(1024, 1024 * 3, bias=False)
        self.out_proj = torch.nn.Linear(1024, 1024, bias=False)
        # Additive mask: 0 on and below the diagonal, -10000 above it, so
        # future positions receive a negligible softmax weight.
        causal_mask = (
            1 - torch.tril(torch.ones((2048, 2048))).view(1, 1, 2048, 2048)
        ) * -10000.0
        # Registered as a buffer so that .to() / .cuda() / .half() move and
        # cast the mask together with the parameters; a plain tensor attribute
        # would stay on the CPU and break the forward pass on other devices.
        # persistent=False keeps it out of state_dict, so checkpoints written
        # before this change still load with strict=True.
        self.register_buffer("causal_mask", causal_mask, persistent=False)
        # NOTE(review): conventional scaled dot-product attention divides by
        # sqrt(head_dim) (64 here), not sqrt(model width). The value is kept
        # as-is to preserve existing numerics -- confirm this is intended.
        self.scale = 1024**-0.5

    def forward(self, x):
        """Apply causal self-attention.

        Args:
            x: Tensor of shape (batch, seq, 1024) with seq <= 2048.

        Returns:
            Tensor of shape (batch, seq, 1024).
        """
        q, k, v = self.qkv_proj(x).chunk(3, dim=-1)
        # Split the feature dimension into 16 heads: (b, n, h*d) -> (b, h, n, d).
        q = rearrange(q, "b n (h d) -> b h n d", h=16)
        k = rearrange(k, "b n (h d) -> b h n d", h=16)
        v = rearrange(v, "b n (h d) -> b h n d", h=16)

        # Pairwise query/key dot products per head, then scaling.
        attn = torch.einsum("b h q d, b h k d -> b h q k", q, k) * self.scale
        # Crop the precomputed mask to the current sequence length.
        attn = attn + self.causal_mask[:, :, : attn.shape[-2], : attn.shape[-1]]
        attn = torch.nn.functional.softmax(attn, dim=-1)

        # Attention-weighted sum of values, then merge the heads back together.
        x = torch.einsum("b h q k, b h k d -> b h q d", attn, v)
        x = rearrange(x, "b h n d -> b n (h d)")

        return self.out_proj(x)


class GPTBlock(torch.nn.Module):
    """One transformer layer: self-attention and an MLP, each with a residual.

    Each sub-layer receives a LayerNorm-ed copy of its input, and its output
    is added back onto the un-normalised input.
    """

    def __init__(self):
        super().__init__()
        # Registration order (ln1, ln2, attn, mlp) is kept because it fixes
        # the state_dict key order and the order of random initialisation.
        self.ln1 = torch.nn.LayerNorm(normalized_shape=1024)
        self.ln2 = torch.nn.LayerNorm(normalized_shape=1024)
        self.attn = GPTSelfAttn()
        self.mlp = GPTMLP()

    def forward(self, x):
        """Run the attention and MLP sub-layers with residual connections.

        Args:
            x: Tensor whose last dimension has size 1024.

        Returns:
            A tensor with the same shape as ``x``.
        """
        after_attn = x + self.attn(self.ln1(x))
        return after_attn + self.mlp(self.ln2(after_attn))


class GPT(torch.nn.Module):
    """Decoder-only language model built from 24 ``GPTBlock`` layers.

    Token ids are embedded and summed with learned position embeddings, run
    through the block stack and a final LayerNorm, then projected by a
    bias-free linear head to 50256 logits per position.
    """

    def __init__(self):
        super().__init__()
        # NOTE(review): GPT-2's BPE vocabulary has 50257 entries (ids
        # 0..50256); if that tokenizer is used, id 50256 would be out of range
        # for this table -- confirm the intended vocabulary size.
        self.tok_emb = torch.nn.Embedding(num_embeddings=50256, embedding_dim=1024)
        # Learned position table; it caps the sequence length at 2048.
        self.pos_emb = torch.nn.Embedding(num_embeddings=2048, embedding_dim=1024)
        layers = [GPTBlock() for _ in range(24)]
        self.blocks = torch.nn.Sequential(*layers)
        self.ln_f = torch.nn.LayerNorm(normalized_shape=1024)
        self.lm_head = torch.nn.Linear(
            in_features=1024, out_features=50256, bias=False
        )

    def forward(self, x):
        """Compute next-token logits for a batch of token-id sequences.

        Args:
            x: Integer tensor of shape (batch, seq) holding token ids, with
                seq <= 2048.

        Returns:
            Tensor of shape (batch, seq, 50256) containing the logits.
        """
        # The two-way unpack also rejects inputs that are not 2-D.
        _, seq_len = x.shape
        positions = torch.arange(seq_len, device=x.device)
        # The (seq, 1024) position embeddings broadcast over the batch.
        hidden = self.tok_emb(x) + self.pos_emb(positions)
        hidden = self.blocks(hidden)
        return self.lm_head(self.ln_f(hidden))