# Paste-site metadata (kept as comments so this file parses as Python):
# Untitled
# avatar
# unknown
# plain_text
# a month ago
# 2.6 kB
# 3
# Indexable
import torch
import torch.nn as nn
import torch.nn.functional as F
from torchvision.models.resnet import resnet50

class ASPP(nn.Module):
    """Atrous Spatial Pyramid Pooling (ASPP) module.

    Runs five parallel branches over the same input feature map:

    * a 1x1 convolution,
    * three 3x3 convolutions with dilations given by ``atrous_rates`` (padding
      equal to the dilation, so the spatial size is preserved),
    * an image-level branch: global average pooling, a 1x1 convolution, and
      bilinear upsampling back to the input's spatial size.

    The five branch outputs are concatenated along the channel dimension and
    fused by a 1x1 convolution followed by BatchNorm and ReLU.

    Args:
        in_channels: Number of channels of the input feature map.
        out_channels: Number of channels produced by every branch and by the
            fused output.
        atrous_rates: Dilation rates of the three 3x3 branches, in order.
            Defaults to ``(6, 12, 18)``. The rates do not affect any parameter
            shape, so ``state_dict`` keys and shapes are the same for every
            choice.

    Raises:
        ValueError: If ``atrous_rates`` does not contain exactly three values.
    """

    def __init__(self, in_channels, out_channels, atrous_rates=(6, 12, 18)):
        super(ASPP, self).__init__()
        atrous_rates = tuple(atrous_rates)
        if len(atrous_rates) != 3:
            raise ValueError(
                "atrous_rates must contain exactly 3 dilation rates, "
                f"got {atrous_rates!r}"
            )
        rate_a, rate_b, rate_c = atrous_rates

        # Branch 1: plain 1x1 convolution.
        self.atrous_block1 = nn.Conv2d(in_channels, out_channels, kernel_size=1, stride=1, padding=0, dilation=1)
        # Branches 2-4: dilated 3x3 convolutions. With stride 1, padding equal
        # to the dilation keeps H and W unchanged, so all branches can be
        # concatenated.
        self.atrous_block2 = nn.Conv2d(in_channels, out_channels, kernel_size=3, stride=1, padding=rate_a, dilation=rate_a)
        self.atrous_block3 = nn.Conv2d(in_channels, out_channels, kernel_size=3, stride=1, padding=rate_b, dilation=rate_b)
        self.atrous_block4 = nn.Conv2d(in_channels, out_channels, kernel_size=3, stride=1, padding=rate_c, dilation=rate_c)
        # Branch 5: image-level context (global average pool -> 1x1 conv).
        self.global_avg_pool = nn.Sequential(
            nn.AdaptiveAvgPool2d((1, 1)),
            nn.Conv2d(in_channels, out_channels, kernel_size=1, stride=1, padding=0)
        )
        # Fuse the five concatenated branches back down to out_channels.
        self.conv1 = nn.Conv2d(out_channels * 5, out_channels, kernel_size=1)
        self.bn1 = nn.BatchNorm2d(out_channels)
        self.relu = nn.ReLU()

    def forward(self, x):
        """Apply ASPP to a feature map.

        Args:
            x: Feature map of shape ``(N, in_channels, H, W)``.

        Returns:
            Tensor of shape ``(N, out_channels, H, W)``, after BatchNorm and
            ReLU.
        """
        size = x.shape[2:]
        out1 = self.atrous_block1(x)
        out2 = self.atrous_block2(x)
        out3 = self.atrous_block3(x)
        out4 = self.atrous_block4(x)
        # The pooled branch is 1x1 spatially. Upsample it to H x W so it can be
        # concatenated with the other branches.
        out5 = self.global_avg_pool(x)
        out5 = F.interpolate(out5, size=size, mode='bilinear', align_corners=False)

        out = torch.cat([out1, out2, out3, out4, out5], dim=1)
        out = self.conv1(out)
        out = self.bn1(out)
        return self.relu(out)


class DeepLabV3(nn.Module):
    """DeepLabV3-style semantic segmentation network on a ResNet-50 backbone.

    The backbone is a randomly initialised torchvision ResNet-50 with its
    trailing average-pool and fully-connected layers removed. Its feature map
    goes through an :class:`ASPP` module (256 channels), then a small
    convolutional head that produces per-class logits. The logits are finally
    bilinearly upsampled back to the input's spatial size.

    Args:
        num_classes: Number of output channels (class logits) per pixel.
        backbone: Backbone name. Only ``'resnet50'`` is supported.
        output_stride: Ratio of input resolution to backbone feature-map
            resolution.

            * ``32`` (the default) keeps the stock ResNet-50 strides.
            * ``16`` replaces the stride of the last ResNet stage with
              dilation, as torchvision's ``replace_stride_with_dilation`` does.
            * ``8`` does the same for the last two stages.

            Smaller strides give denser features to ASPP at the cost of more
            compute and memory. Parameter shapes, and hence ``state_dict``
            keys, are identical for every value.

    Raises:
        NotImplementedError: If ``backbone`` is not ``'resnet50'``.
        ValueError: If ``output_stride`` is not 8, 16 or 32.
    """

    # torchvision ``replace_stride_with_dilation`` flags (layer2, layer3,
    # layer4) that produce each supported output stride.
    _DILATION_FLAGS = {
        32: (False, False, False),  # stock ResNet-50 strides
        16: (False, False, True),
        8: (False, True, True),
    }

    def __init__(self, num_classes, backbone='resnet50', output_stride=32):
        super(DeepLabV3, self).__init__()

        # Encoder (backbone)
        if backbone != 'resnet50':
            raise NotImplementedError("Only ResNet50 is supported as a backbone.")
        if output_stride not in self._DILATION_FLAGS:
            raise ValueError(
                f"output_stride must be one of {sorted(self._DILATION_FLAGS)}, "
                f"got {output_stride!r}"
            )
        resnet = resnet50(
            weights=None,  # No pretraining
            replace_stride_with_dilation=list(self._DILATION_FLAGS[output_stride]),
        )
        # Remove the trailing AvgPool and FC layers, keeping the convolutional
        # feature extractor.
        self.backbone = nn.Sequential(*list(resnet.children())[:-2])
        in_channels = 2048  # channels output by ResNet-50's last stage

        # ASPP module
        self.aspp = ASPP(in_channels, 256)

        # Decoder: 3x3 conv + BN + ReLU, then a 1x1 conv to per-class logits.
        self.decoder = nn.Sequential(
            nn.Conv2d(256, 256, kernel_size=3, stride=1, padding=1),
            nn.BatchNorm2d(256),
            nn.ReLU(),
            nn.Conv2d(256, num_classes, kernel_size=1, stride=1)
        )

    def forward(self, x):
        """Compute per-pixel class logits.

        Args:
            x: Image batch of shape ``(N, C, H, W)``. The torchvision
                ResNet-50 stem expects ``C == 3``.

        Returns:
            Logits of shape ``(N, num_classes, H, W)``: the head output
            upsampled bilinearly to the input's spatial size.
        """
        size = x.shape[2:]  # Original input size
        x = self.backbone(x)
        x = self.aspp(x)
        x = self.decoder(x)
        x = F.interpolate(x, size=size, mode='bilinear', align_corners=False)
        return x


# Number of output channels (class logits) of the segmentation head.
num_classes = 34
# Target device: CUDA when PyTorch can see a GPU, otherwise the CPU.
device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
model = DeepLabV3(num_classes=num_classes).to(device)
# (paste-site footer) Leave a Comment