import pathlib
import shutil
import uuid

import matplotlib.pyplot as plt
import numpy as np
import torch
import torch.nn as nn
import torch.nn.functional as F
from sklearn.manifold import TSNE
from torch.utils.data import DataLoader, TensorDataset

from datasets import get_anomaly_detection_split, get_classification_split
from transforms import (augment_signal_in_ranges, disturb_noise_augmentation,
                        minmax_normalize)


class SmallVAE(nn.Module):
    def __init__(self, latent_dim=50, num_classes=2):
        super().__init__()
        self.latent_dim = latent_dim
        self.num_classes = num_classes
        # Encoder
        self.conv1 = nn.Conv2d(1, 16, kernel_size=3, stride=2, padding=1)
        self.conv2 = nn.Conv2d(16, 32, kernel_size=3, stride=2, padding=1)
        self.fc1 = nn.Linear(32 * 7 * 7 + num_classes, 500)
        self.fc2_mean = nn.Linear(500, latent_dim)
        self.fc2_logvar = nn.Linear(500, latent_dim)
        # Decoder
        self.fc3 = nn.Linear(latent_dim + num_classes, 32 * 7 * 7)
        self.conv3 = nn.ConvTranspose2d(
            32, 16, kernel_size=3, stride=2, padding=1, output_padding=1)
        self.conv4 = nn.ConvTranspose2d(
            16, 1, kernel_size=3, stride=2, padding=1, output_padding=1)

    def encode(self, x, y):
        x = F.relu(self.conv1(x))
        x = F.relu(self.conv2(x))
        x = x.view(x.size(0), -1)
        # Condition the encoder by concatenating the label vector to the features.
        x = torch.cat((x, y), dim=1)
        x = F.relu(self.fc1(x))
        z_mean = self.fc2_mean(x)
        z_logvar = self.fc2_logvar(x)
        return z_mean, z_logvar

    def reparameterize(self, z_mean, z_logvar):
        # Reparameterization trick: z = mu + sigma * eps with eps ~ N(0, I).
        std = torch.exp(0.5 * z_logvar)
        eps = torch.randn_like(std)
        return z_mean + eps * std

    def decode(self, z, y):
        # Condition the decoder by concatenating the label vector to the latent.
        z_cond = torch.cat((z, y), dim=1)
        x = F.relu(self.fc3(z_cond))
        x = x.view(x.size(0), 32, 7, 7)
        x = F.relu(self.conv3(x))
        x = torch.sigmoid(self.conv4(x))
        return x

    def forward(self, x, y):
        z_mean, z_logvar = self.encode(x, y)
        z = self.reparameterize(z_mean, z_logvar)
        x_hat = self.decode(z, y)
        return x_hat, z_mean, z_logvar, z

    def sample(self, z, y):
        with torch.no_grad():
            return self.decode(z, y)


def compute_kl_loss(mu, logvar):
    return -0.5 * torch.sum(1 + logvar - mu.pow(2) - logvar.exp())


def loss_function(x_recon, x, mu, logvar, alpha=1.0):
    mse_loss = F.mse_loss(x_recon, x, reduction="sum")
    kld_loss = compute_kl_loss(mu, logvar)
    total_loss = mse_loss + alpha * kld_loss
    return total_loss, mse_loss, kld_loss
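
# The closed-form KL above is the standard Gaussian result,
#   KL(N(mu, sigma^2) || N(0, I)) = -0.5 * sum(1 + log(sigma^2) - mu^2 - sigma^2),
# summed over latent dimensions and the batch. A minimal sanity check of the
# helper (an illustrative sketch; _demo_kl_sanity_check is not called anywhere):
def _demo_kl_sanity_check():
    mu = torch.zeros(4, 8)
    logvar = torch.zeros(4, 8)
    # A standard-normal posterior matches the prior exactly, so the KL is zero.
    assert compute_kl_loss(mu, logvar).item() == 0.0
    # A shifted mean gives KL = 0.5 * sum(mu^2) = 0.5 * 32 = 16 here.
    mu = torch.ones(4, 8)
    assert torch.isclose(compute_kl_loss(mu, logvar), torch.tensor(16.0))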
def plot_loss_hist(hist, title, filename):
    plt.figure()
    plt.plot(hist)
    plt.title(title)
    plt.xlabel("Epoch")
    plt.ylabel("Loss")
    plt.savefig(filename)


#################################################################################
# Code added

def parse_layer_string(s):
    layers = []
    for ss in s.split(","):
        if "x" in ss:
            # Denotes a block repetition operation
            res, num = ss.split("x")
            count = int(num)
            layers += [(int(res), None) for _ in range(count)]
        elif "u" in ss:
            # Denotes a resolution upsampling operation
            res, mixin = [int(a) for a in ss.split("u")]
            layers.append((res, mixin))
        elif "d" in ss:
            # Denotes a resolution downsampling operation
            res, down_rate = [int(a) for a in ss.split("d")]
            layers.append((res, down_rate))
        elif "t" in ss:
            # Denotes a resolution transition operation
            res1, res2 = [int(a) for a in ss.split("t")]
            layers.append(((res1, res2), None))
        else:
            res = int(ss)
            layers.append((res, None))
    return layers


def parse_channel_string(s):
    channel_config = {}
    for ss in s.split(","):
        res, in_channels = ss.split(":")
        channel_config[int(res)] = int(in_channels)
    return channel_config
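
# Illustrative examples of the two parsers above (sketch only, not used at
# runtime): "32x2" repeats a 32-resolution block twice, "32d2" adds a block with
# a 2x downsample, "1u4" adds a 4x upsample, and "32t16" marks a 32 -> 16
# channel transition:
#
#   parse_layer_string("32x2,32d2,32t16") == [(32, None), (32, None), (32, 2), ((32, 16), None)]
#   parse_channel_string("32:128,16:256") == {32: 128, 16: 256}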
def get_conv(in_dim, out_dim, kernel_size, stride, padding,
             zero_bias=True, zero_weights=False, groups=1):
    c = nn.Conv2d(in_dim, out_dim, kernel_size, stride, padding, groups=groups)
    if zero_bias:
        c.bias.data *= 0.0
    if zero_weights:
        c.weight.data *= 0.0
    return c


def get_3x3(in_dim, out_dim, zero_bias=True, zero_weights=False, groups=1):
    return get_conv(in_dim, out_dim, 3, 1, 1, zero_bias, zero_weights, groups=groups)


def get_1x1(in_dim, out_dim, zero_bias=True, zero_weights=False, groups=1):
    return get_conv(in_dim, out_dim, 1, 1, 0, zero_bias, zero_weights, groups=groups)


class ResBlock(nn.Module):
    def __init__(self, in_width, middle_width, out_width, down_rate=None,
                 residual=False, use_3x3=True, zero_last=False):
        super().__init__()
        self.down_rate = down_rate
        self.residual = residual
        self.c1 = get_1x1(in_width, middle_width)
        self.c2 = (get_3x3(middle_width, middle_width) if use_3x3
                   else get_1x1(middle_width, middle_width))
        self.c3 = (get_3x3(middle_width, middle_width) if use_3x3
                   else get_1x1(middle_width, middle_width))
        self.c4 = get_1x1(middle_width, out_width, zero_weights=zero_last)

    def forward(self, x):
        xhat = self.c1(F.gelu(x))
        xhat = self.c2(F.gelu(xhat))
        xhat = self.c3(F.gelu(xhat))
        xhat = self.c4(F.gelu(xhat))
        out = x + xhat if self.residual else xhat
        if self.down_rate is not None:
            out = F.avg_pool2d(out, kernel_size=self.down_rate, stride=self.down_rate)
        return out


class Encoder(nn.Module):
    def __init__(self, block_config_str, channel_config_str):
        super().__init__()
        self.in_conv = nn.Conv2d(1, 64, 3, stride=1, padding=1, bias=False)
        block_config = parse_layer_string(block_config_str)
        channel_config = parse_channel_string(channel_config_str)
        blocks = []
        for res, down_rate in block_config:
            if isinstance(res, tuple):
                # Denotes transition to another resolution
                res1, res2 = res
                blocks.append(
                    nn.Conv2d(channel_config[res1], channel_config[res2], 1, bias=False))
                continue
            in_channel = channel_config[res]
            use_3x3 = res > 1
            blocks.append(
                ResBlock(in_channel, int(0.5 * in_channel), in_channel,
                         down_rate=down_rate, residual=True, use_3x3=use_3x3))
        # TODO: If the training is unstable, try scaling the weights.
        self.block_mod = nn.Sequential(*blocks)
        # Latents
        self.mu = nn.Conv2d(channel_config[1] * 2, channel_config[1], 1, bias=False)
        self.logvar = nn.Conv2d(channel_config[1] * 2, channel_config[1], 1, bias=False)

    def forward(self, input, cond_vec):
        x = self.in_conv(input)
        x = self.block_mod(x)
        x = F.avg_pool2d(x, kernel_size=(1, 6))
        # Condition on the class embedding by concatenating along the channel dim.
        x = torch.cat((x, cond_vec), dim=1)
        return self.mu(x), self.logvar(x)


class Decoder(nn.Module):
    def __init__(self, input_res, block_config_str, channel_config_str):
        super().__init__()
        block_config = parse_layer_string(block_config_str)
        channel_config = parse_channel_string(channel_config_str)
        blocks = []
        for i, (res, up_rate) in enumerate(block_config):
            if isinstance(res, tuple):
                # Denotes transition to another resolution
                res1, res2 = res
                blocks.append(
                    nn.Conv2d(channel_config[res1], channel_config[res2], 1, bias=False))
                continue
            if i == 0:
                # Expand the 1x1 latent map to the decoder's starting resolution.
                blocks.append(nn.Upsample(size=(1, 6), mode="nearest"))
            if up_rate is not None:
                blocks.append(nn.Upsample(scale_factor=up_rate, mode="nearest"))
                continue
            in_channel = channel_config[res]
            use_3x3 = res > 1
            blocks.append(
                ResBlock(in_channel, int(0.5 * in_channel), in_channel,
                         down_rate=None, residual=True, use_3x3=use_3x3))
        # TODO: If the training is unstable, try scaling the weights.
        self.block_mod = nn.Sequential(*blocks)
        self.last_conv = nn.Conv2d(channel_config[input_res], 1, 3, stride=1, padding=1)

    def forward(self, input, cond_vec):
        input = torch.cat((input, cond_vec), dim=1)
        x = self.block_mod(input)
        x = F.interpolate(x, size=(100, 800), mode="nearest")
        x = self.last_conv(x)
        return torch.sigmoid(x)


# Implementation of the ResNet-VAE using a ResNet backbone as the encoder
# and upsampling blocks as the decoder.
class VAE(nn.Module):
    def __init__(self, enc_block_str, dec_block_str, enc_channel_str,
                 dec_channel_str, alpha=1.0, lr=1e-3, num_classes=2):
        super().__init__()
        self.input_res = 128
        self.enc_block_str = enc_block_str
        self.dec_block_str = dec_block_str
        self.enc_channel_str = enc_channel_str
        self.dec_channel_str = dec_channel_str
        self.alpha = alpha
        self.lr = lr
        self.num_classes = num_classes
        self.embedding_dim = 512

        # Encoder architecture
        self.enc = Encoder(self.enc_block_str, self.enc_channel_str)

        # Decoder architecture
        self.dec = Decoder(self.input_res, self.dec_block_str, self.dec_channel_str)

        # Conditional embedding
        self.embedding = nn.Embedding(num_embeddings=self.num_classes,
                                      embedding_dim=self.embedding_dim)

    def encode(self, x, y):
        B, _, _, _ = x.shape
        cond_vec = self.embedding(y).reshape(B, self.embedding_dim, 1, 1)
        mu, logvar = self.enc(x, cond_vec)
        return mu, logvar

    def decode(self, z, y):
        B, _, _, _ = z.shape
        cond_vec = self.embedding(y).reshape(B, self.embedding_dim, 1, 1)
        return self.dec(z, cond_vec)

    def reparameterize(self, mu, logvar):
        std = torch.exp(0.5 * logvar)
        eps = torch.randn_like(std)
        return mu + eps * std

    def compute_kl(self, mu, logvar):
        return -0.5 * torch.sum(1 + logvar - mu.pow(2) - logvar.exp())

    def forward(self, x, y):
        mu, logvar = self.encode(x, y)
        z = self.reparameterize(mu, logvar)
        x_hat = self.decode(z, y)
        return x_hat, mu, logvar, z

    def sample(self, z, y):
        # Only sample during inference
        return self.decode(z, y)

    def forward_recons(self, x, y):
        # For generating reconstructions during inference.
        # (Fixed: encode/decode are conditional, so the label y is required.)
        mu, logvar = self.encode(x, y)
        z = self.reparameterize(mu, logvar)
        return self.decode(z, y)

#################################################################################
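
# A minimal shape smoke test for the conditional ResNet-VAE above (an
# illustrative sketch, not called anywhere in this script). It assumes the
# (1, 100, 800) input resolution implied by the encoder pooling and the
# decoder's final F.interpolate, plus the config strings used in main() below.
def _demo_vae_shapes():
    enc_blocks = "128x1,128d2,128t64,64x3,64d2,64t32,32x3,32d2,32t16,16x7,16d2,16t8,8x3,8d2,8t4,4x3,4d3,4t1,1x2"
    enc_channels = "128:64,64:64,32:128,16:128,8:256,4:512,1:512"
    dec_blocks = "1x1,1u4,1t4,4x2,4u2,4t8,8x2,8u2,8t16,16x6,16u2,16t32,32x2,32u2,32t64,64x2,64u2,64t128,128x1"
    dec_channels = "128:64,64:64,32:128,16:128,8:256,4:512,1:1024"
    model = VAE(enc_blocks, dec_blocks, enc_channels, dec_channels, num_classes=2)
    x = torch.rand(2, 1, 100, 800)  # dummy batch of two single-channel inputs
    y = torch.randint(0, 2, (2,))   # integer class labels for the conditioning
    x_hat, mu, logvar, z = model(x, y)
    assert x_hat.shape == x.shape       # decoder maps back to the input size
    assert mu.shape == (2, 512, 1, 1)   # latent is a 512-channel 1x1 map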

def main():
    NUM_CLASSES = 2
    LATENT_DIM = 50
    AUGMENT = False
    TRAIN = True
    EVAL = False
    EVAL_AND_SAVE = False
    FROM_PRETRAINED = False
    SAVE_MODEL = True
    SAMPLE = True

    device = torch.device("cuda" if torch.cuda.is_available() else "cpu")

    path0 = "./data/all_samples/Pkls"
    # The anomaly-detection split was loaded and then immediately overwritten in
    # the original script, so only the classification split is actually used:
    # (train_images, train_labels), (val_images, val_labels) = get_anomaly_detection_split(path0)
    (train_images, train_labels), (val_images, val_labels) = get_classification_split(path0)

    train_images = torch.tensor(train_images[:, None, :, :], dtype=torch.float32, device=device)
    # Labels stay as integer class indices (no one-hot encoding); the VAE
    # conditions through an nn.Embedding lookup instead.
    train_labels = torch.as_tensor(train_labels, device=device)

    if AUGMENT:
        train_images = augment_signal_in_ranges(
            train_images, noise_level=1e-8, aug_factor=1.5, ranges=[(50, 300)])
        train_images = disturb_noise_augmentation(
            train_images, noise_level=1e-8, factor=1.2)
        # TODO: Inverted signal spike augmentation.

    train_images = minmax_normalize(train_images)
    print(f"Mean: {torch.mean(train_images):.10f}, Std: {torch.std(train_images):.10f}, "
          f"Min: {torch.min(train_images)}, Max: {torch.max(train_images):.5f}")
    # Both bounds must hold, so `and` is required here (the original used `or`,
    # which could never fail once min >= 0).
    assert train_images.min() >= 0.0 and train_images.max() <= 1.0, "Normalization failed."

    train_dataset = TensorDataset(train_images, train_labels)

    val_images = torch.tensor(val_images[:, None, :, :], dtype=torch.float32, device=device)
    val_labels = torch.as_tensor(val_labels, device=device)
    val_dataset = TensorDataset(val_images, val_labels)

    # shuffle=True for training; the original loader iterated in a fixed order.
    train_dataloader = DataLoader(train_dataset, batch_size=16, shuffle=True)
    val_dataloader = DataLoader(val_dataset, batch_size=16)

    ####################################################################################
    # Code added
    enc_block_config_str = "128x1,128d2,128t64,64x3,64d2,64t32,32x3,32d2,32t16,16x7,16d2,16t8,8x3,8d2,8t4,4x3,4d3,4t1,1x2"
    enc_channel_config_str = "128:64,64:64,32:128,16:128,8:256,4:512,1:512"

    dec_block_config_str = "1x1,1u4,1t4,4x2,4u2,4t8,8x2,8u2,8t16,16x6,16u2,16t32,32x2,32u2,32t64,64x2,64u2,64t128,128x1"
    dec_channel_config_str = "128:64,64:64,32:128,16:128,8:256,4:512,1:1024"

    vae = VAE(
        enc_block_config_str,
        dec_block_config_str,
        enc_channel_config_str,
        dec_channel_config_str,
        num_classes=2,
    )
    ####################################################################################

    # model = SmallVAE(latent_dim=50, num_classes=10)
    model = vae
    model = model.to(device)
    optimizer = torch.optim.Adam(model.parameters(), lr=1e-3)

    loss_history = []
    kld_loss_history = []
    recons_loss_history = []
    epochs = 100

    model.train()
    for epoch in range(1, epochs + 1):
        # Accumulators are named differently from the per-batch losses. In the
        # original, `kld_loss += kld_loss` and `recons_loss += recons_loss`
        # clobbered the batch values and doubled them instead of summing.
        train_loss = 0.0
        epoch_kld = 0.0
        epoch_recons = 0.0
        for data, target in train_dataloader:
            optimizer.zero_grad()
            data = data.to(device)
            target = target.to(device)
            recon_batch, mu, logvar, _ = model(data, target)
            total_loss, recons_loss, kld_loss = loss_function(recon_batch, data, mu, logvar)
            total_loss.backward()
            # .item() detaches the scalars, so the graph is freed every step.
            train_loss += total_loss.item()
            epoch_kld += kld_loss.item()
            epoch_recons += recons_loss.item()
            optimizer.step()

        n = len(train_dataloader.dataset)
        batch_loss = train_loss / n
        epoch_kld /= n
        epoch_recons /= n
        loss_history.append(batch_loss)
        kld_loss_history.append(epoch_kld)
        recons_loss_history.append(epoch_recons)
        torch.cuda.empty_cache()
        print(f"Epoch: {epoch}, Loss: {batch_loss}, KLD: {epoch_kld}, RECONS: {epoch_recons}")

    # Plot loss
    plot_loss_hist(loss_history, "Total training loss", "loss.png")
    plot_loss_hist(kld_loss_history, "KLD loss", "kld_loss.png")
    plot_loss_hist(recons_loss_history, "Recons loss", "recons_loss.png")
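    # NOTE: val_dataloader is built above but never consumed. A per-epoch
    # validation pass could look like the following sketch (illustrative only,
    # reusing the same loss_function):
    #
    #     model.eval()
    #     with torch.no_grad():
    #         val_loss = 0.0
    #         for data, target in val_dataloader:
    #             recon, mu, logvar, _ = model(data, target)
    #             val_loss += loss_function(recon, data, mu, logvar)[0].item()
    #     print(f"Val loss: {val_loss / len(val_dataloader.dataset)}")
    #     model.train()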
    sampling_path = pathlib.Path("mnist_sampling")
    if sampling_path.exists():
        shutil.rmtree(sampling_path)
    sampling_path.mkdir(parents=True, exist_ok=True)

    model.eval()
    num_samples = 10
    real_latent_dim = 512  # channel width of the 1x1 latent map
    num_classes = 2
    all_samples = []
    samples_class = []
    with torch.no_grad():
        for k in range(num_classes):
            folder_path = sampling_path / str(k)
            folder_path.mkdir(parents=True, exist_ok=True)
            z = torch.randn(num_samples, real_latent_dim, 1, 1).to(device)
            y = torch.full((num_samples,), k).to(device)
            samples = model.sample(z, y)
            all_samples.append(samples.to("cpu").numpy())
            samples_class.append(y.to("cpu").numpy())
            samples = samples.permute(0, 2, 3, 1).to("cpu").numpy()
            # Save each sample as an image. The inner loop no longer reuses `k`,
            # which previously clobbered the class index used in the title.
            for x_hat in samples:
                # x_hat = (x_hat * 0.3081) + 0.1307  # MNIST de-normalization;
                # not applicable to this min-max normalized data, so disabled.
                file_stem = f"{uuid.uuid4()}"
                plt.figure()
                plt.imshow(x_hat.squeeze(-1))  # drop the trailing channel dim for imshow
                plt.title(f"sample class {k}")
                image_path = folder_path / file_stem
                plt.savefig(str(image_path) + ".png")
                plt.close()

    all_samples = np.concatenate(all_samples, axis=0)
    samples_class = np.concatenate(samples_class)
    all_samples = torch.from_numpy(all_samples).to(device)
    samples_class = torch.from_numpy(samples_class).to(device)
    recon_batch, mu, logvar, z = model(all_samples, samples_class)
    targets = samples_class.to("cpu").numpy()

    # t-SNE over the sampled latents. scikit-learn requires perplexity < n_samples;
    # with 2 * 10 = 20 samples the original perplexity of 40 would raise an error.
    N, C, *_ = z.shape
    tsne_z = z.reshape(N, C)  # flatten the (N, C, 1, 1) latent maps to (N, C)
    tsne = TSNE(n_components=2, verbose=1, perplexity=min(40, N - 1), n_iter=300)
    tsne_results = tsne.fit_transform(tsne_z.to("cpu").detach().numpy())
    plt.figure()
    plt.scatter(tsne_results[:, 0], tsne_results[:, 1], c=targets)
    plt.title("t-SNE sampled data")
    plt.savefig("tsne_sampling.png")
    plt.close()

    # Plot the latent-space distribution of the training data using t-SNE.
    results = []
    labels = []
    for data, target in train_dataloader:
        data = data.to(device)
        target = target.to(device)
        recon_batch, mu, logvar, z = model(data, target)
        N, C, *_ = z.shape
        results.append(z.reshape(N, C).to("cpu").detach().numpy())
        labels.append(target.to("cpu").detach().numpy())
    results = np.concatenate(results, axis=0)
    labels = np.concatenate(labels)
    tsne = TSNE(n_components=2, verbose=1, perplexity=40, n_iter=300)
    tsne_results = tsne.fit_transform(results)
    plt.figure()
    plt.scatter(tsne_results[:, 0], tsne_results[:, 1], c=labels)
    plt.title("t-SNE training latent space")
    plt.savefig("tsne_training.png")
    plt.close()


if __name__ == "__main__":
    main()