Untitled
import torch
import torch.nn as nn
import torch.nn.functional as F


class BasicBlock(nn.Module):
    """Residual block: two 3x3 conv + BatchNorm layers with an additive shortcut.

    Args:
        in_channels: Channels of the input tensor.
        out_channels: Channels produced by the block.
        stride: Stride of the first 3x3 conv and of the projection shortcut.
            Because ``padding == dilation``, a 3x3 conv maps a spatial size
            ``H`` to ``ceil(H / stride)``.
        dilation: Dilation of both 3x3 convs. Padding equals the dilation, so
            only ``stride`` changes the spatial size.
    """

    def __init__(self, in_channels, out_channels, stride=1, dilation=1):
        super().__init__()
        self.conv1 = nn.Conv2d(
            in_channels, out_channels, kernel_size=3, stride=stride,
            padding=dilation, dilation=dilation, bias=False,
        )
        self.bn1 = nn.BatchNorm2d(out_channels)
        self.conv2 = nn.Conv2d(
            out_channels, out_channels, kernel_size=3, stride=1,
            padding=dilation, dilation=dilation, bias=False,
        )
        self.bn2 = nn.BatchNorm2d(out_channels)

        # Identity shortcut by default. A strided 1x1 projection is needed
        # whenever the residual's channels or spatial size differ from the
        # input's, so that the two can be summed.
        self.shortcut = nn.Sequential()
        if stride != 1 or in_channels != out_channels:
            self.shortcut = nn.Sequential(
                nn.Conv2d(in_channels, out_channels, kernel_size=1,
                          stride=stride, bias=False),
                nn.BatchNorm2d(out_channels),
            )

    def forward(self, x):
        """Return ``relu(bn2(conv2(relu(bn1(conv1(x))))) + shortcut(x))``."""
        out = F.relu(self.bn1(self.conv1(x)))
        out = self.bn2(self.conv2(out))
        out += self.shortcut(x)
        return F.relu(out)


class ASPP(nn.Module):
    """Atrous spatial pyramid pooling head.

    Runs five parallel branches on the same input and fuses them with a
    1x1 conv + BatchNorm + ReLU:

    - a 1x1 conv;
    - three 3x3 convs with dilation 6, 12 and 18;
    - a global-average-pool + 1x1 conv branch, upsampled back to the input size.

    All branches keep the input's spatial size (padding equals dilation), so
    the output has ``out_channels`` channels and the same H x W as the input.
    """

    def __init__(self, in_channels, out_channels):
        super().__init__()
        self.atrous_block1 = nn.Conv2d(in_channels, out_channels,
                                       kernel_size=1, bias=False)
        self.atrous_block2 = nn.Conv2d(in_channels, out_channels, kernel_size=3,
                                       padding=6, dilation=6, bias=False)
        self.atrous_block3 = nn.Conv2d(in_channels, out_channels, kernel_size=3,
                                       padding=12, dilation=12, bias=False)
        self.atrous_block4 = nn.Conv2d(in_channels, out_channels, kernel_size=3,
                                       padding=18, dilation=18, bias=False)
        self.global_avg_pool = nn.Sequential(
            nn.AdaptiveAvgPool2d((1, 1)),
            nn.Conv2d(in_channels, out_channels, kernel_size=1, bias=False),
        )
        # Fuses the 5 concatenated branch outputs back to out_channels.
        self.conv1 = nn.Conv2d(out_channels * 5, out_channels,
                               kernel_size=1, bias=False)
        self.bn1 = nn.BatchNorm2d(out_channels)

    def forward(self, x):
        size = x.shape[2:]
        out1 = F.relu(self.atrous_block1(x))
        out2 = F.relu(self.atrous_block2(x))
        out3 = F.relu(self.atrous_block3(x))
        out4 = F.relu(self.atrous_block4(x))
        # NOTE(review): unlike the other branches, the pooled branch is not
        # passed through ReLU, and no branch has its own BatchNorm. Kept as-is
        # so existing state_dicts and outputs are unchanged; confirm intent.
        out5 = F.interpolate(self.global_avg_pool(x), size=size,
                             mode='bilinear', align_corners=False)
        out = torch.cat([out1, out2, out3, out4, out5], dim=1)
        out = F.relu(self.bn1(self.conv1(out)))
        return out


