Untitled
unknown
python
3 years ago
3.4 kB
17
Indexable
class ResBlock(nn.Module):
    """Two 3x3 conv + BatchNorm layers wrapped in a residual connection.

    The shortcut is the identity when ``in_ch == out_ch``; otherwise a
    1x1 conv + BatchNorm projection matches the channel count so the sum in
    ``forward`` is well-defined. Spatial size is preserved (stride 1, and
    padding 1 on the 3x3 convs).

    Args:
        in_ch: Number of input channels.
        out_ch: Number of output channels.
    """

    def __init__(self, in_ch: int, out_ch: int):
        super().__init__()
        # Projection shortcut only when the channel count changes; ``None``
        # means the input is added back unchanged in ``forward``.
        if in_ch != out_ch:
            self.skip = nn.Sequential(
                nn.Conv2d(in_ch, out_ch, kernel_size=1, stride=1, bias=False),
                nn.BatchNorm2d(out_ch),
            )
        else:
            self.skip = None
        self.block = nn.Sequential(
            nn.Conv2d(in_ch, out_ch, kernel_size=3, padding=1, stride=1, bias=False),
            nn.BatchNorm2d(out_ch),
            nn.ReLU(inplace=True),
            nn.Conv2d(out_ch, out_ch, kernel_size=3, padding=1, stride=1, bias=False),
            nn.BatchNorm2d(out_ch),
        )
        self.relu = nn.ReLU(inplace=True)

    def forward(self, x: torch.Tensor) -> torch.Tensor:
        """Return ``relu(block(x) + shortcut(x))``."""
        out = self.block(x)
        identity = x if self.skip is None else self.skip(x)
        # ``out`` is the final BatchNorm's output, not ``x`` itself, so the
        # in-place add does not modify the caller's tensor.
        out += identity
        return self.relu(out)


class Encoder(nn.Module):
    """Contracting path of the U-Net.

    Builds one ``ResBlock`` per entry of ``features[:-1]``, each followed by
    2x2 max pooling. The last entry of ``features`` is the bottleneck width,
    which ``baseline_UNet`` uses for its own bottleneck block.

    Args:
        in_channels: Channels of the input image.
        out_channels: Unused; kept so the signature matches ``Decoder`` and
            existing call sites.
        features: Channel widths per level, shallowest first.
    """

    def __init__(self, in_channels=3, out_channels=16, features=(16, 32, 64, 128, 256)):
        super().__init__()
        self.downs = nn.ModuleList()
        self.pool = nn.MaxPool2d(kernel_size=2, stride=2)
        for width in features[:-1]:
            self.downs.append(ResBlock(in_channels, width))
            in_channels = width

    def forward(self, x: torch.Tensor):
        """Run every encoder stage.

        Returns:
            ``(x, skip_connections)``: ``x`` is the pooled output of the
            deepest stage, and ``skip_connections`` is a list of each stage's
            pre-pooling activation ordered deepest first, which is the order
            ``Decoder.forward`` consumes them in.
        """
        skip_connections = []
        for down in self.downs:
            x = down(x)
            skip_connections.append(x)
            x = self.pool(x)
        skip_connections.reverse()
        return x, skip_connections


class Decoder(nn.Module):
    """Expanding path of the U-Net.

    For each level, deepest first, it upsamples with a stride-2 transposed
    conv, concatenates the matching encoder skip along the channel axis and
    fuses the result with a ``ResBlock``. ``self.ups`` alternates
    ``[ConvTranspose2d, ResBlock, ConvTranspose2d, ResBlock, ...]``. This
    flat layout is kept so the ``ups.<i>.*`` state_dict keys stay unchanged.

    Args:
        in_channels: Unused; kept for signature compatibility with ``Encoder``.
        out_channels: Unused; kept for signature compatibility.
        features: Channel widths per level, shallowest first (the same
            sequence given to ``Encoder``).
    """

    def __init__(self, in_channels=3, out_channels=16, features=(16, 32, 64, 128, 256)):
        super().__init__()
        self.ups = nn.ModuleList()
        for idx in reversed(range(len(features) - 1)):
            self.ups.append(
                nn.ConvTranspose2d(features[idx + 1], features[idx], kernel_size=2, stride=2)
            )
            # The fused input is upsampled features concatenated with the skip,
            # hence twice the channel count.
            self.ups.append(ResBlock(features[idx] * 2, features[idx]))

    def forward(self, x: torch.Tensor, skip_connections) -> torch.Tensor:
        """Decode ``x`` using ``skip_connections`` ordered deepest first."""
        for level in range(len(self.ups) // 2):
            upsample = self.ups[2 * level]
            fuse = self.ups[2 * level + 1]
            x = upsample(x)
            skip = skip_connections[level]
            # 2x2 max pooling floors odd spatial sizes, so after the stride-2
            # upsample ``x`` can be smaller than its skip; resize to match.
            if x.shape[2:] != skip.shape[2:]:
                x = TF.resize(x, skip.shape[2:])
            x = fuse(torch.cat([x, skip], dim=1))
        return x


class baseline_UNet(nn.Module):
    """Residual U-Net: ``Encoder`` -> bottleneck ``ResBlock`` -> ``Decoder`` -> 1x1 conv.

    Args:
        in_channels: Channels of the input image.
        out_channels: Channels produced by the final 1x1 convolution.
        features: Channel widths per level, shallowest first. The last entry
            is the bottleneck width, so at least two entries are required.
        device: Device the whole model is moved to after construction.

    Raises:
        ValueError: If ``features`` has fewer than two entries.
    """

    def __init__(
        self,
        in_channels=3,
        out_channels=16,
        features=(16, 32, 64, 128, 256),
        device=DEVICE,
    ):
        super().__init__()
        features = tuple(features)
        if len(features) < 2:
            raise ValueError(
                "features needs at least 2 entries (encoder levels + bottleneck), "
                f"got {len(features)}"
            )
        # Submodules are built in the original order (encoder, decoder,
        # bottleneck, head). This keeps parameter initialisation under a fixed
        # seed and the state_dict key order unchanged.
        self.encoder = Encoder(in_channels, out_channels, features)
        self.decoder = Decoder(in_channels, out_channels, features)
        self.bottleneck = ResBlock(features[-2], features[-1])
        self.out = nn.Conv2d(features[0], out_channels, kernel_size=1, padding=0, stride=1)
        # Move every submodule, including the bottleneck and output head, so
        # forward() never mixes tensors from different devices.
        self.to(device)

    def forward(self, x: torch.Tensor) -> torch.Tensor:
        """Map ``x`` to an ``out_channels``-channel map at the input's spatial size."""
        x, skip_connections = self.encoder(x)
        x = self.bottleneck(x)
        x = self.decoder(x, skip_connections)
        return self.out(x)
Editor is loading...