Untitled
unknown
plain_text
a year ago
17 kB
1
Indexable
Never
# -*- coding: utf-8 -*-
"""
Conditional variational auto-encoder (CVAE) with a small latent classifier.

The module defines three networks (``Encoder``, ``Decoder``, ``Classifier``),
two helpers that build train/validation splits from a directory of pickled
samples, and -- when executed as a script -- a train / evaluate / sample
pipeline controlled by the boolean flags at the top of the ``__main__``
section.
"""
import os
import pathlib
import pickle
import random
import shutil
from datetime import datetime
from typing import List, Tuple

import numpy as np
import torch
from torch.utils.data import DataLoader, TensorDataset

# NOTE: the project-local ``transforms`` helpers (augment_signal_in_ranges,
# disturb_noise_augmentation, minmax_normalize) are only used by the script
# section, so they are imported there, next to matplotlib. This keeps the model
# classes importable without those dependencies.

# from sklearn.manifold.t_sne import TSNE

# Fraction of samples routed to the validation split.
VAL_FRAC = 0.2

# (images, labels) pair as returned by the split helpers.
_Split = Tuple[np.ndarray, np.ndarray]


def _leaky(x: torch.Tensor) -> torch.Tensor:
    """Leaky rectifier with negative slope 0.1, written as max(x, 0.1*x)."""
    return torch.max(x, 0.1*x)


# %% define the encoder
class Encoder(torch.nn.Module):
    """Convolutional encoder q(z | x, y) of the conditional VAE.

    Five strided convolutions reduce the single-channel input to 500 feature
    maps. The spatial size must collapse to 1 x 1 (e.g. 100 x 800 inputs),
    because the linear heads expect exactly ``500 + num_classes`` features.
    """

    def __init__(self, latent_dim: int, num_classes: int):
        super(Encoder, self).__init__()
        self.cnn1 = torch.nn.Conv2d(1, 20, (4, 32), (2, 16))
        self.cnn2 = torch.nn.Conv2d(20, 50, (5, 5), (2, 2))
        self.cnn3 = torch.nn.Conv2d(50, 100, (5, 5), (2, 2))
        self.cnn4 = torch.nn.Conv2d(100, 200, (4, 4), (2, 2))
        self.cnn5 = torch.nn.Conv2d(200, 500, (4, 4), (1, 1))
        # cat class label here for conditional VAE
        self.linear1 = torch.nn.Linear(500 + num_classes, latent_dim)
        self.linear2 = torch.nn.Linear(500 + num_classes, 1)

    def forward(self, x, y):
        """
        x: torch.Tensor
            The input tensor. The tensor should have shape
            (batch_size, channels, rows, columns). Values must lie in
            [0, 1]: the first step applies a (slightly shrunk) logit.
        y: torch.Tensor
            The class label. The tensor should have shape
            (batch_size, num_classes).

        Returns the tuple ``(mu, sigma)``: ``mu`` has shape
        (batch_size, latent_dim); ``sigma`` has shape (batch_size, 1), i.e.
        one standard deviation shared by all latent dimensions of a sample.
        """
        # Logit transform; the 0.001 / 0.998 constants keep the argument of
        # the log finite when x is exactly 0 or 1.
        x = -torch.log(1/(0.001 + 0.998*x) - 1)
        # The scalar factors in front of each layer are hand-tuned gains.
        x = _leaky(self.cnn1(x/2))
        x = _leaky(self.cnn2(x))
        x = _leaky(self.cnn3(2*x))
        x = _leaky(self.cnn4(3*x))
        x = _leaky(self.cnn5(2*x))
        x = torch.flatten(x, 1, 3)
        x = torch.cat((x, y), dim=1)
        mu = self.linear1(x)
        # exp() guarantees a strictly positive standard deviation.
        sigma = torch.exp(self.linear2(x))
        # return mean and variance**0.5
        return (mu, sigma)


