# Paste-site metadata captured with this script (not code):
#   Untitled
#   unknown
#   plain_text
#   3 years ago
#   6.6 kB
#   4
#   Indexable
from monai.utils import first, set_determinism
from monai.transforms import (
AsDiscrete,
EnsureChannelFirstd,
Compose,
CropForegroundd,
LoadImaged,
RandCropByPosNegLabeld,
RandAdjustContrastd,
RandGaussianSmoothd,
RandRotated,
RandFlipd,
RandScaleIntensityd,
)
from monai.losses.dice import DiceCELoss
from monai.networks.nets import UNet
from monai.networks.layers import Norm
from monai.metrics import DiceMetric
from monai.losses import DiceLoss
from monai.inferers import sliding_window_inference
from monai.data import CacheDataset, DataLoader, Dataset, decollate_batch
import torch
import matplotlib.pyplot as plt
import os
import glob
from monai.optimizers import Novograd
from monai.metrics import DiceMetric
from monai.utils import set_determinism
# CUDA only honours this flag when it is present in the process environment
# before the first CUDA call; a bare Python variable has no effect. The
# module-level name is kept, and setdefault lets an explicitly exported value
# take precedence.
CUDA_LAUNCH_BLOCKING = "1"
os.environ.setdefault("CUDA_LAUNCH_BLOCKING", CUDA_LAUNCH_BLOCKING)
data_root = "/monai/plavo/2"  # directory the best checkpoint is saved into
path_to_target = "/data/plavo/norm/together"  # images and labels live side by side


def _list_image_label_files(directory):
    """Return sorted ``(images, labels)`` path lists found in *directory*.

    Images are the ``*.nii`` files and labels the ``*.nii.gz`` files; the two
    patterns do not overlap because ``*.nii`` must match the whole file name.
    Images and labels are later paired purely by sorted position, so both
    lists must have the same length.

    Raises:
        ValueError: if the number of images and labels differ, which would
            otherwise make ``zip`` silently drop cases and can misalign pairs.
    """
    images = sorted(glob.glob(os.path.join(directory, "*.nii")))
    labels = sorted(glob.glob(os.path.join(directory, "*.nii.gz")))
    if len(images) != len(labels):
        raise ValueError(
            f"found {len(images)} images (*.nii) but {len(labels)} labels "
            f"(*.nii.gz) in {directory!r}"
        )
    return images, labels


train_images, train_labels = _list_image_label_files(path_to_target)
data_dicts = [
    {"image": image_name, "label": label_name}
    for image_name, label_name in zip(train_images, train_labels)
]
# Fixed split by sorted order: first 173 cases train, the next 10 validate.
train_files, val_files = data_dicts[:173], data_dicts[173:183]
print(len(train_files))
print(len(val_files))
# Training pipeline: load, crop to the image foreground, sample one 320^3
# patch (positive/negative centres equally likely), then augment. Transforms
# that resample both keys use nearest-neighbour for the label so class ids
# stay integral.
train_transforms = Compose(
    [
        LoadImaged(keys=["image", "label"]),
        EnsureChannelFirstd(keys=["image", "label"]),
        CropForegroundd(keys=["image", "label"], source_key="image"),
        RandCropByPosNegLabeld(
            keys=["image", "label"],
            label_key="label",
            spatial_size=(320, 320, 320),
            pos=1,
            neg=1,
            num_samples=1,
            image_key="image",
            image_threshold=0,
        ),
        RandAdjustContrastd(keys="image", prob=0.3),
        RandGaussianSmoothd(keys="image", prob=0.3),
        # RandRotated interpolates bilinearly by default for every key, which
        # blends neighbouring class ids (e.g. 0 and 2 -> 1) along label
        # boundaries; use nearest-neighbour for the label.
        RandRotated(
            keys=["image", "label"],
            range_x=0.3,
            range_y=0.4,
            range_z=0.4,
            prob=0.4,
            mode=("bilinear", "nearest"),
        ),
        RandFlipd(keys=["image", "label"], prob=0.4, spatial_axis=[0, 1, 2]),
        # Rand3DElasticd(keys=['image', 'label'], sigma_range=[0, 0.1], magnitude_range=[0,1], prob=0.3, rotate_range=[-0.5, 0.5], mode=('bilinear', 'nearest'),),
        RandScaleIntensityd(keys="image", factors=0.1, prob=0.3),
    ]
)
# Validation pipeline: deterministic loading and foreground cropping only.
# No patch sampling is done here, so the whole cropped volume reaches the
# sliding-window evaluation.
val_transforms = Compose(
    transforms=[
        LoadImaged(keys=("image", "label")),
        EnsureChannelFirstd(keys=("image", "label")),
        CropForegroundd(keys=("image", "label"), source_key="image"),
    ]
)
# Sanity check: push the first validation case through val_transforms and
# report the spatial shape of its first channel.
check_ds = Dataset(val_files, transform=val_transforms)
check_loader = DataLoader(dataset=check_ds, batch_size=1)
check_data = first(check_loader)
image = check_data["image"][0][0]  # batch item 0, channel 0
label = check_data["label"][0][0]
print(f"image shape: {image.shape}, label shape: {label.shape}")
# Cached datasets (cache_rate=0.2 keeps a fifth of each split in memory) and
# their loaders; only the training loader shuffles.
train_ds = CacheDataset(
    data=train_files,
    transform=train_transforms,
    num_workers=1,
    cache_rate=0.2,
)
train_loader = DataLoader(train_ds, num_workers=1, shuffle=True, batch_size=1)

