Untitled
unknown
python
3 years ago
5.0 kB
51
Indexable
import numpy as np
import torch
from tqdm import tqdm
from model import *
from utils import *
import os
import matplotlib.pyplot as plt

# Let the Intel OpenMP runtime tolerate being loaded more than once instead of aborting.
os.environ['KMP_DUPLICATE_LIB_OK'] = 'True'
torch.backends.cudnn.benchmark = False
torch.cuda.empty_cache()

LEARNING_RATE = 1e-4
DEVICE = "cuda"
N_CLASSES = 16
CLASS_COLORS = [
    (0, 0, 0), (255, 0, 0), (0, 255, 0), (0, 0, 255),
    (255, 255, 0), (0, 255, 255), (255, 0, 255), (255, 239, 213),
    (0, 0, 205), (205, 133, 63), (210, 180, 140), (102, 205, 170),
    (0, 0, 128), (7, 36, 0), (218, 219, 112), (218, 112, 214),
]

CHECKPOINT_PATH = 'Model/last.ckpt'
DATASET_PATH = 'Data/npy_dataset'
RESULTS_DIR = 'results'
MODEL_FEATURES = (16, 32, 64, 128, 192, 256)
MODEL_BATCH_SIZE = 16


def get_colored_segmentation_image(seg_arr, n_classes, colors=CLASS_COLORS):
    """Convert a label map into an RGB image using one colour per class.

    Args:
        seg_arr: array of at least 3 dimensions; the first two axes are the image
            rows/columns and only index 0 of the third axis is read.
        n_classes: number of class ids to colour, i.e. labels 0 .. n_classes-1.
        colors: sequence of (R, G, B) tuples indexed by class id; must have at
            least ``n_classes`` entries.

    Returns:
        ``np.int32`` array of shape (H, W, 3). Pixels whose label is not in
        ``range(n_classes)`` stay (0, 0, 0).
    """
    labels = seg_arr[:, :, 0]
    seg_img = np.zeros(labels.shape + (3,), dtype=np.int32)
    for class_id in range(n_classes):
        # Each pixel matches at most one class, so plain assignment is enough.
        seg_img[labels == class_id] = colors[class_id]
    return seg_img


def _to_model_input(channel):
    """Move one modality to DEVICE and add the two leading axes the model is fed with."""
    tensor = torch.from_numpy(channel).to(DEVICE)
    # Keep the first entry along the leading axis, then add batch and channel axes.
    # NOTE(review): an earlier comment suggested items arrive as (16, 1, 256, 256) — confirm
    # against GaziBrainsDataset.
    return tensor[0, ...].unsqueeze(0).unsqueeze(0)


def _save_comparison(path, flair, t2, t1, mask, prediction):
    """Save inputs, ground truth and prediction side by side as one PNG at ``path``."""
    panels = [
        (flair, "Flair", 'gray'),
        (t2, "T2", 'gray'),
        (t1, "T1", 'gray'),
        (mask, "Ground Truth", None),
        (prediction, "Prediction", None),
    ]
    fig = plt.figure(figsize=(8, 8))
    for position, (image, title, cmap) in enumerate(panels, start=1):
        fig.add_subplot(1, len(panels), position)
        plt.imshow(image, cmap=cmap)
        plt.title(title)
        plt.axis('off')
    plt.savefig(path)
    plt.close(fig)


def main():
    """Run the checkpointed UNet over the test split and save one comparison PNG per item.

    Reads X_test.npy / y_test.npy through GaziBrainsDataset and writes
    ``results/results_<index>.png`` for every item.
    """
    # savefig does not create missing directories.
    os.makedirs(RESULTS_DIR, exist_ok=True)

    with torch.no_grad():
        # load_from_checkpoint returns a NEW model; calling it on an existing instance and
        # discarding the result leaves randomly initialised weights in use.
        # NOTE(review): assumes UNet is a PyTorch Lightning LightningModule.
        model = UNet.load_from_checkpoint(
            CHECKPOINT_PATH,
            in_channels=1,
            out_channels=N_CLASSES,
            features=list(MODEL_FEATURES),
            batch_size=MODEL_BATCH_SIZE,
        ).to(DEVICE)
        model.eval()

        test_dataset = GaziBrainsDataset(
            dataset_path=DATASET_PATH, X_path='X_test.npy', y_path='y_test.npy'
        )

        for index, (flair, t2, t1, mask) in enumerate(tqdm(test_dataset)):
            name = f"{RESULTS_DIR}/results_{index}.png"

            flair, t2, t1 = (_to_model_input(x) for x in (flair, t2, t1))

            # forward pass
            with torch.cuda.amp.autocast():
                logits = model(flair, t2, t1)
            # Class id per pixel, with a trailing axis for get_colored_segmentation_image.
            prediction = logits.argmax(dim=1).squeeze(0).unsqueeze(-1).cpu().numpy()

            mask = np.expand_dims(mask, axis=-1)
            print(prediction.shape, mask.shape)

            mask = get_colored_segmentation_image(mask, N_CLASSES)
            prediction = get_colored_segmentation_image(prediction, N_CLASSES)

            flair, t2, t1 = (x.squeeze(0).squeeze(0).cpu().numpy() for x in (flair, t2, t1))

            _save_comparison(name, flair, t2, t1, mask, prediction)


if __name__ == '__main__':
    main()
Editor is loading...