Untitled
unknown
plain_text
a year ago
3.8 kB
8
Indexable
Never
class conv2DBatchNormRelu(nn.Module):
    """Conv2d -> BatchNorm2d -> ReLU, applied as a single unit.

    Args:
        in_channels: Number of channels of the input feature map.
        n_filters: Number of output channels produced by the convolution.
        k_size: Convolution kernel size.
        stride: Convolution stride.
        padding: Zero-padding added by the convolution.
    """

    def __init__(self, in_channels: int, n_filters: int, k_size: int,
                 stride: int, padding: int):
        super().__init__()
        self.unit = nn.Sequential(
            nn.Conv2d(in_channels, n_filters, kernel_size=k_size,
                      padding=padding, stride=stride),
            nn.BatchNorm2d(n_filters),
            nn.ReLU(inplace=True),
        )

    def forward(self, inputs):
        """Run ``inputs`` through the conv / batch-norm / ReLU stack."""
        return self.unit(inputs)


class SegNet(nn.Module):
    """Encoder-decoder network with index-based max-unpooling.

    The encoder has four conv stages, each followed by a 2x2 max-pool that
    also returns the argmax indices.  The decoder mirrors it: every
    ``MaxUnpool2d`` consumes the indices of the pooling layer working at the
    same resolution, i.e. in the *reverse* of the encoder order.

    The network takes a 3-channel input and produces a 1-channel map with the
    same spatial size.  The output is the raw result of the last
    ``BatchNorm2d``; no sigmoid is applied here.

    The spatial sizes in the inline comments assume a 256x256 input.
    """

    def __init__(self):
        super().__init__()

        # encoder (downsampling)
        self.enc_conv0 = nn.Sequential(
            conv2DBatchNormRelu(3, 64, 3, 1, 1),
            conv2DBatchNormRelu(64, 64, 3, 1, 1),
        )
        self.pool0 = nn.MaxPool2d(2, 2, return_indices=True)  # 256 -> 128
        self.enc_conv1 = nn.Sequential(
            conv2DBatchNormRelu(64, 128, 3, 1, 1),
            conv2DBatchNormRelu(128, 128, 3, 1, 1),
        )
        self.pool1 = nn.MaxPool2d(2, 2, return_indices=True)  # 128 -> 64
        self.enc_conv2 = nn.Sequential(
            conv2DBatchNormRelu(128, 256, 3, 1, 1),
            conv2DBatchNormRelu(256, 256, 3, 1, 1),
            conv2DBatchNormRelu(256, 256, 3, 1, 1),
        )
        self.pool2 = nn.MaxPool2d(2, 2, return_indices=True)  # 64 -> 32
        self.enc_conv3 = nn.Sequential(
            conv2DBatchNormRelu(256, 512, 3, 1, 1),
            conv2DBatchNormRelu(512, 512, 3, 1, 1),
            conv2DBatchNormRelu(512, 512, 3, 1, 1),
        )
        self.pool3 = nn.MaxPool2d(2, 2, return_indices=True)  # 32 -> 16

        # bottleneck (1x1 convolutions, spatial size unchanged)
        self.bottle_neck = nn.Sequential(
            conv2DBatchNormRelu(512, 1024, 1, 1, 0),
            conv2DBatchNormRelu(1024, 512, 1, 1, 0),
        )

        # decoder (upsampling)
        self.upsample0 = nn.MaxUnpool2d(2, 2)  # 16 -> 32
        self.dec_conv0 = nn.Sequential(
            conv2DBatchNormRelu(512, 256, 3, 1, 1),
            conv2DBatchNormRelu(256, 256, 3, 1, 1),
            conv2DBatchNormRelu(256, 256, 3, 1, 1),
        )
        self.upsample1 = nn.MaxUnpool2d(2, 2)  # 32 -> 64
        self.dec_conv1 = nn.Sequential(
            conv2DBatchNormRelu(256, 128, 3, 1, 1),
            conv2DBatchNormRelu(128, 128, 3, 1, 1),
            conv2DBatchNormRelu(128, 128, 3, 1, 1),
        )
        self.upsample2 = nn.MaxUnpool2d(2, 2)  # 64 -> 128
        self.dec_conv2 = nn.Sequential(
            conv2DBatchNormRelu(128, 64, 3, 1, 1),
            conv2DBatchNormRelu(64, 64, 3, 1, 1),
        )
        self.upsample3 = nn.MaxUnpool2d(2, 2)  # 128 -> 256
        self.dec_conv3 = nn.Sequential(
            conv2DBatchNormRelu(64, 1, 3, 1, 1),
            conv2DBatchNormRelu(1, 1, 3, 1, 1),
            nn.Conv2d(1, 1, kernel_size=3, stride=1, padding=1),
            nn.BatchNorm2d(1),
        )

    def forward(self, x):
        """Compute the single-channel output map for a batch of images.

        Args:
            x: Tensor with 3 channels in dimension 1, laid out as
                (batch, 3, height, width).

        Returns:
            Tensor of shape (batch, 1, height, width).
        """
        # encoder: keep each pre-pool feature map so the matching unpool can
        # restore exactly that spatial size (matters for odd sizes).
        f0 = self.enc_conv0(x)
        e0, ind0 = self.pool0(f0)
        f1 = self.enc_conv1(e0)
        e1, ind1 = self.pool1(f1)
        f2 = self.enc_conv2(e1)
        e2, ind2 = self.pool2(f2)
        f3 = self.enc_conv3(e2)
        e3, ind3 = self.pool3(f3)

        # bottleneck
        bottle_neck = self.bottle_neck(e3)

        # decoder: MaxUnpool2d requires indices shaped like its input, so the
        # indices are consumed deepest-first (ind3 -> ind0).
        d0 = self.dec_conv0(
            self.upsample0(bottle_neck, ind3, output_size=f3.size()))
        d1 = self.dec_conv1(self.upsample1(d0, ind2, output_size=f2.size()))
        d2 = self.dec_conv2(self.upsample2(d1, ind1, output_size=f1.size()))
        d3 = self.dec_conv3(self.upsample3(d2, ind0, output_size=f0.size()))
        return d3