val_ds = CacheDataset(
    data=val_files,
    transform=val_transforms,
    num_workers=1,
    cache_rate=0.2,
)
val_loader = DataLoader(val_ds, num_workers=1, batch_size=1)
# Model, loss, optimiser and evaluation setup.
device = torch.device("cuda:0")
num_classes = 3  # network output channels == one-hot width used below
model = UNet(
    spatial_dims=3,
    in_channels=1,
    out_channels=num_classes,
    channels=(16, 32, 64, 128, 256),
    strides=(2, 2, 2, 2),
    num_res_units=2,
    norm=Norm.BATCH,
).to(device)
loss_function = DiceLoss(to_onehot_y=True, softmax=True)
# Warm-start from an existing checkpoint. map_location keeps every tensor on
# `device` even if the file was saved from a different GPU index.
# NOTE(review): this reads ./plavo.pth from the working directory, while the
# training loop saves to os.path.join(data_root, "plavo.pth"); confirm which
# file is meant to be resumed from.
model.load_state_dict(torch.load("plavo.pth", map_location=device))
optimizer = torch.optim.Adam(model.parameters(), 1e-4)
# Channel 0 (background) is excluded from the reported Dice.
dice_metric = DiceMetric(include_background=False, reduction="mean")
max_epochs = 600
val_interval = 1  # validate after every epoch
best_metric = -1
best_metric_epoch = -1
epoch_loss_values = []
metric_values = []
# Predictions: argmax over the output channels, then one-hot.
# Labels: one-hot of the integer class ids.
post_pred = Compose([AsDiscrete(argmax=True, to_onehot=num_classes)])
post_label = Compose([AsDiscrete(to_onehot=num_classes)])
# Sliding-window settings for validation; loop-invariant, so set once.
roi_size = (320, 320, 320)
sw_batch_size = 1
# Make sure the checkpoint directory exists before the first save, instead of
# failing only after a full epoch of training.
os.makedirs(data_root, exist_ok=True)

for epoch in range(max_epochs):
    print("-" * 10)
    print(f"epoch {epoch + 1}/{max_epochs}")

    # ---- training -------------------------------------------------------
    model.train()
    epoch_loss = 0
    step = 0
    for batch_data in train_loader:
        step += 1
        inputs, labels = (
            batch_data["image"].to(device),
            batch_data["label"].to(device),
        )
        optimizer.zero_grad()
        outputs = model(inputs)
        loss = loss_function(outputs, labels)
        loss.backward()
        optimizer.step()
        epoch_loss += loss.item()
        print(
            f"{step}/{len(train_ds) // train_loader.batch_size}, "
            f"Train_loss: {loss.item():.8f}"
        )
    # max(step, 1) avoids ZeroDivisionError if the loader yielded nothing.
    epoch_loss /= max(step, 1)
    epoch_loss_values.append(epoch_loss)
    print(f"Epoch {epoch + 1} Average loss: {epoch_loss:.4f}")

    # ---- validation -----------------------------------------------------
    if (epoch + 1) % val_interval == 0:
        model.eval()
        with torch.no_grad():
            for val_data in val_loader:
                val_inputs, val_labels = (
                    val_data["image"].to(device),
                    val_data["label"].to(device),
                )
                val_outputs = sliding_window_inference(
                    val_inputs, roi_size, sw_batch_size, model
                )
                val_outputs = [post_pred(i) for i in decollate_batch(val_outputs)]
                val_labels = [post_label(i) for i in decollate_batch(val_labels)]
                # compute metric for current iteration
                dice_metric(y_pred=val_outputs, y=val_labels)
            # aggregate the final mean dice result over the whole loader
            metric = dice_metric.aggregate().item()
            # reset the status for next validation round
            dice_metric.reset()

            metric_values.append(metric)
            if metric > best_metric:
                best_metric = metric
                best_metric_epoch = epoch + 1
                torch.save(model.state_dict(), os.path.join(data_root, "plavo.pth"))
                print("Saved new best metric model")
            print(
                f"Current epoch: {epoch + 1} Current mean dice: {metric:.4f}"
                f"\nBest mean dice: {best_metric:.4f} "
                f"at epoch: {best_metric_epoch}"
            )

print(
    f"Training completed, Best_metric: {best_metric:.4f} "
    f"at epoch: {best_metric_epoch}"
)
# Editor is loading...  (paste-page residue, not part of the script)