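"""Conditional VAE experiments on 2D signal data.

Defines two conditional VAEs: a small convolutional VAE (SmallVAE) and a
ResNet-style VAE (VAE) whose encoder/decoder topologies are described by
compact block/channel configuration strings. main() trains the ResNet VAE,
plots the loss curves, draws class-conditional samples, and visualizes the
latent space with t-SNE.
"""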
import pathlib
import shutil
import uuid

import matplotlib.pyplot as plt
import numpy as np
import torch
import torch.nn as nn
import torch.nn.functional as F
from sklearn.manifold import TSNE
from torch.utils.data import DataLoader, TensorDataset

from datasets import get_anomaly_detection_split, get_classification_split
from transforms import augment_signal_in_ranges, disturb_noise_augmentation, minmax_normalize
class SmallVAE(nn.Module):
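    """Small conditional convolutional VAE.

    NOTE: the fully connected layer size (32*7*7) assumes 28x28 inputs
    (e.g. MNIST); two stride-2 convolutions reduce 28x28 to 7x7. The label
    y is expected as a one-hot vector and is concatenated to the flattened
    features in the encoder and to the latent code in the decoder.
    """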
def __init__(self, latent_dim=50, num_classes=2):
super(SmallVAE, self).__init__()
self.latent_dim = latent_dim
self.num_classes = num_classes
# Encoder
self.conv1 = nn.Conv2d(1, 16, kernel_size=3, stride=2, padding=1)
self.conv2 = nn.Conv2d(16, 32, kernel_size=3, stride=2, padding=1)
self.fc1 = nn.Linear(32*7*7 + num_classes, 500)
self.fc2_mean = nn.Linear(500, latent_dim)
self.fc2_logvar = nn.Linear(500, latent_dim)
# Decoder
self.fc3 = nn.Linear(latent_dim + num_classes, 32*7*7)
self.conv3 = nn.ConvTranspose2d(
32, 16, kernel_size=3, stride=2, padding=1, output_padding=1)
self.conv4 = nn.ConvTranspose2d(
16, 1, kernel_size=3, stride=2, padding=1, output_padding=1)
def encode(self, x, y):
# print(f"{x.shape =}")
x = F.relu(self.conv1(x))
# print(x.shape)
x = F.relu(self.conv2(x))
# print(x.shape)
x = x.view(x.size(0), -1)
# print(x.shape)
x = torch.cat((x, y), dim=1)
# print(x.shape)
x = F.relu(self.fc1(x))
# print(x.shape)
z_mean = self.fc2_mean(x)
z_logvar = self.fc2_logvar(x)
# print(f"{z_mean.shape =}")
# print(f"{z_logvar.shape =}")
return z_mean, z_logvar
def reparameterize(self, z_mean, z_logvar):
std = torch.exp(0.5 * z_logvar)
eps = torch.randn_like(std)
z = z_mean + eps * std
return z
def decode(self, z, y):
# print(f"{z.shape =}")
# print(f"{y.shape =}")
z_cond = torch.cat((z, y), dim=1)
# print(f"{z_cond.shape =}")
x = F.relu(self.fc3(z_cond))
# print(f"{x.shape =}")
x = x.view(x.size(0), 32, 7, 7)
# print(f"{x.shape =}")
x = F.relu(self.conv3(x))
# print(f"{x.shape =}")
x = torch.sigmoid(self.conv4(x))
# print(f"{x.shape =}")
return x
def forward(self, x, y):
z_mean, z_logvar = self.encode(x, y)
z = self.reparameterize(z_mean, z_logvar)
x_hat = self.decode(z, y)
return x_hat, z_mean, z_logvar, z
def sample(self, z, y):
with torch.no_grad():
x_hat = self.decode(z, y)
return x_hat
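# Closed-form KL divergence KL(N(mu, diag(sigma^2)) || N(0, I)):
# -0.5 * sum(1 + log(sigma^2) - mu^2 - sigma^2)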
def compute_kl_loss(mu, logvar):
return -0.5 * torch.sum(1 + logvar - mu.pow(2) - logvar.exp())
def loss_function(x_recon, x, mu, logvar, alpha=1.0):
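    # Sum-reduced MSE reconstruction term plus an alpha-weighted KL term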
# print(f"{x_recon.shape =}")
# print(f"{x.shape =}")
mse_loss = F.mse_loss(x_recon, x, reduction="sum")
# bce_loss = F.binary_cross_entropy(x_recon, x, reduction='mean')
kld_loss = compute_kl_loss(mu, logvar)
# print(f"{mse_loss = }")
# print(f"{kld_loss = }")
total_loss = mse_loss + alpha * kld_loss
return total_loss, mse_loss, kld_loss
def plot_loss_hist(hist, title, filename):
plt.figure()
plt.plot(hist)
plt.title(title)
plt.xlabel("Epoch")
plt.ylabel("Loss")
plt.savefig(filename)
#################################################################################
# Code added
def parse_layer_string(s):
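    """Parse a comma-separated block-configuration string.

    Token formats (res = resolution):
      "64x3"  -> repetition: three blocks at res 64, [(64, None)] * 3
      "64d2"  -> downsample by 2 at res 64, (64, 2)
      "64u2"  -> upsample (mixin) by 2 at res 64, (64, 2)
      "64t32" -> transition from res 64 to res 32, ((64, 32), None)
      "64"    -> a single block, (64, None)
    """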
layers = []
for ss in s.split(","):
if "x" in ss:
# Denotes a block repetition operation
res, num = ss.split("x")
count = int(num)
layers += [(int(res), None) for _ in range(count)]
elif "u" in ss:
# Denotes a resolution upsampling operation
res, mixin = [int(a) for a in ss.split("u")]
layers.append((res, mixin))
elif "d" in ss:
# Denotes a resolution downsampling operation
res, down_rate = [int(a) for a in ss.split("d")]
layers.append((res, down_rate))
elif "t" in ss:
# Denotes a resolution transition operation
res1, res2 = [int(a) for a in ss.split("t")]
layers.append(((res1, res2), None))
else:
res = int(ss)
layers.append((res, None))
return layers
def parse_channel_string(s):
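    """Parse "res:channels" pairs, e.g. "128:64,64:64" -> {128: 64, 64: 64}."""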
channel_config = {}
for ss in s.split(","):
res, in_channels = ss.split(":")
channel_config[int(res)] = int(in_channels)
return channel_config
def get_conv(
in_dim,
out_dim,
kernel_size,
stride,
padding,
zero_bias=True,
zero_weights=False,
groups=1,
):
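    """Build a Conv2d with optionally zeroed bias and/or weights.

    zero_weights is used for the last conv of a ResBlock (zero_last) so the
    residual branch starts near zero.
    """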
c = nn.Conv2d(in_dim, out_dim, kernel_size, stride, padding, groups=groups)
if zero_bias:
c.bias.data *= 0.0
if zero_weights:
c.weight.data *= 0.0
return c
def get_3x3(in_dim, out_dim, zero_bias=True, zero_weights=False, groups=1):
return get_conv(in_dim, out_dim, 3, 1, 1, zero_bias, zero_weights, groups=groups)
def get_1x1(in_dim, out_dim, zero_bias=True, zero_weights=False, groups=1):
return get_conv(in_dim, out_dim, 1, 1, 0, zero_bias, zero_weights, groups=groups)
class ResBlock(nn.Module):
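    """Pre-activation bottleneck residual block.

    Applies 1x1 -> 3x3 -> 3x3 -> 1x1 convolutions, each preceded by GELU
    (the 3x3 convs fall back to 1x1 when use_3x3=False, i.e. at resolution 1),
    and optionally average-pools the output by down_rate.
    """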
def __init__(
self,
in_width,
middle_width,
out_width,
down_rate=None,
residual=False,
use_3x3=True,
zero_last=False,
):
super().__init__()
self.down_rate = down_rate
self.residual = residual
self.c1 = get_1x1(in_width, middle_width)
self.c2 = (
get_3x3(middle_width, middle_width)
if use_3x3
else get_1x1(middle_width, middle_width)
)
self.c3 = (
get_3x3(middle_width, middle_width)
if use_3x3
else get_1x1(middle_width, middle_width)
)
self.c4 = get_1x1(middle_width, out_width, zero_weights=zero_last)
def forward(self, x):
xhat = self.c1(F.gelu(x))
xhat = self.c2(F.gelu(xhat))
xhat = self.c3(F.gelu(xhat))
xhat = self.c4(F.gelu(xhat))
out = x + xhat if self.residual else xhat
if self.down_rate is not None:
out = F.avg_pool2d(out, kernel_size=self.down_rate, stride=self.down_rate)
return out
class Encoder(nn.Module):
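    """ResNet-style encoder driven by block/channel configuration strings.

    The final feature map (assumed to be 1x6 after the configured
    downsampling) is average-pooled to 1x1, concatenated with the
    conditioning vector along channels, and projected to mu and logvar
    with 1x1 convolutions.
    """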
def __init__(self, block_config_str, channel_config_str):
super().__init__()
self.in_conv = nn.Conv2d(1, 64, 3, stride=1, padding=1, bias=False)
block_config = parse_layer_string(block_config_str)
channel_config = parse_channel_string(channel_config_str)
blocks = []
        for res, down_rate in block_config:
if isinstance(res, tuple):
# Denotes transition to another resolution
res1, res2 = res
blocks.append(
nn.Conv2d(channel_config[res1], channel_config[res2], 1, bias=False)
)
continue
in_channel = channel_config[res]
use_3x3 = res > 1
blocks.append(
ResBlock(
in_channel,
int(0.5 * in_channel),
in_channel,
down_rate=down_rate,
residual=True,
use_3x3=use_3x3,
)
)
        # TODO: If training is unstable, try scaling the weights
self.block_mod = nn.Sequential(*blocks)
        # Latent heads: the pooled features are concatenated with the
        # conditioning vector, which doubles the channel count
        self.mu = nn.Conv2d(channel_config[1] * 2, channel_config[1], 1, bias=False)
        self.logvar = nn.Conv2d(channel_config[1] * 2, channel_config[1], 1, bias=False)
def forward(self, input, cond_vec):
x = self.in_conv(input)
x = self.block_mod(x)
        # Pool the remaining (1, 6) spatial extent down to 1x1
        x = F.avg_pool2d(x, kernel_size=(1, 6))
x = torch.cat((x, cond_vec), dim=1)
return self.mu(x), self.logvar(x)
class Decoder(nn.Module):
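    """ResNet-style decoder mirroring the encoder.

    Starts from a 1x1 latent concatenated with the conditioning vector,
    upsamples through the configured blocks, interpolates to the output
    size (100, 800), and maps to a single sigmoid-activated channel.
    """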
def __init__(self, input_res, block_config_str, channel_config_str):
super().__init__()
block_config = parse_layer_string(block_config_str)
channel_config = parse_channel_string(channel_config_str)
blocks = []
for i, (res, up_rate) in enumerate(block_config):
if isinstance(res, tuple):
# Denotes transition to another resolution
res1, res2 = res
blocks.append(
nn.Conv2d(channel_config[res1], channel_config[res2], 1, bias=False)
)
continue
            if i == 0:
                # Undo the encoder's final (1, 6) average pooling
                blocks.append(nn.Upsample(size=(1, 6), mode="nearest"))
if up_rate is not None:
blocks.append(nn.Upsample(scale_factor=up_rate, mode="nearest"))
continue
in_channel = channel_config[res]
use_3x3 = res > 1
blocks.append(
ResBlock(
in_channel,
int(0.5 * in_channel),
in_channel,
down_rate=None,
residual=True,
use_3x3=use_3x3,
)
)
        # TODO: If training is unstable, try scaling the weights
self.block_mod = nn.Sequential(*blocks)
self.last_conv = nn.Conv2d(channel_config[input_res], 1, 3, stride=1, padding=1)
def forward(self, input, cond_vec):
input = torch.cat((input, cond_vec), dim=1)
x = self.block_mod(input)
        x = F.interpolate(x, size=(100, 800), mode="nearest")
        # print(f"{x.shape =}")
        x = self.last_conv(x)
        # print(f"{x.shape =}")
        return torch.sigmoid(x)
# Implementation of the Resnet-VAE using a ResNet backbone as encoder
# and Upsampling blocks as the decoder
class VAE(nn.Module):
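    """Conditional ResNet VAE.

    The class label is mapped to a 512-dim embedding that is injected into
    both the encoder and the decoder as an extra (B, 512, 1, 1) channel
    block; encoder/decoder topologies come from the configuration strings.
    """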
def __init__(
self,
enc_block_str,
dec_block_str,
enc_channel_str,
dec_channel_str,
alpha=1.0,
lr=1e-3,
        num_classes=2,
):
super().__init__()
self.input_res = 128
self.enc_block_str = enc_block_str
self.dec_block_str = dec_block_str
self.enc_channel_str = enc_channel_str
self.dec_channel_str = dec_channel_str
self.alpha = alpha
self.lr = lr
self.num_classes = num_classes
self.embedding_dim = 512
# Encoder architecture
self.enc = Encoder(self.enc_block_str, self.enc_channel_str)
# Decoder Architecture
self.dec = Decoder(self.input_res, self.dec_block_str, self.dec_channel_str)
        # Conditional embedding
self.embedding = nn.Embedding(num_embeddings=self.num_classes, embedding_dim=self.embedding_dim)
def encode(self, x, y):
B, _, _, _ = x.shape
cond_vec = self.embedding(y).squeeze(1).reshape(B, self.embedding_dim, 1, 1)
mu, logvar = self.enc(x, cond_vec)
return mu, logvar
def decode(self, z, y):
B, _, _, _ = z.shape
cond_vec = self.embedding(y).squeeze(1).reshape(B, self.embedding_dim, 1, 1)
return self.dec(z, cond_vec)
def reparameterize(self, mu, logvar):
std = torch.exp(0.5 * logvar)
eps = torch.randn_like(std)
return mu + eps * std
def compute_kl(self, mu, logvar):
return -0.5 * torch.sum(1 + logvar - mu.pow(2) - logvar.exp())
def forward(self, x, y):
mu, logvar = self.encode(x, y)
z = self.reparameterize(mu, logvar)
x_hat = self.decode(z, y)
return x_hat, mu, logvar, z
    def sample(self, z, y):
        # Only sample during inference; no gradients are needed
        with torch.no_grad():
            return self.decode(z, y)
    def forward_recons(self, x, y):
        # For generating reconstructions during inference
        mu, logvar = self.encode(x, y)
        z = self.reparameterize(mu, logvar)
        decoder_out = self.decode(z, y)
        return decoder_out
#################################################################################
def main():
    # NOTE: most of these constants/flags are currently unused placeholders;
    # only AUGMENT is consulted below.
    NUM_CLASSES = 2
    LATENT_DIM = 50
    AUGMENT = False
    TRAIN = True
    EVAL = False
    EVAL_AND_SAVE = False
    FROM_PRETRAINED = False
    SAVE_MODEL = True
    SAMPLE = True
device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
    path0 = "./data/all_samples/Pkls"
    # The anomaly-detection split was immediately overwritten by the
    # classification split, so it is kept only as a commented-out alternative.
    # (train_images, train_labels), (val_images,
    #  val_labels) = get_anomaly_detection_split(path0)
    (train_images, train_labels), (val_images,
     val_labels) = get_classification_split(path0)
train_images = torch.tensor(train_images[:, None, :, :],
dtype=torch.float32, device=device)
train_labels = torch.as_tensor(train_labels, device=device)
# Not using one_hot encoding
# train_labels = torch.nn.functional.one_hot(
# train_labels, num_classes=NUM_CLASSES).float()
if AUGMENT:
train_images = augment_signal_in_ranges(
train_images, noise_level=1e-8, aug_factor=1.5, ranges=[(50, 300)])
train_images = disturb_noise_augmentation(
train_images, noise_level=1e-8, factor=1.2)
# TODO Inv signal spike
train_images = minmax_normalize(train_images)
print(f"Mean: {torch.mean(train_images):.10f}, Std: {torch.std(train_images):.10f}, Min: {torch.min(train_images)}, Max: {torch.max(train_images):.5f}")
assert train_images.min() >= 0.0 or train_images.max() <= 1.0, "Normalization failed."
train_dataset = TensorDataset(train_images, train_labels)
val_images = torch.tensor(val_images[:, None, :, :],
dtype=torch.float32, device=device)
val_labels = torch.as_tensor(val_labels, device=device)
# Not using one_hot encoding
# val_labels = torch.nn.functional.one_hot(
# val_labels, num_classes=NUM_CLASSES).float()
val_dataset = TensorDataset(val_images, val_labels)
    # Shuffle training batches each epoch; keep the validation order fixed
    train_dataloader = DataLoader(train_dataset, batch_size=16, shuffle=True)
    val_dataloader = DataLoader(val_dataset, batch_size=16)
####################################################################################
# Code added
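    # Architecture DSL (see parse_layer_string / parse_channel_string):
    # "<res>x<n>" = n ResBlocks, "<res>d<k>" = downsample by k,
    # "<res>u<k>" = upsample by k, "<r1>t<r2>" = 1x1 channel transition,
    # and "<res>:<ch>" sets the channel width at a resolution.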
enc_block_config_str = "128x1,128d2,128t64,64x3,64d2,64t32,32x3,32d2,32t16,16x7,16d2,16t8,8x3,8d2,8t4,4x3,4d3,4t1,1x2"
enc_channel_config_str = "128:64,64:64,32:128,16:128,8:256,4:512,1:512"
dec_block_config_str = "1x1,1u4,1t4,4x2,4u2,4t8,8x2,8u2,8t16,16x6,16u2,16t32,32x2,32u2,32t64,64x2,64u2,64t128,128x1"
dec_channel_config_str = "128:64,64:64,32:128,16:128,8:256,4:512,1:1024"
vae = VAE(
enc_block_config_str,
dec_block_config_str,
enc_channel_config_str,
dec_channel_config_str,
num_classes=2
)
####################################################################################
# model = SmallVAE(latent_dim=50, num_classes=10)
    model = vae.to(device)
optimizer = torch.optim.Adam(model.parameters(), lr=1e-3)
loss_history = []
kld_loss_history = []
recons_loss_history = []
epochs = 100
model.train()
    for epoch in range(1, epochs + 1):
        train_loss = 0.0
        epoch_kld = 0.0
        epoch_recons = 0.0
        for batch_idx, (data, target) in enumerate(train_dataloader):
            optimizer.zero_grad()
            data = data.to(device)
            target = target.to(device)
            recon_batch, mu, logvar, _ = model(data, target)
            total_loss, recons_loss, kld_loss = loss_function(recon_batch, data, mu, logvar)
            total_loss.backward()
            # Accumulate plain floats; accumulating tensors (or the previous
            # "kld_loss += kld_loss" shadowing) would keep graphs alive and
            # double per-batch values instead of summing them
            train_loss += total_loss.item()
            epoch_recons += recons_loss.item()
            epoch_kld += kld_loss.item()
            optimizer.step()
        # Per-sample averages over the epoch
        epoch_loss = train_loss / len(train_dataloader.dataset)
        epoch_kld /= len(train_dataloader.dataset)
        epoch_recons /= len(train_dataloader.dataset)
        kld_loss_history.append(epoch_kld)
        recons_loss_history.append(epoch_recons)
        loss_history.append(epoch_loss)
        torch.cuda.empty_cache()
        print(f"Epoch: {epoch}, Loss: {epoch_loss}, KLD: {epoch_kld}, RECONS: {epoch_recons}")
# Plot loss
plot_loss_hist(loss_history, "Total training loss", "loss.png")
plot_loss_hist(kld_loss_history, "KLD loss", "kld_loss.png")
plot_loss_hist(recons_loss_history, "Recons loss", "recons_loss.png")
sampling_path = pathlib.Path("mnist_sampling")
if sampling_path.exists():
shutil.rmtree(sampling_path)
sampling_path.mkdir(parents=True, exist_ok=True)
model.eval()
    num_samples = 10
    # 512 = the channel width at resolution 1 in the encoder/decoder configs
    real_latent_dim = 512
    num_classes = 2
all_samples = []
samples_class = []
with torch.no_grad():
for k in range(num_classes):
folder_path = sampling_path / str(k)
folder_path.mkdir(parents=True, exist_ok=True)
            z = torch.randn(num_samples, real_latent_dim, 1, 1).to(device)
# one_hot_y = torch.zeros((num_samples, num_classes)).to(device)
# one_hot_y[:, k] = 1
y = torch.full((num_samples,), k).to(device)
samples = model.sample(z, y)
all_samples.append(samples.to("cpu").detach().numpy())
samples_class.append(y.to("cpu").detach().numpy())
samples = samples.permute(0, 2, 3, 1).to("cpu").detach().numpy()
            # Save each generated sample as an image
            for i, x_hat in enumerate(samples):
                # Denormalize with MNIST mean/std; NOTE: these constants are
                # left over from an MNIST setup, the signal data here is
                # minmax-normalized to [0, 1]
                x_hat = (x_hat * 0.3081) + 0.1307
                file_stem = f"{uuid.uuid4()}"
                plt.figure()
                plt.imshow(x_hat)
                plt.title(f"sample class {k}")
# save the image
image_path = folder_path / file_stem
plt.savefig(str(image_path) + ".png")
plt.close()
all_samples = np.concatenate(all_samples, axis=0)
samples_class = np.concatenate(samples_class)
all_samples = torch.from_numpy(all_samples).to(device)
samples_class = torch.from_numpy(samples_class).to(device)
recon_batch, mu, logvar, z = model(all_samples, samples_class)
# targets = np.argmax(samples_class.to("cpu").numpy(), axis=1)
targets = samples_class.to("cpu").numpy()
# print(f"{targets = }")
# print(f"{z.shape = }")
    N, C, *_ = z.shape
    # Flatten the (N, C, 1, 1) latents to (N, C) for t-SNE
    tsne_z = z.view(N, C)
    # perplexity must be smaller than the number of samples
    tsne = TSNE(n_components=2, verbose=1, perplexity=min(40, N - 1), n_iter=300)
    tsne_results = tsne.fit_transform(tsne_z.to("cpu").detach().numpy())
plt.figure()
plt.scatter(tsne_results[:, 0], tsne_results[:, 1], c=targets)
plt.title("t-SNE sampled data")
plt.savefig("tsne_sampling.png")
plt.close()
# plot the latent space distribution of training data using t-SNE
results = []
labels = []
for data, target in train_dataloader:
# target = F.one_hot(target, num_classes=10)
data = data.to(device)
target = target.to(device)
recon_batch, mu, logvar, z = model(data, target)
        N, C, *_ = z.shape
        # Flatten the (N, C, 1, 1) latents to (N, C)
        tsne_z = z.view(N, C)
results.append(tsne_z.to("cpu").detach().numpy())
labels.append(target.to("cpu").detach().numpy())
results = np.concatenate(results, axis=0)
labels = np.concatenate(labels)
tsne = TSNE(n_components=2, verbose=1, perplexity=40, n_iter=300)
tsne_results = tsne.fit_transform(results)
plt.figure()
# plt.scatter(tsne_results[:, 0], tsne_results[:, 1], c=np.argmax(labels, axis=1))
plt.scatter(tsne_results[:, 0], tsne_results[:, 1], c=labels)
plt.title("t-SNE training latent space")
plt.savefig("tsne_training.png")
plt.close()
if __name__ == '__main__':
main()