# Untitled
#
# mail@pastecode.io avatar
# unknown
# python
# 2 years ago
# 6.5 kB
# 0
# Indexable
# Never
# Model zoo: model definitions available for use.
# Feel free to add any model you like for your final result.
# Note: a pretrained model is allowed if and only if it was pretrained on ImageNet.

import torch
import torch.nn as nn
from torchvision import models


def weights_init(m):
    """Re-initialise the parameters of a conv or batch-norm module in place.

    Written to be used as the callback of ``nn.Module.apply``. The module is
    selected by its class name:

    * name contains ``'Conv'``      -> weight ~ Normal(mean=0.0, std=0.02)
    * name contains ``'BatchNorm'`` -> weight ~ Normal(mean=1.0, std=0.02),
      bias set to 0

    Every other module is left untouched.

    Args:
        m: the module to initialise.

    Returns:
        None. The parameters of ``m`` are modified in place.
    """
    classname = m.__class__.__name__

    # A matching module does not necessarily own the parameter: batch-norm
    # layers built with ``affine=False`` expose ``weight``/``bias`` as None,
    # and a user-defined wrapper whose name merely contains 'Conv' may have
    # no such attribute at all. Skip those instead of raising.
    weight = getattr(m, 'weight', None)
    bias = getattr(m, 'bias', None)

    if 'Conv' in classname:
        if weight is not None:
            nn.init.normal_(weight.data, 0.0, 0.02)
    elif 'BatchNorm' in classname:
        if weight is not None:
            nn.init.normal_(weight.data, 1.0, 0.02)
        if bias is not None:
            nn.init.constant_(bias.data, 0)


class myLeNet(nn.Module):
    """LeNet-style CNN: two conv/pool stages followed by three linear layers.

    The first linear layer expects 400 input features, i.e. 16 channels of
    5x5 maps, which is what the two conv/pool stages produce for a
    3-channel 32x32 input.

    Args:
        num_out: number of output units of the last linear layer.
    """

    def __init__(self, num_out):
        super().__init__()

        # Stage 1: 3 -> 6 channels, 5x5 conv, ReLU, 2x2 max-pool.
        self.conv1 = nn.Sequential(
            nn.Conv2d(in_channels=3, out_channels=6, kernel_size=5, stride=1),
            nn.ReLU(),
            nn.MaxPool2d(kernel_size=2, stride=2),
        )
        # Stage 2: 6 -> 16 channels, 5x5 conv, ReLU, 2x2 max-pool.
        self.conv2 = nn.Sequential(
            nn.Conv2d(in_channels=6, out_channels=16, kernel_size=5),
            nn.ReLU(),
            nn.MaxPool2d(kernel_size=2, stride=2),
        )

        # Classifier head: 400 -> 120 -> 84 -> num_out.
        self.fc1 = nn.Sequential(nn.Linear(in_features=400, out_features=120), nn.ReLU())
        self.fc2 = nn.Sequential(nn.Linear(in_features=120, out_features=84), nn.ReLU())
        self.fc3 = nn.Linear(in_features=84, out_features=num_out)

    def forward(self, x):
        """Return the ``(batch, num_out)`` output of the network for batch ``x``."""
        feature_maps = self.conv2(self.conv1(x))

        # Collapse everything except the batch dimension. If the input size
        # changes, inspect ``flat.shape`` here to find the in_features that
        # the first linear layer needs.
        flat = feature_maps.flatten(start_dim=1)

        hidden = self.fc2(self.fc1(flat))
        return self.fc3(hidden)

    
    
class residual_block(nn.Module):
    """Shape-preserving residual block: ``relu(bn(conv3x3(x)) + x)``.

    The convolution keeps the channel count (``in_channels`` in and out) and
    uses a 3x3 kernel with padding 1, so its output has the same shape as the
    input and can be added to it element-wise.

    Args:
        in_channels: number of channels of the input (and of the output).
    """

    def __init__(self, in_channels):
        super(residual_block, self).__init__()

        # Residual branch: 3x3 conv followed by batch normalisation.
        self.conv1 = nn.Sequential(
            nn.Conv2d(in_channels=in_channels, out_channels=in_channels, kernel_size=3, padding=1),
            nn.BatchNorm2d(in_channels),
        )

        self.relu = nn.ReLU()

    def forward(self, x):
        """Return ``relu(conv1(x) + x)``; the result has the same shape as ``x``."""
        identity = x
        out = self.conv1(x)
        # Skip connection: the identity is added *before* the activation, so
        # the block learns a residual on top of its input.
        out = out + identity
        return self.relu(out)

