Untitled
unknown
python
3 years ago
9.5 kB
0
Indexable
Never
from typing import ForwardRef
import torch
import torch.nn as nn
from torch.utils.data.dataloader import DataLoader
import torchvision.transforms.functional as TF
from torch.nn import functional as F
from torchsummaryX import summary
from dataset import GaziBrainsDataset
from pytorch_lightning import LightningModule

DEVICE = 'cpu'
LEARNING_RATE = 1e-4


def _drop_loader_dim(t):
    """Strip a leading size-1 DataLoader axis from a 5-D tensor.

    A 4-D ``(N, C, H, W)`` batch is returned unchanged, so a batch with
    ``N == 1`` is never collapsed into a 3-D tensor (``BatchNorm2d`` only
    accepts 4-D input). For 5-D input this is exactly ``t.squeeze(0)``.
    """
    if t.dim() == 5:
        return t.squeeze(0)
    return t


class ResBlock(LightningModule):
    """Two 3x3 conv + BatchNorm layers with a residual connection.

    Spatial size is preserved (stride 1, padding 1). When ``in_ch != out_ch``
    the shortcut is a 1x1 conv + BatchNorm projection so the residual sum has
    ``out_ch`` channels; otherwise the input is added back unchanged.
    """

    def __init__(self, in_ch, out_ch):
        super().__init__()
        # Projection shortcut only when the channel count changes;
        # ``None`` marks an identity shortcut in forward().
        if in_ch != out_ch:
            self.skip = nn.Sequential(
                nn.Conv2d(in_ch, out_ch, kernel_size=1, stride=1, bias=False),
                nn.BatchNorm2d(out_ch),
            )
        else:
            self.skip = None
        self.block = nn.Sequential(
            nn.Conv2d(in_ch, out_ch, kernel_size=3, padding=1, stride=1, bias=False),
            nn.BatchNorm2d(out_ch),
            nn.ReLU(inplace=True),
            nn.Conv2d(out_ch, out_ch, kernel_size=3, padding=1, stride=1, bias=False),
            nn.BatchNorm2d(out_ch),
        )
        self.relu = nn.ReLU(inplace=True)

    def forward(self, x):
        out = self.block(x)
        identity = x if self.skip is None else self.skip(x)
        out += identity
        return self.relu(out)


class Encoder(LightningModule):
    """U-Net contracting path: a ResBlock followed by 2x2 max-pooling per level.

    One level is built for each entry of ``features[:-1]``; the last entry is
    the bottleneck width, which the owning model's bottleneck block produces.

    Args:
        in_channels: channels of the input image.
        out_channels: unused here; kept for signature parity with the models.
        features: channel width per level.

    ``forward(x)`` returns ``(pooled, skip_connections)``: the pooled output of
    the deepest level and the pre-pool feature maps ordered deepest-first.
    """

    def __init__(self, in_channels=3, out_channels=16, features=(16, 32, 64, 128, 256)):
        super().__init__()
        self.downs = nn.ModuleList()
        self.pool = nn.MaxPool2d(kernel_size=2, stride=2)
        for width in features[:-1]:
            self.downs.append(ResBlock(in_channels, width))
            in_channels = width

    def forward(self, x):
        skip_connections = []
        for down in self.downs:
            x = down(x)
            skip_connections.append(x)
            x = self.pool(x)
        # Deepest-first, the order in which Decoder upsamples.
        return x, skip_connections[::-1]


