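# Paste-site metadata captured with this file (kept as comments so the module parses):
# Untitled
# unknown
# python
# a year ago
# 10 kB
# 5
# Indexable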
def exists(x):
    """Return True if ``x`` is not None."""
    return x is not None


def default(val, d):
    """Return ``val`` if it is not None, otherwise the fallback ``d``.

    If ``d`` is a plain Python function (``inspect.isfunction``), it is called
    with no arguments and its result is returned, so a fallback is only built
    when it is actually needed.
    """
    if exists(val):
        return val
    return d() if isfunction(d) else d


def num_to_groups(num, divisor):
    """Split ``num`` into chunks of size ``divisor`` plus a non-zero remainder.

    Example: ``num_to_groups(10, 4) == [4, 4, 2]``; ``num_to_groups(8, 4) == [4, 4]``.
    """
    groups = num // divisor
    remainder = num % divisor
    arr = [divisor] * groups
    if remainder > 0:
        arr.append(remainder)
    return arr


class Residual(nn.Module):
    """Wrap module ``fn`` so the block returns ``fn(x, ...) + x``."""

    def __init__(self, fn):
        super().__init__()
        self.fn = fn

    def forward(self, x, *args, **kwargs):
        return self.fn(x, *args, **kwargs) + x


def Upsample(dim, dim_out=None):
    """2x nearest-neighbour upsampling followed by a 3x3 conv.

    Output channels default to ``dim`` when ``dim_out`` is None.
    """
    return nn.Sequential(
        nn.Upsample(scale_factor=2, mode="nearest"),
        nn.Conv2d(dim, default(dim_out, dim), 3, padding=1),
    )


def Downsample(dim, dim_out=None):
    """2x downsampling by folding each 2x2 patch into channels, then a 1x1 conv.

    The rearrange produces ``dim * 4`` channels at half the height and width,
    so the input's height and width must be even. Output channels default to
    ``dim`` when ``dim_out`` is None.
    """
    # No More Strided Convolutions or Pooling
    return nn.Sequential(
        Rearrange("b c (h p1) (w p2) -> b (c p1 p2) h w", p1=2, p2=2),
        nn.Conv2d(dim * 4, default(dim_out, dim), 1),
    )


class SinusoidalPositionEmbeddings(nn.Module):
    """Map time steps to sinusoidal features.

    ``time`` is indexed as ``time[:, None]``; it is assumed to be a 1-D tensor
    of length ``b`` (TODO confirm with callers). The output then has shape
    ``(b, 2 * (dim // 2))``: sines in the first half and cosines in the second,
    at frequencies spaced geometrically from 1 down to 1/10000. ``dim`` must be
    at least 4 (``half_dim - 1`` is a divisor). An odd ``dim`` yields
    ``dim - 1`` features.
    """

    def __init__(self, dim):
        super().__init__()
        self.dim = dim

    def forward(self, time):
        device = time.device
        half_dim = self.dim // 2
        embeddings = math.log(10000) / (half_dim - 1)
        embeddings = torch.exp(torch.arange(half_dim, device=device) * -embeddings)
        embeddings = time[:, None] * embeddings[None, :]
        embeddings = torch.cat((embeddings.sin(), embeddings.cos()), dim=-1)
        return embeddings


class WeightStandardizedConv2d(nn.Conv2d):
    """
    https://arxiv.org/abs/1903.10520
    weight standardization purportedly works synergistically with group normalization

    Each output channel's kernel is standardized to zero mean and unit
    (biased) variance on every forward pass before the convolution runs.
    """

    def forward(self, x):
        # Any input dtype other than float32 (presumably fp16/bf16) gets a
        # larger epsilon so rsqrt stays well-conditioned.
        eps = 1e-5 if x.dtype == torch.float32 else 1e-3

        weight = self.weight
        # Per-output-channel statistics over (in_channels, kh, kw).
        mean = reduce(weight, "o ... -> o 1 1 1", "mean")
        var = reduce(weight, "o ... -> o 1 1 1", partial(torch.var, unbiased=False))
        normalized_weight = (weight - mean) * (var + eps).rsqrt()

        return F.conv2d(
            x,
            normalized_weight,
            self.bias,
            self.stride,
            self.padding,
            self.dilation,
            self.groups,
        )


class Block(nn.Module):
    """Weight-standardized 3x3 conv -> GroupNorm -> optional scale/shift -> SiLU."""

    def __init__(self, dim, dim_out, groups=8):
        super().__init__()
        self.proj = WeightStandardizedConv2d(dim, dim_out, 3, padding=1)
        self.norm = nn.GroupNorm(groups, dim_out)
        self.act = nn.SiLU()

    def forward(self, x, scale_shift=None):
        x = self.proj(x)
        x = self.norm(x)

        if exists(scale_shift):
            # Affine modulation after normalization; "+ 1" makes a zero scale
            # an identity instead of zeroing the features.
            scale, shift = scale_shift
            x = x * (scale + 1) + shift

        x = self.act(x)
        return x


class ResnetBlock(nn.Module):
    """https://arxiv.org/abs/1512.03385

    Two ``Block``s plus a residual connection. When ``time_emb_dim`` is given,
    the time embedding is projected to ``2 * dim_out`` values that become the
    per-channel (scale, shift) of the first block. A 1x1 conv adapts the
    residual when ``dim != dim_out``.
    """

    def __init__(self, dim, dim_out, *, time_emb_dim=None, groups=8):
        super().__init__()
        self.mlp = (
            nn.Sequential(nn.SiLU(), nn.Linear(time_emb_dim, dim_out * 2))
            if exists(time_emb_dim)
            else None
        )

        self.block1 = Block(dim, dim_out, groups=groups)
        self.block2 = Block(dim_out, dim_out, groups=groups)
        self.res_conv = nn.Conv2d(dim, dim_out, 1) if dim != dim_out else nn.Identity()

    def forward(self, x, time_emb=None):
        scale_shift = None
        if exists(self.mlp) and exists(time_emb):
            time_emb = self.mlp(time_emb)
            # (b, 2*dim_out) -> (b, 2*dim_out, 1, 1) to broadcast over H and W.
            time_emb = rearrange(time_emb, "b c -> b c 1 1")
            scale_shift = time_emb.chunk(2, dim=1)

        h = self.block1(x, scale_shift=scale_shift)
        h = self.block2(h)
        return h + self.res_conv(x)


class Attention(nn.Module):
    """Multi-head softmax self-attention over all ``h * w`` spatial positions.

    The similarity matrix is (positions x positions) per head. ``Unet`` uses
    this module only at the bottleneck (``mid_attn``); the other levels use
    ``LinearAttention``.
    """

    def __init__(self, dim, heads=4, dim_head=32):
        super().__init__()
        self.scale = dim_head**-0.5
        self.heads = heads
        hidden_dim = dim_head * heads
        self.to_qkv = nn.Conv2d(dim, hidden_dim * 3, 1, bias=False)
        self.to_out = nn.Conv2d(hidden_dim, dim, 1)

    def forward(self, x):
        b, c, h, w = x.shape
        qkv = self.to_qkv(x).chunk(3, dim=1)
        # Split channels into heads and flatten space: (b, heads, dim_head, h*w).
        q, k, v = map(
            lambda t: rearrange(t, "b (h c) x y -> b h c (x y)", h=self.heads), qkv
        )
        q = q * self.scale

        sim = einsum("b h d i, b h d j -> b h i j", q, k)
        # Subtracting the row max keeps exp() bounded; softmax is shift-invariant.
        sim = sim - sim.amax(dim=-1, keepdim=True).detach()
        attn = sim.softmax(dim=-1)

        out = einsum("b h i j, b h d j -> b h i d", attn, v)
        out = rearrange(out, "b h (x y) d -> b (h d) x y", x=h, y=w)
        return self.to_out(out)


class LinearAttention(nn.Module):
    """Multi-head attention whose cost is linear in the number of positions.

    Queries are softmax-normalized over the feature axis and keys over the
    position axis. A (dim_head x dim_head) context is built from keys and
    values first and then applied to the queries. The output is projected
    back to ``dim`` channels and passed through ``GroupNorm(1, dim)``.
    """

    def __init__(self, dim, heads=4, dim_head=32):
        super().__init__()
        self.scale = dim_head**-0.5
        self.heads = heads
        hidden_dim = dim_head * heads
        self.to_qkv = nn.Conv2d(dim, hidden_dim * 3, 1, bias=False)

        self.to_out = nn.Sequential(nn.Conv2d(hidden_dim, dim, 1), nn.GroupNorm(1, dim))

    def forward(self, x):
        b, c, h, w = x.shape
        qkv = self.to_qkv(x).chunk(3, dim=1)
        # (b, heads, dim_head, h*w) for each of q, k, v.
        q, k, v = map(
            lambda t: rearrange(t, "b (h c) x y -> b h c (x y)", h=self.heads), qkv
        )

        q = q.softmax(dim=-2)  # over dim_head
        k = k.softmax(dim=-1)  # over positions

        q = q * self.scale
        # Context is (dim_head x dim_head) per head, independent of h*w.
        context = torch.einsum("b h d n, b h e n -> b h d e", k, v)

        out = torch.einsum("b h d e, b h d n -> b h e n", context, q)
        out = rearrange(out, "b h c (x y) -> b (h c) x y", h=self.heads, x=h, y=w)
        return self.to_out(out)


class PreNorm(nn.Module):
    """Apply ``GroupNorm(1, dim)`` (one group: normalize over C, H, W) before ``fn``."""

    def __init__(self, dim, fn):
        super().__init__()
        self.fn = fn
        self.norm = nn.GroupNorm(1, dim)

    def forward(self, x):
        x = self.norm(x)
        return self.fn(x)


class Unet(nn.Module):
    """Time-conditioned U-Net built from ``ResnetBlock``s and attention.

    Args:
        dim: base channel width; also the size of the sinusoidal time
            embedding (the time MLP widens it to ``4 * dim``).
        init_dim: channels produced by the 1x1 stem conv; defaults to ``dim``.
        out_dim: output channels; defaults to ``channels``.
        dim_mults: per-level width multipliers; level ``i`` has ``dim * dim_mults[i]``
            channels. Every level except the last halves the resolution.
        channels: input image channels.
        self_condition: if True, a conditioning tensor is concatenated in
            front of the input on the channel axis (see ``forward``).
        resnet_block_groups: GroupNorm group count inside every ``Block``.
        att_res: per-head channel width (``dim_head``) of the attention layers.
        att_heads: number of attention heads.
    """

    def __init__(
        self,
        dim,
        init_dim=None,
        out_dim=None,
        dim_mults=(1, 2, 4, 8, 8),
        channels=3,
        self_condition=False,
        resnet_block_groups=4,
        att_res=32,
        att_heads=4,
    ):
        super().__init__()

        # determine dimensions
        self.channels = channels
        self.self_condition = self_condition
        # NOTE(review): self-conditioning adds exactly one input channel here
        # (the usual formula is channels * 2); presumably the conditioning
        # signal is single-channel -- confirm with callers.
        input_channels = channels if not self_condition else channels + 1

        init_dim = default(init_dim, dim)
        # 1x1 stem (kernel 1, padding 0) instead of a 7x7 / padding 3 conv.
        self.init_conv = nn.Conv2d(input_channels, init_dim, 1, padding=0)

        dims = [init_dim, *map(lambda m: dim * m, dim_mults)]
        in_out = list(zip(dims[:-1], dims[1:]))

        block_klass = partial(ResnetBlock, groups=resnet_block_groups)

        # time embeddings
        time_dim = dim * 4

        self.time_mlp = nn.Sequential(
            SinusoidalPositionEmbeddings(dim),
            nn.Linear(dim, time_dim),
            nn.GELU(),
            nn.Linear(time_dim, time_dim),
        )

        # layers
        self.downs = nn.ModuleList([])
        self.ups = nn.ModuleList([])
        num_resolutions = len(in_out)

        for ind, (dim_in, dim_out) in enumerate(in_out):
            is_last = ind >= (num_resolutions - 1)

            self.downs.append(
                nn.ModuleList(
                    [
                        block_klass(dim_in, dim_in, time_emb_dim=time_dim),
                        block_klass(dim_in, dim_in, time_emb_dim=time_dim),
                        Residual(
                            PreNorm(dim_in, LinearAttention(dim_in, att_heads, att_res))
                        ),
                        # The last level changes width without reducing resolution.
                        Downsample(dim_in, dim_out)
                        if not is_last
                        else nn.Conv2d(dim_in, dim_out, 3, padding=1),
                    ]
                )
            )

        mid_dim = dims[-1]
        self.mid_block1 = block_klass(mid_dim, mid_dim, time_emb_dim=time_dim)
        self.mid_attn = Residual(
            PreNorm(mid_dim, Attention(mid_dim, att_heads, att_res))
        )
        self.mid_block2 = block_klass(mid_dim, mid_dim, time_emb_dim=time_dim)

        for ind, (dim_in, dim_out) in enumerate(reversed(in_out)):
            is_last = ind == (len(in_out) - 1)

            self.ups.append(
                nn.ModuleList(
                    [
                        # Input = current features (dim_out) + a dim_in-wide skip.
                        block_klass(dim_out + dim_in, dim_out, time_emb_dim=time_dim),
                        block_klass(dim_out + dim_in, dim_out, time_emb_dim=time_dim),
                        Residual(
                            PreNorm(
                                dim_out, LinearAttention(dim_out, att_heads, att_res)
                            )
                        ),
                        Upsample(dim_out, dim_in)
                        if not is_last
                        else nn.Conv2d(dim_out, dim_in, 3, padding=1),
                    ]
                )
            )

        self.out_dim = default(out_dim, channels)

        # forward() concatenates the last up stage's output (dims[0] == init_dim
        # channels) with the init_dim-channel stem output, so the input width
        # is init_dim * 2 -- not dim * 2, which differs for a custom init_dim.
        self.final_res_block = block_klass(init_dim * 2, dim, time_emb_dim=time_dim)
        self.final_conv = nn.Conv2d(dim, self.out_dim, 1)

    def forward(self, x, time, x_self_cond=None):
        """Run the U-Net on ``x`` at time steps ``time``.

        Args:
            x: image batch, presumably ``(b, channels, H, W)`` with ``H`` and
                ``W`` divisible by ``2 ** (len(dim_mults) - 1)`` (one 2x
                ``Downsample`` per level except the last).
            time: time steps for ``SinusoidalPositionEmbeddings``; assumed to
                be a 1-D tensor of length ``b``.
            x_self_cond: conditioning tensor concatenated in front of ``x`` on
                the channel axis when ``self_condition`` is enabled (ignored
                otherwise). Defaults to a single all-zero channel.

        Returns:
            A tensor with ``out_dim`` channels at the input resolution.
        """
        if self.self_condition:
            # init_conv expects channels + 1 inputs, so the zero fallback must
            # be one channel wide (zeros_like(x) only fit when channels == 1).
            x_self_cond = default(x_self_cond, lambda: torch.zeros_like(x[:, :1]))
            x = torch.cat((x_self_cond, x), dim=1)

        x = self.init_conv(x)
        # Stem output kept for the final long skip connection.
        r = x.clone()

        t = self.time_mlp(time)

        # Two skips per level: after block1, and after block2 + attention.
        h = []

        for block1, block2, attn, downsample in self.downs:
            x = block1(x, t)
            h.append(x)

            x = block2(x, t)
            x = attn(x)
            h.append(x)

            x = downsample(x)

        x = self.mid_block1(x, t)
        x = self.mid_attn(x)
        x = self.mid_block2(x, t)

        for block1, block2, attn, upsample in self.ups:
            # Skips are consumed in reverse (last-in, first-out) order.
            x = torch.cat((x, h.pop()), dim=1)
            x = block1(x, t)

            x = torch.cat((x, h.pop()), dim=1)
            x = block2(x, t)
            x = attn(x)

            x = upsample(x)

        x = torch.cat((x, r), dim=1)

        x = self.final_res_block(x, t)
        return self.final_conv(x)
# Paste-site footer captured with this file (not code):
# Editor is loading...
# Leave a Comment