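# Untitled paste (pastecode.io export; language: python, 6.0 kB).
# The export header lines ("mail@pastecode.io avatar", "unknown", "2 years ago",
# "38", "Indexable", "Never") were page residue, not code, and are kept here
# only as a comment so the module stays importable.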
from typing import ForwardRef
import torch
import torch.nn as nn
from torch.utils.data.dataloader import DataLoader
import torchvision.transforms.functional as TF
from torch.nn import functional as F
from torchsummaryX import summary
from dataset import GaziBrainsDataset
from pytorch_lightning import LightningModule

# Device name string; not referenced anywhere else in this module.
DEVICE = "cpu"
# Adam step size used by UNet.configure_optimizers.
LEARNING_RATE = 0.0001

class ResBlock(nn.Module):
    """Residual block computing ``relu(block(x) + shortcut(x))``.

    ``block`` is conv3x3 -> BN -> ReLU -> conv3x3 -> BN with padding=1 and
    stride=1, so height and width are preserved. When ``in_ch != out_ch`` the
    shortcut is a 1x1 conv + BN projection that matches the channel count;
    otherwise the input is added unchanged.

    This is a plain ``nn.Module`` building block. Only the top-level training
    system needs to be a ``LightningModule``.

    Args:
        in_ch: number of input channels.
        out_ch: number of output channels.
    """

    def __init__(self, in_ch, out_ch):
        super().__init__()

        # Projection shortcut only when channel counts differ. ``None`` means
        # an identity shortcut (no parameters), which keeps the state_dict
        # keys identical to the original layout.
        if in_ch != out_ch:
            self.skip = nn.Sequential(
                nn.Conv2d(in_ch, out_ch, kernel_size=1, stride=1, bias=False),
                nn.BatchNorm2d(out_ch),
            )
        else:
            self.skip = None

        self.block = nn.Sequential(
            nn.Conv2d(in_ch, out_ch, kernel_size=3, padding=1, stride=1, bias=False),
            nn.BatchNorm2d(out_ch),
            nn.ReLU(inplace=True),
            nn.Conv2d(out_ch, out_ch, kernel_size=3, padding=1, stride=1, bias=False),
            nn.BatchNorm2d(out_ch),
        )

        self.relu = nn.ReLU(inplace=True)

    def forward(self, x):
        """Apply the residual block; output has ``out_ch`` channels and x's H/W."""
        out = self.block(x)
        identity = x if self.skip is None else self.skip(x)

        # In-place add/ReLU act on ``out``, a fresh tensor produced by the
        # final BatchNorm, so the caller's ``x`` is never modified.
        out += identity
        return self.relu(out)

class Encoder(nn.Module):
    """Downsampling path of the U-Net.

    Builds one ``ResBlock`` per entry of ``features[:-1]``. The last entry of
    ``features`` is not used here (UNet in this file uses it as the
    bottleneck width). After each block a 2x2/stride-2 max-pool halves H and W,
    rounding down for odd sizes.

    Args:
        in_channels: channels of the input feature map.
        features: per-level channel widths, shallowest first.
    """

    def __init__(self, in_channels, features):
        super().__init__()

        self.downs = nn.ModuleList()
        self.pool = nn.MaxPool2d(kernel_size=2, stride=2)

        # Each level's output width becomes the next level's input width.
        for out_channels in features[:-1]:
            self.downs.append(ResBlock(in_channels, out_channels))
            in_channels = out_channels

    def forward(self, x):
        """Encode ``x``.

        Returns:
            ``(x, skip_connections)``: the pooled output of the deepest level,
            and the pre-pool output of every level ordered deepest first (the
            order in which a decoder consumes them).
        """
        skip_connections = []

        for down in self.downs:
            x = down(x)
            skip_connections.append(x)
            x = self.pool(x)

        skip_connections.reverse()
        return x, skip_connections

class Decoder(nn.Module):
    """Upsampling path of the U-Net, processed deepest level first.

    For every level ``idx`` (from ``len(features) - 2`` down to 0) it holds:
    - a 2x2/stride-2 ``ConvTranspose2d`` mapping ``features[idx+1]`` to
      ``features[idx]`` channels;
    - a ``ResBlock`` that fuses the upsampled map with the matching skip
      connection (``features[idx] * 2`` to ``features[idx]`` channels).

    Both kinds of module live interleaved in ``self.ups`` as
    ``[up, fuse, up, fuse, ...]``. That layout is kept so state_dict keys
    (``ups.<i>.*``) stay stable across versions.

    Args:
        features: per-level channel widths, shallowest first. The last entry is
            the channel count of the ``x`` passed to ``forward``.
    """

    def __init__(self, features):
        super().__init__()

        self.ups = nn.ModuleList()

        for idx in reversed(range(len(features) - 1)):
            self.ups.append(nn.ConvTranspose2d(features[idx + 1], features[idx], kernel_size=2, stride=2))
            self.ups.append(ResBlock(features[idx] * 2, features[idx]))

    def forward(self, x, skip_connections):
        """Decode ``x`` back to the shallowest resolution.

        Args:
            x: deepest feature map with ``features[-1]`` channels.
            skip_connections: encoder feature maps ordered deepest first.
                Entry ``i`` is concatenated after the ``i``-th upsampling step.

        Returns:
            Feature map with ``features[0]`` channels.
        """
        for level in range(len(self.ups) // 2):
            upsample = self.ups[2 * level]
            fuse = self.ups[2 * level + 1]

            x = upsample(x)
            skip = skip_connections[level]

            # A 2x upsample cannot restore an odd size that pooling rounded
            # down (e.g. an encoder level with odd H or W). Resample x to the
            # skip's spatial size. Only H/W are compared because the channel
            # counts already agree by construction.
            if x.shape[2:] != skip.shape[2:]:
                x = TF.resize(x, skip.shape[2:])

            x = fuse(torch.cat([x, skip], dim=1))

        return x

# Pytorch U-Net Model
class UNet(LightningModule):
    """Three-input residual U-Net for segmentation, wrapped as a LightningModule.

    Each of the three inputs (``flair``, ``t2``, ``t1`` in the training/eval
    steps) goes through its own ``Encoder``. The three deepest encoder outputs
    are concatenated and fused by a shared bottleneck ``ResBlock``. The
    bottleneck output is then decoded three times, once per input using that
    input's skip connections. The three decoder outputs are concatenated and
    projected by a 1x1 conv to ``out_channels`` class scores, which are returned
    as log-probabilities.

    Args:
        in_channels: channels of each individual input.
        out_channels: number of output classes.
        features: per-level channel widths, shallowest first. The last entry is
            the bottleneck width. At least two entries are required
            (``features[-2]`` is read).
        batch_size: batch size used by all dataloaders.
    """

    # Directory holding the pre-built .npy splits read in ``setup``.
    _DATASET_PATH = 'Data/npy_dataset'

    def __init__(self, in_channels, out_channels, features, batch_size):
        super().__init__()

        self.batch_size = batch_size

        # Registration order is kept (it fixes the parameter order seen by the
        # optimizer and stored in checkpoints).
        self.encoder1 = Encoder(in_channels, features)
        self.decoder1 = Decoder(features)

        self.encoder2 = Encoder(in_channels, features)
        self.decoder2 = Decoder(features)

        self.encoder3 = Encoder(in_channels, features)
        self.decoder3 = Decoder(features)

        # Input: the three deepest encoder outputs (features[-2] channels each)
        # concatenated on the channel dim.
        self.bottleneck = ResBlock(features[-2] * 3, features[-1])
        # Input: the three decoder outputs (features[0] channels each)
        # concatenated on the channel dim.
        self.out = nn.Conv2d(features[0] * 3, out_channels, kernel_size=1, padding=0, stride=1)

    def forward(self, x1, x2, x3):
        """Return per-pixel class log-probabilities (class scores on dim 1).

        Assumes (N, C, H, W) inputs. TODO: confirm against GaziBrainsDataset.
        """
        x1, skip_connections1 = self.encoder1(x1)
        x2, skip_connections2 = self.encoder2(x2)
        x3, skip_connections3 = self.encoder3(x3)

        x = self.bottleneck(torch.cat([x1, x2, x3], dim=1))

        x1 = self.decoder1(x, skip_connections1)
        x2 = self.decoder2(x, skip_connections2)
        x3 = self.decoder3(x, skip_connections3)

        logits = self.out(torch.cat([x1, x2, x3], dim=1))

        # Explicit class dim. The implicit-dim form is deprecated and warns;
        # for 4-D input it resolved to dim=1 anyway.
        return F.log_softmax(logits, dim=1)

    def _shared_step(self, batch):
        """Return the NLL loss for one ``(flair, t2, t1, y)`` batch."""
        flair, t2, t1, y = batch

        # forward() already returns log-probabilities, so nll_loss here is the
        # usual cross-entropy.
        log_probs = self(flair, t2, t1)
        return F.nll_loss(log_probs, y.long())

    def training_step(self, batch, batch_idx):
        loss = self._shared_step(batch)
        # Lightning no longer reads a returned 'log' dict; log explicitly.
        self.log("train_loss", loss)
        return {'loss': loss}

    def validation_step(self, batch, batch_idx):
        loss = self._shared_step(batch)
        self.log("val_loss", loss, prog_bar=True)
        return {'loss': loss}

    def test_step(self, batch, batch_idx):
        loss = self._shared_step(batch)
        # Own key, so test results are not reported as val_loss.
        self.log("test_loss", loss)
        return {'loss': loss}

    def configure_optimizers(self):
        return torch.optim.Adam(self.parameters(), lr=LEARNING_RATE)

    ################
    # DATA PREPARATION
    ################

    def _load_split(self, X_path, y_path):
        """Build a GaziBrainsDataset from the given .npy files under _DATASET_PATH."""
        return GaziBrainsDataset(dataset_path=self._DATASET_PATH, X_path=X_path, y_path=y_path)

    def setup(self, stage=None):
        # Recent Lightning versions call setup with 'fit', 'validate', 'test'
        # or 'predict'; None means "prepare everything".
        if stage in ('fit', None):
            self.train_dataset = self._load_split('X_train.npy', 'y_train.npy')

        # The validation split also serves as the test set. trainer.validate()
        # ('validate' stage) needs it too. Loaded once even when stage is None.
        if stage in ('fit', 'validate', 'test', None):
            self.validation_dataset = self._load_split('X_validation.npy', 'y_validation.npy')

    def train_dataloader(self):
        return DataLoader(self.train_dataset, batch_size=self.batch_size, shuffle=True, num_workers=2)

    def val_dataloader(self):
        # Fixed order for reproducible evaluation; Lightning warns when
        # val/test loaders shuffle.
        return DataLoader(self.validation_dataset, batch_size=self.batch_size, shuffle=False)

    def test_dataloader(self):
        return DataLoader(self.validation_dataset, batch_size=self.batch_size, shuffle=False)