# Untitled
# unknown
# plain_text
# 3 years ago
# 17 kB
# 9
# Indexable
# -*- coding: utf-8 -*-
"""
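Class-conditional variational auto-encoder (VAE) with a latent-space classifier.

Defines the ``Encoder``, ``Decoder`` and ``Classifier`` modules, two helpers
that build train/validation splits from a directory of pickled samples, and
a script section that trains, evaluates and samples from the model.

(Originally created from the Spyder editor's temporary-script template.)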
"""
import pathlib
import pickle
import random
import shutil
from datetime import datetime
from typing import List, Tuple
import numpy as np
import torch
from torch.utils.data import DataLoader, TensorDataset
from transforms import (augment_signal_in_ranges, disturb_noise_augmentation,
minmax_normalize)
# from sklearn.manifold.t_sne import TSNE
# %% define the encoder
class Encoder(torch.nn.Module):
    """Convolutional encoder of a class-conditional VAE.

    Five strided convolutions turn a single-channel image into a 500-channel
    feature map. The map is flattened, concatenated with the class label and
    projected to the posterior mean (``latent_dim`` values) and to one
    standard deviation per sample.

    The linear heads expect exactly ``500 + num_classes`` features, so the
    convolution stack has to reduce the spatial size to 1x1 (a 100x800 input
    does).
    """

    def __init__(self, latent_dim: int, num_classes: int):
        super().__init__()
        conv = torch.nn.Conv2d
        self.cnn1 = conv(1, 20, kernel_size=(4, 32), stride=(2, 16))
        self.cnn2 = conv(20, 50, kernel_size=(5, 5), stride=(2, 2))
        self.cnn3 = conv(50, 100, kernel_size=(5, 5), stride=(2, 2))
        self.cnn4 = conv(100, 200, kernel_size=(4, 4), stride=(2, 2))
        self.cnn5 = conv(200, 500, kernel_size=(4, 4), stride=(1, 1))
        # The class label is concatenated to the features (conditional VAE).
        head_inputs = 500 + num_classes
        self.linear1 = torch.nn.Linear(head_inputs, latent_dim)
        self.linear2 = torch.nn.Linear(head_inputs, 1)

    @staticmethod
    def _leaky(t):
        """Leaky rectifier with slope 0.1 on the negative side."""
        return torch.max(t, 0.1 * t)

    def forward(self, x, y):
        """Encode an image batch conditioned on its labels.

        x: torch.Tensor
            The input tensor with shape (batch_size, channels, rows, columns).
        y: torch.Tensor
            The class label with shape (batch_size, num_classes).

        Returns the tuple ``(mu, sigma)``: the posterior mean and the
        posterior standard deviation (variance**0.5).
        """
        # Logit of the input after squeezing it into [0.001, 0.999]; inputs
        # are therefore expected to lie in [0, 1].
        squeezed = 0.001 + 0.998 * x
        h = -torch.log(1 / squeezed - 1)

        # The constant gains in front of each layer rescale the activations.
        h = self._leaky(self.cnn1(h / 2))
        h = self._leaky(self.cnn2(h))
        h = self._leaky(self.cnn3(2 * h))
        h = self._leaky(self.cnn4(3 * h))
        h = self._leaky(self.cnn5(2 * h))

        features = torch.cat((h.flatten(1, 3), y), dim=1)
        mu = self.linear1(features)
        # exp() keeps the standard deviation strictly positive.
        sigma = torch.exp(self.linear2(features))
        return (mu, sigma)
# %% define the decoder
class Decoder(torch.nn.Module):
    """Transposed-convolution decoder of a class-conditional VAE.

    A linear layer expands the latent vector (concatenated with the class
    label) to 500 features, which five transposed convolutions upsample from
    a 1x1 map to a single-channel image.
    """

    def __init__(self, latent_dim: int, num_classes: int):
        super().__init__()
        deconv = torch.nn.ConvTranspose2d
        # The class label is concatenated to the latent vector.
        self.linear1 = torch.nn.Linear(latent_dim + num_classes, 500)
        self.cnnt1 = deconv(500, 200, kernel_size=(4, 4), stride=(1, 1))
        self.cnnt2 = deconv(200, 100, kernel_size=(4, 4), stride=(2, 2))
        self.cnnt3 = deconv(100, 50, kernel_size=(5, 5), stride=(2, 2))
        self.cnnt4 = deconv(50, 20, kernel_size=(5, 5), stride=(2, 2))
        self.cnnt5 = deconv(20, 1, kernel_size=(4, 32), stride=(2, 16))

    @staticmethod
    def _leaky(t):
        """Leaky rectifier with slope 0.1 on the negative side."""
        return torch.max(t, 0.1 * t)

    def forward(self, x, y):
        """Decode latent vectors conditioned on their labels.

        x: torch.Tensor
            The latent vector with shape (batch_size, latent_dim).
        y: torch.Tensor
            The class label with shape (batch_size, num_classes).

        Returns the mean of the reconstructed image; the variance is always
        assumed to be 1, so only the mean is returned.
        """
        h = torch.cat((x, y), dim=1)
        h = self.linear1(5 * h)

        # Treat the 500 features as a 1x1 feature map for the deconvolutions.
        h = h[:, :, None, None]
        h = self._leaky(self.cnnt1(5 * h))
        h = self._leaky(self.cnnt2(3 * h))
        h = self._leaky(self.cnnt3(5 * h))
        h = self._leaky(self.cnnt4(3 * h))
        h = self.cnnt5(5 * h)

        # The sigmoid keeps the output in [0, 1]; the -5 offset shifts the
        # pre-activation before squashing.
        return torch.sigmoid(h - 5)
# %% define the classifier
class Classifier(torch.nn.Module):
    """Linear soft-max classifier operating on latent vectors.

    Parameters
    ----------
    num_classes:
        Number of output classes; the module emits one probability per class.
    latent_dim:
        Width of the input latent vector. Defaults to 50, the value that was
        previously hard-coded, so existing callers and checkpoints keep
        working unchanged.
    """

    def __init__(self, num_classes: int, latent_dim: int = 50):
        super().__init__()
        # One logit per class (not a single binary logit).
        self.linear1 = torch.nn.Linear(latent_dim, num_classes)
        self.activation = torch.nn.Softmax(dim=1)

    def forward(self, x):
        """Return class probabilities for ``x``.

        x: torch.Tensor
            Latent vectors with shape (batch_size, latent_dim).

        The soft-max is taken over dimension 1, so each row of the result
        sums to 1.
        """
        return self.activation(self.linear1(x))
# %% define the dataset
VAL_FRAC = 0.2


def get_anomaly_detection_split(dir_path: str) -> Tuple[Tuple[np.ndarray, np.ndarray], Tuple[np.ndarray, np.ndarray]]:
    """Build an anomaly-detection split from a directory of pickled samples.

    Every file in ``dir_path`` is unpickled into a dict that is read through
    the keys ``"label_name"``, ``"data"`` and ``"label_idx"``. Samples whose
    ``label_name`` is ``"NEGATIVE"`` go to the validation split with
    probability ``VAL_FRAC`` and to the training split otherwise; every other
    sample goes to the validation split.

    The draw uses the global ``random`` module, so call ``random.seed`` first
    for a reproducible split (files are visited in sorted order for the same
    reason).

    Returns
    -------
    ``(train_images, train_labels), (val_images, val_labels)`` where all four
    entries are numpy arrays stacked along a new leading axis.

    Raises
    ------
    ValueError
        If either split ends up empty (nothing to stack).
    """
    # Imported locally: at module level ``os`` is only imported under the
    # ``__main__`` guard, so this function would otherwise fail when the
    # module is imported rather than run as a script.
    import os

    train_images, train_labels = [], []
    val_images, val_labels = [], []

    # os.listdir returns entries in arbitrary order; sorting makes a seeded
    # split independent of the file system.
    for file in sorted(os.listdir(dir_path)):
        # NOTE(security): pickle.load can execute arbitrary code; only point
        # this at directories whose contents are trusted.
        with open(os.path.join(dir_path, file), 'rb') as f:
            sample = pickle.load(f)

        is_negative = sample.get("label_name") == "NEGATIVE"
        # random.random() is only drawn for negative samples.
        if is_negative and random.random() >= VAL_FRAC:
            train_images.append(sample['data'])
            train_labels.append(sample['label_idx'])
        else:
            # A fraction of the negatives plus all positive samples.
            val_images.append(sample['data'])
            val_labels.append(sample['label_idx'])

    print('number of train images: {}'.format(len(train_images)))
    print('number of val images: {}'.format(len(val_images)))
    for image in train_images:
        print('image size: {}; min: {}; max: {}'.format(
            image.shape, np.min(image), np.max(image)))

    if not train_images or not val_images:
        raise ValueError(
            "Empty split for {!r}: {} train / {} val samples".format(
                dir_path, len(train_images), len(val_images)))

    train_split = (np.stack(train_images, axis=0), np.asarray(train_labels))
    val_split = (np.stack(val_images, axis=0), np.asarray(val_labels))
    return train_split, val_split
def get_classification_split(dir_path: str) -> Tuple[Tuple[np.ndarray, np.ndarray], Tuple[np.ndarray, np.ndarray]]:
    """Build a classification split from a directory of pickled samples.

    Every file in ``dir_path`` is unpickled into a dict that is read through
    the keys ``"data"`` and ``"label_idx"``. All samples are then divided
    with ``train_test_split`` using ``test_size=VAL_FRAC`` and a fixed
    ``random_state`` of 42.

    Returns
    -------
    ``(train_images, train_labels), (val_images, val_labels)`` where all four
    entries are numpy arrays stacked along a new leading axis.
    """
    # Imported locally: at module level ``os`` is only imported under the
    # ``__main__`` guard, so this function would otherwise fail when the
    # module is imported rather than run as a script.
    import os

    from sklearn.model_selection import train_test_split

    X, y = [], []
    # os.listdir returns entries in arbitrary order; sorting keeps the
    # fixed-seed split below identical across file systems.
    for file in sorted(os.listdir(dir_path)):
        # NOTE(security): pickle.load can execute arbitrary code; only point
        # this at directories whose contents are trusted.
        with open(os.path.join(dir_path, file), "rb") as f:
            sample = pickle.load(f)
        X.append(sample['data'])
        y.append(sample['label_idx'])

    train_images, val_images, train_labels, val_labels = train_test_split(
        X, y, test_size=VAL_FRAC, random_state=42)

    train_images = np.stack(train_images, axis=0)
    val_images = np.stack(val_images, axis=0)
    train_labels = np.stack(train_labels, axis=0)
    val_labels = np.stack(val_labels, axis=0)

    print('Num train samples: {}'.format(len(train_images)))
    print('Num val samples: {}'.format(len(val_images)))
    return (train_images, train_labels), (val_images, val_labels)
if __name__ == "__main__":
    # Script entry point. Loads the pickled samples and then, depending on
    # the flags below, trains the conditional VAE + classifier, evaluates
    # reconstructions on the validation split and samples new images.
    import os
    import pickle
    import uuid

    import matplotlib.pyplot as plt
    import numpy as np

    # %% load images
    NUM_CLASSES = 2
    LATENT_DIM = 50
    AUGMENT = False
    TRAIN = True
    EVAL = False
    EVAL_AND_SAVE = False
    FROM_PRETRAINED = False
    SAVE_MODEL = True
    SAMPLE = True
    # path0 = 'C:/Users/16692/Downloads/Pkls'
    path0 = "./data/all_samples/Pkls"

    # Send images, labels and models to CUDA when it is available; fall back
    # to the CPU instead of crashing on machines without a GPU.
    device = torch.device('cuda' if torch.cuda.is_available() else 'cpu')

    # (train_images, train_labels), (val_images,
    #                                val_labels) = get_anomaly_detection_split(path0)
    (train_images, train_labels), (val_images,
                                   val_labels) = get_classification_split(path0)

    # Insert a channel axis at position 1 and convert labels to one-hot rows.
    train_images = torch.tensor(train_images[:, None, :, :],
                                dtype=torch.float32, device=device)
    train_labels = torch.as_tensor(train_labels, device=device)
    train_labels = torch.nn.functional.one_hot(
        train_labels, num_classes=NUM_CLASSES).float()

    if AUGMENT:
        train_images = augment_signal_in_ranges(
            train_images, noise_level=1e-8, aug_factor=1.5, ranges=[(50, 300)])
        train_images = disturb_noise_augmentation(
            train_images, noise_level=1e-8, factor=1.2)
        # TODO Inv signal spike

    # NOTE(review): the original indentation of this file was lost; the
    # normalization is assumed to run regardless of AUGMENT - confirm.
    train_images = minmax_normalize(train_images)
    print(f"Mean: {torch.mean(train_images):.10f}, Std: {torch.std(train_images):.10f}, Min: {torch.min(train_images)}, Max: {torch.max(train_images):.5f}")
    # Both bounds have to hold (the previous check joined them with `or`,
    # which lets a one-sided violation through). Written as a negated `and`
    # so that NaN values are rejected as well.
    if not (train_images.min() >= 0.0 and train_images.max() <= 1.0):
        raise ValueError("Normalization failed.")
    train_dataset = TensorDataset(train_images, train_labels)

    # NOTE(review): unlike the training images, the validation images are
    # never passed through minmax_normalize - confirm this is intended.
    val_images = torch.tensor(val_images[:, None, :, :],
                              dtype=torch.float32, device=device)
    val_labels = torch.as_tensor(val_labels, device=device)
    val_labels = torch.nn.functional.one_hot(
        val_labels, num_classes=NUM_CLASSES).float()
    val_dataset = TensorDataset(val_images, val_labels)

    train_dataloader = DataLoader(train_dataset, batch_size=32)
    val_dataloader = DataLoader(val_dataset, batch_size=32)

    encoder = Encoder(latent_dim=LATENT_DIM,
                      num_classes=NUM_CLASSES).to(device)
    decoder = Decoder(latent_dim=LATENT_DIM,
                      num_classes=NUM_CLASSES).to(device)
    classifier = Classifier(num_classes=NUM_CLASSES).to(device)
    threshold = None

    if FROM_PRETRAINED:
        # map_location lets a checkpoint saved on a GPU load on the CPU too.
        checkpoint = torch.load(
            "./trained_models/model_04-04-2023_13:15:42.pth",
            map_location=device)
        encoder.load_state_dict(checkpoint['encoder_state_dict'])
        decoder.load_state_dict(checkpoint['decoder_state_dict'])
        classifier.load_state_dict(checkpoint['classifier_state_dict'])
        threshold = checkpoint['threshold']

    # %% use the adam optimizer
    if TRAIN:
        print("Training job")
        opt = torch.optim.Adam(list(encoder.parameters()) +
                               list(decoder.parameters()) +
                               list(classifier.parameters()),
                               lr=1e-4)

        # %% training loops begin here
        num_iters = 30
        Loss = []
        encoder.train()
        decoder.train()
        classifier.train()
        # NOTE(review): the threshold is saved with the checkpoint but is
        # never updated; the code that computed it is commented out below.
        threshold = torch.zeros(len(train_images)).to(device)
        l1_loss = torch.nn.L1Loss()
        reconstruction_losses = []

        def total_loss(images, labels):
            """Return KL divergence + reconstruction + classification loss."""
            # encoder loss
            x = images
            y_true = labels
            mu, sigma = encoder(x, y_true)
            # KL divergence loss
            divergence = torch.sum(
                0.5*mu**2 + 0.5*sigma**2 - torch.log(sigma))
            # classification loss
            # or first sample to get z; then classify with z
            y_pred = classifier(mu)
            # NOTE(review): this is a logistic-style loss applied to 0/1
            # one-hot targets and soft-max outputs - confirm it is intended.
            xentropy = torch.mean(
                torch.log(1 + torch.exp(-y_true*torch.squeeze(y_pred))))
            # decoder loss: reparameterized sample, then squared error
            z = mu + sigma*torch.randn_like(mu)
            x_mean = decoder(z, y_true)
            distortion = torch.sum(0.5*(x - x_mean)**2)
            reconstruction_losses.append(distortion.detach().cpu())
            # assign proper weights to each loss, and return their sum
            return 1.0*(divergence + distortion) + 1.0*xentropy

        for i in range(num_iters):
            for batch_data, batch_targets in train_dataloader:
                opt.zero_grad()
                loss = total_loss(batch_data, batch_targets)
                loss.backward()
                opt.step()
            # NOTE(review): assumed to be logged once per epoch (the loss of
            # the last batch); the original indentation was lost.
            Loss.append(loss.item())
            print('iter: {}; loss: {}'.format(i, Loss[-1]))

        # Calculating the threshold
        # for sample_data, target_data in train_dataloader:
        #     print(sample_data.shape)
        #     print(target_data.shape)
        #     for idx, sample in enumerate(sample_data):
        #         sample = torch.unsqueeze(sample, axis=0)
        #         print(sample.shape)
        #         target = target_data[idx, : ]
        #         mu, _ = encoder(sample, target)
        #         mu = decoder(mu, target)
        #         l1_loss_result = l1_loss(sample, mu)
        #         # reconstruction_losses.append(l1_loss_result.detach().cpu())
        #         threshold = torch.maximum(threshold, l1_loss_result)

        plt.plot(Loss)
        # plt.hist(reconstruction_losses, bins=50)
        plt.savefig("loss.png")
        plt.show()

        # NOTE(review): assumed to belong to the training branch, so that
        # only freshly trained weights are written - confirm.
        if SAVE_MODEL:
            # Create the output directory up front so a finished training
            # run is not lost to a missing folder.
            os.makedirs("./trained_models", exist_ok=True)
            # NOTE(review): ':' is not allowed in Windows file names.
            now = datetime.now().strftime("%m-%d-%Y_%H:%M:%S")
            torch.save({
                'encoder_state_dict': encoder.state_dict(),
                'decoder_state_dict': decoder.state_dict(),
                'classifier_state_dict': classifier.state_dict(),
                'threshold': threshold
            }, f"./trained_models/model_{now}.pth")

    # Evaluating the reconstruction loss
    if EVAL:
        print("Evaluating")
        encoder.eval()
        decoder.eval()
        classifier.eval()
        negative_reconstruction_losses = []
        positive_reconstruction_losses = []
        with torch.no_grad():
            l1_loss = torch.nn.L1Loss()
            mu, _ = encoder(val_images, val_labels)
            reconstructions = decoder(mu, val_labels)
            # The labels are one-hot rows here, so recover the class index
            # before comparing (comparing a whole row to 0 raises).
            class_indices = val_labels.argmax(dim=1)
            for image, reconstruction, class_index in zip(
                    val_images, reconstructions, class_indices):
                assert image.shape == reconstruction.shape
                output = l1_loss(image, reconstruction).item()
                # NOTE(review): class index 0 is treated as "positive", as
                # in the original comparison - confirm the label mapping.
                if class_index.item() == 0:
                    positive_reconstruction_losses.append(output)
                else:
                    negative_reconstruction_losses.append(output)
        # Fresh figure so the histograms are not drawn over the loss curve.
        plt.figure()
        plt.hist(negative_reconstruction_losses, bins=50, label='negative')
        plt.hist(positive_reconstruction_losses, bins=50, label='positive')
        plt.legend(loc='upper right')
        plt.savefig("loss_reconstruction.png")

    # %% let's check the reconstructed images
    if EVAL_AND_SAVE:
        print("Evaluating")
        encoder.eval()
        decoder.eval()
        classifier.eval()
        with torch.no_grad():
            mu, _ = encoder(val_images, val_labels)
            y = classifier(mu)
            mu = decoder(mu, val_labels)
        y = y.detach().cpu().numpy()
        mu = mu.detach().cpu().numpy()
        images0 = val_images.detach().cpu().numpy()
        true_classes = val_labels.argmax(dim=1).cpu().numpy()
        os.makedirs("reconstructed", exist_ok=True)
        for i in range(len(val_images)):
            probabilities = y[i]  # prediction for THIS sample, not sample 0
            true_class = int(true_classes[i])
            title = (f"reconstructed - Classification {int(probabilities.argmax())} "
                     f"(p={float(probabilities.max()):.3f}) \n original class {true_class}")
            fig = plt.figure()
            plt.subplot(121)
            plt.imshow(mu[i, 0])
            plt.title(title)
            plt.subplot(122)
            plt.imshow(images0[i, 0])
            plt.title("original")
            # Save before showing: show() can discard the figure, which
            # would leave savefig writing an empty image.
            plt.savefig(os.path.join(
                "reconstructed", f"{uuid.uuid4()}_{true_class}.png"))
            plt.show()
            plt.close(fig)

    if SAMPLE:
        print("Sampling")
        sampling_path = pathlib.Path("gen_samples")
        # Start from an empty output folder on every run.
        if sampling_path.exists():
            shutil.rmtree(sampling_path)
        encoder.eval()
        decoder.eval()
        classifier.eval()
        with torch.no_grad():
            for k in range(NUM_CLASSES):
                folder_path = sampling_path / f"class_{k}"
                plot_path = folder_path / "plots"
                pkls_path = folder_path / "pkls"
                plot_path.mkdir(parents=True, exist_ok=True)
                pkls_path.mkdir(parents=True, exist_ok=True)
                for _ in range(30):
                    z = torch.randn(1, LATENT_DIM, device=device)
                    one_hot_y = torch.zeros(1, NUM_CLASSES, device=device)
                    one_hot_y[0, k] = 1
                    # Keep the raw decoder output (channels-first, on the
                    # model's device) for the re-encoding step below.
                    generated = decoder(z, one_hot_y)
                    # Channels-last copy with the batch axis removed, used
                    # for pickling and plotting.
                    x_hat = np.squeeze(
                        generated.permute(0, 2, 3, 1).detach().cpu().numpy(),
                        axis=0)
                    # Save as pickle and as image
                    file_stem = f"{uuid.uuid4()}_class_{k}"
                    with open(pkls_path / f"{file_stem}.pkl", "wb") as f:
                        pickle.dump(x_hat, f)
                    fig = plt.figure()
                    plt.imshow(x_hat)
                    plt.title(f"sample class {k}")
                    # save the image
                    plt.savefig(plot_path / f"{file_stem}.png")
                    # Close each figure; 60 open figures would pile up.
                    plt.close(fig)
                    # Re-encode the generated sample. The encoder needs the
                    # channels-first tensor on the model's device, not the
                    # permuted CPU array.
                    # TODO: mu and sigma are not used yet.
                    mu, sigma = encoder(generated, one_hot_y)
# Editor is loading...