Untitled
unknown
plain_text
3 years ago
2.9 kB
5
Indexable
# Inference script: segment every *.nii volume in ``data_dir`` with a 3-D UNet
# (1 input channel, 3 output classes) and write predictions to ./out.
import glob
import os
import warnings
from os import listdir
from os.path import isfile, join

import matplotlib.pyplot as plt
import monai
import nibabel as nib
import numpy as np
import torch
from monai.data import CacheDataset, DataLoader, Dataset, decollate_batch
from monai.inferers import sliding_window_inference
from monai.losses import DiceLoss
from monai.losses.dice import DiceCELoss
# NOTE(review): compute_meandice is unused here and was deprecated in MONAI 1.0
# in favour of compute_dice; drop it if it fails to import on your MONAI version.
from monai.metrics import DiceMetric, compute_meandice
from monai.networks.layers import Norm
from monai.networks.nets import UNet
from monai.optimizers import Novograd
from monai.transforms import (
    AsDiscrete,
    AsDiscreted,
    Compose,
    CropForegroundd,
    EnsureChannelFirstd,
    Invertd,
    LoadImaged,
    Orientationd,
    SaveImaged,
)
from monai.utils import first

device = torch.device("cpu")
root_dir = "/monai/plavo/2"  # holds the trained weights file plavo.pth
data_dir = "/data/plavo/norm/test"  # scanned for test volumes

# Only uncompressed *.nii files are collected.
# NOTE(review): add a "*.nii.gz" pattern if compressed volumes should be included.
test_images = sorted(glob.glob(os.path.join(data_dir, "*.nii")))
if not test_images:
    # Without this check first() below returns None and the script fails with
    # an unhelpful "'NoneType' object is not subscriptable".
    raise FileNotFoundError(f"no *.nii images found in {data_dir!r}")
test_data = [{"image": image} for image in test_images]

# Load -> channel-first -> reorient to RAS -> crop to the image's foreground.
test_org_transforms = Compose(
    [
        LoadImaged(keys="image"),
        EnsureChannelFirstd(keys="image"),
        Orientationd(keys=["image"], axcodes="RAS"),
        CropForegroundd(keys=["image"], source_key="image"),
    ]
)

# One Dataset serves both the shape check and the inference loader; plain
# Dataset does no caching, so sharing it has no side effects.
test_org_ds = Dataset(data=test_data, transform=test_org_transforms)

# Sanity check: load the first volume once and report its spatial shape.
check_loader = DataLoader(test_org_ds, batch_size=1)
check_data = first(check_loader)
image = check_data["image"][0][0]
print(f"image shape: {image.shape}")

test_org_loader = DataLoader(test_org_ds, batch_size=1, num_workers=1)

# NOTE(review): to_onehot=3 writes a 3-channel one-hot volume rather than a
# single-channel label map. Predictions are also saved in the cropped, RAS-
# oriented space with resample=False; presumably Invertd (imported above) was
# meant to map them back onto the original image grid. Confirm intent.
post_transforms = Compose(
    [
        AsDiscreted(keys="pred", argmax=True, to_onehot=3),
        SaveImaged(
            keys="pred",
            output_dir="./out",
            output_postfix="seg",
            separate_folder=False,
            resample=False,
            squeeze_end_dims=True,
        ),
    ]
)

# 3-D residual UNet: 1 input channel, 3 output classes, five resolution levels.
model = UNet(
    spatial_dims=3,
    in_channels=1,
    out_channels=3,
    channels=(16, 32, 64, 128, 256),
    strides=(2, 2, 2, 2),
    num_res_units=2,
    norm=Norm.BATCH,
).to(device)
# Sliding-window parameters; loop-invariant, so they are defined once here.
roi_size = (320, 320, 320)
sw_batch_size = 1

# The loader uses num_workers=1. On spawn-start platforms (Windows, macOS) the
# worker re-imports this module, so top-level iteration without this guard
# fails during worker bootstrapping.
if __name__ == "__main__":
    # map_location lets a checkpoint saved on a GPU load onto this CPU device.
    state_dict = torch.load(
        os.path.join(root_dir, "plavo.pth"), map_location=device
    )
    model.load_state_dict(state_dict)
    model.eval()

    # Training-only objects (DiceLoss, Adam optimizer, DiceMetric) were never
    # used during inference and are no longer constructed.
    with torch.no_grad():
        for batch in test_org_loader:
            test_inputs = batch["image"].to(device)
            batch["pred"] = sliding_window_inference(
                test_inputs, roi_size, sw_batch_size, model
            )
            # Split the batch into per-sample dicts; post_transforms discretises
            # each prediction and writes it to disk via SaveImaged.
            outputs = [post_transforms(item) for item in decollate_batch(batch)]
Editor is loading...