import numpy as np
import torch
from tqdm import tqdm
from model import *
from utils import *
import os
import matplotlib.pyplot as plt
# Runtime setup -------------------------------------------------------------
# Lets the process continue when more than one OpenMP runtime gets loaded
# (the Intel OpenMP "duplicate lib" abort).
os.environ['KMP_DUPLICATE_LIB_OK'] = 'True'
# Turn off the cuDNN convolution auto-tuner.
torch.backends.cudnn.benchmark = False
# Release cached, currently unused CUDA memory held by PyTorch's allocator.
torch.cuda.empty_cache()

# Constants -----------------------------------------------------------------
# NOTE(review): LEARNING_RATE is not referenced anywhere in the visible part
# of this script; kept in case other code relies on it.
LEARNING_RATE = 1e-4
# Prefer the GPU, but fall back to the CPU so the script still runs on hosts
# without CUDA instead of crashing on the first `.to(DEVICE)`.
DEVICE = "cuda" if torch.cuda.is_available() else "cpu"
# Number of segmentation classes (also the model's output channel count).
N_CLASSES = 16
# RGB colour assigned to each class id; the list index is the class id.
CLASS_COLORS = [
    (0, 0, 0),
    (255, 0, 0),
    (0, 255, 0),
    (0, 0, 255),
    (255, 255, 0),
    (0, 255, 255),
    (255, 0, 255),
    (255, 239, 213),
    (0, 0, 205),
    (205, 133, 63),
    (210, 180, 140),
    (102, 205, 170),
    (0, 0, 128),
    (7, 36, 0),
    (218, 219, 112),
    (218, 112, 214),
]


def get_colored_segmentation_image(seg_arr, n_classes, colors=CLASS_COLORS):
    """Render a label map as an RGB image using a per-class palette.

    Args:
        seg_arr: Array indexed as ``seg_arr[:, :, 0]``; that slice is read as
            the per-pixel class id. Other entries of the last axis are ignored.
        n_classes: Class ids ``0 .. n_classes - 1`` are painted; pixels holding
            any other value stay black ``(0, 0, 0)``.
        colors: Sequence of ``(R, G, B)`` tuples indexed by class id. It must
            provide at least ``n_classes`` entries.

    Returns:
        ``numpy.ndarray`` of dtype ``int32`` and shape ``(height, width, 3)``.
    """
    labels = seg_arr[:, :, 0]
    height, width = labels.shape[0], labels.shape[1]
    rgb = np.zeros((height, width, 3), dtype=np.int32)

    for class_id in range(n_classes):
        red, green, blue = colors[class_id][0], colors[class_id][1], colors[class_id][2]
        # Colour components go through uint8, so values wrap/truncate the same
        # way an 8-bit image channel would.
        shade = np.array([red, green, blue]).astype('uint8')
        # Each pixel carries exactly one label, so painting the matching
        # pixels is equivalent to accumulating per-class contributions.
        rgb[labels == class_id] = shade

    return rgb
def _to_model_input(array):
    """Move one modality array to DEVICE and shape it for a single forward pass."""
    tensor = torch.from_numpy(array).to(DEVICE)
    # Keep only the first entry along the leading axis, then prepend two
    # singleton axes.
    # NOTE(review): the resulting rank depends on what GaziBrainsDataset
    # yields, which is not visible here -- confirm against the dataset.
    return tensor[0, ...].unsqueeze(0).unsqueeze(0)


def _save_comparison_figure(path, panels):
    """Draw the panels side by side in one row and save the figure to *path*.

    Args:
        path: Output image path passed to ``Figure.savefig``.
        panels: Sequence of ``(title, image, imshow_kwargs)`` triples, drawn
            left to right.
    """
    fig = plt.figure(figsize=(8, 8))
    for position, (title, image, imshow_kwargs) in enumerate(panels, start=1):
        axis = fig.add_subplot(1, len(panels), position)
        axis.imshow(image, **imshow_kwargs)
        axis.set_title(title)
        axis.axis('off')
    fig.savefig(path)
    # Close explicitly so figures do not pile up across the whole test set.
    plt.close(fig)


def main():
    """Run the trained UNet over the test set and save comparison figures.

    For every item of the test dataset this runs one forward pass on the
    FLAIR / T2 / T1 inputs, colour-codes the ground-truth mask and the
    arg-max prediction, and writes a five-panel figure to
    ``results/results_<index>.png``.
    """
    # savefig does not create missing directories.
    os.makedirs("results", exist_ok=True)

    with torch.no_grad():
        # NOTE(review): assumes UNet is a (PyTorch) Lightning module. There,
        # `load_from_checkpoint` is a classmethod that RETURNS a new model;
        # calling it on an instance and dropping the result leaves the model
        # with its random initial weights. Use the returned model instead.
        model = UNet.load_from_checkpoint(
            'Model/last.ckpt',
            in_channels=1,
            out_channels=N_CLASSES,
            features=[16, 32, 64, 128, 192, 256],
            batch_size=16,
        ).to(DEVICE)
        model.eval()

        test_dataset = GaziBrainsDataset(
            dataset_path='Data/npy_dataset',
            X_path='X_test.npy',
            y_path='y_test.npy',
        )

        for index, batch in enumerate(tqdm(test_dataset)):
            name = "results/results_" + str(index) + ".png"

            flair, t2, t1, mask = batch
            flair = _to_model_input(flair)
            t2 = _to_model_input(t2)
            t1 = _to_model_input(t1)

            # Forward pass under CUDA mixed precision.
            with torch.cuda.amp.autocast():
                prediction = model(flair, t2, t1)

            # Arg-max over dim 1 (presumably the class axis -- confirm against
            # UNet), drop the leading axis and append a trailing channel axis
            # so it can be indexed as [:, :, 0] by the colouring helper.
            prediction = (
                prediction.argmax(dim=1).squeeze(0).unsqueeze(-1).cpu().detach().numpy()
            )
            mask = np.expand_dims(mask, axis=-1)
            print(prediction.shape, mask.shape)

            mask = get_colored_segmentation_image(mask, N_CLASSES)
            prediction = get_colored_segmentation_image(prediction, N_CLASSES)

            # Strip the two leading singleton axes again for plotting.
            flair = flair.squeeze(0).squeeze(0).cpu().numpy()
            t2 = t2.squeeze(0).squeeze(0).cpu().numpy()
            t1 = t1.squeeze(0).squeeze(0).cpu().numpy()

            _save_comparison_figure(
                name,
                [
                    ("Flair", flair, {'cmap': 'gray'}),
                    ("T2", t2, {'cmap': 'gray'}),
                    ("T1", t1, {'cmap': 'gray'}),
                    ("Ground Truth", mask, {}),
                    ("Prediction", prediction, {}),
                ],
            )
# Run the inference/visualisation pipeline only when executed as a script,
# not when this module is imported.
if __name__ == "__main__":
    main()