Untitled
unknown
plain_text
2 years ago
6.6 kB
2
Indexable
"""Fine-tune a 3D MONAI UNet (1 input channel, 3 output classes) on NIfTI volumes.

The script pairs ``*.nii`` images with ``*.nii.gz`` labels found in
``path_to_target``, resumes from an existing checkpoint, trains with Dice
loss, validates every ``val_interval`` epochs with sliding-window inference,
and saves the weights with the best mean Dice to ``data_root``.
"""
import glob
import os

import matplotlib.pyplot as plt
import torch
from monai.data import CacheDataset, DataLoader, Dataset, decollate_batch
from monai.inferers import sliding_window_inference
from monai.losses import DiceLoss
from monai.losses.dice import DiceCELoss
from monai.metrics import DiceMetric
from monai.networks.layers import Norm
from monai.networks.nets import UNet
from monai.optimizers import Novograd
from monai.transforms import (
    AsDiscrete,
    EnsureChannelFirstd,
    Compose,
    CropForegroundd,
    LoadImaged,
    RandCropByPosNegLabeld,
    RandAdjustContrastd,
    RandGaussianSmoothd,
    RandRotated,
    RandFlipd,
    RandScaleIntensityd,
)
from monai.utils import first, set_determinism

# Debug switch: a plain Python variable is invisible to the CUDA runtime, so
# the value is exported through the process environment. ``setdefault`` lets a
# value supplied by the caller's shell win. Synchronous launches slow training
# down; drop this once debugging is finished.
CUDA_LAUNCH_BLOCKING = "1"
os.environ.setdefault("CUDA_LAUNCH_BLOCKING", CUDA_LAUNCH_BLOCKING)

data_root = "/monai/plavo/2"  # directory the best checkpoint is written to
path_to_target = "/data/plavo/norm/together"  # folder with images and labels

NUM_TRAIN = 173  # first NUM_TRAIN pairs are used for training
NUM_VAL = 10  # the following NUM_VAL pairs are used for validation
ROI_SIZE = (320, 320, 320)  # training crop size and sliding-window size
NUM_CLASSES = 3  # network output channels / one-hot depth
# NOTE(review): weights are loaded from the working directory but saved under
# ``data_root`` -- confirm this asymmetry is intended.
RESUME_CHECKPOINT = "plavo.pth"
BEST_CHECKPOINT = os.path.join(data_root, "plavo.pth")


def collect_data_dicts(folder):
    """Return ``[{"image": path, "label": path}, ...]`` for *folder*.

    Images are the ``*.nii`` files and labels the ``*.nii.gz`` files; both
    lists are sorted and paired by position.

    Raises:
        ValueError: if the number of images and labels differs, because
            positional pairing would then silently drop or misalign cases.
    """
    images = sorted(glob.glob(os.path.join(folder, "*.nii")))
    labels = sorted(glob.glob(os.path.join(folder, "*.nii.gz")))
    if len(images) != len(labels):
        raise ValueError(
            f"Found {len(images)} images but {len(labels)} labels in {folder!r}"
        )
    return [
        {"image": image_name, "label": label_name}
        for image_name, label_name in zip(images, labels)
    ]


def build_transforms():
    """Build and return the ``(train_transforms, val_transforms)`` pipelines."""
    train_transforms = Compose(
        [
            LoadImaged(keys=["image", "label"]),
            EnsureChannelFirstd(keys=["image", "label"]),
            CropForegroundd(keys=["image", "label"], source_key="image"),
            RandCropByPosNegLabeld(
                keys=["image", "label"],
                label_key="label",
                spatial_size=ROI_SIZE,
                pos=1,
                neg=1,
                num_samples=1,
                image_key="image",
                image_threshold=0,
            ),
            RandAdjustContrastd(keys="image", prob=0.3),
            RandGaussianSmoothd(keys="image", prob=0.3),
            RandRotated(
                keys=["image", "label"],
                range_x=0.3,
                range_y=0.4,
                range_z=0.4,
                prob=0.4,
                # One interpolation mode per key: the label must use nearest
                # neighbour so rotated voxels stay integer class ids instead
                # of being blended into fractional values.
                mode=("bilinear", "nearest"),
            ),
            RandFlipd(keys=["image", "label"], prob=0.4, spatial_axis=[0, 1, 2]),
            # Disabled augmentation, kept for reference:
            # Rand3DElasticd(keys=['image', 'label'], sigma_range=[0, 0.1], magnitude_range=[0,1], prob=0.3, rotate_range=[-0.5, 0.5], mode=('bilinear', 'nearest'),),
            RandScaleIntensityd(keys="image", factors=0.1, prob=0.3),
        ]
    )
    val_transforms = Compose(
        [
            LoadImaged(keys=["image", "label"]),
            EnsureChannelFirstd(keys=["image", "label"]),
            CropForegroundd(keys=["image", "label"], source_key="image"),
            # Disabled, kept for reference:
            # RandCropByPosNegLabeld(keys=["image", "label"], label_key="label", spatial_size=[96,96,96], pos=1, neg=1, num_samples=1),
        ]
    )
    return train_transforms, val_transforms


