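# Page metadata that appears to have been captured along with the code (not Python):
# Untitled
# avatar
# unknown
# plain_text
# a month ago
# 4.4 kB
# 2
# Indexable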
import torch
import torch.nn as nn
import torch.nn.functional as F

class BasicBlock(nn.Module):
    """Residual block: two 3x3 conv + BatchNorm layers plus a shortcut connection.

    The first conv carries ``stride``; both convs use ``dilation`` with
    ``padding == dilation``, so the stride alone determines the output's
    spatial size. When the stride or the channel count changes, the shortcut
    is a strided 1x1 conv + BatchNorm projection so it can be summed with the
    main path; otherwise it is an empty ``nn.Sequential`` (identity).
    """

    def __init__(self, in_channels, out_channels, stride=1, dilation=1):
        super().__init__()
        # Settings shared by both 3x3 convs; padding == dilation preserves size.
        shared = dict(kernel_size=3, padding=dilation, dilation=dilation, bias=False)
        self.conv1 = nn.Conv2d(in_channels, out_channels, stride=stride, **shared)
        self.bn1 = nn.BatchNorm2d(out_channels)
        self.conv2 = nn.Conv2d(out_channels, out_channels, stride=1, **shared)
        self.bn2 = nn.BatchNorm2d(out_channels)

        needs_projection = stride != 1 or in_channels != out_channels
        if needs_projection:
            # Match the main path's channels and resolution before the addition.
            self.shortcut = nn.Sequential(
                nn.Conv2d(in_channels, out_channels, kernel_size=1, stride=stride, bias=False),
                nn.BatchNorm2d(out_channels),
            )
        else:
            self.shortcut = nn.Sequential()

    def forward(self, x):
        identity = self.shortcut(x)
        hidden = F.relu(self.bn1(self.conv1(x)))
        residual = self.bn2(self.conv2(hidden))
        return F.relu(residual + identity)

class ASPP(nn.Module):
    """Atrous Spatial Pyramid Pooling head.

    Applies parallel branches to the same input feature map:

    - a 1x1 conv;
    - one dilated 3x3 conv per entry in ``rates``;
    - an image-level branch (global average pool + 1x1 conv).

    The branch outputs are concatenated along the channel dimension and fused
    with a 1x1 conv + BatchNorm + ReLU. Every branch keeps the input's spatial
    size, so the output has ``out_channels`` channels and the same height and
    width as the input.

    Args:
        in_channels: channels of the input feature map.
        out_channels: channels produced by each branch and by the fused output.
        rates: dilation rates of the 3x3 branches. The default ``(6, 12, 18)``
            reproduces the original fixed layout. Those branches are registered
            as ``atrous_block2`` .. ``atrous_block4``, so parameter names (and
            state-dict keys) are unchanged.

    NOTE(review): the branch convs use ``bias=False`` but no BatchNorm follows
    them, so they have neither a bias nor normalization before fusion. Adding
    per-branch BatchNorm may be worth considering, but it changes checkpoint keys.
    """

    def __init__(self, in_channels, out_channels, rates=(6, 12, 18)):
        super(ASPP, self).__init__()
        rates = tuple(rates)
        self.atrous_block1 = nn.Conv2d(in_channels, out_channels, kernel_size=1, bias=False)
        # Dilated branches are registered in order as atrous_block2, atrous_block3, ...
        # padding == dilation keeps a 3x3 conv's output the same size as its input.
        for idx, rate in enumerate(rates, start=2):
            self.add_module(
                f"atrous_block{idx}",
                nn.Conv2d(in_channels, out_channels, kernel_size=3,
                          padding=rate, dilation=rate, bias=False),
            )
        self.global_avg_pool = nn.Sequential(
            nn.AdaptiveAvgPool2d((1, 1)),
            nn.Conv2d(in_channels, out_channels, kernel_size=1, bias=False)
        )
        # Inputs to the fusion conv: 1x1 branch + one per rate + image-level branch.
        self.conv1 = nn.Conv2d(out_channels * (len(rates) + 2), out_channels, kernel_size=1, bias=False)
        self.bn1 = nn.BatchNorm2d(out_channels)

    def _dilated_branches(self):
        """Return the dilated 3x3 branch modules (atrous_block2, ...) in registration order."""
        return [
            module for name, module in self.named_children()
            if name.startswith("atrous_block") and name != "atrous_block1"
        ]

    def forward(self, x):
        size = x.shape[2:]
        branches = [F.relu(self.atrous_block1(x))]
        branches.extend(F.relu(branch(x)) for branch in self._dilated_branches())
        # The image-level branch is ReLU-activated like every other branch before
        # fusion. Bilinearly resizing its 1x1 output broadcasts it back to the
        # input's spatial size.
        pooled = F.relu(self.global_avg_pool(x))
        branches.append(F.interpolate(pooled, size=size, mode='bilinear', align_corners=False))
        out = torch.cat(branches, dim=1)
        return F.relu(self.bn1(self.conv1(out)))

class DecoderBlock(nn.Module):
    """Upsampling stage: transposed-conv upsample, then 3x3 conv + BatchNorm + ReLU.

    ``upconv`` (kernel 2, stride 2) doubles height and width and maps
    ``in_channels`` to ``out_channels``. ``conv`` then refines the result at
    the new resolution without changing the channel count.
    """

    def __init__(self, in_channels, out_channels):
        super().__init__()
        self.upconv = nn.ConvTranspose2d(in_channels, out_channels, kernel_size=2, stride=2)
        refine_layers = [
            nn.Conv2d(out_channels, out_channels, kernel_size=3, padding=1, bias=False),
            nn.BatchNorm2d(out_channels),
            nn.ReLU(inplace=True),
        ]
        self.conv = nn.Sequential(*refine_layers)

    def forward(self, x):
        upsampled = self.upconv(x)
        return self.conv(upsampled)

class SemanticSegmentationNet(nn.Module):
    """Encoder / ASPP / decoder network producing per-pixel class logits.

    Architecture:

    - Encoder: four stride-2 ``BasicBlock`` stages (64, 128, 256, 512 channels),
      each halving height and width (rounding odd sizes up).
    - Context: an ``ASPP`` head at the deepest resolution.
    - Decoder: four ``DecoderBlock`` stages, each upsampling 2x. Before every
      stage after the first, the matching encoder feature map is added as a
      skip connection.
    - Head: a 1x1 conv mapping 64 channels to ``num_classes`` logits.

    Args:
        num_classes: number of output channels (one logit map per class).

    Input is (N, 3, H, W). Output is (N, num_classes, H, W) for any H, W.
    When H or W is not a multiple of 16, decoder maps are resized to their skip
    partners and the logits are resized to the input. For multiples of 16 no
    resizing happens.
    """

    def __init__(self, num_classes):
        super(SemanticSegmentationNet, self).__init__()
        self.encoder1 = BasicBlock(3, 64, stride=2)
        self.encoder2 = BasicBlock(64, 128, stride=2)
        self.encoder3 = BasicBlock(128, 256, stride=2)
        self.encoder4 = BasicBlock(256, 512, stride=2)

        self.aspp = ASPP(512, 512)

        self.decoder4 = DecoderBlock(512, 256)
        self.decoder3 = DecoderBlock(256, 128)
        self.decoder2 = DecoderBlock(128, 64)
        self.decoder1 = DecoderBlock(64, 64)

        self.final_conv = nn.Conv2d(64, num_classes, kernel_size=1)

    @staticmethod
    def _match_size(t, ref):
        """Bilinearly resize ``t`` to ``ref``'s height/width; return ``t`` unchanged if they match."""
        if t.shape[-2:] == ref.shape[-2:]:
            return t
        return F.interpolate(t, size=ref.shape[-2:], mode='bilinear', align_corners=False)

    def forward(self, x):
        # Encoder: each stride-2 stage maps a size S to ceil(S / 2).
        e1 = self.encoder1(x)
        e2 = self.encoder2(e1)
        e3 = self.encoder3(e2)
        e4 = self.encoder4(e3)

        # ASPP
        aspp = self.aspp(e4)

        # Decoder with additive skips. Because the encoder rounds odd sizes up,
        # a 2x-upsampled map can be one pixel larger than its skip partner, so
        # align it before adding (no-op when sizes already agree).
        d4 = self.decoder4(aspp)
        d3 = self.decoder3(self._match_size(d4, e3) + e3)
        d2 = self.decoder2(self._match_size(d3, e2) + e2)
        d1 = self.decoder1(self._match_size(d2, e1) + e1)

        # Final output, aligned to the input's spatial size.
        out = self.final_conv(d1)
        return self._match_size(out, x)

# Smoke test: build the model and run one forward pass on random input.
# Guarded so that importing this module does not construct a model or run inference.
if __name__ == "__main__":
    model = SemanticSegmentationNet(num_classes=34)
    model.eval()  # inference mode: BatchNorm uses running statistics
    x = torch.randn(1, 3, 256, 512)  # Example input: (batch, channels, height, width)
    with torch.no_grad():  # no autograd graph needed for a shape check
        output = model(x)
    print("Output shape:", output.shape)
# Leave a Comment  (page footer captured along with the code; not Python)