import torch
import numpy as np
import glob
from sklearn.manifold import TSNE
import plotly.express as px
import plotly.graph_objects as go
from tqdm import tqdm
from PIL import Image
import matplotlib.pyplot as plt
from sklearn.metrics import roc_auc_score
import time
import plotly.figure_factory as ff


models =["customnet", "resnet18_scratch", "resnet50_scratch", "resnet18_pretrained", "resnet50_pretrained"]
models =["customnet"]

dataset = "lungs"
V = "18"
logs_folder = "logs_V"+V

colors = ["darkorange","yellow","green","blue","black","red"]
# 2D/3D Proejctions ( SIMCLR, SHIFT)

D2 = True # if False 3D projections are done
PNH = True # if False SHIFT features are enabled
Nomalized = False

#groups = {"150":"Emphysema 5-10", "300":"Emphysema 10-15", "450":"Emphysema 15-20", "600":"Emphysema 20-100" }
                             
groups = {"150":"Emphysema 0.1-0.5","300":"Emphysema 0.5-1","450":"Emphysema 1-2","600":"Emphysema 2-3", 
          "750":"Emphysema 3-5", "900":"Emphysema 5-10", "1050":"Emphysema 10-15", "1200":"Emphysema 15-20",
          "1350":"Emphysema 20-100"}

tsne_imgs = {}

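# For each model, load the saved penultimate-layer features for the anchor (normal) class
# (train and test splits) and for each abnormal emphysema group, stack them into one
# matrix, and project it to 2D/3D with t-SNE.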
for model in models:
    f_name_train = r"C:\Users\Leonard\Desktop\ConLea\CSI_AISSAM\{}\{}_{}_unsup_simclr_CSI_shift_{}_one_class_{}\feats_1_{}_train_f_name.pth".format(logs_folder,dataset,model, mode, str(class_idx), dataset)
    f_name_one_class = r"C:\Users\Leonard\Desktop\ConLea\CSI_AISSAM\{}\{}_{}_unsup_simclr_CSI_shift_{}_one_class_{}\feats_1_{}_f_name.pth".format(logs_folder,dataset,model, mode, str(class_idx), dataset)
    
    if PNH and model not in ["resnet50_pretrained"]:
        f_size = 512
        path_mode_train = r"C:\Users\Leonard\Desktop\ConLea\CSI_AISSAM\{}\{}_{}_unsup_simclr_CSI_shift_{}_one_class_{}\feats_1_{}_train_penultimate.pth".format(logs_folder,dataset,model, mode, str(class_idx), dataset)
        path_mode_one_class = r"C:\Users\Leonard\Desktop\ConLea\CSI_AISSAM\{}\{}_{}_unsup_simclr_CSI_shift_{}_one_class_{}\feats_1_{}_penultimate.pth".format(logs_folder,dataset,model, mode,str(class_idx), dataset)
    else:
        print("we are here")
        f_size = 2048
        path_mode_train = r"C:\Users\Leonard\Desktop\ConLea\CSI_AISSAM\{}\{}_{}_unsup_simclr_CSI_shift_{}_one_class_{}\feats_1_{}_train_penultimate.pth".format(logs_folder,dataset,model, mode, str(class_idx), dataset)
        path_mode_one_class = r"C:\Users\Leonard\Desktop\ConLea\CSI_AISSAM\{}\{}_{}_unsup_simclr_CSI_shift_{}_one_class_{}\feats_1_{}_penultimate.pth".format(logs_folder,dataset,model, mode,str(class_idx), dataset)  
    
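    # i indexes the second axis of the saved feature tensors ([N, ?, f_size]); it is
    # assumed to select the shift/augmentation view, and only i == 0 is used here.
    # F, L and f_names_ start with 100 dummy entries that are sliced away again once
    # the real features have been appended (the [100:] slicing below).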
    for i in range(0,1):
        F = torch.zeros((100, f_size))
        L = [99]*100
        f_names_ = [99]*100
        for idx in range(number_of_classes):
            if idx!=class_idx:
                f_name_path = r"C:\Users\Leonard\Desktop\ConLea\CSI_AISSAM\{}\{}_{}_unsup_simclr_CSI_shift_{}_one_class_{}\feats_1_one_class_{}_f_name.pth".format(logs_folder,dataset,model,mode,str(class_idx),str(idx))
                if PNH:
                    path_mode = r"C:\Users\Leonard\Desktop\ConLea\CSI_AISSAM\{}\{}_{}_unsup_simclr_CSI_shift_{}_one_class_{}\feats_1_one_class_{}_penultimate.pth".format(logs_folder,dataset,model,mode,str(class_idx),str(idx))
                
                feat = torch.load(path_mode)[:,i,:]
                print(feat.shape)
                F = torch.cat((F, feat), dim = 0)
                L += ["class "+str(item) for item in [idx]*torch.load(path_mode)[:,i,:].shape[0]]
                f_names_ += list(np.unique(torch.load(f_name_path), axis=0))

        F , L, f_names_ = F[100:], L[100:], f_names_[100:]
        L = [item for sublist in [[groups[str(idx)]]*150 for idx in np.arange(150,1500,150)] for item in sublist]
        
        F = torch.cat((F,torch.load(path_mode_train)[:,i,:]), dim=0)
        L += ["Anchor train"]*torch.load(path_mode_train)[:,i,:].shape[0]
        f_names_ += list(np.unique(torch.load(f_name_train), axis=0))
        
        # Add some samples from test:
        F = torch.cat((F, torch.load(path_mode_one_class)[:,i,:]), dim=0)
        L += ["Anchor test"]*torch.load(path_mode_one_class)[:,i,:].shape[0]
        f_names_ += list(np.unique(torch.load(f_name_one_class), axis=0))
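
    # Project the stacked features with t-SNE (2D or 3D) and plot them with Plotly
    # Express, colored by emphysema group vs. anchor train/test.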
    if D2:
        print("2D T-Sne Projections for model {} with Features size : {},{}".format(model, F.shape[0],F.shape[1]))
        tsne = TSNE(n_components=2, random_state=0 ,
                    perplexity=110.0,
                    learning_rate=200.0,
                    n_iter=1000,
                    min_grad_norm=1e-07,
                    verbose=0,
                    method='barnes_hut',
                    square_distances='legacy')
        projections = tsne.fit_transform(F)
        if Normalized:
            projections = projections / np.linalg.norm(projections, keepdims=True, axis=-1)
        if i==0:
            tsne_imgs[model] = projections 
        fig = px.scatter(
            projections, x=0, y=1,
            color=L, symbol=L,  # `labels` expects a dict of axis-name overrides, so the per-point list is not passed
            color_discrete_sequence=colors
        )
    else:
        print("3D T-Sne Projections for {} with Features size : {},{}".format(model, F.shape[0],F.shape[1]))
        tsne = TSNE(n_components=3, random_state=0 ,
                perplexity=100.0,
                learning_rate=200.0,
                n_iter=1000,
                min_grad_norm=1e-07,
                verbose=0,
                method='barnes_hut',
                square_distances='legacy')

        projections = tsne.fit_transform(F)
        projections = projections / np.linalg.norm(projections, keepdims=True, axis=-1)

        fig = px.scatter_3d(
            projections, x=0, y=1, z=2, color=L,
            color_discrete_sequence=px.colors.qualitative.Dark24)

    fig.update_layout(
        margin=dict(l=10, r=40, t=30, b=10),
        paper_bgcolor="LightSteelBlue",
        plot_bgcolor="LightSteelBlue",
        width=1000,
        height=500,
        legend=dict(yanchor="top", y=1, xanchor="left", x=1),
        bargroupgap=0.0,
        title="Representation of {} Class VS Abnormal Classes".format("Normal"),
        yaxis_title="Shifted samples with shift: {}".format(rot[str(i)]))
    fig.update_traces(marker_size=4)
    fig.show()


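# t-SNE image mosaic: place each slice thumbnail at its (rescaled) 2D t-SNE coordinate
# on one large canvas, then overlay per-group markers and bounding boxes.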
model = "resnet18_scratch"
projections = tsne_imgs[model]
no_of_images = len(projections) # number of images. It is recommended to use a square of 2 number
ellipside = True # elipsoid or rectangular visualization
image_width = 64 # width and height of each visualized images
image_names = f_names_

# rescale the 2D t-SNE projections to the [0, 1] range in each dimension
reduced = projections
reduced_transformed = reduced - np.min(reduced, axis=0)
reduced_transformed /= np.max(reduced_transformed, axis=0)

image_xindex_sorted = np.argsort(np.sum(reduced_transformed, axis=1))

# draw all images in a merged image
merged_width = int(np.ceil(np.sqrt(no_of_images))*image_width) + 8000
merged_image = np.ones((merged_width, merged_width, 3), dtype='uint8')*200

ind, ood = [],[]
ind_f, ood_f = [],[]

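# Tiles are placed in the original (group-ordered) index order (the sorted index from
# above is not used for lookups), so the per-group slicing of `ood` further below stays
# valid. `ind`/`ood` collect tile bounding boxes for anchor vs. abnormal slices; a .png
# next to each .npz file is assumed to hold the corresponding slice image.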
for counter, index in tqdm(enumerate(image_xindex_sorted)):
    if ellipsoid:
        a = np.ceil(reduced_transformed[counter, 0] * (merged_width-image_width-1)+1)
        b = np.ceil(reduced_transformed[counter, 1] * (merged_width-image_width-1)+1)
        a = int(a - np.mod(a-1,image_width) + 1)
        b = int(b - np.mod(b-1,image_width) + 1)
        image_address = image_names[counter]
        img = np.asarray(Image.open(image_address.replace(".npz",".png")).resize((image_width, image_width)))
        img = np.expand_dims(img, axis=-1)
        
        if "one_class_train" in image_address:
            # for the anchor class:
            ind_f.append(image_address)
            ind.append((a, a+image_width, b, b+image_width))
        else:
            ood_f.append(image_address)
            ood.append((a, a+image_width, b, b+image_width))

        merged_image[a:a+image_width, b:b+image_width,:] = img[:,:,:]  
        
        
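# Overlay the mosaic with semi-transparent markers for the first four emphysema groups
# (150 tiles per group) and with rectangles around abnormal (group-colored) and anchor
# (cyan) tiles. colors[0] is an unused placeholder because class_ starts counting at 1.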
fig = px.imshow(merged_image)
save = False
colors = ["_","darkorange","yellow","green", "blue", "lightpink"]

groups = {"150":"Emphysema 0.1-0.5","300":"Emphysema 0.5-1","450":"Emphysema 1-2","600":"Emphysema 2-3", 
          "750":"Emphysema 3-5", "900":"Emphysema 5-10", "1050":"Emphysema 10-15", "1200":"Emphysema 15-20",
          "1350":"Emphysema 20-100"}
class_ = 0
for g in np.arange(150, 750, 150):    
    class_ += 1
    fig.add_trace( go.Scatter(
        x=[(ood_[2]+ood_[3])//2 for ood_ in ood[g-150:g]],
        y=[(ood_[0]+ood_[1])//2 for ood_ in ood[g-150:g]],
        mode="markers",
        marker = dict(color=colors[class_],),
        opacity = 0.3,
        name = groups[str(g)]))
class_ = 0
for idx,(ind_,ood_) in tqdm(enumerate(zip(ind[:1000], ood[:1000]))):
    if idx%150==0 and idx<600:
        class_ += 1
    if idx==600:
        class_ += 1
    c = colors[class_]    
    fig.add_shape(
        type='rect',
        x0=ood_[2], x1=ood_[3], y0=ood_[0], y1=ood_[1],
        xref='x', yref='y',
        line_color=c)
    fig.add_shape(
        type='rect',
        x0=ind_[2], x1=ind_[3], y0=ind_[0], y1=ind_[1],
        xref='x', yref='y',
        line_color='cyan')
fig.update_layout(height=1000, width=1000, newshape=dict(line_color='cyan'))
fig.show()
if save:
    merged_image = Image.fromarray(merged_image)
    merged_image.save("test.png")