Untitled
unknown
python
3 years ago
2.6 kB
43
Indexable
import torch
from model import *
from utils import *
import os
from pytorch_lightning import Trainer
from pytorch_lightning.loggers import TensorBoardLogger
from pytorch_lightning.callbacks.early_stopping import EarlyStopping
from pytorch_lightning.callbacks.model_checkpoint import ModelCheckpoint
from tqdm import tqdm
import numpy as np
import matplotlib.pyplot as plt
from torchsummaryX import summary

# Presumably a workaround for the "duplicate OpenMP runtime" abort that occurs
# when several libraries each bundle their own OpenMP; NOTE(review): confirm
# this is still needed.
os.environ['KMP_DUPLICATE_LIB_OK'] = 'True'
torch.backends.cudnn.benchmark = False
torch.cuda.empty_cache()

LEARNING_RATE = 1e-4
# Fall back to CPU so the script can still start on machines without CUDA
# (the Trainer below already copes with zero GPUs).
DEVICE = "cuda" if torch.cuda.is_available() else "cpu"
BATCH_SIZE = 4
NUM_EPOCHS = 100
NUM_WORKERS = 2
IMAGE_HEIGHT = 480
IMAGE_WIDTH = 480
N_CLASSES = 16
PIN_MEMORY = True
LOAD_MODEL = False

# One (R, G, B) colour per class label; index 0 is black.
CLASS_COLORS = [
    (0, 0, 0), (255, 0, 0), (0, 255, 0), (0, 0, 255),
    (255, 255, 0), (0, 255, 255), (255, 0, 255), (255, 239, 213),
    (0, 0, 205), (205, 133, 63), (210, 180, 140), (102, 205, 170),
    (0, 0, 128), (7, 36, 0), (218, 219, 112), (218, 112, 214),
]


def get_colored_segmentation_image(seg_arr, n_classes, colors=CLASS_COLORS):
    """Render a label image as an RGB image, one colour per class.

    Args:
        seg_arr: 3-D array-like; the class labels are read from
            ``seg_arr[:, :, 0]``.
        n_classes: labels ``0 .. n_classes - 1`` are coloured with
            ``colors[label]``.
        colors: sequence of ``(R, G, B)`` triples indexed by class label.

    Returns:
        ``np.int32`` array of shape ``(H, W, 3)``. Pixels whose label is not
        in ``range(n_classes)`` stay ``(0, 0, 0)``.

    Raises:
        IndexError: if ``n_classes`` exceeds ``len(colors)``.
    """
    labels = np.asarray(seg_arr)[:, :, 0]
    seg_img = np.zeros(labels.shape + (3,), dtype=np.int32)
    for label in range(n_classes):
        # Class masks are disjoint, so plain assignment gives the same
        # result as accumulating each channel separately.
        seg_img[labels == label] = colors[label]
    return seg_img


def main():
    """Build the multi-path UNet, print a layer summary, and train it.

    ``UNet`` comes from the star import of ``model``. Since ``trainer.fit`` is
    given no dataloaders and EarlyStopping monitors ``val_loss``, the model is
    presumably a LightningModule that supplies its own data/optimizer and logs
    ``val_loss`` -- TODO confirm.
    """
    # NOTE(review): batch_size=16 differs from BATCH_SIZE = 4 above; check
    # which value UNet actually uses.
    model = UNet(in_channels=1, out_channels=N_CLASSES,
                 features=[16, 32, 64, 128, 192, 256],
                 batch_size=16).to(DEVICE)

    # The summary is run on three identical dummy inputs of shape
    # (1, 1, 256, 256); NOTE(review): smaller than IMAGE_HEIGHT x IMAGE_WIDTH.
    x = torch.randn((1, 1, 256, 256)).to(DEVICE)
    summary(model, x, x, x)

    # TODO: LOAD_MODEL is currently unused. The previously disabled code was
    #     load_checkpoint('model.pth.tar', model, optimizer, LEARNING_RATE)
    # and referenced an undefined `optimizer`; with Lightning, resuming is
    # normally done by passing `ckpt_path` to `trainer.fit`.

    logger = TensorBoardLogger('logs', name='multi_path_UNet')
    callbacks = [
        EarlyStopping(monitor="val_loss", patience=5),
        # Lightning appends ".ckpt" itself, so the filename carries no
        # extension (the old 'model.ckpt' produced 'model.ckpt.ckpt').
        # Monitoring val_loss makes Model/model.ckpt hold the best epoch,
        # while save_last=True keeps the most recent one in last.ckpt.
        ModelCheckpoint(filename='model', monitor="val_loss", mode="min",
                        save_last=True, verbose=True, dirpath='Model/'),
    ]
    # NOTE(review): `gpus=` is the pre-2.0 Lightning API; newer versions
    # expect `accelerator=`/`devices=` instead.
    trainer = Trainer(gpus=torch.cuda.device_count(), max_epochs=NUM_EPOCHS,
                      logger=logger, callbacks=callbacks)
    trainer.fit(model)


if __name__ == '__main__':
    main()
Editor is loading...