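# Untitled paste from pastecode.io (author: unknown; format: plain_text; 4.6 kB).
# The surrounding page text (advertisement banner, avatar, view/expiry metadata)
# was scrape residue rather than code and is kept only as this comment.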
import os
import torch
import torch.nn as nn
import torch.nn.functional as F
import torch.optim as optim
from torch.utils.data import DataLoader

class Generator(nn.Module):
    """Convolutional encoder/decoder mapping a 3-channel image to a 4-channel volume.

    Data flow, with shapes as annotated for a ``(B, 3, 256, 256)`` input:

    * three encoder stages, each summing two parallel paths:
      ``conv1 + conv2`` -> 64x64x64, ``conv3 + conv4`` -> 256x16x16,
      ``conv5 + conv6`` -> 1024x4x4;
    * the bottleneck is max-pooled to 1024x1x1 and reshaped to 256x2x2;
    * four (bilinear 2x upsample -> 3x3 conv block) steps
      (``conv7``..``conv10``) bring it to 512x32x32;
    * two 1x1 heads produce 300 and 100 channels, reshaped to
      ``(B, 3, 100, 32, 32)`` and ``(B, 1, 100, 32, 32)``, concatenated on
      dim 1 and refined by a 3-D conv stack (``conv3d``);
    * the result is permuted to ``(B, 100, 4, 32, 32)``.

    ``conv3d`` ends with ``BatchNorm3d`` and no activation, so the output is
    unbounded.
    """

    def __init__(self):
        super().__init__()
        # Encoder stage 1 (256 -> 64): two-conv path and one-conv path.
        self.conv1 = nn.Sequential(
            *self._conv_bn_relu(3, 16, 8, 2, 3),  # 16 x 128 x 128
            *self._conv_bn_relu(16, 64, 6, 2, 2),  # 64 x 64 x 64
        )
        self.conv2 = nn.Sequential(
            *self._conv_bn_relu(3, 64, 8, 4, 3),  # 64 x 64 x 64
        )
        # Encoder stage 2 (64 -> 16).
        self.conv3 = nn.Sequential(
            *self._conv_bn_relu(64, 128, 6, 2, 2),  # 128 x 32 x 32
            *self._conv_bn_relu(128, 256, 4, 2, 1),  # 256 x 16 x 16
        )
        self.conv4 = nn.Sequential(
            *self._conv_bn_relu(64, 256, 6, 4, 2),  # 256 x 16 x 16
        )
        # Encoder stage 3 (16 -> 4).
        self.conv5 = nn.Sequential(
            *self._conv_bn_relu(256, 512, 4, 2, 1),  # 512 x 8 x 8
            *self._conv_bn_relu(512, 1024, 4, 2, 1),  # 1024 x 4 x 4
        )
        self.conv6 = nn.Sequential(
            *self._conv_bn_relu(256, 1024, 4, 4, 1),  # 1024 x 4 x 4
        )
        # NOTE(review): fc1 is never called in forward(). It is kept so that
        # state_dict keys (and seeded initialisation order) stay compatible
        # with existing checkpoints. Because its parameters get no gradient,
        # DistributedDataParallel may need find_unused_parameters=True.
        self.fc1 = nn.Sequential(
            nn.Linear(1024, 4096),
            nn.ReLU(inplace=True),
        )
        # Decoder: each block follows a bilinear 2x upsample in forward().
        self.conv7 = nn.Sequential(*self._conv_bn_relu(256, 512, 3, 1, 1))
        self.conv8 = nn.Sequential(*self._conv_bn_relu(512, 512, 3, 1, 1))
        self.conv9 = nn.Sequential(*self._conv_bn_relu(512, 512, 3, 1, 1))
        self.conv10 = nn.Sequential(*self._conv_bn_relu(512, 512, 3, 1, 1))

        # 1x1 heads; 300 = 3 * 100 and 100 = 1 * 100 channels, split in forward().
        self.brx = self._pointwise_head(300)
        self.bry = self._pointwise_head(100)
        self.conv3d = nn.Sequential(
            nn.Conv3d(4, 4, 3, 1, 1),
            nn.BatchNorm3d(4),
            nn.Tanh(),
            nn.Conv3d(4, 4, 3, 1, 1),
            nn.BatchNorm3d(4),
        )

    @staticmethod
    def _conv_bn_relu(in_channels, out_channels, kernel_size, stride, padding):
        """Return ``[Conv2d, BatchNorm2d, ReLU(inplace)]`` for splicing into a Sequential."""
        return [
            nn.Conv2d(in_channels, out_channels, kernel_size, stride, padding),
            nn.BatchNorm2d(out_channels),
            nn.ReLU(inplace=True),
        ]

    @staticmethod
    def _pointwise_head(out_channels):
        """Return three 1x1 conv/BN/ReLU blocks: 512 -> 512 -> 512 -> ``out_channels``."""
        return nn.Sequential(
            *Generator._conv_bn_relu(512, 512, 1, 1, 0),
            *Generator._conv_bn_relu(512, 512, 1, 1, 0),
            *Generator._conv_bn_relu(512, out_channels, 1, 1, 0),
        )

    def forward(self, x):
        """Run the generator.

        Args:
            x: image batch of shape ``(B, 3, H, W)``. The layer arithmetic is
                written for ``H = W = 256``. Spatial sizes that do not pool
                down to a ``1 x 1`` bottleneck are rejected with a
                ``RuntimeError`` rather than being silently reinterpreted as
                a different batch size.

        Returns:
            Tensor of shape ``(B, 100, 4, 32, 32)``.
        """
        batch = x.size(0)
        h = self.conv1(x) + self.conv2(x)  # 64 x 64 x 64
        h = self.conv3(h) + self.conv4(h)  # 256 x 16 x 16
        h = self.conv5(h) + self.conv6(h)  # 1024 x 4 x 4
        h = F.max_pool2d(h, 4)  # 1024 x 1 x 1
        # Explicit batch size: the previous view(-1, 256, 2, 2) turned a
        # larger pooled map (e.g. 2x2 from a 512x512 input) into extra
        # "samples" instead of failing.
        h = h.view(batch, 256, 2, 2)

        # 2x2 -> 4 -> 8 -> 16 -> 32; channels 256 -> 512 at conv7.
        for block in (self.conv7, self.conv8, self.conv9, self.conv10):
            h = F.interpolate(h, scale_factor=2, mode='bilinear', align_corners=False)
            h = block(h)
        # h: 512 x 32 x 32

        height, width = h.shape[-2:]
        brx = self.brx(h).view(batch, 3, 100, height, width)
        bry = self.bry(h).view(batch, 1, 100, height, width)

        out = self.conv3d(torch.cat([brx, bry], 1))  # (B, 4, 100, 32, 32)
        return out.permute(0, 2, 1, 3, 4)  # (B, 100, 4, 32, 32)

# (end of paste; trailing advertisement text from the scraped page omitted)