import torch
from model import *
from utils import *
import os
from pytorch_lightning import Trainer
from pytorch_lightning.loggers import TensorBoardLogger
from pytorch_lightning.callbacks.early_stopping import EarlyStopping
from pytorch_lightning.callbacks.model_checkpoint import ModelCheckpoint
from tqdm import tqdm
import numpy as np
import matplotlib.pyplot as plt
from torchsummaryX import summary
# Let a second copy of the Intel OpenMP runtime load instead of aborting.
# Presumably needed because several imported libraries each bundle libiomp.
os.environ['KMP_DUPLICATE_LIB_OK'] = 'True'
# Keep cuDNN's autotuner off (no per-shape algorithm benchmarking).
torch.backends.cudnn.benchmark = False
torch.cuda.empty_cache()

LEARNING_RATE = 1e-4
# Fall back to CPU so building the model and printing its summary also work on
# machines without CUDA. On CUDA hosts this is still "cuda".
DEVICE = "cuda" if torch.cuda.is_available() else "cpu"
BATCH_SIZE = 4
NUM_EPOCHS = 100
NUM_WORKERS = 2
IMAGE_HEIGHT = 480
IMAGE_WIDTH = 480
N_CLASSES = 16
PIN_MEMORY = True
LOAD_MODEL = False
# RGB colour per class id (list index == class id); one entry per class.
CLASS_COLORS = [
    (0, 0, 0), (255, 0, 0), (0, 255, 0), (0, 0, 255),
    (255, 255, 0), (0, 255, 255), (255, 0, 255), (255, 239, 213),
    (0, 0, 205), (205, 133, 63), (210, 180, 140), (102, 205, 170),
    (0, 0, 128), (7, 36, 0), (218, 219, 112), (218, 112, 214),
]
def get_colored_segmentation_image(seg_arr, n_classes, colors=None):
    """Paint a per-pixel class-label map as an RGB image.

    Parameters
    ----------
    seg_arr : array_like
        Label map of shape (H, W) or (H, W, C). When 3-D, only channel 0 is
        used. Each pixel holds a class id.
    n_classes : int
        Number of classes to paint. Pixels whose value is not one of
        ``0 .. n_classes - 1`` stay black.
    colors : sequence of (r, g, b), optional
        Palette indexed by class id. It needs at least ``n_classes`` entries.
        Defaults to the module-level ``CLASS_COLORS``. Components are cast to
        uint8, so values outside 0-255 wrap.

    Returns
    -------
    numpy.ndarray
        int32 array of shape (H, W, 3).

    Raises
    ------
    ValueError
        If ``seg_arr`` is neither 2-D nor 3-D.
    IndexError
        If ``colors`` has fewer than ``n_classes`` entries.
    """
    if colors is None:
        # Resolve at call time instead of binding a shared list as the default.
        colors = CLASS_COLORS

    seg_arr = np.asarray(seg_arr)
    if seg_arr.ndim == 3:
        seg_arr = seg_arr[:, :, 0]
    elif seg_arr.ndim != 2:
        raise ValueError(
            f"seg_arr must be 2-D (H, W) or 3-D (H, W, C), got shape {seg_arr.shape}"
        )

    seg_img = np.zeros(seg_arr.shape + (3,), dtype=np.int32)
    for c in range(n_classes):
        # A pixel matches at most one class, so plain assignment is enough.
        # The uint8 cast keeps the original wrap-around for out-of-range components.
        seg_img[seg_arr == c] = np.asarray(colors[c][:3]).astype(np.uint8)
    return seg_img
def main():
    """Build the multi-path UNet, print a layer summary, then train it with Lightning.

    ``trainer.fit(model)`` is called without explicit data loaders, so the
    model is expected to supply its own. TensorBoard logs are written under
    ``logs/multi_path_UNet`` and checkpoints under ``Model/``.
    """
    model = UNet(
        in_channels=1,
        out_channels=N_CLASSES,
        features=[16, 32, 64, 128, 192, 256],
        # NOTE(review): the BATCH_SIZE constant is 4 but 16 is passed here;
        # confirm which is intended.
        batch_size=16,
    ).to(DEVICE)

    # Dummy forward pass for the layer summary. The model is called with three
    # tensors (presumably one per input path). Run it in eval mode so layers
    # with train-mode state (e.g. BatchNorm running statistics, if UNet uses
    # them) are not updated from random data before training starts.
    x = torch.randn((1, 1, 256, 256), device=DEVICE)
    was_training = model.training
    model.eval()
    try:
        summary(model, x, x, x)
    finally:
        model.train(was_training)

    # TODO: LOAD_MODEL is not wired up. Resuming would need an optimizer and
    # load_checkpoint('model.pth.tar', model, optimizer, LEARNING_RATE).

    logger = TensorBoardLogger('logs', name='multi_path_UNet')
    callbacks = [
        EarlyStopping(monitor="val_loss", patience=5),
        # Lightning appends the ".ckpt" extension itself, so the filename is
        # given without it ('model.ckpt' would be saved as 'model.ckpt.ckpt').
        # Monitoring the same metric as EarlyStopping makes model.ckpt hold the
        # best epoch; last.ckpt (save_last) keeps the most recent one.
        ModelCheckpoint(
            dirpath='Model/',
            filename='model',
            monitor="val_loss",
            save_last=True,
            verbose=True,
        ),
    ]
    # NOTE: `gpus=` is the pre-2.0 Lightning API (2.x uses accelerator/devices).
    trainer = Trainer(
        gpus=torch.cuda.device_count(),
        max_epochs=NUM_EPOCHS,
        logger=logger,
        callbacks=callbacks,
    )
    trainer.fit(model)
if __name__ == "__main__":
    # Start training only when run as a script, never on import.
    main()