class conv2DBatchNormRelu(nn.Module):
    """Convolution followed by batch normalisation and an in-place ReLU.

    Args:
        in_channels: Number of channels of the incoming feature map.
        n_filters: Number of output channels produced by the convolution
            (and normalised by the batch-norm layer).
        k_size: Convolution kernel size.
        stride: Convolution stride.
        padding: Zero-padding applied by the convolution.
    """

    def __init__(self, in_channels: int, n_filters: int, k_size: int, stride: int, padding: int):
        super().__init__()
        convolution = nn.Conv2d(
            in_channels,
            n_filters,
            kernel_size=k_size,
            stride=stride,
            padding=padding,
        )
        normalisation = nn.BatchNorm2d(n_filters)
        activation = nn.ReLU(inplace=True)
        # Kept as a positional Sequential under the name `unit` so parameter
        # names (unit.0.*, unit.1.*) stay the same.
        self.unit = nn.Sequential(convolution, normalisation, activation)

    def forward(self, inputs):
        """Run ``inputs`` through conv -> batch-norm -> ReLU and return the result."""
        outputs = self.unit(inputs)
        return outputs
class SegNet(nn.Module):
    """SegNet-style encoder/decoder that unpools with the encoder's max-pool indices.

    The encoder applies four conv stages, each followed by a 2x2/stride-2
    max-pool that also returns the argmax indices.  The decoder mirrors it:
    each ``MaxUnpool2d`` consumes the indices of the pooling layer at the
    *same* resolution (last pool first), followed by a conv stage.

    Input is a 3-channel image batch ``(N, 3, H, W)``; output is a
    single-channel map ``(N, 1, H, W)``.
    """

    def __init__(self):
        super().__init__()

        # encoder (downsampling); every pool halves the spatial size
        self.enc_conv0 = nn.Sequential(
            conv2DBatchNormRelu(3, 64, 3, 1, 1),
            conv2DBatchNormRelu(64, 64, 3, 1, 1),
        )
        self.pool0 = nn.MaxPool2d(2, 2, return_indices=True)
        self.enc_conv1 = nn.Sequential(
            conv2DBatchNormRelu(64, 128, 3, 1, 1),
            conv2DBatchNormRelu(128, 128, 3, 1, 1),
        )
        self.pool1 = nn.MaxPool2d(2, 2, return_indices=True)
        self.enc_conv2 = nn.Sequential(
            conv2DBatchNormRelu(128, 256, 3, 1, 1),
            conv2DBatchNormRelu(256, 256, 3, 1, 1),
            conv2DBatchNormRelu(256, 256, 3, 1, 1),
        )
        self.pool2 = nn.MaxPool2d(2, 2, return_indices=True)
        self.enc_conv3 = nn.Sequential(
            conv2DBatchNormRelu(256, 512, 3, 1, 1),
            conv2DBatchNormRelu(512, 512, 3, 1, 1),
            conv2DBatchNormRelu(512, 512, 3, 1, 1),
        )
        self.pool3 = nn.MaxPool2d(2, 2, return_indices=True)

        # bottleneck: 1x1 convs, spatial size unchanged
        self.bottle_neck = nn.Sequential(
            conv2DBatchNormRelu(512, 1024, 1, 1, 0),
            conv2DBatchNormRelu(1024, 512, 1, 1, 0),
        )

        # decoder (upsampling); every unpool doubles the spatial size
        self.upsample0 = nn.MaxUnpool2d(2, 2)  # undoes pool3
        self.dec_conv0 = nn.Sequential(
            conv2DBatchNormRelu(512, 256, 3, 1, 1),
            conv2DBatchNormRelu(256, 256, 3, 1, 1),
            conv2DBatchNormRelu(256, 256, 3, 1, 1),
        )
        self.upsample1 = nn.MaxUnpool2d(2, 2)  # undoes pool2
        self.dec_conv1 = nn.Sequential(
            conv2DBatchNormRelu(256, 128, 3, 1, 1),
            conv2DBatchNormRelu(128, 128, 3, 1, 1),
            conv2DBatchNormRelu(128, 128, 3, 1, 1),
        )
        self.upsample2 = nn.MaxUnpool2d(2, 2)  # undoes pool1
        self.dec_conv2 = nn.Sequential(
            conv2DBatchNormRelu(128, 64, 3, 1, 1),
            conv2DBatchNormRelu(64, 64, 3, 1, 1),
        )
        self.upsample3 = nn.MaxUnpool2d(2, 2)  # undoes pool0
        # Final stage ends with conv + batch-norm and no ReLU, so the output
        # is not clamped to non-negative values.
        self.dec_conv3 = nn.Sequential(
            conv2DBatchNormRelu(64, 1, 3, 1, 1),
            conv2DBatchNormRelu(1, 1, 3, 1, 1),
            nn.Conv2d(1, 1, kernel_size=3, stride=1, padding=1),
            nn.BatchNorm2d(1),
        )

    def forward(self, x):
        """Encode ``x``, pass it through the bottleneck and decode it.

        Args:
            x: Tensor of shape ``(N, 3, H, W)``.

        Returns:
            Tensor of shape ``(N, 1, H, W)``.
        """
        # encoder: remember each pre-pool size so the matching unpool can
        # restore it exactly, even when H or W is odd at some stage.
        f0 = self.enc_conv0(x)
        size0 = f0.size()
        e0, ind0 = self.pool0(f0)

        f1 = self.enc_conv1(e0)
        size1 = f1.size()
        e1, ind1 = self.pool1(f1)

        f2 = self.enc_conv2(e1)
        size2 = f2.size()
        e2, ind2 = self.pool2(f2)

        f3 = self.enc_conv3(e2)
        size3 = f3.size()
        e3, ind3 = self.pool3(f3)

        # bottleneck
        bottle_neck = self.bottle_neck(e3)

        # decoder: indices are consumed in reverse order of the pools --
        # each unpool must use the indices produced at its own resolution
        # and channel count (512 -> ind3, 256 -> ind2, 128 -> ind1, 64 -> ind0).
        d0 = self.dec_conv0(self.upsample0(bottle_neck, ind3, output_size=size3))
        d1 = self.dec_conv1(self.upsample1(d0, ind2, output_size=size2))
        d2 = self.dec_conv2(self.upsample2(d1, ind1, output_size=size1))
        d3 = self.dec_conv3(self.upsample3(d2, ind0, output_size=size0))
        return d3