class resnet(nn.Module):
    """ImageNet-pretrained torchvision ResNet-18 with dropout and a custom head.

    The torchvision backbone is run layer by layer so that dropout (p=0.4)
    can be inserted between its stages. An extra ``residual_block(512)`` is
    applied after ``layer4``, and the original ``fc`` layer is replaced by a
    bias-free 512 -> 256 -> ``num_out`` head.

    Args:
        in_channels: currently unused; the backbone's own stem convolution is
            used unchanged. Kept so the constructor signature stays stable.
        num_out: number of output units of the head (default 10).
    """

    def __init__(self, in_channels=3, num_out=10):
        super(resnet, self).__init__()
        # NOTE: newer torchvision releases deprecate ``pretrained=True`` in
        # favour of the ``weights=`` argument; the original call is kept so
        # the file still works with the torchvision version it was written for.
        self.pretrained = models.resnet18(pretrained=True)
        # The head width follows ``num_out`` instead of a hard-coded 10.
        self.pretrained.fc = nn.Sequential(nn.Linear(512, 256, bias=False),
                                           nn.ReLU(),
                                           nn.Linear(256, num_out, bias=False))
        self.residual_block = residual_block(512)
        self.dropout = nn.Dropout(0.4)

    def forward(self, x):
        """Return the ``(batch, num_out)`` head output for the image batch ``x``."""
        # Stem (dropout is applied to the raw conv output, before batch norm).
        x = self.pretrained.conv1(x)
        x = self.dropout(x)
        x = self.pretrained.bn1(x)
        x = self.pretrained.relu(x)
        x = self.pretrained.maxpool(x)
        x = self.dropout(x)

        # Backbone stages; there is no dropout between layer1 and layer2.
        x = self.pretrained.layer1(x)
        x = self.pretrained.layer2(x)
        x = self.dropout(x)
        x = self.pretrained.layer3(x)
        x = self.dropout(x)
        x = self.pretrained.layer4(x)
        x = self.dropout(x)

        x = self.residual_block(x)

        # Reduce each of the 512 feature maps to a single value so the head
        # always receives 512 features. For 32x32 inputs the maps are already
        # 1x1 and this pooling is the identity; for larger inputs it avoids a
        # shape mismatch in the first linear layer.
        x = nn.functional.adaptive_avg_pool2d(x, (1, 1))
        x = torch.flatten(x, start_dim=1, end_dim=-1)
        x = self.pretrained.fc(x)

        return x
        
        
class myResnet(nn.Module):
    """Small hand-built CNN made of conv stages interleaved with residual_block.

    Active data path (see ``forward``):
    stem conv -> residual1 -> cnn2 -> residual2 -> cnn3 -> residual3 -> fc.

    ``cnn2`` and ``cnn3`` each halve the spatial size with a 2x2 max-pool, so
    a 32x32 input reaches the classifier as 128 channels of 8x8 maps, i.e.
    the 8192 features the first linear layer expects.

    Stages 4-6 (``cnn4``..``cnn6`` / ``residual4``..``residual6``) are built
    but not applied in ``forward``. They are kept so the parameter names in
    the state dict stay the same.

    Args:
        in_channels: number of channels of the input images (default 3).
        num_out: number of output units of the classifier (default 10).
    """

    @staticmethod
    def _down_stage(in_ch, out_ch):
        """Build a 3x3 conv -> 2x2 max-pool -> batch norm -> RReLU stage."""
        return nn.Sequential(
            nn.Conv2d(in_channels=in_ch, out_channels=out_ch, kernel_size=3, padding=1),
            # kernel_size=2, stride=2, padding=0, dilation=1
            nn.MaxPool2d(2, 2, 0, 1),
            nn.BatchNorm2d(out_ch),
            nn.RReLU(),
        )

    def __init__(self, in_channels=3, num_out=10):
        super(myResnet, self).__init__()

        # Layers are created in the same order as before so that seeded
        # initialisation produces the same parameters.
        self.stem_conv = nn.Conv2d(in_channels=in_channels, out_channels=64, kernel_size=3, padding=1)
        self.residual1 = residual_block(64)
        self.cnn2 = self._down_stage(64, 64)
        self.residual2 = residual_block(64)
        self.cnn3 = self._down_stage(64, 128)
        self.residual3 = residual_block(128)

        # Stages 4-6: defined but currently not applied in forward().
        self.cnn4 = self._down_stage(128, 256)
        self.residual4 = residual_block(256)
        self.cnn5 = self._down_stage(256, 256)
        self.residual5 = residual_block(256)
        self.cnn6 = self._down_stage(256, 512)
        self.residual6 = residual_block(512)

        # Classifier; the output width follows ``num_out`` instead of a
        # hard-coded 10.
        self.fc = nn.Sequential(nn.Linear(8192, 1024),
                                nn.ReLU(),
                                nn.Dropout(0.3),
                                nn.Linear(1024, num_out),)
        self.dropout = nn.Dropout(0.5)
        self.dropout2 = nn.Dropout(0.3)

    def forward(self, x):
        """Return the ``(batch, num_out)`` classifier output for batch ``x``."""
        x = self.stem_conv(x)
        x = self.dropout2(x)
        x = self.residual1(x)
        x = self.dropout(x)

        x = self.cnn2(x)
        x = self.residual2(x)
        x = self.dropout(x)

        x = self.cnn3(x)
        x = self.residual3(x)
        x = self.dropout(x)

        # Stages 4-6 are intentionally skipped here. Enabling them changes
        # the flattened feature count, so the in_features of ``self.fc``
        # (8192) would have to be updated as well.
        x = torch.flatten(x, start_dim=1, end_dim=-1)
        x = self.fc(x)
        return x