from monai.utils import first, set_determinism
from monai.transforms import (
    AsDiscrete,
    EnsureChannelFirstd,
    Compose,
    CropForegroundd,
    LoadImaged,
    RandCropByPosNegLabeld,
    RandAdjustContrastd,
    RandGaussianSmoothd,
    RandRotated,
    RandFlipd,
    RandScaleIntensityd,
)
from monai.networks.nets import UNet
from monai.networks.layers import Norm
from monai.metrics import DiceMetric
from monai.losses import DiceLoss
from monai.inferers import sliding_window_inference
from monai.data import CacheDataset, DataLoader, Dataset, decollate_batch
import torch
import matplotlib.pyplot as plt
import os
import glob
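# Optional (added): fix random seeds so augmentation and patch sampling are reproducible;
# the seed value is arbitrary.
set_determinism(seed=0)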

# Make CUDA errors surface at the failing call (debugging aid); best set before any CUDA work starts.
os.environ["CUDA_LAUNCH_BLOCKING"] = "1"
data_root = "/monai/plavo/2"
path_to_target = "/data/plavo/norm/together"

train_images = sorted(glob.glob(os.path.join(path_to_target, "*.nii")))
train_labels = sorted(glob.glob(os.path.join(path_to_target, "*.nii.gz")))
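# Sanity check (added): zip() below silently truncates, so make sure every image has a matching label.
assert len(train_images) == len(train_labels), "image/label count mismatch"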

data_dicts = [
    {"image": image_name, "label": label_name}
    for image_name, label_name in zip(train_images, train_labels)
]
train_files, val_files = data_dicts[:173], data_dicts[173:183]
print(len(train_files))
print(len(val_files))

train_transforms = Compose(
    [
        LoadImaged(keys=["image", "label"]),
        EnsureChannelFirstd(keys=["image", "label"]),
        CropForegroundd(keys=["image", "label"], source_key="image"),
        RandCropByPosNegLabeld(
            keys=["image", "label"],
            label_key="label",
            spatial_size=(320, 320, 320),
            pos=1,
            neg=1,
            num_samples=1,
            image_key="image",
            image_threshold=0,
        ),
        RandAdjustContrastd(keys="image", prob=0.3),
        RandGaussianSmoothd(keys="image", prob=0.3),
        RandRotated(
            keys=["image", "label"], range_x=0.3, range_y=0.4, range_z=0.4, prob=0.4
        ),
        RandFlipd(keys=["image", "label"], prob=0.4, spatial_axis=[0, 1, 2]),
        # Rand3DElasticd(keys=['image', 'label'], sigma_range=[0, 0.1], magnitude_range=[0,1], prob=0.3, rotate_range=[-0.5, 0.5],  mode=('bilinear', 'nearest'),),
        RandScaleIntensityd(keys="image", factors=0.1, prob=0.3),
    ]
)
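# Validation transforms stay deterministic (no random patch sampling); whole volumes
# are evaluated with sliding-window inference below.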
val_transforms = Compose(
    [
        LoadImaged(keys=["image", "label"]),
        EnsureChannelFirstd(keys=["image", "label"]),
        CropForegroundd(keys=["image", "label"], source_key="image"),
        # RandCropByPosNegLabeld(keys=["image", "label"], label_key="label", spatial_size=[96,96,96], pos=1, neg=1, num_samples=1),
    ]
)

check_ds = Dataset(data=val_files, transform=val_transforms)
check_loader = DataLoader(check_ds, batch_size=1)
check_data = first(check_loader)
image, label = (check_data["image"][0][0], check_data["label"][0][0])
print(f"image shape: {image.shape}, label shape: {label.shape}")

train_ds = CacheDataset(
    data=train_files, transform=train_transforms, cache_rate=0.2, num_workers=1
)
train_loader = DataLoader(train_ds, batch_size=1, shuffle=True, num_workers=1)
val_ds = CacheDataset(
    data=val_files, transform=val_transforms, cache_rate=0.2, num_workers=1
)
val_loader = DataLoader(val_ds, batch_size=1, num_workers=1)

device = torch.device("cuda:0")

model = UNet(
    spatial_dims=3,
    in_channels=1,
    out_channels=3,
    channels=(16, 32, 64, 128, 256),
    strides=(2, 2, 2, 2),
    num_res_units=2,
    norm=Norm.BATCH,
).to(device)
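# out_channels=3 (background + 2 foreground classes) must match to_onehot=3 in the
# post-transforms below; DiceMetric is configured to ignore the background channel.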
loss_function = DiceLoss(to_onehot_y=True, softmax=True)
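# Resume from an existing checkpoint ("plavo.pth" in the working directory); the best model
# found during validation is saved back under data_root.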
model.load_state_dict(torch.load("plavo.pth"))
optimizer = torch.optim.Adam(model.parameters(), 1e-4)
dice_metric = DiceMetric(include_background=False, reduction="mean")

max_epochs = 600
val_interval = 1
best_metric = -1
best_metric_epoch = -1
epoch_loss_values = []
metric_values = []
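# Post-transforms for DiceMetric: argmax the logits, then one-hot encode both prediction and label.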
post_pred = Compose([AsDiscrete(argmax=True, to_onehot=3)])
post_label = Compose([AsDiscrete(to_onehot=3)])

for epoch in range(max_epochs):
    print("-" * 10)
    print(f"epoch {epoch + 1}/{max_epochs}")
    model.train()
    epoch_loss = 0
    step = 0
    for batch_data in train_loader:
        step += 1
        inputs, labels = (
            batch_data["image"].to(device),
            batch_data["label"].to(device),
        )
        optimizer.zero_grad()
        outputs = model(inputs)
        loss = loss_function(outputs, labels)
        loss.backward()
        optimizer.step()
        epoch_loss += loss.item()
        print(
            f"{step}/{len(train_ds) // train_loader.batch_size}, "
            f"Train_loss: {loss.item():.8f}"
        )
    epoch_loss /= step
    epoch_loss_values.append(epoch_loss)
    print(f"Epoch {epoch + 1} Average loss: {epoch_loss:.4f}")

    if (epoch + 1) % val_interval == 0:
        model.eval()
        with torch.no_grad():
            for val_data in val_loader:
                val_inputs, val_labels = (
                    val_data["image"].to(device),
                    val_data["label"].to(device),
                )
                roi_size = (320, 320, 320)
                sw_batch_size = 1
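                # Stitch a full-volume prediction from (320, 320, 320) windows,
                # matching the training crop size.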
                val_outputs = sliding_window_inference(
                    val_inputs, roi_size, sw_batch_size, model
                )
                val_outputs = [post_pred(i) for i in decollate_batch(val_outputs)]
                val_labels = [post_label(i) for i in decollate_batch(val_labels)]
                # compute metric for current iteration
                dice_metric(y_pred=val_outputs, y=val_labels)

            # aggregate the final mean dice result
            metric = dice_metric.aggregate().item()
            # reset the status for next validation round
            dice_metric.reset()

            metric_values.append(metric)
            if metric > best_metric:
                best_metric = metric
                best_metric_epoch = epoch + 1
                torch.save(model.state_dict(), os.path.join(data_root, "plavo.pth"))
                print("Saved new best metric model")
            print(
                f"Current epoch: {epoch + 1} Current mean dice: {metric:.4f}"
                f"\nBest mean dice: {best_metric:.4f} "
                f"at epoch: {best_metric_epoch}"
            )

print(
    f"Training completed, best mean dice: {best_metric:.4f} "
    f"at epoch: {best_metric_epoch}"
)
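
# Optional (added sketch): plot the tracked loss and validation dice curves.
# Assumes a non-interactive matplotlib backend is acceptable and that data_root is writable.
plt.figure("train", (12, 6))
plt.subplot(1, 2, 1)
plt.title("Epoch Average Loss")
plt.xlabel("epoch")
plt.plot([i + 1 for i in range(len(epoch_loss_values))], epoch_loss_values)
plt.subplot(1, 2, 2)
plt.title("Validation Mean Dice")
plt.xlabel("epoch")
plt.plot([val_interval * (i + 1) for i in range(len(metric_values))], metric_values)
plt.savefig(os.path.join(data_root, "training_curves.png"))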