# Untitled: pasted from pastecode.io (author: unknown, language: python,
# 9.5 kB, posted 2 years ago; visibility: Indexable; expiry: Never).
from typing import ForwardRef
import torch
import torch.nn as nn
from torch.utils.data.dataloader import DataLoader
import torchvision.transforms.functional as TF
from torch.nn import functional as F
from torchsummaryX import summary
from dataset import GaziBrainsDataset
from pytorch_lightning import LightningModule

# Default device passed to the model constructors below.
DEVICE: str = 'cpu'
# Learning rate handed to torch.optim.Adam in configure_optimizers().
LEARNING_RATE: float = 1e-4

class ResBlock(LightningModule):
    """Residual block: ``relu(block(x) + shortcut(x))``.

    ``block`` is conv3x3 -> BatchNorm -> ReLU -> conv3x3 -> BatchNorm (stride 1,
    padding 1, so the spatial size is preserved). The shortcut is the identity
    when ``in_ch == out_ch`` (stored as ``skip = None``); otherwise it is a 1x1
    convolution + BatchNorm projecting the input to ``out_ch`` channels so the
    sum is well-defined.

    Args:
        in_ch: number of input channels.
        out_ch: number of output channels.
    """

    def __init__(self, in_ch, out_ch):
        super().__init__()

        # Build the projection shortcut first; None means "identity". add_module
        # registers ``skip`` as the first child even when it is None.
        projection = None
        if in_ch != out_ch:
            projection = nn.Sequential(
                nn.Conv2d(in_ch, out_ch, kernel_size=1, stride=1, bias=False),
                nn.BatchNorm2d(out_ch),
            )
        self.add_module('skip', projection)

        layers = [
            nn.Conv2d(in_ch, out_ch, kernel_size=3, padding=1, stride=1, bias=False),
            nn.BatchNorm2d(out_ch),
            nn.ReLU(inplace=True),
            nn.Conv2d(out_ch, out_ch, kernel_size=3, padding=1, stride=1, bias=False),
            nn.BatchNorm2d(out_ch),
        ]
        self.block = nn.Sequential(*layers)

        self.relu = nn.ReLU(inplace=True)

    def forward(self, x):
        residual = self.block(x)
        shortcut = x if self.skip is None else self.skip(x)
        # In-place add onto the freshly computed block output (x is untouched).
        residual += shortcut
        return self.relu(residual)

class Encoder(LightningModule):
    """Contracting path: one ResBlock per level, each followed by 2x2 max-pooling.

    Args:
        in_channels: channels of the input tensor.
        out_channels: accepted for signature symmetry; not used.
        features: per-level widths. Every entry except the last gets a ResBlock;
            the last entry is not used by the encoder.

    ``forward`` returns ``(x, skip_connections)``. ``x`` is the pooled output of
    the deepest ResBlock. ``skip_connections`` holds the pre-pooling output of
    every ResBlock, ordered deepest first.
    """

    def __init__(self, in_channels=3, out_channels=16, features=[16, 32, 64, 128, 256]):
        super(Encoder, self).__init__()

        self.downs = nn.ModuleList()
        self.pool = nn.MaxPool2d(kernel_size=2, stride=2)

        prev_width = in_channels
        for width in features[:-1]:
            self.downs.append(ResBlock(prev_width, width))
            prev_width = width

    def forward(self, x):
        skips = []

        for stage in self.downs:
            x = stage(x)
            skips.append(x)
            x = self.pool(x)

        # The decoder consumes skips from the deepest level upwards.
        skips.reverse()
        return x, skips

