unet
import torch
import torch.nn as nn


class SkipConnection(nn.Module):
    """Identity wrapper marking where encoder features are stored for skip connections."""

    def __init__(self) -> None:
        super().__init__()
        self.skip_con = nn.Identity()

    def forward(self, x):
        return self.skip_con(x)


class ResidualBlock(nn.Module):
    # TODO: Change this for a basic UNet
    def __init__(
        self,
        in_channels,
        out_channels,
        n_groups=32,
    ):
        super().__init__()
        self.norm1 = nn.GroupNorm(n_groups, in_channels)
        self.act1 = nn.SiLU()
        self.conv1 = nn.Conv2d(
            in_channels, out_channels, kernel_size=(3, 3), padding=(1, 1)
        )
        self.norm2 = nn.GroupNorm(n_groups, out_channels)
        self.act2 = nn.SiLU()
        self.conv2 = nn.Conv2d(
            out_channels, out_channels, kernel_size=(3, 3), padding=(1, 1)
        )
        # 1x1 conv so the residual shortcut matches the output channel count.
        if in_channels != out_channels:
            self.shortcut = nn.Conv2d(in_channels, out_channels, kernel_size=(1, 1))
        else:
            self.shortcut = nn.Identity()

    def forward(self, x):
        h = self.conv1(self.act1(self.norm1(x)))
        h = self.conv2(self.act2(self.norm2(h)))
        return h + self.shortcut(x)


class AttentionBlock(nn.Module):
    """Multi-head self-attention over the spatial positions of a feature map."""

    def __init__(self, n_channels, n_heads=1, d_k=None):
        super().__init__()
        if d_k is None:
            d_k = n_channels // n_heads
        self.projection = nn.Linear(n_channels, n_heads * d_k * 3)
        self.output = nn.Linear(n_heads * d_k, n_channels)
        self.scale = d_k**-0.5
        self.n_heads = n_heads
        self.d_k = d_k

    def forward(self, x: torch.Tensor):
        batch_size, n_channels, height, width = x.shape
        # Flatten spatial dims: (B, C, H, W) -> (B, H*W, C)
        x = x.view(batch_size, n_channels, -1).permute(0, 2, 1)
        qkv = self.projection(x).view(batch_size, -1, self.n_heads, 3 * self.d_k)
        q, k, v = torch.chunk(qkv, 3, dim=-1)
        attn = torch.einsum("bihd,bjhd->bijh", q, k) * self.scale
        attn = attn.softmax(dim=2)
        res = torch.einsum("bijh,bjhd->bihd", attn, v)
        res = res.view(batch_size, -1, self.n_heads * self.d_k)
        res = self.output(res)
        res += x  # residual connection
        # Restore spatial layout: (B, H*W, C) -> (B, C, H, W)
        res = res.permute(0, 2, 1).view(batch_size, n_channels, height, width)
        return res


class DownBlock(nn.Module):
    """Residual block (optionally followed by attention) on the encoder path."""

    def __init__(
        self,
        in_channels: int,
        out_channels: int,
        has_attn: bool,
    ):
        super().__init__()
        self.has_attn = has_attn
        self.res = ResidualBlock(in_channels, out_channels)
        if has_attn:
            self.attn = AttentionBlock(out_channels)
        else:
            self.attn = nn.Identity()

    def forward(self, x: torch.Tensor):
        x = self.res(x)
        x = self.attn(x)
        return x


class UpBlock(nn.Module):
    """Residual block on the decoder path; the skip connection is concatenated
    along the channel dimension, hence the extra input channels."""

    def __init__(
        self,
        in_channels: int,
        out_channels: int,
        has_attn: bool,
    ):
        super().__init__()
        add = out_channels
        self.has_attn = has_attn
        self.res = ResidualBlock(in_channels + add, out_channels)
        if has_attn:
            self.attn = AttentionBlock(out_channels)
        else:
            self.attn = nn.Identity()

    def forward(self, x: torch.Tensor):
        x = self.res(x)
        x = self.attn(x)
        return x


class MiddleBlock(nn.Module):
    """Bottleneck: residual -> attention -> residual at the lowest resolution."""

    def __init__(self, n_channels: int):
        super().__init__()
        self.res1 = ResidualBlock(n_channels, n_channels)
        self.attn = AttentionBlock(n_channels)
        self.res2 = ResidualBlock(n_channels, n_channels)

    def forward(self, x: torch.Tensor):
        x = self.res1(x)
        x = self.attn(x)
        x = self.res2(x)
        return x


class Upsample(nn.Module):
    """Doubles the spatial resolution with a transposed convolution."""

    def __init__(self, n_channels):
        super().__init__()
        self.conv = nn.ConvTranspose2d(n_channels, n_channels, (4, 4), (2, 2), (1, 1))

    def forward(self, x: torch.Tensor):
        return self.conv(x)


class Downsample(nn.Module):
    """Halves the spatial resolution with a strided convolution."""

    def __init__(self, n_channels):
        super().__init__()
        self.conv = nn.Conv2d(n_channels, n_channels, (3, 3), (2, 2), (1, 1))

    def forward(self, x: torch.Tensor):
        return self.conv(x)


class UNet(nn.Module):
    def __init__(self, parameters: dict):
        super().__init__()
        n_channels = parameters["n_channels"]
        ch_mults = parameters["ch_mults"]
        is_attn = parameters["is_attn"]
        n_blocks = parameters["n_blocks"]
        image_channels = int(parameters["image_channels"])
        self.start_img_channels = image_channels
        self.n_layers = len(ch_mults)
        n_resolutions = len(ch_mults)

        # Project the input image to the base number of feature channels.
        self.image_proj = nn.Conv2d(
            image_channels, n_channels, kernel_size=(3, 3), padding=(1, 1)
        )

        # Encoder (contracting path).
        down = []
        skip_con = []
        self.cond_embs = []
        out_channels = in_channels = n_channels
        for i in range(n_resolutions):
            out_channels = in_channels * ch_mults[i]
            for _ in range(n_blocks):
                down.append(
                    DownBlock(
                        in_channels,
                        out_channels,
                        is_attn[i],
                    )
                )
                skip_con.append(SkipConnection())
                in_channels = out_channels
            if i < n_resolutions - 1:
                down.append(Downsample(in_channels))
                skip_con.append(SkipConnection())
        self.skip_con = nn.ModuleList(skip_con)
        self.down = nn.ModuleList(down)

        # Bottleneck.
        self.middle = MiddleBlock(out_channels)

        # Decoder (expanding path).
        up = []
        in_channels = out_channels
        for i in reversed(range(n_resolutions)):
            out_channels = in_channels
            for _ in range(n_blocks):
                up.append(
                    UpBlock(
                        in_channels,
                        out_channels,
                        is_attn[i],
                    )
                )
            out_channels = in_channels // ch_mults[i]
            up.append(
                UpBlock(
                    in_channels,
                    out_channels,
                    is_attn[i],
                )
            )
            in_channels = out_channels
            if i > 0:
                up.append(Upsample(in_channels))
        self.up = nn.ModuleList(up)

        # Final normalization and projection back to image channels.
        self.norm = nn.GroupNorm(8, in_channels)
        self.act = nn.SiLU()
        self.final = nn.Conv2d(
            in_channels, image_channels, kernel_size=(3, 3), padding=(1, 1)
        )

    def prepare_start(self, x):
        x = self.image_proj(x)
        return x

    def encode(self, x):
        # Keep every intermediate feature map for the decoder skip connections.
        h = [x]
        for m, sc in zip(self.down, self.skip_con):
            x = m(x)
            h.append(sc(x))
        return x, h

    def middle_block(self, x):
        x = self.middle(x)
        return x

    def decode(self, x, h):
        for m in self.up:
            if not isinstance(m, Upsample):
                # Concatenate the matching encoder feature map along channels.
                s = h.pop()
                x = torch.cat((x, s), dim=1)
            x = m(x)
        return x

    def final_output(self, x):
        f = self.act(self.norm(x))
        f = self.final(f)
        return f

    def forward(self, x: torch.Tensor):
        x = self.prepare_start(x)
        x, h = self.encode(x)
        x = self.middle_block(x)
        x = self.decode(x, h)
        return self.final_output(x)


if __name__ == "__main__":
    params = {
        "n_channels": 64,
        "ch_mults": [1, 1],
        "is_attn": [False, False],
        "n_blocks": 1,
        "image_channels": 3,
    }
    ResUNet = UNet(params).cuda()
    x = torch.rand((1, 3, 16, 16)).cuda()
    x = ResUNet(x)
    print(x.shape)
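    # A minimal extra smoke test, not part of the original paste: the same
    # UNet with a second resolution level (ch_mults=[1, 2]) and attention
    # enabled at the coarser 8x8 level, run on CPU. The config values below
    # are illustrative assumptions; channel counts stay multiples of 32, so
    # the GroupNorm(32, ...) layers inside ResidualBlock remain valid.
    params_attn = {
        "n_channels": 64,
        "ch_mults": [1, 2],
        "is_attn": [False, True],
        "n_blocks": 1,
        "image_channels": 3,
    }
    model = UNet(params_attn)
    y = model(torch.rand((1, 3, 16, 16)))
    print(y.shape)  # expected: torch.Size([1, 3, 16, 16])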