Untitled
class UNet(nn.Module):
    """Small two-level U-Net that outputs a sigmoid-activated per-pixel map.

    Architecture:
        Encoder:
            1. enc1 = DoubleConv(in_channels -> 16)
            2. pool, dropout
            3. enc2 = DoubleConv(16 -> 32)
            4. pool, dropout
            5. bottleneck DoubleConv(32 -> 64), dropout
        Decoder:
            1. Upsample to enc2's size, concatenate enc2 (skip connection),
               then DoubleConv(96 -> 32), dropout.
            2. Upsample to enc1's size, concatenate enc1, then
               DoubleConv(48 -> 16).
        Head:
            A 1x1 convolution to ``out_channels``, then ``torch.sigmoid``.

    Args:
        in_channels: Channels of the input tensor (default 3, the value
            previously hard-coded).
        out_channels: Channels of the output map (default 1).
        dropout_p: Drop probability of the single shared ``nn.Dropout2d``
            (default 0.3).

    NOTE(review): the output already has a sigmoid applied, so it presumably
    pairs with ``nn.BCELoss`` rather than ``nn.BCEWithLogitsLoss``; confirm
    against the training code.
    """

    def __init__(self, in_channels=3, out_channels=1, dropout_p=0.3):
        super(UNet, self).__init__()
        # One Dropout2d module is reused at several points in forward().
        self.dropout = nn.Dropout2d(dropout_p)
        self.enc1 = DoubleConv(in_channels, 16)
        self.enc2 = DoubleConv(16, 32)
        self.bottleneck = DoubleConv(32, 64)
        # Decoder inputs = upsampled features + skip features (concat on dim=1).
        self.dec2 = DoubleConv(64 + 32, 32)
        self.dec1 = DoubleConv(32 + 16, 16)
        self.final_conv = nn.Conv2d(16, out_channels, kernel_size=1)
        self.pool = nn.MaxPool2d(2)
        # Kept for its mode / align_corners settings and attribute compatibility.
        # forward() resizes to the skip tensor's exact size instead of a fixed
        # x2, so inputs whose H/W are not divisible by 4 no longer break the
        # concatenation.
        self.up = nn.Upsample(scale_factor=2, mode='bilinear', align_corners=True)

    def _upsample_to(self, x, ref):
        """Resize ``x`` to the spatial size (last two dims) of ``ref``, using self.up's mode."""
        return nn.functional.interpolate(
            x,
            size=ref.shape[-2:],
            mode=self.up.mode,
            align_corners=self.up.align_corners,
        )

    def forward(self, x):
        """Run the network on ``x``.

        Assumes a 4-D (N, C, H, W) tensor, which bilinear interpolation
        requires. Also assumes DoubleConv preserves spatial size (e.g.
        padding=1); TODO confirm. Under that assumption the output is
        (N, out_channels, H, W) with values in (0, 1).
        """
        enc1 = self.enc1(x)
        x = self.dropout(self.pool(enc1))

        enc2 = self.enc2(x)
        x = self.dropout(self.pool(enc2))

        x = self.dropout(self.bottleneck(x))

        # MaxPool2d floors odd sizes, so upsample to the skip's exact size
        # rather than by a fixed factor of 2 before concatenating.
        x = torch.cat([self._upsample_to(x, enc2), enc2], dim=1)
        x = self.dropout(self.dec2(x))

        x = torch.cat([self._upsample_to(x, enc1), enc1], dim=1)
        x = self.dec1(x)

        x = self.final_conv(x)
        return torch.sigmoid(x)
Leave a Comment