# %% define the decoder
class Decoder(torch.nn.Module):
    """Transposed-convolution decoder p(x | z, y), mirroring ``Encoder``."""

    def __init__(self, latent_dim: int, num_classes: int):
        super(Decoder, self).__init__()
        self.linear1 = torch.nn.Linear(
            latent_dim + num_classes, 500)  # cat class label here
        self.cnnt1 = torch.nn.ConvTranspose2d(500, 200, (4, 4), (1, 1))
        self.cnnt2 = torch.nn.ConvTranspose2d(200, 100, (4, 4), (2, 2))
        self.cnnt3 = torch.nn.ConvTranspose2d(100, 50, (5, 5), (2, 2))
        self.cnnt4 = torch.nn.ConvTranspose2d(50, 20, (5, 5), (2, 2))
        self.cnnt5 = torch.nn.ConvTranspose2d(20, 1, (4, 32), (2, 16))

    def forward(self, x, y):
        """
        x: torch.Tensor
            The latent vector. The tensor should have shape
            (batch_size, latent_dim).
        y: torch.Tensor
            The class label. The tensor should have shape
            (batch_size, num_classes).

        Returns the reconstruction mean with values in [0, 1]; with the layer
        sizes above its shape is (batch_size, 1, 100, 800).
        """
        x = torch.cat((x, y), dim=1)
        x = self.linear1(5*x)
        # Treat the 500 features as a 1 x 1 image with 500 channels.
        x = _leaky(self.cnnt1(5*x[:, :, None, None]))
        x = _leaky(self.cnnt2(3*x))
        x = _leaky(self.cnnt3(5*x))
        x = _leaky(self.cnnt4(3*x))
        x = self.cnnt5(5*x)
        mu = torch.sigmoid(x - 5)
        # always assume variance=1. so only return the mean
        return mu


# %% define the classifier
class Classifier(torch.nn.Module):
    """Single linear layer + softmax mapping a latent vector to class
    probabilities.

    ``latent_dim`` defaults to 50, the value that used to be hard-coded, so
    existing ``Classifier(num_classes=...)`` calls behave as before.
    """

    def __init__(self, num_classes: int, latent_dim: int = 50):
        super(Classifier, self).__init__()
        # One output per class; softmax turns them into probabilities.
        self.linear1 = torch.nn.Linear(latent_dim, num_classes)
        self.activation = torch.nn.Softmax(dim=1)

    def forward(self, x):
        """Return class probabilities of shape (batch_size, num_classes)."""
        x = self.linear1(x)
        x = self.activation(x)
        return x


# %% define the dataset
def _load_sample(path: str) -> dict:
    """Unpickle one sample file.

    SECURITY: ``pickle.load`` can execute arbitrary code; only point the
    split helpers at directories whose contents you trust.
    """
    with open(path, 'rb') as f:
        return pickle.load(f)


def _stack_images(images: list, split_name: str) -> np.ndarray:
    """Stack per-sample arrays along a new first axis, failing clearly when
    the split is empty (np.stack's own message is not helpful)."""
    if not images:
        raise ValueError(
            'no samples ended up in the {} split'.format(split_name))
    return np.stack(images, axis=0)


def get_anomaly_detection_split(dir_path: str) -> Tuple[Tuple[np.ndarray, np.ndarray], Tuple[np.ndarray, np.ndarray]]:
    """
    Train split contains 80% Negative samples and val split contains
    all positives and the 20% of negatives

    Every file in ``dir_path`` is expected to be a pickled dict with the keys
    'data', 'label_idx' and 'label_name'. Negatives are assigned with
    ``random.random()``, so seed ``random`` for a reproducible split.

    Raises ValueError when either split ends up empty.
    """
    train_images = []
    train_labels = []
    val_images = []
    val_labels = []

    # Sorted so that a seeded RNG always yields the same split.
    for file in sorted(os.listdir(dir_path)):
        sample = _load_sample(os.path.join(dir_path, file))
        is_negative = sample.get("label_name") == "NEGATIVE"
        # All Positive samples go to validation
        if is_negative and random.random() >= VAL_FRAC:
            train_images.append(sample['data'])
            train_labels.append(sample['label_idx'])
        else:
            val_images.append(sample['data'])
            val_labels.append(sample['label_idx'])

    print('number of train images: {}'.format(len(train_images)))
    print('number of val images: {}'.format(len(val_images)))
    for image in train_images:
        print('image size: {}; min: {}; max: {}'.format(
            image.shape, np.min(image), np.max(image)))

    train_images = _stack_images(train_images, 'train')
    val_images = _stack_images(val_images, 'validation')
    # Labels are returned as arrays, matching the annotation and
    # get_classification_split.
    return ((train_images, np.asarray(train_labels)),
            (val_images, np.asarray(val_labels)))