class Decoder(LightningModule):
    """U-Net expanding path mirroring :class:`Encoder`.

    ``self.ups`` alternates ``ConvTranspose2d`` (2x upsample, halves channels)
    and ``ResBlock`` (fuses the upsampled map with its skip connection), from
    the deepest level to the shallowest.

    Args:
        in_channels, out_channels: unused here; kept for signature parity.
        features: channel width per level; must match the paired Encoder.
    """

    def __init__(self, in_channels=3, out_channels=16, features=(16, 32, 64, 128, 256)):
        super().__init__()
        self.ups = nn.ModuleList()
        for idx in reversed(range(len(features) - 1)):
            self.ups.append(
                nn.ConvTranspose2d(features[idx + 1], features[idx], kernel_size=2, stride=2)
            )
            self.ups.append(ResBlock(features[idx] * 2, features[idx]))

    def forward(self, x, skip_connections):
        """Upsample ``x`` level by level, concatenating deepest-first skips."""
        for idx in range(0, len(self.ups), 2):
            x = self.ups[idx](x)
            skip_connection = skip_connections[idx // 2]
            # Odd input sizes make the pooled/upsampled map one pixel short;
            # resize spatially so the channel concat lines up.
            if x.shape[2:] != skip_connection.shape[2:]:
                x = TF.resize(x, skip_connection.shape[2:])
            concat_skip = torch.cat([x, skip_connection], dim=1)
            x = self.ups[idx + 1](concat_skip)
        return x


class _GaziBrainsSegmentationModule(LightningModule):
    """Training, evaluation and data plumbing shared by UNet and baseline_UNet.

    Subclasses implement ``forward(flair, t2, t1)`` returning per-pixel
    log-probabilities with the class axis at dim 1, as ``F.nll_loss`` expects.
    """

    def _compute_loss(self, batch):
        """Run the model on one ``(flair, t2, t1, y)`` batch and return NLL loss."""
        flair, t2, t1, y = batch
        # Drop the DataLoader's batch axis (batch_size=1). Presumably each
        # dataset item already stacks several slices -- TODO confirm against
        # GaziBrainsDataset. ``y`` must be an integer class-index tensor.
        flair, t2, t1, y = (t.squeeze(0) for t in (flair, t2, t1, y))
        logits = self(flair, t2, t1)
        return F.nll_loss(logits, y)

    def training_step(self, batch, batch_idx):
        loss = self._compute_loss(batch)
        # The 'log' key of the returned dict is not processed by Lightning >= 1.0;
        # self.log is what records the metric.
        self.log("train_loss", loss)
        return {'loss': loss, 'log': {'train_loss': loss}}

    def validation_step(self, batch, batch_idx):
        loss = self._compute_loss(batch)
        self.log("val_loss", loss, prog_bar=True)
        return {'loss': loss, 'log': {'val_loss': loss}}

    def test_step(self, batch, batch_idx):
        # Reuses validation logic, so the metric is logged under "val_loss".
        return self.validation_step(batch, batch_idx)

    def configure_optimizers(self):
        return torch.optim.Adam(self.parameters(), lr=LEARNING_RATE)

    ################
    # DATA PREPARATION
    ################

    def setup(self, stage=None):
        """Create datasets for the requested Lightning stage (None = all)."""
        if stage in (None, 'fit'):
            self.train_dataset = GaziBrainsDataset(root_dir='Data/split_npy_dataset')
        # Validation data serves fit, validate and test.
        if stage in (None, 'fit', 'validate', 'test'):
            self.validation_dataset = GaziBrainsDataset(
                root_dir='Data/split_npy_dataset_validation'
            )

    def train_dataloader(self):
        return DataLoader(self.train_dataset, batch_size=1, shuffle=True, num_workers=2)

    def val_dataloader(self):
        # Shuffling evaluation data only changes visiting order; keep it deterministic.
        return DataLoader(self.validation_dataset, batch_size=1, shuffle=False)

    def test_dataloader(self):
        return DataLoader(self.validation_dataset, batch_size=1, shuffle=False)


# Pytorch U-Net Model
class UNet(_GaziBrainsSegmentationModule):
    """Three-branch (FLAIR / T2 / T1) residual U-Net with a fused bottleneck.

    Each modality has its own Encoder and Decoder. The three deepest encoder
    outputs are concatenated on the channel axis and fused by one ResBlock.
    That fused map feeds all three decoders, each using its own skip
    connections. The decoder outputs are concatenated and mapped to
    ``out_channels`` class scores by a 1x1 conv.

    Args:
        in_channels: channels of each modality input.
        out_channels: number of output classes.
        features: channel widths per level; the last entry is the bottleneck width.
        device: device the whole model is moved to at construction.
    """

    def __init__(self, in_channels=1, out_channels=16, features=(16, 32, 64, 128), device=DEVICE):
        super().__init__()
        self.encoder1 = Encoder(in_channels, out_channels, features)
        self.decoder1 = Decoder(in_channels, out_channels, features)
        self.encoder2 = Encoder(in_channels, out_channels, features)
        self.decoder2 = Decoder(in_channels, out_channels, features)
        self.encoder3 = Encoder(in_channels, out_channels, features)
        self.decoder3 = Decoder(in_channels, out_channels, features)
        self.bottleneck = ResBlock(features[-2] * 3, features[-1])
        self.out = nn.Conv2d(features[0] * 3, out_channels, kernel_size=1, padding=0, stride=1)
        # Move every submodule, so bottleneck and head share the encoders' device.
        self.to(device)

    def forward(self, x1, x2, x3):
        """Return log-probabilities of shape (N, out_channels, H, W).

        Inputs are expected as 4-D (N, C, H, W) batches. A 5-D input still
        carrying a size-1 loader axis has that axis removed first.
        """
        x1, x2, x3 = (_drop_loader_dim(t) for t in (x1, x2, x3))
        x1, skip_connections1 = self.encoder1(x1)
        x2, skip_connections2 = self.encoder2(x2)
        x3, skip_connections3 = self.encoder3(x3)
        x = self.bottleneck(torch.cat([x1, x2, x3], dim=1))
        x1 = self.decoder1(x, skip_connections1)
        x2 = self.decoder2(x, skip_connections2)
        x3 = self.decoder3(x, skip_connections3)
        x = self.out(torch.cat([x1, x2, x3], dim=1))
        # Class axis is dim 1; an explicit dim avoids the implicit-dim deprecation.
        return F.log_softmax(x, dim=1)


class baseline_UNet(_GaziBrainsSegmentationModule):
    """Single-branch residual U-Net baseline.

    NOTE(review): forward() uses only ``x1`` and ignores ``x2``/``x3``, while
    ``in_channels`` defaults to 3. Confirm whether the three modalities were
    meant to be concatenated into a single 3-channel input.

    Args:
        in_channels: channels of the ``x1`` input.
        out_channels: number of output classes.
        features: channel widths per level; the last entry is the bottleneck width.
        device: device the whole model is moved to at construction.
    """

    def __init__(self, in_channels=3, out_channels=16, features=(16, 32, 64, 128), device=DEVICE):
        super().__init__()
        self.encoder = Encoder(in_channels, out_channels, features)
        self.decoder = Decoder(in_channels, out_channels, features)
        self.bottleneck = ResBlock(features[-2], features[-1])
        self.out = nn.Conv2d(features[0], out_channels, kernel_size=1, padding=0, stride=1)
        self.to(device)

    def forward(self, x1, x2, x3):
        """Return log-probabilities of shape (N, out_channels, H, W) from ``x1``."""
        x1 = _drop_loader_dim(x1)
        x, skip_connections = self.encoder(x1)
        x = self.bottleneck(x)
        x = self.decoder(x, skip_connections)
        # Conv2d.forward takes only the input tensor.
        x = self.out(x)
        return F.log_softmax(x, dim=1)


def test():
    """Smoke test: a (1, 1, 256, 256) input yields (1, 16, 256, 256) log-probs."""
    x = torch.randn((1, 1, 256, 256))
    model = UNet(in_channels=1, out_channels=16)
    preds = model(x, x, x)
    assert preds.shape == torch.Size([1, 16, 256, 256])


if __name__ == "__main__":
    model = UNet(in_channels=1, out_channels=16)
    x = torch.randn((1, 1, 256, 256)).to(DEVICE)
    summary(model, x, x, x)
    test()