import pathlib
import shutil
import uuid

import matplotlib.pyplot as plt
import numpy as np
import torch
import torch.nn as nn
import torch.nn.functional as F
from sklearn.manifold import TSNE
from torch.utils.data import DataLoader, TensorDataset

from datasets import get_anomaly_detection_split, get_classification_split
from transforms import augment_signal_in_ranges, disturb_noise_augmentation, minmax_normalize

class SmallVAE(nn.Module):
    def __init__(self, latent_dim=50, num_classes=2):
        super(SmallVAE, self).__init__()
        self.latent_dim = latent_dim
        self.num_classes = num_classes

        # Encoder
        self.conv1 = nn.Conv2d(1, 16, kernel_size=3, stride=2, padding=1)
        self.conv2 = nn.Conv2d(16, 32, kernel_size=3, stride=2, padding=1)
        self.fc1 = nn.Linear(32*7*7 + num_classes, 500)
        self.fc2_mean = nn.Linear(500, latent_dim)
        self.fc2_logvar = nn.Linear(500, latent_dim)

        # Decoder
        self.fc3 = nn.Linear(latent_dim + num_classes, 32*7*7)
        self.conv3 = nn.ConvTranspose2d(
            32, 16, kernel_size=3, stride=2, padding=1, output_padding=1)
        self.conv4 = nn.ConvTranspose2d(
            16, 1, kernel_size=3, stride=2, padding=1, output_padding=1)

    def encode(self, x, y):
        x = F.relu(self.conv1(x))
        x = F.relu(self.conv2(x))
        x = x.view(x.size(0), -1)
        # Condition the encoder on the label vector
        x = torch.cat((x, y), dim=1)
        x = F.relu(self.fc1(x))
        z_mean = self.fc2_mean(x)
        z_logvar = self.fc2_logvar(x)
        return z_mean, z_logvar

    def reparameterize(self, z_mean, z_logvar):
        std = torch.exp(0.5 * z_logvar)
        eps = torch.randn_like(std)
        z = z_mean + eps * std
        return z

    def decode(self, z, y):
        # Condition the decoder on the same label vector
        z_cond = torch.cat((z, y), dim=1)
        x = F.relu(self.fc3(z_cond))
        x = x.view(x.size(0), 32, 7, 7)
        x = F.relu(self.conv3(x))
        x = torch.sigmoid(self.conv4(x))
        return x

    def forward(self, x, y):
        z_mean, z_logvar = self.encode(x, y)
        z = self.reparameterize(z_mean, z_logvar)
        x_hat = self.decode(z, y)
        return x_hat, z_mean, z_logvar, z

    def sample(self, z, y):
        with torch.no_grad():
            x_hat = self.decode(z, y)
        return x_hat

def compute_kl_loss(mu, logvar):
    return -0.5 * torch.sum(1 + logvar - mu.pow(2) - logvar.exp())
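
# Closed-form KL divergence between the approximate posterior N(mu, sigma^2)
# and the standard normal prior N(0, I), with logvar = log(sigma^2):
#   KL = -1/2 * sum(1 + log(sigma^2) - mu^2 - sigma^2)
# Quick sanity check (illustrative): a posterior equal to the prior gives zero KL.
#   mu, logvar = torch.zeros(4, 50), torch.zeros(4, 50)
#   assert compute_kl_loss(mu, logvar).item() == 0.0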

def loss_function(x_recon, x, mu, logvar, alpha=1.0):
    # Reconstruction term, summed so its scale matches the summed KL term
    mse_loss = F.mse_loss(x_recon, x, reduction="sum")
    # Alternative for strictly binary data:
    # bce_loss = F.binary_cross_entropy(x_recon, x, reduction='mean')
    kld_loss = compute_kl_loss(mu, logvar)
    total_loss = mse_loss + alpha * kld_loss
    return total_loss, mse_loss, kld_loss
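
# The alpha weight trades reconstruction fidelity against latent-space
# regularization (a beta-VAE-style knob). Minimal usage sketch (illustrative):
#   x_hat, mu, logvar, _ = model(x, y)
#   total, mse, kld = loss_function(x_hat, x, mu, logvar, alpha=0.5)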

def plot_loss_hist(hist, title, filename):
    plt.figure()
    plt.plot(hist)
    plt.title(title)
    plt.xlabel("Epoch")
    plt.ylabel("Loss")
    plt.savefig(filename)
    # Close the figure so repeated calls do not accumulate open figures
    plt.close()

#################################################################################
# Code added

def parse_layer_string(s):
    layers = []
    for ss in s.split(","):
        if "x" in ss:
            # Denotes a block repetition operation
            res, num = ss.split("x")
            count = int(num)
            layers += [(int(res), None) for _ in range(count)]
        elif "u" in ss:
            # Denotes a resolution upsampling operation
            res, mixin = [int(a) for a in ss.split("u")]
            layers.append((res, mixin))
        elif "d" in ss:
            # Denotes a resolution downsampling operation
            res, down_rate = [int(a) for a in ss.split("d")]
            layers.append((res, down_rate))
        elif "t" in ss:
            # Denotes a resolution transition operation
            res1, res2 = [int(a) for a in ss.split("t")]
            layers.append(((res1, res2), None))
        else:
            res = int(ss)
            layers.append((res, None))
    return layers
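
# Illustrative example of the block mini-language (not from the original):
#   parse_layer_string("8x2,8d2,8t4")
#   -> [(8, None), (8, None),   # two blocks at resolution 8
#       (8, 2),                 # downsample by 2 at resolution 8
#       ((8, 4), None)]         # 1x1-conv transition from res 8 to res 4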


def parse_channel_string(s):
    channel_config = {}
    for ss in s.split(","):
        res, in_channels = ss.split(":")
        channel_config[int(res)] = int(in_channels)
    return channel_config
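
# Illustrative example: map each resolution tag to a channel width.
#   parse_channel_string("8:256,4:512,1:512") -> {8: 256, 4: 512, 1: 512}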


def get_conv(
    in_dim,
    out_dim,
    kernel_size,
    stride,
    padding,
    zero_bias=True,
    zero_weights=False,
    groups=1,
):
    c = nn.Conv2d(in_dim, out_dim, kernel_size, stride, padding, groups=groups)
    if zero_bias:
        c.bias.data *= 0.0
    if zero_weights:
        c.weight.data *= 0.0
    return c


def get_3x3(in_dim, out_dim, zero_bias=True, zero_weights=False, groups=1):
    return get_conv(in_dim, out_dim, 3, 1, 1, zero_bias, zero_weights, groups=groups)


def get_1x1(in_dim, out_dim, zero_bias=True, zero_weights=False, groups=1):
    return get_conv(in_dim, out_dim, 1, 1, 0, zero_bias, zero_weights, groups=groups)
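
# Note: zero_weights=True initializes a conv to output zeros, so a residual
# branch ending in such a conv starts as the identity mapping; this is a
# common stabilization trick in deep residual VAEs (used via zero_last below).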


class ResBlock(nn.Module):
    def __init__(
        self,
        in_width,
        middle_width,
        out_width,
        down_rate=None,
        residual=False,
        use_3x3=True,
        zero_last=False,
    ):
        super().__init__()
        self.down_rate = down_rate
        self.residual = residual
        self.c1 = get_1x1(in_width, middle_width)
        self.c2 = (
            get_3x3(middle_width, middle_width)
            if use_3x3
            else get_1x1(middle_width, middle_width)
        )
        self.c3 = (
            get_3x3(middle_width, middle_width)
            if use_3x3
            else get_1x1(middle_width, middle_width)
        )
        self.c4 = get_1x1(middle_width, out_width, zero_weights=zero_last)

    def forward(self, x):
        xhat = self.c1(F.gelu(x))
        xhat = self.c2(F.gelu(xhat))
        xhat = self.c3(F.gelu(xhat))
        xhat = self.c4(F.gelu(xhat))
        out = x + xhat if self.residual else xhat
        if self.down_rate is not None:
            out = F.avg_pool2d(out, kernel_size=self.down_rate, stride=self.down_rate)
        return out
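
# A bottleneck residual block: 1x1 reduce -> two 3x3 convs -> 1x1 expand, with
# GELU pre-activations. Illustrative shape check (assumed sizes):
#   blk = ResBlock(64, 32, 64, down_rate=2, residual=True)
#   blk(torch.randn(2, 64, 16, 16)).shape  # -> torch.Size([2, 64, 8, 8])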


class Encoder(nn.Module):
    def __init__(self, block_config_str, channel_config_str):
        super().__init__()
        
        self.in_conv = nn.Conv2d(1, 64, 3, stride=1, padding=1, bias=False)

        block_config = parse_layer_string(block_config_str)
        channel_config = parse_channel_string(channel_config_str)
        blocks = []
        for _, (res, down_rate) in enumerate(block_config):
            if isinstance(res, tuple):
                # Denotes transition to another resolution
                res1, res2 = res
                blocks.append(
                    nn.Conv2d(channel_config[res1], channel_config[res2], 1, bias=False)
                )
                continue
            in_channel = channel_config[res]
            use_3x3 = res > 1
            blocks.append(
                ResBlock(
                    in_channel,
                    int(0.5 * in_channel),
                    in_channel,
                    down_rate=down_rate,
                    residual=True,
                    use_3x3=use_3x3,
                )
            )
        # TODO: If training is unstable, try scaling the weights
        self.block_mod = nn.Sequential(*blocks)

        # Latents
        self.mu = nn.Conv2d(channel_config[1]*2, channel_config[1], 1, bias=False)
        self.logvar = nn.Conv2d(channel_config[1]*2, channel_config[1], 1, bias=False)

    def forward(self, input, cond_vec):
        x = self.in_conv(input)
        x = self.block_mod(x)
        # Pool away the remaining spatial extent before conditioning
        x = F.avg_pool2d(x, kernel_size=(1, 6))
        # Condition on the class embedding along the channel dimension
        x = torch.cat((x, cond_vec), dim=1)
        return self.mu(x), self.logvar(x)


class Decoder(nn.Module):
    def __init__(self, input_res, block_config_str, channel_config_str):
        super().__init__()
        block_config = parse_layer_string(block_config_str)
        channel_config = parse_channel_string(channel_config_str)
        blocks = []
        for i, (res, up_rate) in enumerate(block_config):
            if isinstance(res, tuple):
                # Denotes transition to another resolution
                res1, res2 = res
                blocks.append(
                    nn.Conv2d(channel_config[res1], channel_config[res2], 1, bias=False)
                )
                continue

            if i == 0:
                blocks.append(nn.Upsample(size=(1, 6), mode="nearest"))

            if up_rate is not None:
                blocks.append(nn.Upsample(scale_factor=up_rate, mode="nearest"))
                continue

            in_channel = channel_config[res]
            use_3x3 = res > 1
            blocks.append(
                ResBlock(
                    in_channel,
                    int(0.5 * in_channel),
                    in_channel,
                    down_rate=None,
                    residual=True,
                    use_3x3=use_3x3,
                )
            )
        # TODO: If training is unstable, try scaling the weights
        self.block_mod = nn.Sequential(*blocks)
        self.last_conv = nn.Conv2d(channel_config[input_res], 1, 3, stride=1, padding=1)

    def forward(self, input, cond_vec):
        # Condition on the class embedding along the channel dimension
        input = torch.cat((input, cond_vec), dim=1)
        x = self.block_mod(input)
        # Resize to the target signal resolution before the output conv
        x = F.interpolate(x, size=(100, 800), mode="nearest")
        x = self.last_conv(x)
        return torch.sigmoid(x)


# Implementation of a ResNet-style conditional VAE: a residual-block encoder
# and an upsampling-block decoder.
class VAE(nn.Module):
    def __init__(
        self,
        enc_block_str,
        dec_block_str,
        enc_channel_str,
        dec_channel_str,
        alpha=1.0,
        lr=1e-3,
        num_classes=2,
    ):
        super().__init__()
        self.input_res = 128
        self.enc_block_str = enc_block_str
        self.dec_block_str = dec_block_str
        self.enc_channel_str = enc_channel_str
        self.dec_channel_str = dec_channel_str
        self.alpha = alpha
        self.lr = lr
        self.num_classes = num_classes
        self.embedding_dim = 512

        # Encoder architecture
        self.enc = Encoder(self.enc_block_str, self.enc_channel_str)

        # Decoder Architecture
        self.dec = Decoder(self.input_res, self.dec_block_str, self.dec_channel_str)

        # Conditional class embedding
        self.embedding = nn.Embedding(num_embeddings=self.num_classes, embedding_dim=self.embedding_dim)

    def encode(self, x, y):
        B, _, _, _ = x.shape
        cond_vec = self.embedding(y).squeeze(1).reshape(B, self.embedding_dim, 1, 1)
        mu, logvar = self.enc(x, cond_vec)
        return mu, logvar

    def decode(self, z, y):
        B, _, _, _ = z.shape
        cond_vec = self.embedding(y).squeeze(1).reshape(B, self.embedding_dim, 1, 1)
        return self.dec(z, cond_vec)

    def reparameterize(self, mu, logvar):
        std = torch.exp(0.5 * logvar)
        eps = torch.randn_like(std)
        return mu + eps * std

    def compute_kl(self, mu, logvar):
        return -0.5 * torch.sum(1 + logvar - mu.pow(2) - logvar.exp())

    def forward(self, x, y):
        mu, logvar = self.encode(x, y)
        z = self.reparameterize(mu, logvar)
        x_hat = self.decode(z, y)
        return x_hat, mu, logvar, z
    
    def sample(self, z, y):
        # Only used for sampling at inference time
        with torch.no_grad():
            decoder_out = self.decode(z, y)
        return decoder_out

    def forward_recons(self, x, y):
        # For generating reconstructions during inference; the conditioning
        # label y is required by both encode and decode
        mu, logvar = self.encode(x, y)
        z = self.reparameterize(mu, logvar)
        decoder_out = self.decode(z, y)
        return decoder_out
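
# Minimal smoke test for the conditional VAE (a sketch, not part of the
# original script). The (100, 800) input size is assumed from the interpolate
# call in Decoder.forward; the block/channel strings match those in main().
#   vae = VAE(
#       "128x1,128d2,128t64,64x3,64d2,64t32,32x3,32d2,32t16,16x7,16d2,16t8,"
#       "8x3,8d2,8t4,4x3,4d3,4t1,1x2",
#       "1x1,1u4,1t4,4x2,4u2,4t8,8x2,8u2,8t16,16x6,16u2,16t32,32x2,32u2,"
#       "32t64,64x2,64u2,64t128,128x1",
#       "128:64,64:64,32:128,16:128,8:256,4:512,1:512",
#       "128:64,64:64,32:128,16:128,8:256,4:512,1:1024",
#       num_classes=2,
#   )
#   x = torch.randn(2, 1, 100, 800)
#   y = torch.randint(0, 2, (2,))
#   x_hat, mu, logvar, z = vae(x, y)  # x_hat: (2, 1, 100, 800), z: (2, 512, 1, 1)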



#################################################################################


def main():
    NUM_CLASSES = 2
    LATENT_DIM = 50
    AUGMENT = False
    TRAIN = True
    EVAL = False
    EVAL_AND_SAVE = False
    FROM_PRETRAINED = False
    SAVE_MODEL = True
    SAMPLE = True
    
    device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
    path0 = "./data/all_samples/Pkls"

    # get_anomaly_detection_split(path0) is also available; the classification
    # split is the one actually used here.
    (train_images, train_labels), (val_images,
                                   val_labels) = get_classification_split(path0)
    train_images = torch.tensor(train_images[:, None, :, :],
                                dtype=torch.float32, device=device)
    train_labels = torch.as_tensor(train_labels, device=device)
    # Not using one_hot encoding
    # train_labels = torch.nn.functional.one_hot(
    #     train_labels, num_classes=NUM_CLASSES).float()

    if AUGMENT:
        train_images = augment_signal_in_ranges(
            train_images, noise_level=1e-8, aug_factor=1.5, ranges=[(50, 300)])
        train_images = disturb_noise_augmentation(
            train_images, noise_level=1e-8, factor=1.2)
        # TODO Inv signal spike

    train_images = minmax_normalize(train_images)
    print(f"Mean: {torch.mean(train_images):.10f}, Std: {torch.std(train_images):.10f}, Min: {torch.min(train_images)}, Max: {torch.max(train_images):.5f}")
    assert train_images.min() >= 0.0 and train_images.max() <= 1.0, "Normalization failed."

    train_dataset = TensorDataset(train_images, train_labels)

    val_images = torch.tensor(val_images[:, None, :, :],
                              dtype=torch.float32, device=device)
    val_labels = torch.as_tensor(val_labels, device=device)
    # Not using one_hot encoding
    # val_labels = torch.nn.functional.one_hot(
    #     val_labels, num_classes=NUM_CLASSES).float()

    val_dataset = TensorDataset(val_images, val_labels)

    train_dataloader = DataLoader(train_dataset, batch_size=16, shuffle=True)
    val_dataloader = DataLoader(val_dataset, batch_size=16)
    
    
    ####################################################################################
    # Code added
    # Block strings: "RxN" = N res-blocks at resolution R, "RdK" = downsample
    # by K, "RuK" = upsample by K, "R1tR2" = 1x1-conv transition from R1 to R2.
    enc_block_config_str = "128x1,128d2,128t64,64x3,64d2,64t32,32x3,32d2,32t16,16x7,16d2,16t8,8x3,8d2,8t4,4x3,4d3,4t1,1x2"
    enc_channel_config_str = "128:64,64:64,32:128,16:128,8:256,4:512,1:512"

    dec_block_config_str = "1x1,1u4,1t4,4x2,4u2,4t8,8x2,8u2,8t16,16x6,16u2,16t32,32x2,32u2,32t64,64x2,64u2,64t128,128x1"
    dec_channel_config_str = "128:64,64:64,32:128,16:128,8:256,4:512,1:1024"
    vae = VAE(
        enc_block_config_str,
        dec_block_config_str,
        enc_channel_config_str,
        dec_channel_config_str,
        num_classes=2
    )
    ####################################################################################

    # model = SmallVAE(latent_dim=50, num_classes=10)
    model = vae
    model = model.to(device)
    optimizer = torch.optim.Adam(model.parameters(), lr=1e-3)

    loss_history = []
    kld_loss_history = []
    recons_loss_history = []
    epochs = 100
    model.train()
    for epoch in range(1, epochs + 1):
        epoch_loss = 0.0
        epoch_kld = 0.0
        epoch_recons = 0.0
        for batch_idx, (data, target) in enumerate(train_dataloader):
            optimizer.zero_grad()
            data = data.to(device)
            target = target.to(device)
            recon_batch, mu, logvar, _ = model(data, target)
            total_loss, recons_loss, kld_loss = loss_function(recon_batch, data, mu, logvar)
            total_loss.backward()
            # Accumulate detached scalars in separate variables; reusing the
            # per-batch names here (e.g. kld_loss += kld_loss) would double the
            # last batch instead of accumulating, and keeping the loss tensors
            # alive would retain the autograd graph across batches.
            epoch_loss += total_loss.item()
            epoch_kld += kld_loss.item()
            epoch_recons += recons_loss.item()
            optimizer.step()
        epoch_loss /= len(train_dataloader.dataset)
        epoch_kld /= len(train_dataloader.dataset)
        epoch_recons /= len(train_dataloader.dataset)
        loss_history.append(epoch_loss)
        kld_loss_history.append(epoch_kld)
        recons_loss_history.append(epoch_recons)
        torch.cuda.empty_cache()
        print(f"Epoch: {epoch}, Loss: {epoch_loss}, KLD: {epoch_kld}, RECONS: {epoch_recons}")

    # Plot loss
    plot_loss_hist(loss_history, "Total training loss", "loss.png")
    plot_loss_hist(kld_loss_history, "KLD loss", "kld_loss.png")
    plot_loss_hist(recons_loss_history, "Recons loss", "recons_loss.png")

    sampling_path = pathlib.Path("mnist_sampling")
    if sampling_path.exists():
        shutil.rmtree(sampling_path)
    sampling_path.mkdir(parents=True, exist_ok=True)
    model.eval()
    num_samples = 10
    real_latent_dim = 512  # matches the encoder's channel width at resolution 1
    num_classes = 2
    all_samples = []
    samples_class = []
    with torch.no_grad():
        for k in range(num_classes):
            folder_path = sampling_path / str(k)
            folder_path.mkdir(parents=True, exist_ok=True)
            z = torch.randn(num_samples, real_latent_dim, 1, 1).to(device)
            y = torch.full((num_samples,), k).to(device)
            samples = model.sample(z, y)
            all_samples.append(samples.to("cpu").detach().numpy())
            samples_class.append(y.to("cpu").detach().numpy())
            samples = samples.permute(0, 2, 3, 1).to("cpu").detach().numpy()
            # Save each sample as an image. The decoder ends in a sigmoid, so
            # samples are already in [0, 1] and need no de-standardization.
            # The inner loop index must not shadow the class index k.
            for i, x_hat in enumerate(samples):
                file_stem = f"{uuid.uuid4()}"
                plt.figure()
                # imshow rejects (H, W, 1) arrays, so drop the channel axis
                plt.imshow(x_hat.squeeze(-1))
                plt.title(f"sample class {k}")
                image_path = folder_path / file_stem
                plt.savefig(str(image_path) + ".png")
                plt.close()
        all_samples = np.concatenate(all_samples, axis=0)
        samples_class = np.concatenate(samples_class)
        all_samples = torch.from_numpy(all_samples).to(device)
        samples_class = torch.from_numpy(samples_class).to(device)
        recon_batch, mu, logvar, z = model(all_samples, samples_class)
        targets = samples_class.to("cpu").numpy()
        N, C, *_ = z.shape
        # Flatten the (N, C, 1, 1) latents to (N, C) for t-SNE
        tsne_z = z.view(N, C)
        # Perplexity must be smaller than the number of samples (here N = 20)
        tsne = TSNE(n_components=2, verbose=1, perplexity=min(40, N - 1), n_iter=300)
        tsne_results = tsne.fit_transform(tsne_z.to("cpu").detach().numpy())
        plt.figure()
        plt.scatter(tsne_results[:, 0], tsne_results[:, 1], c=targets)
        plt.title("t-SNE sampled data")
        plt.savefig("tsne_sampling.png")
        plt.close()
        # Plot the latent-space distribution of the training data using t-SNE
        results = []
        labels = []
        for data, target in train_dataloader:
            data = data.to(device)
            target = target.to(device)
            recon_batch, mu, logvar, z = model(data, target)
            N, C, *_ = z.shape
            # Flatten the (N, C, 1, 1) latents to (N, C) for t-SNE
            tsne_z = z.view(N, C)
            results.append(tsne_z.to("cpu").detach().numpy())
            labels.append(target.to("cpu").detach().numpy())
        results = np.concatenate(results, axis=0)
        labels = np.concatenate(labels)
        tsne = TSNE(n_components=2, verbose=1, perplexity=40, n_iter=300)
        tsne_results = tsne.fit_transform(results)
        plt.figure()
        plt.scatter(tsne_results[:, 0], tsne_results[:, 1], c=labels)
        plt.title("t-SNE training latent space")
        plt.savefig("tsne_training.png")
        plt.close()





if __name__ == '__main__':
    main()