def validate(model, val_loader, device, dice_metric, post_pred, post_label):
    """Run one validation pass and return the aggregated mean Dice as a float.

    The metric buffer is reset before returning so the next validation round
    starts clean.
    """
    model.eval()
    with torch.no_grad():
        for val_data in val_loader:
            val_inputs, val_labels = (
                val_data["image"].to(device),
                val_data["label"].to(device),
            )
            sw_batch_size = 1
            val_outputs = sliding_window_inference(
                val_inputs, ROI_SIZE, sw_batch_size, model
            )
            val_outputs = [post_pred(i) for i in decollate_batch(val_outputs)]
            val_labels = [post_label(i) for i in decollate_batch(val_labels)]
            # compute metric for current iteration
            dice_metric(y_pred=val_outputs, y=val_labels)
        # aggregate the final mean dice result
        metric = dice_metric.aggregate().item()
        # reset the status for next validation round
        dice_metric.reset()
    return metric


def main():
    """Load the data, resume the model from its checkpoint and train it."""
    data_dicts = collect_data_dicts(path_to_target)
    train_files = data_dicts[:NUM_TRAIN]
    val_files = data_dicts[NUM_TRAIN:NUM_TRAIN + NUM_VAL]
    print(len(train_files))
    print(len(val_files))
    if not train_files or not val_files:
        # Fail early with a clear message instead of a later division by zero
        # or a subscript on None.
        raise ValueError(
            f"Need more than {NUM_TRAIN} image/label pairs in "
            f"{path_to_target!r}, found {len(data_dicts)}"
        )

    train_transforms, val_transforms = build_transforms()

    # Sanity check: load one validation case and report its spatial shape.
    check_ds = Dataset(data=val_files, transform=val_transforms)
    check_loader = DataLoader(check_ds, batch_size=1)
    check_data = first(check_loader)
    image, label = (check_data["image"][0][0], check_data["label"][0][0])
    print(f"image shape: {image.shape}, label shape: {label.shape}")

    train_ds = CacheDataset(
        data=train_files, transform=train_transforms, cache_rate=0.2, num_workers=1
    )
    train_loader = DataLoader(train_ds, batch_size=1, shuffle=True, num_workers=1)
    val_ds = CacheDataset(
        data=val_files, transform=val_transforms, cache_rate=0.2, num_workers=1
    )
    val_loader = DataLoader(val_ds, batch_size=1, num_workers=1)

    # Use the first GPU when present; otherwise run (slowly) on the CPU.
    device = torch.device("cuda:0" if torch.cuda.is_available() else "cpu")
    model = UNet(
        spatial_dims=3,
        in_channels=1,
        out_channels=NUM_CLASSES,
        channels=(16, 32, 64, 128, 256),
        strides=(2, 2, 2, 2),
        num_res_units=2,
        norm=Norm.BATCH,
    ).to(device)
    loss_function = DiceLoss(to_onehot_y=True, softmax=True)
    # map_location keeps the load working when the checkpoint was written on
    # a different device than the one in use now.
    model.load_state_dict(torch.load(RESUME_CHECKPOINT, map_location=device))
    optimizer = torch.optim.Adam(model.parameters(), 1e-4)
    dice_metric = DiceMetric(include_background=False, reduction="mean")

    max_epochs = 600
    val_interval = 1
    best_metric = -1
    best_metric_epoch = -1
    epoch_loss_values = []
    metric_values = []
    post_pred = Compose([AsDiscrete(argmax=True, to_onehot=NUM_CLASSES)])
    post_label = Compose([AsDiscrete(to_onehot=NUM_CLASSES)])

    # Make sure the checkpoint directory exists before the first save.
    os.makedirs(data_root, exist_ok=True)
    steps_per_epoch = len(train_ds) // train_loader.batch_size

    for epoch in range(max_epochs):
        print("-" * 10)
        print(f"epoch {epoch + 1}/{max_epochs}")
        model.train()
        epoch_loss = 0
        step = 0
        for batch_data in train_loader:
            step += 1
            inputs, labels = (
                batch_data["image"].to(device),
                batch_data["label"].to(device),
            )
            optimizer.zero_grad()
            outputs = model(inputs)
            loss = loss_function(outputs, labels)
            loss.backward()
            optimizer.step()
            loss_value = loss.item()
            epoch_loss += loss_value
            print(
                f"{step}/{steps_per_epoch}, "
                f"Train_loss: {loss_value:.8f}"
            )
        epoch_loss /= step
        epoch_loss_values.append(epoch_loss)
        print(f"Epoch {epoch + 1} Average loss: {epoch_loss:.4f}")

        if (epoch + 1) % val_interval == 0:
            metric = validate(
                model, val_loader, device, dice_metric, post_pred, post_label
            )
            metric_values.append(metric)
            if metric > best_metric:
                best_metric = metric
                best_metric_epoch = epoch + 1
                torch.save(model.state_dict(), BEST_CHECKPOINT)
                print("Saved new best metric model")
            print(
                f"Current epoch: {epoch + 1} Current mean dice: {metric:.4f}"
                f"\nBest mean dice: {best_metric:.4f} "
                f"at epoch: {best_metric_epoch}"
            )

    print(
        f"Training completed, Best_metric: {best_metric:.4f} "
        f"at epoch: {best_metric_epoch}"
    )


# The guard is required because the DataLoaders start worker processes; on
# spawn-based platforms the module is re-imported in each worker.
if __name__ == "__main__":
    main()
Editor is loading...