Untitled
unknown
plain_text
6 months ago
4.6 kB
2
Indexable
Never
import os

import torch
import torch.nn as nn
import torch.nn.functional as F
import torch.optim as optim
from torch.utils.data import DataLoader


def _conv_bn_relu(in_channels, out_channels, kernel_size, stride, padding):
    """Return the [Conv2d, BatchNorm2d, ReLU] layers of one 2-D conv unit."""
    return [
        nn.Conv2d(in_channels, out_channels, kernel_size, stride, padding),
        nn.BatchNorm2d(out_channels),
        nn.ReLU(inplace=True),
    ]


class Generator(nn.Module):
    """Convolutional encoder/decoder producing a (batch, 100, 4, 32, 32) volume.

    The encoder is three stages; in each stage a two-convolution path and a
    single larger-stride convolution path are applied to the same tensor and
    their outputs are summed.  The encoded features are max-pooled, reshaped
    to 256 channels of 2x2, and upsampled four times (bilinear x2 followed by
    a 3x3 convolution) to 512 channels of 32x32.  Two 1x1-convolution heads
    (``brx`` with 300 channels, ``bry`` with 100 channels) are reshaped to
    3x100 and 1x100 slices, concatenated into 4 channels, refined by a small
    3-D convolution stack and finally permuted.

    The layer geometry (and the shape comments below) is worked out for
    inputs of shape ``(batch, 3, 256, 256)``.
    """

    def __init__(self):
        super().__init__()

        # --- Encoder stage 1: 3 x 256 x 256 -> 64 x 64 x 64 -----------------
        self.conv1 = nn.Sequential(
            *_conv_bn_relu(3, 16, 8, 2, 3),      # 16 x 128 x 128
            *_conv_bn_relu(16, 64, 6, 2, 2),     # 64 x 64 x 64
        )
        self.conv2 = nn.Sequential(
            *_conv_bn_relu(3, 64, 8, 4, 3),      # 64 x 64 x 64
        )

        # --- Encoder stage 2: 64 x 64 x 64 -> 256 x 16 x 16 -----------------
        self.conv3 = nn.Sequential(
            *_conv_bn_relu(64, 128, 6, 2, 2),    # 128 x 32 x 32
            *_conv_bn_relu(128, 256, 4, 2, 1),   # 256 x 16 x 16
        )
        self.conv4 = nn.Sequential(
            *_conv_bn_relu(64, 256, 6, 4, 2),    # 256 x 16 x 16
        )

        # --- Encoder stage 3: 256 x 16 x 16 -> 1024 x 4 x 4 -----------------
        self.conv5 = nn.Sequential(
            *_conv_bn_relu(256, 512, 4, 2, 1),   # 512 x 8 x 8
            *_conv_bn_relu(512, 1024, 4, 2, 1),  # 1024 x 4 x 4
        )
        self.conv6 = nn.Sequential(
            *_conv_bn_relu(256, 1024, 4, 4, 1),  # 1024 x 4 x 4
        )

        # NOTE: fc1 is not used by forward().  It is kept so that parameter
        # names (and therefore existing state_dict checkpoints) are unchanged.
        self.fc1 = nn.Sequential(
            nn.Linear(1024, 4096),
            nn.ReLU(inplace=True),
        )

        # --- Decoder: one 3x3 conv after each bilinear x2 upsampling --------
        self.conv7 = nn.Sequential(*_conv_bn_relu(256, 512, 3, 1, 1))
        self.conv8 = nn.Sequential(*_conv_bn_relu(512, 512, 3, 1, 1))
        self.conv9 = nn.Sequential(*_conv_bn_relu(512, 512, 3, 1, 1))
        self.conv10 = nn.Sequential(*_conv_bn_relu(512, 512, 3, 1, 1))

        # --- Output heads (1x1 convolutions) --------------------------------
        # brx: 300 channels, later viewed as 3 x 100.
        self.brx = nn.Sequential(
            *_conv_bn_relu(512, 512, 1, 1, 0),
            *_conv_bn_relu(512, 512, 1, 1, 0),
            *_conv_bn_relu(512, 300, 1, 1, 0),
        )
        # bry: 100 channels, later viewed as 1 x 100.
        self.bry = nn.Sequential(
            *_conv_bn_relu(512, 512, 1, 1, 0),
            *_conv_bn_relu(512, 512, 1, 1, 0),
            *_conv_bn_relu(512, 100, 1, 1, 0),
        )

        # 3-D refinement over the concatenated 4-channel volume; the final
        # BatchNorm3d output is returned without a further activation.
        self.conv3d = nn.Sequential(
            nn.Conv3d(4, 4, 3, 1, 1),
            nn.BatchNorm3d(4),
            nn.Tanh(),
            nn.Conv3d(4, 4, 3, 1, 1),
            nn.BatchNorm3d(4),
        )

    def forward(self, x):
        """Run the generator.

        Args:
            x: input tensor of shape ``(batch, 3, 256, 256)``.

        Returns:
            Tensor of shape ``(batch, 100, 4, 32, 32)``.

        Raises:
            RuntimeError: if the spatial size of ``x`` does not reduce to the
                expected 1024 x 1 x 1 bottleneck (raised by ``view``).
        """
        # Pin the batch size: with view(-1, ...) a wrong input resolution
        # would be folded silently into the batch dimension instead of failing.
        batch = x.size(0)

        # Encoder: each stage sums a two-conv path with a single-conv path.
        h12 = self.conv1(x) + self.conv2(x)        # 64 x 64 x 64
        h34 = self.conv3(h12) + self.conv4(h12)    # 256 x 16 x 16
        h0 = self.conv5(h34) + self.conv6(h34)     # 1024 x 4 x 4

        h0 = F.max_pool2d(h0, 4)                   # 1024 x 1 x 1
        h0 = h0.view(batch, 256, 2, 2)             # 256 x 2 x 2

        # Decoder: 2x2 -> 4x4 -> 8x8 -> 16x16 -> 32x32, ending at 512 channels.
        for conv in (self.conv7, self.conv8, self.conv9, self.conv10):
            h0 = F.interpolate(
                h0, scale_factor=2, mode='bilinear', align_corners=False
            )
            h0 = conv(h0)

        brx = self.brx(h0).view(batch, 3, 100, 32, 32)
        bry = self.bry(h0).view(batch, 1, 100, 32, 32)

        out = torch.cat([brx, bry], 1)             # (batch, 4, 100, 32, 32)
        out = self.conv3d(out)
        return out.permute(0, 2, 1, 3, 4)          # (batch, 100, 4, 32, 32)