class ResBlock(nn.Module):
    """Residual block: two 3x3 conv/BN layers plus a shortcut connection.

    The main path is conv3x3 -> BN -> ReLU -> conv3x3 -> BN. Its output is
    added to the shortcut and passed through a final ReLU. Every convolution
    uses stride 1 (the 3x3 ones with padding 1), so height and width are
    unchanged by the block.

    Args:
        in_ch: Number of channels of the input tensor.
        out_ch: Number of channels produced by the block.
    """

    def __init__(self, in_ch, out_ch):
        super().__init__()
        # The shortcut is created before the main path so that parameters are
        # initialised and registered in the order skip -> block.
        if in_ch == out_ch:
            # Same width on both sides: the input itself is the shortcut.
            projection = None
        else:
            # A 1x1 convolution (followed by BN) matches the channel count.
            projection = nn.Sequential(
                nn.Conv2d(in_ch, out_ch, kernel_size=1, stride=1, bias=False),
                nn.BatchNorm2d(out_ch),
            )
        # Registered under the name "skip" even when it is None.
        self.add_module("skip", projection)

        main_path = [
            nn.Conv2d(in_ch, out_ch, kernel_size=3, padding=1, stride=1, bias=False),
            nn.BatchNorm2d(out_ch),
            nn.ReLU(inplace=True),
            nn.Conv2d(out_ch, out_ch, kernel_size=3, padding=1, stride=1, bias=False),
            nn.BatchNorm2d(out_ch),
        ]
        self.block = nn.Sequential(*main_path)
        self.relu = nn.ReLU(inplace=True)

    def forward(self, x):
        """Return ``relu(block(x) + shortcut(x))``."""
        residual = self.block(x)
        shortcut = x if self.skip is None else self.skip(x)
        # In-place accumulation into the main-path output; ``x`` is untouched.
        residual += shortcut
        return self.relu(residual)
class Encoder(nn.Module):
    """Contracting path: a chain of ResBlocks, each followed by 2x2 max pooling.

    One ResBlock is created for every entry of ``features`` except the last
    one; the first block consumes ``in_channels`` and each later block
    consumes the width produced by the block before it.

    Args:
        in_channels: Number of channels of the input tensor.
        out_channels: Accepted for signature compatibility; not used here.
        features: Channel widths. All entries but the last define one
            encoder stage each.
    """

    def __init__(self, in_channels=3, out_channels=16, features=[16, 32, 64, 128, 256]):
        super(Encoder, self).__init__()
        self.downs = nn.ModuleList()
        self.pool = nn.MaxPool2d(kernel_size=2, stride=2)

        prev_width = in_channels
        for width in features[:-1]:
            self.downs.append(ResBlock(prev_width, width))
            prev_width = width

    def forward(self, x):
        """Run all stages and collect the pre-pooling activations.

        Returns:
            A tuple ``(x, skip_connections)`` where ``x`` is the output after
            the final pooling step and ``skip_connections`` is the list of
            stage outputs taken before pooling, ordered from the last stage
            to the first.
        """
        stage_outputs = []
        for stage in self.downs:
            x = stage(x)
            stage_outputs.append(x)
            x = self.pool(x)
        # Last stage first.
        stage_outputs.reverse()
        return x, stage_outputs
class Decoder(nn.Module):
    """Expanding path: alternating transposed convolutions and ResBlocks.

    ``self.ups`` holds the layers in pairs. For each pair of adjacent widths
    in ``features``, walked from the end of the list to the start, a 2x2
    stride-2 ``ConvTranspose2d`` maps the larger width to the smaller one,
    and a ResBlock then maps twice the smaller width (upsampled tensor
    concatenated with a skip tensor) back to the smaller width.

    Args:
        in_channels: Accepted for signature compatibility; not used here.
        out_channels: Accepted for signature compatibility; not used here.
        features: Channel widths, in the same order as given to the encoder.
    """

    def __init__(self, in_channels=3, out_channels=16, features=[16, 32, 64, 128, 256]):
        super(Decoder, self).__init__()
        self.ups = nn.ModuleList()
        # Pair every width with its predecessor, starting from the end.
        wide_to_narrow = zip(reversed(features[1:]), reversed(features[:-1]))
        for wide, narrow in wide_to_narrow:
            self.ups.append(nn.ConvTranspose2d(wide, narrow, kernel_size=2, stride=2))
            self.ups.append(ResBlock(narrow * 2, narrow))

    def forward(self, x, skip_connections):
        """Upsample ``x`` step by step, merging one skip tensor per level.

        Args:
            x: Input tensor fed to the first upsampling layer.
            skip_connections: Sequence indexed by level; entry ``k`` is
                concatenated (along dim 1) with the output of the ``k``-th
                upsampling layer.

        Returns:
            The output of the last layer in ``self.ups``.
        """
        for level, first in enumerate(range(0, len(self.ups), 2)):
            upsample = self.ups[first]
            merge = self.ups[first + 1]

            x = upsample(x)
            skip = skip_connections[level]
            if x.shape != skip.shape:
                # Shapes disagree: resize to the skip tensor's trailing
                # (dim 2 onward) size before concatenating.
                # NOTE(review): TF is defined outside this block, presumably
                # torchvision.transforms.functional -- confirm.
                x = TF.resize(x, skip.shape[2:])
            x = merge(torch.cat([x, skip], dim=1))
        return x
class baseline_UNet(nn.Module):
    """U-Net style network assembled from Encoder, a bottleneck ResBlock and Decoder.

    Data flow: encoder -> bottleneck (``features[-2]`` to ``features[-1]``
    channels) -> decoder (fed the encoder's skip tensors) -> 1x1 convolution
    mapping ``features[0]`` channels to ``out_channels``.

    Args:
        in_channels: Number of channels of the input tensor.
        out_channels: Number of channels produced by the final 1x1 convolution.
        features: Channel widths shared by encoder, bottleneck and decoder;
            must contain at least two entries.
        device: Device the whole model (all sub-modules) is moved to at
            construction time.
    """

    def __init__(self, in_channels=3, out_channels=16, features=[16, 32, 64, 128, 256], device=DEVICE):
        super(baseline_UNet, self).__init__()
        # Sub-modules are created in the same order as before so parameter
        # initialisation order and state_dict key order are unchanged.
        self.encoder = Encoder(in_channels, out_channels, features)
        self.decoder = Decoder(in_channels, out_channels, features)
        self.bottleneck = ResBlock(features[-2], features[-1])
        self.out = nn.Conv2d(features[0], out_channels, kernel_size=1, padding=0, stride=1)
        # Move everything in one step. Previously only the encoder and decoder
        # were moved, leaving the bottleneck and output layer on the CPU and
        # causing a device mismatch in forward() for any non-CPU ``device``.
        self.to(device)

    def forward(self, x):
        """Run the full encoder -> bottleneck -> decoder -> 1x1 conv pipeline.

        Args:
            x: Input tensor; it must already live on the model's device.

        Returns:
            The output of the final 1x1 convolution.
        """
        x, skip_connections = self.encoder(x)
        x = self.bottleneck(x)
        x = self.decoder(x, skip_connections)
        return self.out(x)