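# NOTE: the following lines are metadata left over from the paste site this
# snippet was copied from; they are kept only as comments.
# Untitled
# unknown
# plain_text
# 10 months ago
# 2.6 kB
# 5
# Indexable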
import torch
import torch.nn as nn
import torch.nn.functional as F
from torchvision.models.resnet import resnet50
class ASPP(nn.Module):
    """Atrous Spatial Pyramid Pooling block used by DeepLabV3.

    Five parallel branches are applied to the same input feature map:

    * a 1x1 convolution,
    * three 3x3 atrous (dilated) convolutions, one per rate in
      ``atrous_rates``,
    * an image-level branch: global average pooling followed by a 1x1
      convolution, bilinearly upsampled back to the input's spatial size.

    The branch outputs are concatenated along the channel axis and fused by a
    1x1 convolution, BatchNorm and ReLU.

    Args:
        in_channels: Channel count of the input feature map.
        out_channels: Channel count of every branch and of the fused output.
        atrous_rates: Dilation rates of the three 3x3 branches. Defaults to
            ``(6, 12, 18)``, the values this block previously hard-coded.
            The DeepLabV3 paper uses these at output stride 16 and doubles
            them at output stride 8.

    Raises:
        ValueError: If ``atrous_rates`` does not contain exactly three values.
    """

    def __init__(self, in_channels, out_channels, atrous_rates=(6, 12, 18)):
        super().__init__()
        rates = tuple(atrous_rates)
        # Exactly three rates keeps the submodule names (and therefore the
        # state_dict keys) identical to the original fixed-rate layout.
        if len(rates) != 3:
            raise ValueError(
                f"atrous_rates must contain exactly 3 rates, got {len(rates)}: {rates!r}"
            )
        rate2, rate3, rate4 = rates

        self.atrous_block1 = nn.Conv2d(
            in_channels, out_channels, kernel_size=1, stride=1, padding=0, dilation=1
        )
        # For a 3x3 kernel with stride 1, padding == dilation preserves H and W.
        self.atrous_block2 = nn.Conv2d(
            in_channels, out_channels, kernel_size=3, stride=1, padding=rate2, dilation=rate2
        )
        self.atrous_block3 = nn.Conv2d(
            in_channels, out_channels, kernel_size=3, stride=1, padding=rate3, dilation=rate3
        )
        self.atrous_block4 = nn.Conv2d(
            in_channels, out_channels, kernel_size=3, stride=1, padding=rate4, dilation=rate4
        )
        self.global_avg_pool = nn.Sequential(
            nn.AdaptiveAvgPool2d((1, 1)),
            nn.Conv2d(in_channels, out_channels, kernel_size=1, stride=1, padding=0),
        )
        # Fuses the five concatenated branches back down to out_channels.
        self.conv1 = nn.Conv2d(out_channels * 5, out_channels, kernel_size=1)
        self.bn1 = nn.BatchNorm2d(out_channels)
        self.relu = nn.ReLU()

    def forward(self, x):
        """Apply ASPP to a batched feature map.

        Args:
            x: Tensor of shape (N, in_channels, H, W).

        Returns:
            Tensor of shape (N, out_channels, H, W); spatial size is preserved.
        """
        size = x.shape[2:]
        out1 = self.atrous_block1(x)
        out2 = self.atrous_block2(x)
        out3 = self.atrous_block3(x)
        out4 = self.atrous_block4(x)
        # The image-level branch is 1x1 after pooling; stretch it back to (H, W)
        # so it can be concatenated with the other branches.
        out5 = F.interpolate(
            self.global_avg_pool(x), size=size, mode='bilinear', align_corners=False
        )
        out = torch.cat([out1, out2, out3, out4, out5], dim=1)
        out = self.conv1(out)
        out = self.bn1(out)
        return self.relu(out)
class DeepLabV3(nn.Module):
    """DeepLabV3 semantic-segmentation network with a ResNet-50 encoder.

    Pipeline: ResNet-50 feature extractor (avgpool/fc removed) -> ASPP
    (2048 -> 256 channels) -> 3x3 conv/BN/ReLU -> 1x1 classifier ->
    bilinear upsampling back to the input resolution.

    Args:
        num_classes: Number of output channels (one logit map per class).
        backbone: Encoder name; only ``'resnet50'`` is supported.
        output_stride: Ratio between the input resolution and the encoder's
            output resolution; one of 8, 16 (default) or 32. The last ResNet
            stage(s) swap their stride-2 downsampling for dilation so that
            the feature map stays large enough for ASPP's atrous rates
            (6, 12, 18), which DeepLabV3 pairs with output stride 16. At
            stride 32 the dilated taps mostly fall into zero padding on
            typical input sizes. ``output_stride=32`` reproduces the previous
            undilated backbone; no parameters change between settings.

    Raises:
        NotImplementedError: If ``backbone`` is not ``'resnet50'``.
        ValueError: If ``output_stride`` is not 8, 16 or 32.
    """

    # For each supported output stride, which of ResNet's layer2/layer3/layer4
    # replace their stride with dilation (torchvision's
    # ``replace_stride_with_dilation`` argument).
    _STRIDE_TO_DILATION = {
        8: (False, True, True),
        16: (False, False, True),
        32: (False, False, False),
    }

    def __init__(self, num_classes, backbone='resnet50', output_stride=16):
        super().__init__()
        # Encoder (backbone)
        if backbone != 'resnet50':
            raise NotImplementedError("Only ResNet50 is supported as a backbone.")
        if output_stride not in self._STRIDE_TO_DILATION:
            raise ValueError(
                f"output_stride must be one of {sorted(self._STRIDE_TO_DILATION)}, "
                f"got {output_stride!r}"
            )
        self.output_stride = output_stride
        resnet = resnet50(
            weights=None,  # no pretraining
            replace_stride_with_dilation=list(self._STRIDE_TO_DILATION[output_stride]),
        )
        # Drop the trailing AvgPool and FC layers to keep the spatial feature map.
        self.backbone = nn.Sequential(*list(resnet.children())[:-2])
        in_channels = 2048  # channels produced by ResNet-50's last stage

        # ASPP module
        self.aspp = ASPP(in_channels, 256)
        # Decoder: refine ASPP features, then project to per-class logits.
        self.decoder = nn.Sequential(
            nn.Conv2d(256, 256, kernel_size=3, stride=1, padding=1),
            nn.BatchNorm2d(256),
            nn.ReLU(),
            nn.Conv2d(256, num_classes, kernel_size=1, stride=1),
        )

    def forward(self, x):
        """Return per-pixel class logits at the input's spatial resolution.

        Args:
            x: Image batch of shape (N, 3, H, W).

        Returns:
            Logits of shape (N, num_classes, H, W).
        """
        size = x.shape[2:]  # original (H, W), restored at the end
        x = self.backbone(x)
        x = self.aspp(x)
        x = self.decoder(x)
        # The encoder downsampled by output_stride; upsample logits back.
        return F.interpolate(x, size=size, mode='bilinear', align_corners=False)
num_classes = 34
# `device` was referenced here without ever being defined (NameError);
# prefer CUDA when it is available and fall back to the CPU otherwise.
device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
model = DeepLabV3(num_classes=num_classes).to(device)
# (paste-site residue) Leave a Comment