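# Paste-site metadata captured together with this script (pastecode.io).
# Kept as comments so the file remains valid Python:
#   Untitled | mail@pastecode.io | unknown | python | 2 years ago | 2.6 kB
#   20 | Indexable | Never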
import torch
from model import *
from utils import *
import os

from pytorch_lightning import Trainer
from pytorch_lightning.loggers import TensorBoardLogger
from pytorch_lightning.callbacks.early_stopping import EarlyStopping
from pytorch_lightning.callbacks.model_checkpoint import ModelCheckpoint

from tqdm import tqdm
import numpy as np
import matplotlib.pyplot as plt

from torchsummaryX import summary

# --- Runtime environment tweaks (run at import time) -------------------------
# Let the process continue when more than one OpenMP runtime gets loaded,
# instead of Intel's libiomp aborting with "OMP: Error #15".
os.environ["KMP_DUPLICATE_LIB_OK"] = "True"
# Disable the cuDNN auto-tuner (no per-shape benchmarking of conv algorithms).
torch.backends.cudnn.benchmark = False
# Ask PyTorch's CUDA caching allocator to release unused cached blocks; this
# early in the process it is most likely a no-op.
torch.cuda.empty_cache()

# --- Training configuration --------------------------------------------------
# NOTE(review): several of these (e.g. BATCH_SIZE, NUM_WORKERS, PIN_MEMORY,
# IMAGE_HEIGHT/IMAGE_WIDTH) are not referenced anywhere in this script —
# confirm whether other modules rely on them before changing or removing.
LEARNING_RATE = 1e-4
DEVICE = "cuda"
BATCH_SIZE = 4
NUM_EPOCHS = 100
NUM_WORKERS = 2
IMAGE_HEIGHT = 480
IMAGE_WIDTH = 480
N_CLASSES = 16
PIN_MEMORY = True
LOAD_MODEL = False

# Per-class colour triples, indexed by class id (16 entries, matching
# N_CLASSES). get_colored_segmentation_image writes entry c's three components
# into output channels 0, 1 and 2 for pixels of class c.
CLASS_COLORS = [
    (0, 0, 0), (255, 0, 0), (0, 255, 0), (0, 0, 255),
    (255, 255, 0), (0, 255, 255), (255, 0, 255), (255, 239, 213),
    (0, 0, 205), (205, 133, 63), (210, 180, 140), (102, 205, 170),
    (0, 0, 128), (7, 36, 0), (218, 219, 112), (218, 112, 214),
]

def get_colored_segmentation_image(seg_arr, n_classes, colors=None):
    """Render a per-pixel class-index map as a 3-channel colour image.

    Args:
        seg_arr: NumPy array of class indices, either ``(H, W)`` or
            ``(H, W, C)``; for a 3-D array only channel 0 is used.
        n_classes: Number of classes to colour. Pixels whose value is not in
            ``range(n_classes)`` are left as ``(0, 0, 0)``.
        colors: Sequence of colour triples indexed by class id; needs at least
            ``n_classes`` entries, and only the first three components of each
            entry are used. Defaults to the module-level ``CLASS_COLORS``.

    Returns:
        ``np.ndarray`` of shape ``(H, W, 3)`` and dtype ``int32``.

    Raises:
        IndexError: if ``colors`` has fewer than ``n_classes`` entries.
    """
    if colors is None:
        # Resolved at call time rather than as a shared mutable default.
        colors = CLASS_COLORS

    seg_arr = np.asarray(seg_arr)
    if seg_arr.ndim == 3:
        seg_arr = seg_arr[:, :, 0]

    seg_img = np.zeros(seg_arr.shape[:2] + (3,), dtype=np.int32)
    for c in range(n_classes):
        # Class masks are disjoint (a pixel equals at most one c), so direct
        # assignment gives the same result as summing each class's colour.
        seg_img[seg_arr == c] = colors[c][:3]

    return seg_img

def main():
    """Build the multi-path UNet, print a layer summary, and train it with Lightning.

    Side effects: moves the model to ``DEVICE``, writes TensorBoard logs under
    ``logs/multi_path_UNet`` and checkpoints under ``Model/``.
    """
    # NOTE(review): batch_size=16 disagrees with the module constant
    # BATCH_SIZE (4). Left unchanged because UNet's use of it is not visible
    # here — confirm which value is intended.
    model = UNet(
        in_channels=1,
        out_channels=N_CLASSES,
        features=[16, 32, 64, 128, 192, 256],
        batch_size=16,
    ).to(DEVICE)

    # Dummy single-channel 256x256 input for torchsummaryX. The same tensor is
    # passed three times, presumably one per input path of the multi-path
    # UNet (TODO confirm against UNet.forward).
    x = torch.randn((1, 1, 256, 256)).to(DEVICE)
    summary(model, x, x, x)

    # TODO: LOAD_MODEL is not wired up. Resuming from 'model.pth.tar' would
    # need an optimizer instance, which this script never creates.

    logger = TensorBoardLogger("logs", name="multi_path_UNet")

    callbacks = [
        # Stops after 5 consecutive checks without improvement in "val_loss".
        # Assumes the LightningModule logs a metric named "val_loss".
        EarlyStopping(monitor="val_loss", patience=5),
        # Lightning appends ".ckpt" to `filename` itself, so pass only the
        # stem; "model.ckpt" produced "Model/model.ckpt.ckpt".
        ModelCheckpoint(filename="model", save_last=True, verbose=True, dirpath="Model/"),
    ]

    # NOTE: `gpus=` is deprecated since Lightning 1.7 and removed in 2.0;
    # switch to `accelerator="gpu", devices=...` when upgrading.
    trainer = Trainer(
        gpus=torch.cuda.device_count(),
        max_epochs=NUM_EPOCHS,
        logger=logger,
        callbacks=callbacks,
    )

    trainer.fit(model)


if __name__ == "__main__":
    main()