def get_classification_split(dir_path: str) -> Tuple[Tuple[np.ndarray, np.ndarray], Tuple[np.ndarray, np.ndarray]]:
    """
    Train split contains 80% of all samples and val split contains
    the 20% of all samples

    Every file in ``dir_path`` is expected to be a pickled dict with the keys
    'data' and 'label_idx'. The split is done by scikit-learn's
    ``train_test_split`` with a fixed ``random_state`` of 42.
    """
    from sklearn.model_selection import train_test_split

    X, y = [], []
    # Sorted so that the fixed random_state really gives the same split.
    for file in sorted(os.listdir(dir_path)):
        sample = _load_sample(os.path.join(dir_path, file))
        X.append(sample['data'])
        y.append(sample['label_idx'])

    train_images, val_images, train_labels, val_labels = train_test_split(
        X, y, test_size=VAL_FRAC, random_state=42)

    train_images = np.stack(train_images, axis=0)
    val_images = np.stack(val_images, axis=0)
    train_labels = np.stack(train_labels, axis=0)
    val_labels = np.stack(val_labels, axis=0)

    print('Num train samples: {}'.format(len(train_images)))
    print('Num val samples: {}'.format(len(val_images)))
    return (train_images, train_labels), (val_images, val_labels)


def _cvae_loss(encoder, decoder, classifier, images, labels):
    """Compute the training loss of one batch.

    Returns ``(loss, distortion)`` where ``loss`` is the sum of the KL
    divergence, the reconstruction distortion and the classification term,
    and ``distortion`` is the detached reconstruction part alone.
    """
    # encoder loss
    mu, sigma = encoder(images, labels)
    # KL divergence loss (sigma broadcasts over the latent dimensions)
    divergence = torch.sum(0.5*mu**2 + 0.5*sigma**2 - torch.log(sigma))

    # classification loss
    # or first sample to get z; then classify with z
    y_pred = classifier(mu)
    # NOTE(review): this is a logistic loss written for +/-1 targets and raw
    # scores, but it receives one-hot targets and softmax probabilities. It
    # still rewards a high probability for the true class; kept as is so
    # training behaves as before -- consider a proper cross-entropy.
    xentropy = torch.mean(
        torch.log(1 + torch.exp(-labels*torch.squeeze(y_pred))))

    # decoder loss (reparameterisation trick)
    z = mu + sigma*torch.randn_like(mu)
    reconstruction = decoder(z, labels)
    distortion = torch.sum(0.5*(images - reconstruction)**2)

    # assign proper weights to each loss, and return their sum
    loss = 1.0*(divergence + distortion) + 1.0*xentropy
    return loss, distortion.detach()


