import torch
import torch.nn as nn
from torch.utils.data.dataloader import DataLoader
import torchvision.transforms.functional as TF
from torch.nn import functional as F
from torchsummaryX import summary
from dataset import GaziBrainsDataset
from pytorch_lightning import LightningModule
DEVICE = 'cpu'
LEARNING_RATE = 1e-4
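# Residual block: two 3x3 conv + BatchNorm layers, with a 1x1 projection on the
# skip path whenever the channel count changes.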
class ResBlock(LightningModule):
def __init__(self, in_ch, out_ch):
super().__init__()
        # 1x1 projection on the skip path only when the channel count changes;
        # otherwise the identity is used directly.
        if in_ch != out_ch:
            self.skip = nn.Sequential(
                nn.Conv2d(in_ch, out_ch, kernel_size=1, stride=1, bias=False),
                nn.BatchNorm2d(out_ch)
            )
        else:
            self.skip = None
self.block = nn.Sequential(
nn.Conv2d(in_ch, out_ch, kernel_size=3, padding=1, stride=1, bias=False),
nn.BatchNorm2d(out_ch),
nn.ReLU(inplace=True),
nn.Conv2d(out_ch, out_ch, kernel_size=3, padding=1, stride=1, bias=False),
nn.BatchNorm2d(out_ch)
)
self.relu = nn.ReLU(inplace=True)
def forward(self, x):
identity = x
out = self.block(x)
if self.skip is not None:
identity = self.skip(x)
out += identity
out = self.relu(out)
return out
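# Encoder: a stack of ResBlocks separated by 2x2 max-pooling. The pre-pool
# activation of each stage is saved as a skip connection, and the list is
# returned deepest-first to match the decoder's traversal order.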
class Encoder(LightningModule):
def __init__(self, in_channels=3, out_channels=16, features=[16, 32, 64, 128, 256]):
        super().__init__()
self.downs = nn.ModuleList()
self.pool = nn.MaxPool2d(kernel_size=2, stride=2)
for idx in range(len(features) - 1):
self.downs.append(ResBlock(in_channels, features[idx]))
in_channels = features[idx]
def forward(self, x):
skip_connections = []
for down in self.downs:
x = down(x)
skip_connections.append(x)
x = self.pool(x)
skip_connections = skip_connections[::-1]
return x, skip_connections
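# Decoder: alternating ConvTranspose2d upsampling and ResBlocks. Each upsampled
# map is concatenated with the matching encoder skip (resized first if pooling
# on odd input sizes made the spatial shapes drift apart).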
class Decoder(LightningModule):
def __init__(self, in_channels=3, out_channels=16, features=[16, 32, 64, 128, 256]):
        super().__init__()
self.ups = nn.ModuleList()
for idx in reversed(range(len(features) - 1)):
self.ups.append(nn.ConvTranspose2d(features[idx+1], features[idx], kernel_size=2, stride=2))
self.ups.append(ResBlock(features[idx]*2, features[idx]))
def forward(self, x, skip_connections):
for idx in range(0, len(self.ups), 2):
x = self.ups[idx](x)
skip_connection = skip_connections[idx//2]
if x.shape != skip_connection.shape:
x = TF.resize(x, skip_connection.shape[2:])
concat_skip = torch.cat([x, skip_connection], dim=1)
x = self.ups[idx+1](concat_skip)
return x
# Multi-input U-Net: three parallel encoder/decoder branches (FLAIR, T2, T1)
# fused at a shared bottleneck and again at the final 1x1 output convolution.
class UNet(LightningModule):
def __init__(self, in_channels=1, out_channels=16, features=[16, 32, 64, 128], device=DEVICE):
        super().__init__()
self.encoder1 = Encoder(in_channels, out_channels, features).to(device)
self.decoder1 = Decoder(in_channels, out_channels, features).to(device)
self.encoder2 = Encoder(in_channels, out_channels, features).to(device)
self.decoder2 = Decoder(in_channels, out_channels, features).to(device)
self.encoder3 = Encoder(in_channels, out_channels, features).to(device)
self.decoder3 = Decoder(in_channels, out_channels, features).to(device)
        # Each encoder branch ends with features[-2] channels, so the fused
        # bottleneck input is features[-2] * 3; likewise the three decoder
        # outputs (features[0] channels each) are concatenated before the 1x1 head.
        self.bottleneck = ResBlock(features[-2] * 3, features[-1])
        self.out = nn.Conv2d(features[0] * 3, out_channels, kernel_size=1, padding=0, stride=1)
def forward(self, x1, x2, x3):
        # Inputs are expected as (N, C, H, W); the DataLoader batch dimension is
        # already squeezed out in the *_step hooks, so no squeeze is needed here
        # (a second squeeze would break the shape check in test()).
x1, skip_connections1 = self.encoder1(x1)
x2, skip_connections2 = self.encoder2(x2)
x3, skip_connections3 = self.encoder3(x3)
x = self.bottleneck(torch.cat([x1, x2, x3], dim=1))
x1 = self.decoder1(x, skip_connections1)
x2 = self.decoder2(x, skip_connections2)
x3 = self.decoder3(x, skip_connections3)
x = self.out(torch.cat([x1, x2, x3], dim=1))
        return F.log_softmax(x, dim=1)
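    # Lightning hooks: each DataLoader batch presumably holds one full volume as
    # (1, S, C, H, W); squeezing the leading batch dimension makes the S slices
    # the effective batch for the 2D network.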
def training_step(self, batch, batch_idx):
flair, t2, t1, y = batch
flair = flair.squeeze(0)
t2 = t2.squeeze(0)
t1 = t1.squeeze(0)
y = y.squeeze(0)
logits = self(flair, t2, t1)
loss = F.nll_loss(logits, y)
        self.log('train_loss', loss)
        return loss
def validation_step(self, batch, batch_idx):
flair, t2, t1, y = batch
flair = flair.squeeze(0)
t2 = t2.squeeze(0)
t1 = t1.squeeze(0)
y = y.squeeze(0)
logits = self(flair, t2, t1)
loss = F.nll_loss(logits, y)
self.log("val_loss", loss, prog_bar=True)
        return loss
def test_step(self, batch, batch_idx):
return self.validation_step(batch, batch_idx)
def configure_optimizers(self):
        return torch.optim.Adam(self.parameters(), lr=LEARNING_RATE)
################
# DATA PREPARATION
################
def setup(self, stage=None):
if stage == 'fit' or stage is None:
self.train_dataset = GaziBrainsDataset(root_dir='Data/split_npy_dataset')
self.validation_dataset = GaziBrainsDataset(root_dir='Data/split_npy_dataset_validation')
        if stage == 'test' or stage is None:
            # No separate test split is available; reuse the validation data.
            self.test_dataset = GaziBrainsDataset(root_dir='Data/split_npy_dataset_validation')
def train_dataloader(self):
return DataLoader(self.train_dataset, batch_size=1, shuffle=True, num_workers=2)
    def val_dataloader(self):
        return DataLoader(self.validation_dataset, batch_size=1, shuffle=False)
    def test_dataloader(self):
        return DataLoader(self.test_dataset, batch_size=1, shuffle=False)
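# Single-branch baseline: a standard U-Net over one modality. It accepts three
# inputs so it is call-compatible with UNet, but only the first (FLAIR) is used.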
class baseline_UNet(LightningModule):
def __init__(self, in_channels=3, out_channels=16, features=[16, 32, 64, 128], device=DEVICE):
        super().__init__()
self.encoder = Encoder(in_channels, out_channels, features).to(device)
self.decoder = Decoder(in_channels, out_channels, features).to(device)
self.bottleneck = ResBlock(features[-2], features[-1])
self.out = nn.Conv2d(features[0], out_channels, kernel_size=1, padding=0, stride=1)
    def forward(self, x1, x2, x3):
        # The baseline consumes only the first modality; x2 and x3 are ignored.
        x1, skip_connections = self.encoder(x1)
        x1 = self.bottleneck(x1)
        x1 = self.decoder(x1, skip_connections)
        x1 = self.out(x1)
        return F.log_softmax(x1, dim=1)
def training_step(self, batch, batch_idx):
flair, t2, t1, y = batch
flair = flair.squeeze(0)
t2 = t2.squeeze(0)
t1 = t1.squeeze(0)
y = y.squeeze(0)
logits = self(flair, t2, t1)
loss = F.nll_loss(logits, y)
        self.log('train_loss', loss)
        return loss
def validation_step(self, batch, batch_idx):
flair, t2, t1, y = batch
flair = flair.squeeze(0)
t2 = t2.squeeze(0)
t1 = t1.squeeze(0)
y = y.squeeze(0)
logits = self(flair, t2, t1)
loss = F.nll_loss(logits, y)
self.log("val_loss", loss, prog_bar=True)
        return loss
def test_step(self, batch, batch_idx):
return self.validation_step(batch, batch_idx)
def configure_optimizers(self):
        return torch.optim.Adam(self.parameters(), lr=LEARNING_RATE)
################
# DATA PREPARATION
################
def setup(self, stage=None):
if stage == 'fit' or stage is None:
self.train_dataset = GaziBrainsDataset(root_dir='Data/split_npy_dataset')
self.validation_dataset = GaziBrainsDataset(root_dir='Data/split_npy_dataset_validation')
        if stage == 'test' or stage is None:
            # No separate test split is available; reuse the validation data.
            self.test_dataset = GaziBrainsDataset(root_dir='Data/split_npy_dataset_validation')
def train_dataloader(self):
return DataLoader(self.train_dataset, batch_size=1, shuffle=True, num_workers=2)
    def val_dataloader(self):
        return DataLoader(self.validation_dataset, batch_size=1, shuffle=False)
    def test_dataloader(self):
        return DataLoader(self.test_dataset, batch_size=1, shuffle=False)
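# Shape sanity check: one (1, 1, 256, 256) tensor fed to all three branches
# should yield per-class log-probabilities of shape (1, 16, 256, 256).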
def test():
x = torch.randn((1, 1, 256, 256))
model = UNet(in_channels=1, out_channels=16)
preds = model(x, x, x)
assert preds.shape == torch.Size([1, 16, 256, 256])
if __name__ == "__main__":
model = UNet(in_channels=1, out_channels=16)
x = torch.randn((1, 1, 256, 256)).to(DEVICE)
summary(model, x, x, x)
test()
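# A minimal training sketch, assuming the GaziBrainsDataset directories named in
# setup() exist on disk; max_epochs is illustrative, not a tuned value:
#
#   from pytorch_lightning import Trainer
#
#   model = UNet(in_channels=1, out_channels=16)
#   trainer = Trainer(max_epochs=10)
#   trainer.fit(model)    # uses the model's own setup()/train_dataloader()
#   trainer.test(model)   # evaluates on the validation split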