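# Untitled snippet (python, 1.1 kB).
# NOTE(review): the other lines of this header were paste-site page metadata
# ("unknown", "9 months ago", "12", "Indexable"), not code; they are kept
# only as this comment so the file parses as Python.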
class UNet(nn.Module):
    """Two-level U-Net producing a per-pixel sigmoid map.

    Encoder: ``enc1`` -> 2x2 max-pool -> ``enc2`` -> 2x2 max-pool ->
    ``bottleneck``. Decoder: at each level the features are bilinearly
    resized to the spatial size of the matching encoder output, concatenated
    with it along the channel axis (skip connection), and passed through a
    ``DoubleConv`` (defined elsewhere; presumably two stacked convolutions --
    confirm). A 1x1 convolution maps to ``out_channels`` and ``torch.sigmoid``
    squashes the result into [0, 1]. ``nn.Dropout2d`` is applied between
    stages.

    Channel widths are ``base_channels`` -> ``2 * base_channels`` ->
    ``4 * base_channels``. The defaults reproduce the original fixed
    configuration (3 -> 16 -> 32 -> 64 -> 1, dropout 0.3).

    NOTE: the output is already a probability. For training, pair it with
    ``nn.BCELoss`` rather than ``nn.BCEWithLogitsLoss``, which would apply
    a second sigmoid.

    Args:
        in_channels: Channels of the input tensor.
        out_channels: Channels of the output map.
        base_channels: Width of the first encoder stage.
        dropout: Channel-dropout probability for ``nn.Dropout2d``.
    """

    def __init__(self, *, in_channels=3, out_channels=1, base_channels=16,
                 dropout=0.3):
        super().__init__()
        c1, c2, c3 = base_channels, 2 * base_channels, 4 * base_channels
        # One weightless dropout module is reused after every stage.
        self.dropout = nn.Dropout2d(dropout)
        self.enc1 = DoubleConv(in_channels, c1)
        self.enc2 = DoubleConv(c1, c2)
        self.bottleneck = DoubleConv(c2, c3)
        # Decoder inputs = upsampled features concatenated with the skip.
        self.dec2 = DoubleConv(c3 + c2, c2)
        self.dec1 = DoubleConv(c2 + c1, c1)
        self.final_conv = nn.Conv2d(c1, out_channels, kernel_size=1)
        self.pool = nn.MaxPool2d(2)
        # Kept for compatibility and as the source of the interpolation
        # mode/align_corners. forward() resizes to the skip tensor's exact
        # size instead of a fixed x2 (see _upsample_to).
        self.up = nn.Upsample(scale_factor=2, mode='bilinear', align_corners=True)

    def _upsample_to(self, x, skip):
        """Resize ``x`` to the spatial (last two) dimensions of ``skip``.

        ``MaxPool2d(2)`` floor-divides odd sizes, so a fixed x2 upsample can
        come back one pixel short and make ``torch.cat`` fail. Resizing to
        the skip's exact size avoids that. With ``align_corners=True`` the
        sampling grid depends only on the input/output sizes, so when the
        size is exactly doubled this matches ``self.up(x)``.
        """
        return nn.functional.interpolate(
            x,
            size=skip.shape[-2:],
            mode=self.up.mode,
            align_corners=self.up.align_corners,
        )

    def forward(self, x):
        """Run the network on a batch.

        Args:
            x: Input batch, expected as (N, in_channels, H, W). The 2-D
                pooling and bilinear interpolation require a 4-D tensor.

        Returns:
            Sigmoid activations in [0, 1] with ``out_channels`` channels.
            Spatial size equals ``enc1``'s output size (H, W if
            ``DoubleConv`` preserves spatial size -- confirm).
        """
        enc1 = self.enc1(x)
        x = self.dropout(self.pool(enc1))

        enc2 = self.enc2(x)
        x = self.dropout(self.pool(enc2))

        x = self.dropout(self.bottleneck(x))

        x = torch.cat([self._upsample_to(x, enc2), enc2], dim=1)
        x = self.dropout(self.dec2(x))

        x = torch.cat([self._upsample_to(x, enc1), enc1], dim=1)
        x = self.dec1(x)

        return torch.sigmoid(self.final_conv(x))
# (paste-site footer residue: "Leave a Comment")