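# Untitled
#
# Metadata captured from the paste site this snippet was copied from
# (not code; kept as comments so the module can be imported):
#   avatar: unknown
#   language: python
#   posted: 3 years ago
#   size: 5.0 kB
#   48 (unlabelled in the original paste)
#   Indexable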
import numpy as np
import torch
from tqdm import tqdm
from model import *
from utils import *
import os
import matplotlib.pyplot as plt

# --- Runtime configuration (applied once, at import time) -------------------
# Intel OpenMP switch that tolerates a second copy of the OpenMP runtime
# being loaded into the process instead of aborting.
os.environ.update(KMP_DUPLICATE_LIB_OK='True')
# Keep cuDNN from auto-tuning convolution algorithms.
torch.backends.cudnn.benchmark = False
# Release any cached, unused CUDA memory before inference starts.
torch.cuda.empty_cache()

# --- Constants --------------------------------------------------------------
LEARNING_RATE = 1e-4  # not referenced in the inference code visible here
DEVICE = "cuda"
N_CLASSES = 16

# RGB colour per class id; entry ``i`` colours label ``i`` (class 0 is black).
CLASS_COLORS = [
    (0, 0, 0),
    (255, 0, 0),
    (0, 255, 0),
    (0, 0, 255),
    (255, 255, 0),
    (0, 255, 255),
    (255, 0, 255),
    (255, 239, 213),
    (0, 0, 205),
    (205, 133, 63),
    (210, 180, 140),
    (102, 205, 170),
    (0, 0, 128),
    (7, 36, 0),
    (218, 219, 112),
    (218, 112, 214),
]

def get_colored_segmentation_image(seg_arr, n_classes, colors=None):
    """Render a class-label map as an RGB image.

    Args:
        seg_arr: Array of class labels, shaped (H, W) or (H, W, C). For 3-D
            input only channel 0 is used.
        n_classes: Number of classes to colour. Pixels whose label is not in
            ``range(n_classes)`` stay black.
        colors: Sequence of (R, G, B) tuples indexed by class id. ``None``
            (the default) selects the module-level ``CLASS_COLORS`` palette.

    Returns:
        ``np.int32`` array of shape (H, W, 3).

    Raises:
        IndexError: if ``colors`` has fewer than ``n_classes`` entries.
    """
    if colors is None:
        # Resolved at call time rather than bound as a shared default list.
        colors = CLASS_COLORS

    labels = np.asarray(seg_arr)
    if labels.ndim == 3:
        labels = labels[:, :, 0]
    height, width = labels.shape

    seg_img = np.zeros((height, width, 3), dtype=np.int32)
    for class_id in range(n_classes):
        # A pixel equals at most one class id, so a direct assignment is
        # equivalent to accumulating per-channel contributions.
        seg_img[labels == class_id] = colors[class_id]

    return seg_img
        
def _load_model():
    """Load the trained UNet from ``Model/last.ckpt`` and return it on DEVICE in eval mode."""
    # NOTE(review): assumes UNet is a PyTorch Lightning module. There,
    # ``load_from_checkpoint`` is a classmethod that RETURNS a new model. The
    # previous code called it on an instance and discarded the result, so
    # inference ran on randomly initialised weights (Lightning >= 2 raises instead).
    model = UNet.load_from_checkpoint(
        'Model/last.ckpt',
        in_channels=1,
        out_channels=N_CLASSES,
        features=[16, 32, 64, 128, 192, 256],
        batch_size=16,
    )
    model = model.to(DEVICE)
    model.eval()
    return model


def _prepare_input(volume):
    """Convert one modality array from the dataset into a model input tensor on DEVICE.

    Keeps the first entry along axis 0 and adds two leading singleton axes.
    """
    # NOTE(review): an old comment claimed the first-axis slice maps
    # (16, 1, 256, 256) -> (1, 256, 256); the real shape comes from
    # GaziBrainsDataset — confirm before changing this indexing.
    tensor = torch.from_numpy(volume).to(DEVICE)
    return tensor[0, ...].unsqueeze(0).unsqueeze(0)


def _to_display(tensor):
    """Undo the two leading unsqueezes of ``_prepare_input`` and return a NumPy array."""
    return tensor.squeeze(0).squeeze(0).cpu().numpy()


def _save_comparison_figure(path, panels):
    """Save ``panels`` — (title, image, cmap) triples — side by side in one row to ``path``."""
    fig = plt.figure(figsize=(8, 8))
    for position, (title, image, cmap) in enumerate(panels, start=1):
        fig.add_subplot(1, len(panels), position)
        plt.imshow(image, cmap=cmap)
        plt.title(title)
        plt.axis('off')
    fig.savefig(path)
    plt.close(fig)


def main():
    """Run the model over the test split and save one comparison PNG per sample.

    Each ``results/results_<i>.png`` shows the FLAIR, T2 and T1 inputs next to
    the colour-coded ground-truth mask and the model's prediction.
    """
    # savefig does not create missing directories.
    os.makedirs('results', exist_ok=True)

    with torch.no_grad():
        model = _load_model()
        test_dataset = GaziBrainsDataset(dataset_path='Data/npy_dataset', X_path='X_test.npy', y_path='y_test.npy')

        for index, (flair, t2, t1, mask) in enumerate(tqdm(test_dataset)):
            flair = _prepare_input(flair)
            t2 = _prepare_input(t2)
            t1 = _prepare_input(t1)

            with torch.cuda.amp.autocast():
                logits = model(flair, t2, t1)

            # Per-pixel class index. The trailing axis is added because
            # get_colored_segmentation_image reads channel 0 of its input.
            prediction = logits.argmax(dim=1).squeeze(0).unsqueeze(-1).cpu().numpy()
            mask = np.expand_dims(mask, axis=-1)

            panels = [
                ("Flair", _to_display(flair), 'gray'),
                ("T2", _to_display(t2), 'gray'),
                ("T1", _to_display(t1), 'gray'),
                ("Ground Truth", get_colored_segmentation_image(mask, N_CLASSES), None),
                ("Prediction", get_colored_segmentation_image(prediction, N_CLASSES), None),
            ]
            _save_comparison_figure("results/results_" + str(index) + ".png", panels)


if __name__ == '__main__':
    # Run inference only when executed as a script, not when imported.
    main()