Untitled

mail@pastecode.io avatar
unknown
python
2 years ago
3.4 kB
15
Indexable
Never
class ResBlock(nn.Module):
    """Residual block: two 3x3 conv/BN layers plus an identity or 1x1 projection shortcut.

    Main path: ``conv3x3 -> BN -> ReLU -> conv3x3 -> BN``. The shortcut is the
    input itself when ``in_ch == out_ch``; otherwise it is a bias-free 1x1 conv
    followed by BN so the channel counts match before the addition. A ReLU is
    applied to the sum. Every conv uses stride 1 and the 3x3 convs use padding 1,
    so the output has the same spatial size as the input.

    Args:
        in_ch: Number of input channels.
        out_ch: Number of output channels.
    """

    def __init__(self, in_ch, out_ch):
        super().__init__()

        # Projection shortcut only when the channel count changes; ``None`` marks
        # the identity shortcut and is checked in ``forward``.
        if in_ch != out_ch:
            self.skip = nn.Sequential(
                nn.Conv2d(in_ch, out_ch, kernel_size=1, stride=1, bias=False),
                nn.BatchNorm2d(out_ch),
            )
        else:
            self.skip = None

        self.block = nn.Sequential(
            nn.Conv2d(in_ch, out_ch, kernel_size=3, padding=1, stride=1, bias=False),
            nn.BatchNorm2d(out_ch),
            nn.ReLU(inplace=True),
            nn.Conv2d(out_ch, out_ch, kernel_size=3, padding=1, stride=1, bias=False),
            nn.BatchNorm2d(out_ch),
        )

        self.relu = nn.ReLU(inplace=True)

    def forward(self, x):
        """Return ``relu(block(x) + shortcut(x))``.

        Args:
            x: Batched feature map ``(N, in_ch, H, W)`` (``BatchNorm2d`` requires 4-D input).

        Returns:
            Tensor of shape ``(N, out_ch, H, W)``.
        """
        out = self.block(x)
        identity = x if self.skip is None else self.skip(x)
        # In-place add is safe: ``out`` is a fresh tensor from the last BatchNorm,
        # not a view of ``x``.
        out += identity
        return self.relu(out)

class Encoder(nn.Module):
    """UNet contracting path: ``ResBlock`` stages, each followed by 2x2 max-pooling.

    One stage is built for every entry of ``features`` except the last, which is
    left for the bottleneck (not part of this module). Stage ``i`` maps the
    previous stage's width (``in_channels`` for the first stage) to
    ``features[i]`` channels.

    Args:
        in_channels: Channels of the input image.
        out_channels: Unused; kept so the signature matches ``Decoder``.
        features: Channel widths, shallowest first; only ``features[:-1]`` are used here.
    """

    def __init__(self, in_channels=3, out_channels=16, features=(16, 32, 64, 128, 256)):
        super().__init__()

        self.downs = nn.ModuleList()
        self.pool = nn.MaxPool2d(kernel_size=2, stride=2)

        for width in features[:-1]:
            self.downs.append(ResBlock(in_channels, width))
            in_channels = width

    def forward(self, x):
        """Run every stage, keeping each stage's pre-pool output as a skip connection.

        Returns:
            ``(x, skip_connections)``: ``x`` is the pooled output of the last stage;
            ``skip_connections`` holds the pre-pool stage outputs ordered deepest
            first, the order in which ``Decoder.forward`` indexes them.
        """
        skip_connections = []

        for down in self.downs:
            x = down(x)
            skip_connections.append(x)
            # MaxPool2d(2, 2) halves H and W (flooring odd sizes).
            x = self.pool(x)

        skip_connections.reverse()
        return x, skip_connections

class Decoder(nn.Module):
    """UNet expanding path: upsample, concatenate the matching skip, refine with a ``ResBlock``.

    For each ``i`` from ``len(features) - 2`` down to ``0`` it appends a 2x2,
    stride-2 ``ConvTranspose2d`` (``features[i+1] -> features[i]`` channels) and a
    ``ResBlock`` (``2 * features[i] -> features[i]`` channels, taking the
    concatenation of the upsampled map and the skip). They are stored flat in
    ``self.ups`` as alternating (upsample, block) entries; this layout fixes the
    parameter names ``ups.0``, ``ups.1``, ...

    Args:
        in_channels: Unused; kept so the signature matches ``Encoder``.
        out_channels: Unused; kept so the signature matches ``Encoder``.
        features: Channel widths, shallowest first.
    """

    def __init__(self, in_channels=3, out_channels=16, features=(16, 32, 64, 128, 256)):
        super().__init__()

        self.ups = nn.ModuleList()

        for idx in reversed(range(len(features) - 1)):
            self.ups.append(
                nn.ConvTranspose2d(features[idx + 1], features[idx], kernel_size=2, stride=2)
            )
            self.ups.append(ResBlock(features[idx] * 2, features[idx]))

    def forward(self, x, skip_connections):
        """Decode ``x`` stage by stage, merging one skip connection per stage.

        Args:
            x: Bottleneck output with ``features[-1]`` channels.
            skip_connections: Encoder feature maps ordered deepest first; entry ``k``
                is concatenated at stage ``k``.

        Returns:
            The decoded feature map with ``features[0]`` channels.
        """
        for stage in range(len(self.ups) // 2):
            upsample = self.ups[2 * stage]
            refine = self.ups[2 * stage + 1]

            x = upsample(x)
            skip = skip_connections[stage]

            # The transposed conv exactly doubles H and W, so a map that was
            # floored by an odd-sized 2x2 max-pool comes back one pixel short.
            # Only spatial dims can legitimately differ here.
            if x.shape[2:] != skip.shape[2:]:
                x = TF.resize(x, skip.shape[2:])

            x = refine(torch.cat([x, skip], dim=1))

        return x
    
class baseline_UNet(nn.Module):
    """Residual UNet: ``Encoder`` -> ``ResBlock`` bottleneck -> ``Decoder`` -> 1x1 conv head.

    Args:
        in_channels: Channels of the input image.
        out_channels: Channels produced by the final 1x1 convolution.
        features: Channel widths, shallowest first. ``features[:-1]`` size the
            encoder/decoder stages; ``features[-1]`` is the bottleneck width.
        device: Device that every parameter and buffer is moved to on construction.
    """

    def __init__(self, in_channels=3, out_channels=16,
                 features=(16, 32, 64, 128, 256), device=DEVICE):
        super().__init__()

        # Registration order (encoder, decoder, bottleneck, out) determines
        # state_dict key order and parameters() order; keep it stable.
        self.encoder = Encoder(in_channels, out_channels, features)
        self.decoder = Decoder(in_channels, out_channels, features)
        self.bottleneck = ResBlock(features[-2], features[-1])
        self.out = nn.Conv2d(features[0], out_channels, kernel_size=1, padding=0, stride=1)

        # Move every submodule, including the bottleneck and head, so that
        # forward() never mixes tensors from different devices.
        self.to(device)

    def forward(self, x):
        """Return the ``out_channels``-channel prediction map for the batch ``x``."""
        x, skip_connections = self.encoder(x)
        x = self.bottleneck(x)
        x = self.decoder(x, skip_connections)
        return self.out(x)