Untitled
unknown
plain_text
2 years ago
2.3 kB
16
Indexable
import torch from einops import rearrange class GPTMLP(torch.nn.Module): def __init__(self): super().__init__() self.fc_in = torch.nn.Linear(1024, 4096, bias=False) self.fc_out = torch.nn.Linear(4096, 1024, bias=False) def forward(self, x): x = self.fc_in(x) x = torch.nn.functional.gelu(x) return self.fc_out(x) class GPTSelfAttn(torch.nn.Module): def __init__(self): super().__init__() self.qkv_proj = torch.nn.Linear(1024, 1024 * 3, bias=False) self.out_proj = torch.nn.Linear(1024, 1024, bias=False) self.causal_mask = ( 1 - torch.tril(torch.ones((2048, 2048))).view(1, 1, 2048, 2048) ) * -10000.0 self.scale = 1024**-0.5 def forward(self, x): q, k, v = self.qkv_proj(x).chunk(3, dim=-1) q = rearrange(q, "b n (h d) -> b h n d", h=16) k = rearrange(k, "b n (h d) -> b h n d", h=16) v = rearrange(v, "b n (h d) -> b h n d", h=16) attn = torch.einsum("b h q d, b h k d -> b h q k", q, k) * self.scale attn = attn + self.causal_mask[:, :, : attn.shape[-2], : attn.shape[-1]] attn = torch.nn.functional.softmax(attn, dim=-1) x = torch.einsum("b h q k, b h k d -> b h q d", attn, v) x = rearrange(x, "b h n d -> b n (h d)") return self.out_proj(x) class GPTBlock(torch.nn.Module): def __init__(self): super().__init__() self.ln1 = torch.nn.LayerNorm(1024) self.ln2 = torch.nn.LayerNorm(1024) self.attn = GPTSelfAttn() self.mlp = GPTMLP() def forward(self, x): x = x + self.attn(self.ln1(x)) x = x + self.mlp(self.ln2(x)) return x class GPT(torch.nn.Module): def __init__(self): super().__init__() self.tok_emb = torch.nn.Embedding(50256, 1024) self.pos_emb = torch.nn.Embedding(2048, 1024) self.blocks = torch.nn.Sequential(*[GPTBlock() for _ in range(24)]) self.ln_f = torch.nn.LayerNorm(1024) self.lm_head = torch.nn.Linear(1024, 50256, bias=False) def forward(self, x): b, n = x.shape x = self.tok_emb(x) + self.pos_emb(torch.arange(n, device=x.device)) x = self.ln_f(self.blocks(x)) return self.lm_head(x)
Editor is loading...