Untitled
unknown
plain_text
10 months ago
4.4 kB
5
Indexable
import torch
import torch.nn as nn
import torch.nn.functional as F
class BasicBlock(nn.Module):
    """Residual block: 3x3 conv -> BN -> ReLU -> 3x3 conv -> BN, plus a skip path.

    The first convolution carries ``stride``; both 3x3 convolutions use
    ``dilation`` with padding equal to it. When the block changes the stride
    or the channel count, the skip path is a strided 1x1 convolution followed
    by BatchNorm; otherwise it is an identity (an empty ``nn.Sequential``).

    Args:
        in_channels: Channels of the incoming feature map.
        out_channels: Channels produced by the block.
        stride: Stride of the first 3x3 convolution and of the projection skip.
        dilation: Dilation (and padding) of both 3x3 convolutions.
    """

    def __init__(self, in_channels, out_channels, stride=1, dilation=1):
        super().__init__()
        self.conv1 = nn.Conv2d(in_channels, out_channels, kernel_size=3,
                               stride=stride, padding=dilation,
                               dilation=dilation, bias=False)
        self.bn1 = nn.BatchNorm2d(out_channels)
        self.conv2 = nn.Conv2d(out_channels, out_channels, kernel_size=3,
                               stride=1, padding=dilation,
                               dilation=dilation, bias=False)
        self.bn2 = nn.BatchNorm2d(out_channels)

        needs_projection = stride != 1 or in_channels != out_channels
        if needs_projection:
            # Reshape the input (channels and stride) so it can be summed
            # with the main path.
            self.shortcut = nn.Sequential(
                nn.Conv2d(in_channels, out_channels, kernel_size=1,
                          stride=stride, bias=False),
                nn.BatchNorm2d(out_channels),
            )
        else:
            # An empty Sequential returns its input unchanged: identity skip.
            self.shortcut = nn.Sequential()

    def forward(self, x):
        """Return ``relu(bn2(conv2(relu(bn1(conv1(x))))) + shortcut(x))``."""
        h = self.bn1(self.conv1(x))
        h = F.relu(h)
        h = self.bn2(self.conv2(h))
        h = h + self.shortcut(x)
        return F.relu(h)
class ASPP(nn.Module):
    """Atrous Spatial Pyramid Pooling head.

    Five parallel branches see the same input:

    - a 1x1 convolution;
    - three 3x3 convolutions with dilation 6, 12 and 18;
    - a global-average-pool + 1x1 convolution branch, upsampled back to the
      input's spatial size.

    The branch outputs are concatenated along the channel axis and fused by
    a 1x1 convolution, BatchNorm and ReLU.

    Args:
        in_channels: Channels of the incoming feature map.
        out_channels: Channels of every branch and of the fused output.
    """

    def __init__(self, in_channels, out_channels):
        super().__init__()

        def dilated_conv(rate):
            # 3x3 convolution with padding equal to its dilation, so the
            # spatial size of the input is preserved.
            return nn.Conv2d(in_channels, out_channels, kernel_size=3,
                             padding=rate, dilation=rate, bias=False)

        self.atrous_block1 = nn.Conv2d(in_channels, out_channels,
                                       kernel_size=1, bias=False)
        self.atrous_block2 = dilated_conv(6)
        self.atrous_block3 = dilated_conv(12)
        self.atrous_block4 = dilated_conv(18)
        self.global_avg_pool = nn.Sequential(
            nn.AdaptiveAvgPool2d((1, 1)),
            nn.Conv2d(in_channels, out_channels, kernel_size=1, bias=False),
        )
        # Five branches of out_channels each are concatenated before fusion.
        self.conv1 = nn.Conv2d(out_channels * 5, out_channels,
                               kernel_size=1, bias=False)
        self.bn1 = nn.BatchNorm2d(out_channels)

    def forward(self, x):
        """Fuse the five ASPP branches; output keeps ``x``'s spatial size."""
        conv_branches = (self.atrous_block1, self.atrous_block2,
                         self.atrous_block3, self.atrous_block4)
        features = [F.relu(branch(x)) for branch in conv_branches]
        # Image-level branch: the pooled 1x1 map gets no ReLU and is
        # bilinearly upsampled to x's spatial size.
        pooled = self.global_avg_pool(x)
        features.append(F.interpolate(pooled, size=x.shape[2:],
                                      mode='bilinear', align_corners=False))
        fused = self.conv1(torch.cat(features, dim=1))
        return F.relu(self.bn1(fused))
