Untitled
import torch
import torch.nn as nn
import torch.nn.functional as F
from torchvision.models.resnet import resnet50


class ASPP(nn.Module):
    """Atrous Spatial Pyramid Pooling head.

    Five branches are evaluated in parallel on the same feature map: a 1x1
    convolution, three 3x3 convolutions with dilation 6, 12 and 18, and a
    globally average-pooled branch. Their outputs are concatenated on the
    channel axis and fused by a 1x1 convolution, batch norm and ReLU.

    Args:
        in_channels: Number of channels of the incoming feature map.
        out_channels: Number of channels produced by each branch and by the
            fused output.
    """

    def __init__(self, in_channels: int, out_channels: int) -> None:
        super().__init__()
        # For the 3x3 branches padding equals dilation, so every branch keeps
        # the spatial size of its input and the outputs can be concatenated.
        self.atrous_block1 = nn.Conv2d(
            in_channels, out_channels,
            kernel_size=1, stride=1, padding=0, dilation=1,
        )
        self.atrous_block2 = nn.Conv2d(
            in_channels, out_channels,
            kernel_size=3, stride=1, padding=6, dilation=6,
        )
        self.atrous_block3 = nn.Conv2d(
            in_channels, out_channels,
            kernel_size=3, stride=1, padding=12, dilation=12,
        )
        self.atrous_block4 = nn.Conv2d(
            in_channels, out_channels,
            kernel_size=3, stride=1, padding=18, dilation=18,
        )
        # Image-level context: pool to 1x1, then project with a 1x1 conv.
        self.global_avg_pool = nn.Sequential(
            nn.AdaptiveAvgPool2d((1, 1)),
            nn.Conv2d(in_channels, out_channels, kernel_size=1, stride=1, padding=0),
        )
        # Fuses the five concatenated branches back down to ``out_channels``.
        self.conv1 = nn.Conv2d(out_channels * 5, out_channels, kernel_size=1)
        self.bn1 = nn.BatchNorm2d(out_channels)
        self.relu = nn.ReLU()

    def forward(self, x: torch.Tensor) -> torch.Tensor:
        """Apply all branches to ``x`` and return the fused feature map.

        Args:
            x: Feature map laid out as (N, C, H, W); dimensions 2 onward are
                treated as the spatial size.

        Returns:
            A tensor with ``out_channels`` channels and the same spatial size
            as ``x``.
        """
        size = x.shape[2:]
        out1 = self.atrous_block1(x)
        out2 = self.atrous_block2(x)
        out3 = self.atrous_block3(x)
        out4 = self.atrous_block4(x)
        # The pooled branch is 1x1 spatially; stretch it back to the input
        # resolution so it can be concatenated with the other branches.
        out5 = self.global_avg_pool(x)
        out5 = F.interpolate(out5, size=size, mode='bilinear', align_corners=False)
        out = torch.cat([out1, out2, out3, out4, out5], dim=1)
        out = self.conv1(out)
        out = self.bn1(out)
        return self.relu(out)


class DeepLabV3(nn.Module):
    """Segmentation network: ResNet-50 encoder, ASPP head and a small decoder.

    Args:
        num_classes: Number of output channels (one logit map per class).
        backbone: Encoder name. Only ``'resnet50'`` is implemented.

    Raises:
        NotImplementedError: If ``backbone`` is anything other than
            ``'resnet50'``.
    """

    def __init__(self, num_classes: int, backbone: str = 'resnet50') -> None:
        super().__init__()

        # Encoder (backbone)
        if backbone == 'resnet50':
            resnet = resnet50(weights=None)  # No pretraining
            # Remove FC and AvgPool layers (the last two children).
            self.backbone = nn.Sequential(*list(resnet.children())[:-2])
            in_channels = 2048
        else:
            raise NotImplementedError("Only ResNet50 is supported as a backbone.")

        # NOTE(review): the backbone is built with its default strides, so the
        # feature map reaching ASPP is presumably much smaller than the input;
        # confirm the dilation rates 6/12/18 are still meaningful at that
        # resolution for the image sizes used in training.

        # ASPP module
        self.aspp = ASPP(in_channels, 256)

        # Decoder
        self.decoder = nn.Sequential(
            nn.Conv2d(256, 256, kernel_size=3, stride=1, padding=1),
            nn.BatchNorm2d(256),
            nn.ReLU(),
            nn.Conv2d(256, num_classes, kernel_size=1, stride=1),
        )

    def forward(self, x: torch.Tensor) -> torch.Tensor:
        """Return per-pixel class logits at the input resolution.

        Args:
            x: Image batch laid out as (N, C, H, W); dimensions 2 onward are
                treated as the spatial size.

        Returns:
            A tensor with ``num_classes`` channels, bilinearly resized to the
            spatial size of ``x``.
        """
        size = x.shape[2:]  # Original input size
        x = self.backbone(x)
        x = self.aspp(x)
        x = self.decoder(x)
        x = F.interpolate(x, size=size, mode='bilinear', align_corners=False)
        return x


num_classes = 34

# ``device`` was used below without ever being defined; select the GPU when
# one is available and fall back to the CPU otherwise.
device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
model = DeepLabV3(num_classes=num_classes).to(device)
Leave a Comment