class Decoder(LightningModule):
    """Expanding path: upsample, fuse with a skip connection, refine.

    For each level (deepest first) ``self.ups`` holds, interleaved, a 2x2
    stride-2 ``ConvTranspose2d`` (``features[i+1] -> features[i]``) followed by a
    ``ResBlock`` (``2 * features[i] -> features[i]``). The ResBlock consumes the
    upsampled map concatenated with the matching skip connection.

    ``in_channels`` and ``out_channels`` are accepted for signature symmetry and
    are not used.
    """

    def __init__(self, in_channels=3, out_channels=16, features=[16, 32, 64, 128, 256]):
        super(Decoder, self).__init__()

        self.ups = nn.ModuleList()

        # Levels from deepest (len(features) - 2) down to 0.
        for level in range(len(features) - 2, -1, -1):
            wide, narrow = features[level + 1], features[level]
            self.ups.append(nn.ConvTranspose2d(wide, narrow, kernel_size=2, stride=2))
            self.ups.append(ResBlock(narrow * 2, narrow))

    def forward(self, x, skip_connections):
        """Decode ``x`` with ``skip_connections`` ordered deepest first.

        If the upsampled map's shape differs from the skip connection's shape,
        it is resized to the skip's spatial size before the channel-wise concat.
        """
        for stage in range(len(self.ups) // 2):
            upsample, refine = self.ups[2 * stage], self.ups[2 * stage + 1]
            x = upsample(x)
            skip = skip_connections[stage]

            if x.shape != skip.shape:
                x = TF.resize(x, skip.shape[2:])

            x = refine(torch.cat([x, skip], dim=1))

        return x

# PyTorch U-Net model (three input branches fused at the bottleneck)
class UNet(LightningModule):
    """Three-branch U-Net that fuses three input images at the bottleneck.

    ``forward(x1, x2, x3)`` works in four stages:

    1. Each input runs through its own :class:`Encoder`.
    2. The three deepest feature maps are concatenated along the channel axis
       and fused by a shared bottleneck :class:`ResBlock`.
    3. The fused map is decoded three times, each :class:`Decoder` using the
       skip connections of its own encoder.
    4. The three decoder outputs are concatenated and projected to
       ``out_channels`` with a 1x1 convolution.

    The result is per-pixel log-probabilities (``log_softmax`` over dim 1),
    which the step methods feed to ``F.nll_loss``.

    Args:
        in_channels: channels of EACH of the three inputs.
        out_channels: number of output classes.
        features: per-level widths. All but the last are encoder/decoder widths;
            the last is the bottleneck width.
        device: device the whole model is moved to at construction time.
    """

    def __init__(self, in_channels=1, out_channels=16, features=(16, 32, 64, 128), device=DEVICE):
        super(UNet, self).__init__()

        self.encoder1 = Encoder(in_channels, out_channels, features)
        self.decoder1 = Decoder(in_channels, out_channels, features)

        self.encoder2 = Encoder(in_channels, out_channels, features)
        self.decoder2 = Decoder(in_channels, out_channels, features)

        self.encoder3 = Encoder(in_channels, out_channels, features)
        self.decoder3 = Decoder(in_channels, out_channels, features)

        # Three encoder outputs (features[-2] channels each) are concatenated.
        self.bottleneck = ResBlock(features[-2] * 3, features[-1])
        # Three decoder outputs (features[0] channels each) are concatenated.
        self.out = nn.Conv2d(features[0] * 3, out_channels, kernel_size=1, padding=0, stride=1)

        # Move EVERY sub-module. Moving only the encoders/decoders would leave the
        # bottleneck and output conv behind on any non-CPU device.
        self.to(device)

    @staticmethod
    def _drop_outer_dim(x):
        """Remove a leading size-1 dim from a 5-D ``(1, N, C, H, W)`` input.

        Only 5-D tensors are squeezed. Squeezing an ordinary 4-D ``(1, C, H, W)``
        batch would yield a 3-D tensor, which ``BatchNorm2d`` rejects.
        """
        return x.squeeze(0) if x.dim() == 5 else x

    def forward(self, x1, x2, x3):
        x1, x2, x3 = (self._drop_outer_dim(t) for t in (x1, x2, x3))

        x1, skip_connections1 = self.encoder1(x1)
        x2, skip_connections2 = self.encoder2(x2)
        x3, skip_connections3 = self.encoder3(x3)

        x = self.bottleneck(torch.cat([x1, x2, x3], dim=1))

        x1 = self.decoder1(x, skip_connections1)
        x2 = self.decoder2(x, skip_connections2)
        x3 = self.decoder3(x, skip_connections3)

        x = self.out(torch.cat([x1, x2, x3], dim=1))

        # Explicit class axis. The implicit-dim form is deprecated and warns on
        # every call (for 4-D input it also picked dim 1).
        return F.log_softmax(x, dim=1)

    def _step_loss(self, batch):
        """Unpack ``(flair, t2, t1, y)``, run the model and return the NLL loss.

        Each element is squeezed at dim 0 to drop the DataLoader's batch axis
        (the loaders below use ``batch_size=1``).
        """
        flair, t2, t1, y = (t.squeeze(0) for t in batch)
        logits = self(flair, t2, t1)
        return F.nll_loss(logits, y)

    def training_step(self, batch, batch_idx):
        loss = self._step_loss(batch)
        # A returned 'log' dict is ignored by Lightning >= 1.0 (self.log, which
        # this file uses, exists only there), so log explicitly.
        self.log("train_loss", loss)
        return {'loss': loss}

    def validation_step(self, batch, batch_idx):
        loss = self._step_loss(batch)
        self.log("val_loss", loss, prog_bar=True)
        return {'loss': loss}

    def test_step(self, batch, batch_idx):
        return self.validation_step(batch, batch_idx)

    def configure_optimizers(self):
        return torch.optim.Adam(self.parameters(), lr=LEARNING_RATE)

    ################
    # DATA PREPARATION
    ################

    def setup(self, stage=None):
        if stage in ('fit', None):
            self.train_dataset = GaziBrainsDataset(root_dir='Data/split_npy_dataset')

        # The validation split also backs test_dataloader(). 'validate' is included
        # so Trainer.validate() finds it; it is loaded once even when stage is None.
        if stage in ('fit', 'validate', 'test', None):
            self.validation_dataset = GaziBrainsDataset(root_dir='Data/split_npy_dataset_validation')

    def train_dataloader(self):
        return DataLoader(self.train_dataset, batch_size=1, shuffle=True, num_workers=2)

    def val_dataloader(self):
        # Evaluation order does not affect the averaged loss; shuffling here only
        # triggers a Lightning warning.
        return DataLoader(self.validation_dataset, batch_size=1, shuffle=False)

    def test_dataloader(self):
        return DataLoader(self.validation_dataset, batch_size=1, shuffle=False)
    
class baseline_UNet(LightningModule):
    """Single-branch ("early fusion") U-Net baseline.

    The three inputs are concatenated along the channel axis and passed through
    one :class:`Encoder`, a bottleneck :class:`ResBlock` and one
    :class:`Decoder`. A 1x1 convolution then projects the result to
    ``out_channels``. Returns per-pixel log-probabilities (``log_softmax`` over
    dim 1).

    NOTE(review): the original ``forward`` could not run. It called
    ``torch.cat`` on a single tensor, used a non-existent ``self.decoder1``,
    passed ``dim=`` to a Conv2d, and sized the bottleneck/output conv for three
    branches. Early fusion of all three inputs is assumed here because the
    default ``in_channels=3`` matches three single-channel inputs. Confirm this
    is the intended baseline.

    Args:
        in_channels: TOTAL channels of ``x1``, ``x2`` and ``x3`` combined.
        out_channels: number of output classes.
        features: per-level widths. All but the last are encoder/decoder widths;
            the last is the bottleneck width.
        device: device the whole model is moved to at construction time.
    """

    def __init__(self, in_channels=3, out_channels=16, features=(16, 32, 64, 128), device=DEVICE):
        super(baseline_UNet, self).__init__()

        self.encoder = Encoder(in_channels, out_channels, features)
        self.decoder = Decoder(in_channels, out_channels, features)

        # A single encoder yields features[-2] channels (no 3x as in UNet).
        self.bottleneck = ResBlock(features[-2], features[-1])
        # A single decoder yields features[0] channels.
        self.out = nn.Conv2d(features[0], out_channels, kernel_size=1, padding=0, stride=1)

        self.to(device)

    @staticmethod
    def _drop_outer_dim(x):
        """Remove a leading size-1 dim from a 5-D ``(1, N, C, H, W)`` input.

        4-D batches are left alone, because ``BatchNorm2d`` rejects 3-D input.
        """
        return x.squeeze(0) if x.dim() == 5 else x

    def forward(self, x1, x2, x3):
        # Early fusion: stack the three inputs along the channel axis.
        x = torch.cat([self._drop_outer_dim(t) for t in (x1, x2, x3)], dim=1)

        x, skip_connections = self.encoder(x)
        x = self.bottleneck(x)
        x = self.decoder(x, skip_connections)

        return F.log_softmax(self.out(x), dim=1)

    def _step_loss(self, batch):
        """Unpack ``(flair, t2, t1, y)``, run the model and return the NLL loss.

        Each element is squeezed at dim 0 to drop the DataLoader's batch axis
        (the loaders below use ``batch_size=1``).
        """
        flair, t2, t1, y = (t.squeeze(0) for t in batch)
        logits = self(flair, t2, t1)
        return F.nll_loss(logits, y)

    def training_step(self, batch, batch_idx):
        loss = self._step_loss(batch)
        # A returned 'log' dict is ignored by Lightning >= 1.0; log explicitly.
        self.log("train_loss", loss)
        return {'loss': loss}

    def validation_step(self, batch, batch_idx):
        loss = self._step_loss(batch)
        self.log("val_loss", loss, prog_bar=True)
        return {'loss': loss}

    def test_step(self, batch, batch_idx):
        return self.validation_step(batch, batch_idx)

    def configure_optimizers(self):
        return torch.optim.Adam(self.parameters(), lr=LEARNING_RATE)

    ################
    # DATA PREPARATION
    ################

    def setup(self, stage=None):
        if stage in ('fit', None):
            self.train_dataset = GaziBrainsDataset(root_dir='Data/split_npy_dataset')

        # The validation split also backs test_dataloader(). 'validate' is included
        # so Trainer.validate() finds it.
        if stage in ('fit', 'validate', 'test', None):
            self.validation_dataset = GaziBrainsDataset(root_dir='Data/split_npy_dataset_validation')

    def train_dataloader(self):
        return DataLoader(self.train_dataset, batch_size=1, shuffle=True, num_workers=2)

    def val_dataloader(self):
        return DataLoader(self.validation_dataset, batch_size=1, shuffle=False)

    def test_dataloader(self):
        return DataLoader(self.validation_dataset, batch_size=1, shuffle=False)
    
def test():
    """Smoke test: three (1, 1, 256, 256) inputs must give a (1, 16, 256, 256) output."""
    sample = torch.randn((1, 1, 256, 256))
    net = UNet(in_channels=1, out_channels=16)
    output = net(sample, sample, sample)

    expected_shape = torch.Size([1, 16, 256, 256])
    assert output.shape == expected_shape
        
if __name__ == "__main__":
    # Print a torchsummaryX layer summary for one 1x256x256 sample fed to all
    # three branches, then run the shape smoke test. The model is built before
    # the sample so the random-number sequence matches the original script.
    net = UNet(in_channels=1, out_channels=16)
    sample = torch.randn((1, 1, 256, 256)).to(DEVICE)

    summary(net, sample, sample, sample)

    test()