class DecoderBlock(nn.Module):
    """2x transposed-conv upsampling followed by a 3x3 conv + BatchNorm + ReLU.

    Args:
        in_channels: Channels of the input tensor.
        out_channels: Channels produced by the block.
    """

    def __init__(self, in_channels, out_channels):
        super().__init__()
        # kernel_size=2, stride=2, padding=0 maps (H, W) -> (2H, 2W) exactly.
        self.upconv = nn.ConvTranspose2d(in_channels, out_channels,
                                         kernel_size=2, stride=2)
        self.conv = nn.Sequential(
            nn.Conv2d(out_channels, out_channels, kernel_size=3,
                      padding=1, bias=False),
            nn.BatchNorm2d(out_channels),
            nn.ReLU(inplace=True),
        )

    def forward(self, x, output_size=None):
        """Upsample ``x`` by 2 and refine it.

        Args:
            x: Input feature map.
            output_size: Optional ``(H, W)`` target. When given and it differs
                from the 2x-upsampled size, the upsampled map is bilinearly
                resized to it before the refining conv. This lets the output be
                summed with an encoder feature whose size was rounded up by a
                strided conv. ``None`` keeps the plain 2x behaviour.
        """
        x = self.upconv(x)
        if output_size is not None:
            target = tuple(output_size)
            if tuple(x.shape[-2:]) != target:
                x = F.interpolate(x, size=target, mode='bilinear',
                                  align_corners=False)
        return self.conv(x)


class SemanticSegmentationNet(nn.Module):
    """Encoder / ASPP / decoder segmentation network with additive skips.

    Input is ``(N, 3, H, W)``. Output is ``(N, num_classes, H, W)`` per-pixel
    class scores, taken straight from the final 1x1 conv with no softmax.
    Any ``H, W`` works; sides need not be multiples of 16.
    """

    def __init__(self, num_classes):
        super().__init__()
        self.encoder1 = BasicBlock(3, 64, stride=2)
        self.encoder2 = BasicBlock(64, 128, stride=2)
        self.encoder3 = BasicBlock(128, 256, stride=2)
        self.encoder4 = BasicBlock(256, 512, stride=2)

        self.aspp = ASPP(512, 512)

        self.decoder4 = DecoderBlock(512, 256)
        self.decoder3 = DecoderBlock(256, 128)
        self.decoder2 = DecoderBlock(128, 64)
        self.decoder1 = DecoderBlock(64, 64)

        self.final_conv = nn.Conv2d(64, num_classes, kernel_size=1)

    def forward(self, x):
        # Encoder: each stage halves the spatial size, rounding up.
        e1 = self.encoder1(x)   # 64 ch,  ceil(H/2)
        e2 = self.encoder2(e1)  # 128 ch, ceil(H/4)
        e3 = self.encoder3(e2)  # 256 ch, ceil(H/8)
        e4 = self.encoder4(e3)  # 512 ch, ceil(H/16)

        # ASPP keeps e4's spatial size.
        aspp = self.aspp(e4)

        # Decoder: each stage is brought to the size of the encoder feature it
        # is summed with next, and the last stage to the input size.
        # - When H, W are multiples of 16, this equals the plain 2x upsampling.
        # - Otherwise, 2x the rounded-up size can exceed the skip's size by one
        #   pixel, which previously made the element-wise sums raise.
        d4 = self.decoder4(aspp, e3.shape[-2:])
        d3 = self.decoder3(d4 + e3, e2.shape[-2:])
        d2 = self.decoder2(d3 + e2, e1.shape[-2:])
        d1 = self.decoder1(d2 + e1, x.shape[-2:])

        return self.final_conv(d1)


if __name__ == "__main__":
    # Smoke test; guarded so importing this module does not build a model
    # and run a forward pass.
    model = SemanticSegmentationNet(num_classes=34)
    model.eval()
    x = torch.randn(1, 3, 256, 512)  # Example input
    with torch.no_grad():
        output = model(x)
    print("Output shape:", output.shape)
Leave a Comment