if __name__ == "__main__":
    import uuid

    import matplotlib.pyplot as plt

    from transforms import (augment_signal_in_ranges,
                            disturb_noise_augmentation, minmax_normalize)

    # %% load images
    NUM_CLASSES = 2
    LATENT_DIM = 50
    AUGMENT = False
    TRAIN = True
    EVAL = False
    EVAL_AND_SAVE = False
    FROM_PRETRAINED = False
    SAVE_MODEL = True
    SAMPLE = True

    # path0 = 'C:/Users/16692/Downloads/Pkls'
    path0 = "./data/all_samples/Pkls"

    # send images, labels and models to cuda when it is available
    device = torch.device('cuda' if torch.cuda.is_available() else 'cpu')

    # (train_images, train_labels), (val_images,
    #  val_labels) = get_anomaly_detection_split(path0)
    (train_images, train_labels), (val_images,
                                   val_labels) = get_classification_split(path0)

    # [:, None] adds the channel axis expected by Conv2d.
    train_images = torch.tensor(train_images[:, None, :, :],
                                dtype=torch.float32, device=device)
    train_labels = torch.as_tensor(train_labels, device=device)
    train_labels = torch.nn.functional.one_hot(
        train_labels, num_classes=NUM_CLASSES).float()

    if AUGMENT:
        train_images = augment_signal_in_ranges(
            train_images, noise_level=1e-8, aug_factor=1.5,
            ranges=[(50, 300)])
        train_images = disturb_noise_augmentation(
            train_images, noise_level=1e-8, factor=1.2)
        # TODO Inv signal spike

    train_images = minmax_normalize(train_images)
    print(f"Mean: {torch.mean(train_images):.10f}, Std: {torch.std(train_images):.10f}, Min: {torch.min(train_images)}, Max: {torch.max(train_images):.5f}")
    # Both bounds must hold (the previous check joined them with `or`, which
    # is satisfied by almost any data).
    if not (train_images.min() >= 0.0 and train_images.max() <= 1.0):
        raise ValueError("Normalization failed.")
    train_dataset = TensorDataset(train_images, train_labels)

    # NOTE(review): validation images are not passed through
    # minmax_normalize, yet the encoder requires inputs in [0, 1] -- confirm
    # that the stored data is already in that range.
    val_images = torch.tensor(val_images[:, None, :, :],
                              dtype=torch.float32, device=device)
    val_labels = torch.as_tensor(val_labels, device=device)
    val_labels = torch.nn.functional.one_hot(
        val_labels, num_classes=NUM_CLASSES).float()
    val_dataset = TensorDataset(val_images, val_labels)

    train_dataloader = DataLoader(train_dataset, batch_size=32)
    val_dataloader = DataLoader(val_dataset, batch_size=32)

    encoder = Encoder(latent_dim=LATENT_DIM,
                      num_classes=NUM_CLASSES).to(device)
    decoder = Decoder(latent_dim=LATENT_DIM,
                      num_classes=NUM_CLASSES).to(device)
    classifier = Classifier(num_classes=NUM_CLASSES,
                            latent_dim=LATENT_DIM).to(device)
    threshold = None

    if FROM_PRETRAINED:
        # map_location lets a checkpoint saved on GPU load on a CPU machine.
        checkpoint = torch.load(
            "./trained_models/model_04-04-2023_13:15:42.pth",
            map_location=device)
        encoder.load_state_dict(checkpoint['encoder_state_dict'])
        decoder.load_state_dict(checkpoint['decoder_state_dict'])
        classifier.load_state_dict(checkpoint['classifier_state_dict'])
        threshold = checkpoint['threshold']

    # %% use the adam optimizer
    if TRAIN:
        print("Training job")
        opt = torch.optim.Adam(list(encoder.parameters()) +
                               list(decoder.parameters()) +
                               list(classifier.parameters()), lr=1e-4)

        # %% training loops begin here
        num_iters = 30
        Loss = []
        encoder.train()
        decoder.train()
        classifier.train()
        # TODO: the per-sample reconstruction threshold is never updated; it
        # is stored in the checkpoint as all zeros.
        threshold = torch.zeros(len(train_images)).to(device)
        reconstruction_losses = []
        for i in range(num_iters):
            for batch_data, batch_targets in train_dataloader:
                opt.zero_grad()
                loss, distortion = _cvae_loss(
                    encoder, decoder, classifier, batch_data, batch_targets)
                loss.backward()
                opt.step()
                Loss.append(loss.item())
                reconstruction_losses.append(distortion.cpu())
            print('iter: {}; loss: {}'.format(i, Loss[-1]))

        plt.figure()
        plt.plot(Loss)
        # plt.hist(reconstruction_losses, bins=50)
        plt.savefig("loss.png")
        plt.show()

        if SAVE_MODEL:
            # NOTE: the ':' characters make this file name invalid on
            # Windows; the format is kept so existing checkpoints match.
            now = datetime.now().strftime("%m-%d-%Y_%H:%M:%S")
            os.makedirs("./trained_models", exist_ok=True)
            torch.save({
                'encoder_state_dict': encoder.state_dict(),
                'decoder_state_dict': decoder.state_dict(),
                'classifier_state_dict': classifier.state_dict(),
                'threshold': threshold
            }, f"./trained_models/model_{now}.pth")

    # Evaluating the reconstruction loss
    if EVAL:
        print("Evaluating")
        encoder.eval()
        decoder.eval()
        classifier.eval()
        sample_losses = []
        sample_classes = []
        with torch.no_grad():
            for batch_images, batch_labels in val_dataloader:
                latent_mean, _ = encoder(batch_images, batch_labels)
                reconstruction = decoder(latent_mean, batch_labels)
                if reconstruction.shape != batch_images.shape:
                    raise ValueError(
                        "reconstruction shape {} != input shape {}".format(
                            tuple(reconstruction.shape),
                            tuple(batch_images.shape)))
                # Mean absolute error per sample (L1Loss on each image).
                sample_losses.append(torch.mean(
                    torch.abs(batch_images - reconstruction),
                    dim=(1, 2, 3)).cpu())
                # Labels are one-hot: recover the class index before
                # comparing (comparing the whole row to 0 is ambiguous).
                sample_classes.append(batch_labels.argmax(dim=1).cpu())
        sample_losses = torch.cat(sample_losses).numpy()
        sample_classes = torch.cat(sample_classes).numpy()
        # Class index 0 is plotted as "positive", as before -- TODO confirm
        # this matches the label_idx convention of the data set.
        positive_reconstruction_losses = sample_losses[sample_classes == 0]
        negative_reconstruction_losses = sample_losses[sample_classes != 0]

        plt.figure()
        plt.hist(negative_reconstruction_losses, bins=50, label='negative')
        plt.hist(positive_reconstruction_losses, bins=50, label='positive')
        plt.legend(loc='upper right')
        plt.savefig("loss_reconstruction.png")

    # %% let's check the reconstructed images
    if EVAL_AND_SAVE:
        print("Evaluating")
        encoder.eval()
        decoder.eval()
        classifier.eval()
        reconstructed_dir = pathlib.Path("reconstructed")
        reconstructed_dir.mkdir(parents=True, exist_ok=True)
        with torch.no_grad():
            for batch_images, batch_labels in val_dataloader:
                latent_mean, _ = encoder(batch_images, batch_labels)
                probabilities = classifier(latent_mean).cpu().numpy()
                reconstructions = decoder(
                    latent_mean, batch_labels).cpu().numpy()
                originals = batch_images.cpu().numpy()
                true_classes = batch_labels.argmax(dim=1).cpu().numpy()
                for recon, original, probs, true_class in zip(
                        reconstructions, originals, probabilities,
                        true_classes):
                    predicted = int(np.argmax(probs))
                    title = (f"reconstructed - Classification {predicted} "
                             f"(p={float(probs[predicted]):.3f}) \n "
                             f"original class {int(true_class)}")
                    fig = plt.figure()
                    plt.subplot(121)
                    plt.imshow(recon[0])
                    plt.title(title)
                    plt.subplot(122)
                    plt.imshow(original[0])
                    plt.title("original")
                    # Save before show(): interactive backends may clear
                    # the figure once the window is closed.
                    plt.savefig(str(
                        reconstructed_dir /
                        f"{uuid.uuid4()}_{int(true_class)}.png"))
                    plt.show()
                    plt.close(fig)

    if SAMPLE:
        print("Sampling")
        sampling_path = pathlib.Path("gen_samples")
        # Start from a clean output directory on every run.
        if sampling_path.exists():
            shutil.rmtree(sampling_path)
        encoder.eval()
        decoder.eval()
        classifier.eval()
        with torch.no_grad():
            for k in range(NUM_CLASSES):
                folder_path = sampling_path / f"class_{k}"
                plot_path = folder_path / "plots"
                pkls_path = folder_path / "pkls"
                plot_path.mkdir(parents=True, exist_ok=True)
                pkls_path.mkdir(parents=True, exist_ok=True)
                one_hot_y = torch.zeros(1, NUM_CLASSES, device=device)
                one_hot_y[0, k] = 1
                for i in range(30):
                    z = torch.randn(1, LATENT_DIM).to(device)
                    decoded = decoder(z, one_hot_y)
                    # Channels-last copy (rows, columns, 1) for saving.
                    x_hat = np.squeeze(
                        decoded.permute(0, 2, 3, 1).cpu().numpy(), axis=0)

                    # Save as pickle and as image
                    file_stem = f"{uuid.uuid4()}_class_{k}"
                    pkl_path = pkls_path / file_stem
                    with open(str(pkl_path) + ".pkl", "wb") as f:
                        pickle.dump(x_hat, f)

                    fig = plt.figure()
                    plt.imshow(x_hat)
                    plt.title(f"sample class {k}")
                    # save the image
                    image_path = plot_path / file_stem
                    plt.savefig(str(image_path) + ".png")
                    # Close it, otherwise 30 * NUM_CLASSES figures stay open.
                    plt.close(fig)

                    # Re-encode the generated sample. The encoder needs the
                    # channels-first tensor on the model's device, i.e. the
                    # decoder output itself. (mu, sigma are not used yet.)
                    mu, sigma = encoder(decoded, one_hot_y)