# Paste-site page header captured together with the script (not Python code):
#   Untitled | mail@pastecode.io avatar | unknown | plain_text | 2 years ago
#   | 2.9 kB | 2 | Indexable | Never
from monai.transforms import (
    AsDiscreted,
    AsDiscrete,
    EnsureChannelFirstd,
    Compose,
    CropForegroundd,
    LoadImaged,
    Orientationd,
    SaveImaged,
    Invertd,
)
from monai.losses.dice import DiceCELoss
from monai.networks.nets import UNet
from monai.networks.layers import Norm
from monai.metrics import DiceMetric
from monai.losses import DiceLoss
from monai.inferers import sliding_window_inference
from monai.data import CacheDataset, DataLoader, Dataset, decollate_batch
import torch
import os, warnings, monai
import numpy as np
import nibabel as nib
from os import listdir
from os.path import isfile, join
import matplotlib.pyplot as plt
import nibabel as nib
import os
import glob
from monai.optimizers import Novograd
from monai.metrics import compute_meandice, DiceMetric
from monai.utils import first


# Run everything on the CPU.
device = torch.device("cpu")

# root_dir is where the trained weights are read from; data_dir holds the
# NIfTI volumes to segment.
root_dir = "/monai/plavo/2"
data_dir = "/data/plavo/norm/test"

# Collect every "*.nii" file directly under data_dir in sorted order and wrap
# each path in the {"image": path} record that the dictionary transforms use.
test_images = glob.glob(os.path.join(data_dir, ".", "*.nii"))
test_images.sort()
test_data = [dict(image=nifti_path) for nifti_path in test_images]

# Preprocessing applied to every test record, all operating on the "image"
# entry: load the file, move the channel axis first, reorient to RAS, then
# crop to the foreground using the image itself as the source.
test_org_transforms = Compose(
    transforms=[
        LoadImaged(keys=["image"]),
        EnsureChannelFirstd(keys=["image"]),
        Orientationd(keys=["image"], axcodes="RAS"),
        CropForegroundd(keys=["image"], source_key="image"),
    ]
)

# Sanity check: push the first test volume through the preprocessing pipeline
# and print its spatial shape before running the full inference.
check_ds = Dataset(data=test_data, transform=test_org_transforms)
check_loader = DataLoader(check_ds, batch_size=1)
check_data = first(check_loader)
if check_data is None:
    # first() yields None for an empty loader; fail with a clear message
    # instead of a TypeError when indexing None below.
    raise FileNotFoundError(
        f"No '*.nii' files found in {data_dir!r}; nothing to run inference on."
    )
# Drop the batch axis and the channel axis to get the bare volume.
image = check_data["image"][0][0]
print(f"image shape: {image.shape}")

# Dataset and loader used by the inference loop: one volume per batch,
# loaded through a single worker process.
test_org_ds = Dataset(test_data, transform=test_org_transforms)
test_org_loader = DataLoader(
    test_org_ds,
    num_workers=1,
    batch_size=1,
)

# Post-processing for each decollated sample: turn the "pred" logits into a
# 3-class one-hot map (argmax first), then write it under ./out with the
# "seg" postfix, without resampling and without a per-case sub-folder.
# NOTE: meta_keys is left at its default; a "pred_meta_dict" override was
# previously commented out here.
post_transforms = Compose(
    transforms=[
        AsDiscreted(keys=["pred"], argmax=True, to_onehot=3),
        SaveImaged(
            keys=["pred"],
            output_dir="./out",
            output_postfix="seg",
            resample=False,
            separate_folder=False,
            squeeze_end_dims=True,
        ),
    ]
)

# NOTE(review): n1 is not referenced anywhere in the visible script.
n1 = 32

# 3D UNet with one input channel and three output classes. load_state_dict is
# strict by default, so this presumably has to mirror the architecture that
# produced "plavo.pth" - confirm against the training script.
model = UNet(
    spatial_dims=3,
    in_channels=1,
    out_channels=3,
    channels=(16, 32, 64, 128, 256),
    strides=(2, 2, 2, 2),
    num_res_units=2,
    norm=Norm.BATCH,
).to(device)

# NOTE(review): loss_function, optimizer and dice_metric are not used by the
# visible inference code; they are kept so that existing names stay defined.
loss_function = DiceLoss(to_onehot_y=True, softmax=True)

# map_location remaps the stored tensors onto `device`; without it a
# checkpoint saved from a GPU cannot be loaded on a CPU-only machine.
# NOTE: torch.load unpickles the file - only load checkpoints you trust.
model.load_state_dict(
    torch.load(os.path.join(root_dir, "plavo.pth"), map_location=device)
)
optimizer = torch.optim.Adam(model.parameters(), 1e-4)
dice_metric = DiceMetric(include_background=False, reduction="mean")


# Sliding-window settings are identical for every volume, so define them once
# instead of on every loop iteration.
roi_size = (320, 320, 320)
sw_batch_size = 1

# Inference only: switch the model to eval mode and disable gradient tracking.
model.eval()
with torch.no_grad():
    # A dedicated loop variable keeps the module-level `test_data` list (the
    # input records) from being overwritten by each batch.
    for batch_data in test_org_loader:
        test_inputs = batch_data["image"].to(device)
        batch_data["pred"] = sliding_window_inference(
            test_inputs, roi_size, sw_batch_size, model
        )

        # post_transforms is applied per sample: split the batch back into
        # individual dictionaries, then discretize and save each prediction.
        for sample in decollate_batch(batch_data):
            post_transforms(sample)