Untitled

 avatar
unknown
python
10 days ago
1.1 kB
10
Indexable
class UNet(nn.Module):
    """Small two-level U-Net with skip connections and a sigmoid output.

    Layout: two encoder stages (``DoubleConv`` followed by 2x2 max-pooling),
    a ``DoubleConv`` bottleneck, and two decoder stages. Each decoder stage
    upsamples, concatenates the matching encoder feature map on the channel
    axis (dim 1), and applies ``DoubleConv``. A 1x1 convolution maps to
    ``out_channels``, and a sigmoid squashes the result into [0, 1].

    The defaults reproduce the original fixed configuration (3 -> 16/32/64 -> 1
    channels, dropout 0.3). Submodule attribute names are unchanged, so existing
    ``state_dict`` checkpoints still load.

    Args:
        in_channels: Number of channels in the input batch.
        out_channels: Number of channels produced by the final 1x1 convolution.
        base_channels: Width of the first encoder stage. The second stage and
            the bottleneck use 2x and 4x this width.
        dropout_p: Drop probability for the shared ``nn.Dropout2d``.

    NOTE(review): ``DoubleConv(in_ch, out_ch)`` is defined elsewhere. It is
    assumed here to map ``in_ch`` -> ``out_ch`` channels; confirm against its
    definition.
    NOTE(review): the model returns probabilities. If it is trained with a BCE
    loss, ``nn.BCEWithLogitsLoss`` on pre-sigmoid logits is the numerically
    stabler option. Returning logits here would change the public contract,
    so that is left to the caller.
    """

    def __init__(self, in_channels=3, out_channels=1, base_channels=16, dropout_p=0.3):
        super().__init__()
        c1, c2, c3 = base_channels, base_channels * 2, base_channels * 4

        # Dropout2d has no parameters, so one instance is reused at every stage.
        self.dropout = nn.Dropout2d(dropout_p)

        self.enc1 = DoubleConv(in_channels, c1)
        self.enc2 = DoubleConv(c1, c2)

        self.bottleneck = DoubleConv(c2, c3)

        # Decoder inputs are upsampled features concatenated with the skip map,
        # hence the summed input channel counts.
        self.dec2 = DoubleConv(c3 + c2, c2)
        self.dec1 = DoubleConv(c2 + c1, c1)

        self.final_conv = nn.Conv2d(c1, out_channels, kernel_size=1)
        self.pool = nn.MaxPool2d(2)
        # Holds the interpolation mode / align_corners settings that
        # _upsample_and_concat uses. Kept as a module for backward compatibility.
        self.up = nn.Upsample(scale_factor=2, mode='bilinear', align_corners=True)

    def forward(self, x):
        """Map an input batch to per-pixel sigmoid activations.

        Args:
            x: Input tensor with ``in_channels`` channels on dim 1.

        Returns:
            Tensor with ``out_channels`` channels whose spatial size matches the
            output of the first encoder stage, with values in [0, 1].
        """
        enc1 = self.enc1(x)
        x = self.dropout(self.pool(enc1))

        enc2 = self.enc2(x)
        x = self.dropout(self.pool(enc2))

        x = self.dropout(self.bottleneck(x))

        x = self.dropout(self.dec2(self._upsample_and_concat(x, enc2)))
        # No dropout after the last decoder stage, matching the original model.
        x = self.dec1(self._upsample_and_concat(x, enc1))

        return torch.sigmoid(self.final_conv(x))

    def _upsample_and_concat(self, x, skip):
        """Resize ``x`` to ``skip``'s height/width and concatenate them on dim 1.

        Resizing to the skip tensor's exact size (rather than a fixed 2x) keeps
        ``torch.cat`` from failing when an input side is not divisible by 4.
        ``MaxPool2d(2)`` floors odd sizes, so doubling cannot always recover
        them. When the sizes are exact multiples, the result matches
        ``self.up(x)``, because bilinear interpolation with
        ``align_corners=True`` depends only on the input and output sizes.
        """
        x = nn.functional.interpolate(
            x,
            size=skip.shape[-2:],
            mode=self.up.mode,
            align_corners=self.up.align_corners,
        )
        return torch.cat([x, skip], dim=1)
Leave a Comment