class DecoderBlock(nn.Module):
    """Upsampling stage: 2x transposed convolution, then 3x3 conv -> BN -> ReLU.

    Args:
        in_channels: Channels of the incoming (coarser) feature map.
        out_channels: Channels after upsampling and of the block output.
    """

    def __init__(self, in_channels, out_channels):
        super().__init__()
        # kernel_size=2 with stride=2 exactly doubles height and width.
        self.upconv = nn.ConvTranspose2d(in_channels, out_channels,
                                         kernel_size=2, stride=2)
        refine_layers = [
            nn.Conv2d(out_channels, out_channels, kernel_size=3, padding=1, bias=False),
            nn.BatchNorm2d(out_channels),
            nn.ReLU(inplace=True),
        ]
        self.conv = nn.Sequential(*refine_layers)

    def forward(self, x):
        """Upsample ``x`` by 2 and refine it into ``out_channels`` maps."""
        return self.conv(self.upconv(x))
class SemanticSegmentationNet(nn.Module):
    """Encoder / ASPP / decoder network producing per-pixel class logits.

    Structure:

    - Four stride-2 ``BasicBlock`` encoders each halve the spatial size,
      rounding up.
    - An ``ASPP`` module aggregates multi-scale context on the deepest
      features.
    - Four ``DecoderBlock`` stages each double the spatial size. The matching
      encoder feature map is added before each of the last three decoders.

    Args:
        num_classes: Number of output channels (one logit map per class).
        in_channels: Channels of the input image. Defaults to 3.
    """

    def __init__(self, num_classes, in_channels=3):
        super(SemanticSegmentationNet, self).__init__()
        self.encoder1 = BasicBlock(in_channels, 64, stride=2)
        self.encoder2 = BasicBlock(64, 128, stride=2)
        self.encoder3 = BasicBlock(128, 256, stride=2)
        self.encoder4 = BasicBlock(256, 512, stride=2)
        self.aspp = ASPP(512, 512)
        self.decoder4 = DecoderBlock(512, 256)
        self.decoder3 = DecoderBlock(256, 128)
        self.decoder2 = DecoderBlock(128, 64)
        self.decoder1 = DecoderBlock(64, 64)
        self.final_conv = nn.Conv2d(64, num_classes, kernel_size=1)

    @staticmethod
    def _match_size(t, ref):
        """Bilinearly resize ``t`` to ``ref``'s spatial size when they differ.

        Returns ``t`` unchanged if the last two dimensions already match.
        """
        if t.shape[-2:] == ref.shape[-2:]:
            return t
        return F.interpolate(t, size=ref.shape[-2:], mode='bilinear',
                             align_corners=False)

    def forward(self, x):
        """Compute class logits for an image batch.

        Args:
            x: Tensor of shape (N, in_channels, H, W).

        Returns:
            Logits of shape (N, num_classes, H, W).
        """
        # Encoder. Stride-2 3x3 convs with padding 1 output ceil(size / 2),
        # so when H or W is not a multiple of 16 the 2x-upsampled decoder
        # maps can be one pixel off from these skip tensors.
        e1 = self.encoder1(x)
        e2 = self.encoder2(e1)
        e3 = self.encoder3(e2)
        e4 = self.encoder4(e3)
        # ASPP
        aspp = self.aspp(e4)
        # Decoder. Align each upsampled map to its skip tensor before adding.
        # This is a no-op when the sizes already agree.
        d4 = self.decoder4(aspp)
        d3 = self.decoder3(self._match_size(d4, e3) + e3)
        d2 = self.decoder2(self._match_size(d3, e2) + e2)
        d1 = self.decoder1(self._match_size(d2, e1) + e1)
        # Final output, resized to the input resolution if needed.
        out = self.final_conv(d1)
        return self._match_size(out, x)
# Smoke test: run one random example through the network and print the
# shape of the logits. Guarded so that importing this module does not build
# the model and run a full forward pass as a side effect.
if __name__ == "__main__":
    model = SemanticSegmentationNet(num_classes=34)
    model.eval()  # use BatchNorm running statistics, as at inference time
    x = torch.randn(1, 3, 256, 512)  # Example input
    with torch.no_grad():  # no autograd graph is needed just to inspect the shape
        output = model(x)
    print("Output shape:", output.shape)
Editor is loading...
Leave a Comment