import torch
from einops import rearrange

class GPTMLP(torch.nn.Module):
    def __init__(self):
        super().__init__()
        # Standard 4x feed-forward expansion: 1024 -> 4096 -> 1024.
        self.fc_in = torch.nn.Linear(1024, 4096, bias=False)
        self.fc_out = torch.nn.Linear(4096, 1024, bias=False)

    def forward(self, x):
        x = self.fc_in(x)
        x = torch.nn.functional.gelu(x)
        return self.fc_out(x)

class GPTSelfAttn(torch.nn.Module):
    def __init__(self):
        super().__init__()
        self.qkv_proj = torch.nn.Linear(1024, 1024 * 3, bias=False)
        self.out_proj = torch.nn.Linear(1024, 1024, bias=False)
        # Additive causal mask: 0 on and below the diagonal, -10000 above.
        # Registered as a buffer so it follows the module across devices.
        self.register_buffer(
            "causal_mask",
            (1 - torch.tril(torch.ones((2048, 2048))).view(1, 1, 2048, 2048))
            * -10000.0,
        )
        # Scale by the per-head dimension (1024 / 16 heads = 64), not the model width.
        self.scale = (1024 // 16) ** -0.5

    def forward(self, x):
        # Single fused projection, then split into query, key, value.
        q, k, v = self.qkv_proj(x).chunk(3, dim=-1)
        q = rearrange(q, "b n (h d) -> b h n d", h=16)
        k = rearrange(k, "b n (h d) -> b h n d", h=16)
        v = rearrange(v, "b n (h d) -> b h n d", h=16)
        attn = torch.einsum("b h q d, b h k d -> b h q k", q, k) * self.scale
        # Slice the precomputed mask down to the current sequence length.
        attn = attn + self.causal_mask[:, :, : attn.shape[-2], : attn.shape[-1]]
        attn = torch.nn.functional.softmax(attn, dim=-1)
        x = torch.einsum("b h q k, b h k d -> b h q d", attn, v)
        x = rearrange(x, "b h n d -> b n (h d)")
        return self.out_proj(x)

class GPTBlock(torch.nn.Module):
    def __init__(self):
        super().__init__()
        self.ln1 = torch.nn.LayerNorm(1024)
        self.ln2 = torch.nn.LayerNorm(1024)
        self.attn = GPTSelfAttn()
        self.mlp = GPTMLP()

    def forward(self, x):
        # Pre-norm residual layout: LayerNorm before each sublayer.
        x = x + self.attn(self.ln1(x))
        x = x + self.mlp(self.ln2(x))
        return x

class GPT(torch.nn.Module):
    def __init__(self):
        super().__init__()
        self.tok_emb = torch.nn.Embedding(50256, 1024)  # vocab size
        self.pos_emb = torch.nn.Embedding(2048, 1024)  # learned positions, 2048-token context
        self.blocks = torch.nn.Sequential(*[GPTBlock() for _ in range(24)])
        self.ln_f = torch.nn.LayerNorm(1024)
        self.lm_head = torch.nn.Linear(1024, 50256, bias=False)

    def forward(self, x):
        b, n = x.shape
        # Token embeddings plus learned absolute position embeddings.
        x = self.tok_emb(x) + self.pos_emb(torch.arange(n, device=x.device))
        x = self.ln_f(self.blocks(x))
        return self.lm_head(x)
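
# Minimal smoke test sketch: the batch size, sequence length, and random
# token ids below are illustrative assumptions, and the model runs on CPU
# with untrained weights.
if __name__ == "__main__":
    model = GPT()
    tokens = torch.randint(0, 50256, (2, 128))  # (batch, sequence) of token ids
    logits = model(tokens)
    assert logits.shape == (2, 128, 50256)  # per-position vocabulary logits
    print(logits.shape)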