# Paste-site metadata (pastecode.io), kept as comments so the file parses:
# Untitled
#
# mail@pastecode.io avatar
# unknown
# plain_text
# a year ago
# 17 kB
# 1
# Indexable
# Never
# -*- coding: utf-8 -*-
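"""
Conditional VAE for single-channel 2-D signals.

Defines the encoder, the decoder and a latent-space classifier, two helpers
that build train/validation splits from a directory of pickled samples, and
a script section that trains, evaluates and samples from the model.
"""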
import os
import pathlib
import pickle
import random
import shutil
from datetime import datetime
from typing import List, Optional, Tuple

import numpy as np
import torch
from torch.utils.data import DataLoader, TensorDataset

from transforms import (augment_signal_in_ranges, disturb_noise_augmentation,
                        minmax_normalize)
# from sklearn.manifold.t_sne import TSNE

# %% define the encoder


class Encoder(torch.nn.Module):
    """Convolutional encoder of the conditional VAE.

    Five strided 2-D convolutions, each followed by a leaky activation,
    reduce a single-channel input to a feature map. The map is flattened,
    concatenated with the class label and passed to two linear heads: one
    for the latent mean and one for the log of a scalar standard deviation.
    """

    def __init__(self, latent_dim: int, num_classes: int):
        """
        latent_dim: int
            Size of the latent mean vector.
        num_classes: int
            Width of the class-label vector concatenated to the features.
        """
        super(Encoder, self).__init__()
        conv = torch.nn.Conv2d
        self.cnn1 = conv(1, 20, kernel_size=(4, 32), stride=(2, 16))
        self.cnn2 = conv(20, 50, kernel_size=(5, 5), stride=(2, 2))
        self.cnn3 = conv(50, 100, kernel_size=(5, 5), stride=(2, 2))
        self.cnn4 = conv(100, 200, kernel_size=(4, 4), stride=(2, 2))
        self.cnn5 = conv(200, 500, kernel_size=(4, 4), stride=(1, 1))
        # The class label is concatenated to the flattened features
        # (conditional VAE), so both heads take 500 + num_classes inputs.
        head_inputs = 500 + num_classes
        self.linear1 = torch.nn.Linear(head_inputs, latent_dim)
        self.linear2 = torch.nn.Linear(head_inputs, 1)

    def forward(self, x, y):
        """
        x: torch.Tensor
            The input tensor. The tensor should have shape (batch_size, channels, rows, columns).
        y: torch.Tensor
            The class label. The tensor should have shape (batch_size, num_classes).

        Returns the tuple (mu, sigma): the latent mean and the standard
        deviation (variance**0.5), the latter with one value per sample.
        """
        # Logit of the input after an affine squeeze; for x in [0, 1] the
        # argument stays within [0.001, 0.999], so the logit is finite.
        h = -torch.log(1/(0.001 + 0.998*x) - 1)

        # (layer, gain applied to that layer's input).
        stages = (
            (self.cnn1, 0.5),
            (self.cnn2, 1),
            (self.cnn3, 2),
            (self.cnn4, 3),
            (self.cnn5, 2),
        )
        for layer, gain in stages:
            h = layer(gain*h)
            # Leaky activation with slope 0.1 on the negative side.
            h = torch.maximum(h, 0.1*h)

        features = torch.cat((torch.flatten(h, 1, 3), y), dim=1)
        mu = self.linear1(features)
        # exp() keeps the predicted standard deviation strictly positive.
        sigma = torch.exp(self.linear2(features))

        return (mu, sigma)


# %% define the decoder
class Decoder(torch.nn.Module):
    """Transposed-convolution decoder of the conditional VAE.

    A linear layer maps the latent vector concatenated with the class label
    to 500 features, which five transposed convolutions expand into a
    single-channel image whose values lie in (0, 1).
    """

    def __init__(self, latent_dim: int, num_classes: int):
        """
        latent_dim: int
            Size of the latent vector.
        num_classes: int
            Width of the class-label vector concatenated to the latent vector.
        """
        super(Decoder, self).__init__()
        deconv = torch.nn.ConvTranspose2d
        # The class label is concatenated to the latent vector here.
        self.linear1 = torch.nn.Linear(latent_dim + num_classes, 500)
        self.cnnt1 = deconv(500, 200, kernel_size=(4, 4), stride=(1, 1))
        self.cnnt2 = deconv(200, 100, kernel_size=(4, 4), stride=(2, 2))
        self.cnnt3 = deconv(100, 50, kernel_size=(5, 5), stride=(2, 2))
        self.cnnt4 = deconv(50, 20, kernel_size=(5, 5), stride=(2, 2))
        self.cnnt5 = deconv(20, 1, kernel_size=(4, 32), stride=(2, 16))

    def forward(self, x, y):
        """
        x: torch.Tensor
            The latent vector. The tensor should have shape (batch_size, latent_dim).
        y: torch.Tensor
            The class label. The tensor should have shape (batch_size, num_classes).

        Returns only the mean of the reconstruction; the variance is
        always assumed to be 1.
        """
        conditioned = torch.cat((x, y), dim=1)
        h = self.linear1(5*conditioned)
        # Add two trailing singleton axes so the features form a 1x1 map.
        h = h.unsqueeze(-1).unsqueeze(-1)

        # (layer, gain applied to that layer's input).
        hidden_stages = (
            (self.cnnt1, 5),
            (self.cnnt2, 3),
            (self.cnnt3, 5),
            (self.cnnt4, 3),
        )
        for layer, gain in hidden_stages:
            h = layer(gain*h)
            # Leaky activation with slope 0.1 on the negative side.
            h = torch.maximum(h, 0.1*h)

        # The last layer has no leaky activation; a shifted sigmoid bounds
        # the output to (0, 1).
        h = self.cnnt5(5*h)
        mu = torch.sigmoid(h - 5)

        return mu


# %% define the classifier
class Classifier(torch.nn.Module):
    """Linear softmax classifier operating on latent vectors.

    Parameters
    ----------
    num_classes: int
        Number of output classes; the output has one probability per class.
    latent_dim: int
        Width of the input latent vector. Defaults to 50, the value that was
        previously hard-coded, so existing callers are unaffected.
    """

    def __init__(self, num_classes: int, latent_dim: int = 50):
        super(Classifier, self).__init__()
        # One logit per class; the softmax below normalizes over dim 1.
        self.linear1 = torch.nn.Linear(latent_dim, num_classes)
        self.activation = torch.nn.Softmax(dim=1)

    def forward(self, x):
        """
        x: torch.Tensor
            Latent vectors of shape (batch_size, latent_dim).

        Returns class probabilities of shape (batch_size, num_classes);
        each row sums to 1.
        """
        x = self.linear1(x)
        x = self.activation(x)
        return x


# %% define the dataset
# Fraction of samples held out for validation by the split helpers below.
VAL_FRAC = 0.2


def get_anomaly_detection_split(
    dir_path: str, val_frac: Optional[float] = None
) -> Tuple[Tuple[np.ndarray, np.ndarray], Tuple[np.ndarray, np.ndarray]]:
    """
    Build an anomaly-detection split from a directory of pickled samples.

    Every regular file in `dir_path` is unpickled into a mapping that
    provides the keys 'data' and 'label_idx' and, optionally, 'label_name'.
    Samples whose 'label_name' is "NEGATIVE" go to the validation split with
    probability `val_frac` and to the training split otherwise; every other
    sample goes to the validation split. With the default fraction the train
    split holds roughly 80% of the negatives, and the val split holds all
    positives plus the remaining negatives.

    Parameters
    ----------
    dir_path: str
        Directory containing one pickle file per sample.
    val_frac: Optional[float]
        Probability of sending a negative sample to validation. When None,
        the module-level VAL_FRAC is used.

    Returns
    -------
    (train_images, train_labels), (val_images, val_labels), all as numpy
    arrays; the image arrays are stacked along a new leading axis.

    Raises
    ------
    ValueError
        If either split ends up without any sample.
    """
    if val_frac is None:
        val_frac = VAL_FRAC

    train_images, train_labels = [], []
    val_images, val_labels = [], []
    # os.listdir returns entries in arbitrary order; sorting makes the
    # assignment reproducible whenever the `random` module is seeded.
    for file in sorted(os.listdir(dir_path)):
        path = os.path.join(dir_path, file)
        if not os.path.isfile(path):
            continue  # sub-directories cannot be unpickled
        # NOTE: pickle.load executes arbitrary code; only point this at
        # trusted data directories.
        with open(path, 'rb') as f:
            sample = pickle.load(f)

        is_negative = sample.get("label_name") == "NEGATIVE"
        if is_negative and not random.random() < val_frac:
            train_images.append(sample['data'])
            train_labels.append(sample['label_idx'])
        else:
            # All positives, plus the held-out share of negatives.
            val_images.append(sample['data'])
            val_labels.append(sample['label_idx'])

    print('number of train images: {}'.format(len(train_images)))
    print('number of val images: {}'.format(len(val_images)))
    if not train_images or not val_images:
        # np.stack would fail here anyway with a less helpful message.
        raise ValueError(
            'empty split for {!r}: {} train / {} val samples'.format(
                dir_path, len(train_images), len(val_images)))
    for image in train_images:
        print('image size: {}; min: {}; max: {}'.format(
            image.shape, np.min(image), np.max(image)))

    # Labels are returned as arrays to match the annotated return type and
    # get_classification_split.
    return ((np.stack(train_images, axis=0), np.asarray(train_labels)),
            (np.stack(val_images, axis=0), np.asarray(val_labels)))


def get_classification_split(
    dir_path: str, val_frac: Optional[float] = None, random_state: int = 42
) -> Tuple[Tuple[np.ndarray, np.ndarray], Tuple[np.ndarray, np.ndarray]]:
    """
    Build a random train/validation split from a directory of pickled samples.

    Every regular file in `dir_path` is unpickled into a mapping that
    provides the keys 'data' and 'label_idx'. With the default fraction the
    train split contains 80% of all samples and the val split the other 20%.

    Parameters
    ----------
    dir_path: str
        Directory containing one pickle file per sample.
    val_frac: Optional[float]
        Fraction of samples assigned to validation. When None, the
        module-level VAL_FRAC is used.
    random_state: int
        Seed forwarded to scikit-learn's train_test_split.

    Returns
    -------
    (train_images, train_labels), (val_images, val_labels), all as numpy
    arrays stacked along a new leading axis.
    """
    # Local import, as in the original: scikit-learn is only needed here.
    from sklearn.model_selection import train_test_split

    if val_frac is None:
        val_frac = VAL_FRAC

    X, y = [], []
    # os.listdir returns entries in arbitrary order; sorting is required for
    # the fixed random_state to actually yield the same split on every run.
    for file in sorted(os.listdir(dir_path)):
        path = os.path.join(dir_path, file)
        if not os.path.isfile(path):
            continue  # sub-directories cannot be unpickled
        # NOTE: pickle.load executes arbitrary code; only point this at
        # trusted data directories.
        with open(path, "rb") as f:
            sample = pickle.load(f)
        X.append(sample['data'])
        y.append(sample['label_idx'])

    train_images, val_images, train_labels, val_labels = train_test_split(
        X, y, test_size=val_frac, random_state=random_state)
    train_images = np.stack(train_images, axis=0)
    val_images = np.stack(val_images, axis=0)
    train_labels = np.stack(train_labels, axis=0)
    val_labels = np.stack(val_labels, axis=0)
    print('Num train samples: {}'.format(len(train_images)))
    print('Num val samples: {}'.format(len(val_images)))
    return (train_images, train_labels), (val_images, val_labels)


if __name__ == "__main__":
    # Script entry point: load the data, then optionally train, evaluate and
    # sample from the conditional VAE, as selected by the flags below.
    import uuid

    import matplotlib.pyplot as plt

    # %% load images
    NUM_CLASSES = 2
    LATENT_DIM = 50
    AUGMENT = False
    TRAIN = True
    EVAL = False
    EVAL_AND_SAVE = False
    FROM_PRETRAINED = False
    SAVE_MODEL = True
    SAMPLE = True
    # path0 = 'C:/Users/16692/Downloads/Pkls'
    path0 = "./data/all_samples/Pkls"
    # Images, labels and models all live on this device. Fall back to the
    # CPU so the script still runs on machines without CUDA.
    device = torch.device('cuda' if torch.cuda.is_available() else 'cpu')

    # (train_images, train_labels), (val_images,
    #                                val_labels) = get_anomaly_detection_split(path0)
    (train_images, train_labels), (val_images,
                                   val_labels) = get_classification_split(path0)
    # Insert the channel axis expected by the convolutions.
    train_images = torch.tensor(train_images[:, None, :, :],
                                dtype=torch.float32, device=device)
    # one_hot requires int64 indices, whatever integer dtype numpy produced.
    train_labels = torch.as_tensor(
        train_labels, dtype=torch.long, device=device)
    train_labels = torch.nn.functional.one_hot(
        train_labels, num_classes=NUM_CLASSES).float()

    if AUGMENT:
        train_images = augment_signal_in_ranges(
            train_images, noise_level=1e-8, aug_factor=1.5, ranges=[(50, 300)])
        train_images = disturb_noise_augmentation(
            train_images, noise_level=1e-8, factor=1.2)
        # TODO Inv signal spike

    train_images = minmax_normalize(train_images)
    print(f"Mean: {torch.mean(train_images):.10f}, Std: {torch.std(train_images):.10f}, Min: {torch.min(train_images)}, Max: {torch.max(train_images):.5f}")
    # Both bounds must hold (the original `or` let a one-sided violation
    # through); the negated form also rejects NaN.
    if not (train_images.min() >= 0.0 and train_images.max() <= 1.0):
        raise ValueError("Normalization failed.")

    train_dataset = TensorDataset(train_images, train_labels)

    val_images = torch.tensor(val_images[:, None, :, :],
                              dtype=torch.float32, device=device)
    # The encoder is trained on normalized images, so the validation images
    # need the same preprocessing before they are encoded.
    # NOTE(review): assumes minmax_normalize may be applied to each split
    # independently - confirm against transforms.minmax_normalize.
    val_images = minmax_normalize(val_images)
    val_labels = torch.as_tensor(val_labels, dtype=torch.long, device=device)
    val_labels = torch.nn.functional.one_hot(
        val_labels, num_classes=NUM_CLASSES).float()

    val_dataset = TensorDataset(val_images, val_labels)

    train_dataloader = DataLoader(train_dataset, batch_size=32)
    val_dataloader = DataLoader(val_dataset, batch_size=32)

    encoder = Encoder(latent_dim=LATENT_DIM,
                      num_classes=NUM_CLASSES).to(device)
    decoder = Decoder(latent_dim=LATENT_DIM,
                      num_classes=NUM_CLASSES).to(device)
    classifier = Classifier(num_classes=NUM_CLASSES).to(device)
    threshold = None

    if FROM_PRETRAINED:
        # map_location lets a checkpoint saved on a GPU load on a CPU-only
        # machine as well.
        checkpoint = torch.load(
            "./trained_models/model_04-04-2023_13:15:42.pth",
            map_location=device)
        encoder.load_state_dict(checkpoint['encoder_state_dict'])
        decoder.load_state_dict(checkpoint['decoder_state_dict'])
        classifier.load_state_dict(checkpoint['classifier_state_dict'])
        threshold = checkpoint['threshold']
    # %% use the adam optimizer
    if TRAIN:
        print("Training job")
        opt = torch.optim.Adam(list(encoder.parameters()) +
                               list(decoder.parameters()) +
                               list(classifier.parameters()),
                               lr=1e-4)

        # %% training loops begin here
        num_iters = 30
        Loss = []
        encoder.train()
        decoder.train()
        classifier.train()
        # Placeholder stored in the checkpoint; see the TODO after the loop.
        threshold = torch.zeros(len(train_images), device=device)
        reconstruction_losses = []

        def total_loss(images, labels):
            """Return the KL + reconstruction + classification loss of a batch."""
            x = images
            y_true = labels
            mu, sigma = encoder(x, y_true)
            # KL divergence to a standard normal, up to an additive
            # constant; sigma holds one value per sample and is broadcast
            # over the latent dimensions.
            divergence = torch.sum(
                0.5*mu**2 + 0.5*sigma**2 - torch.log(sigma))

            # classification loss
            # or first sample to get z; then classify with z
            # NOTE(review): this is a logistic-style surrogate applied to
            # softmax outputs and one-hot targets, not a standard
            # cross-entropy - confirm this is intended.
            y_pred = classifier(mu)
            xentropy = torch.mean(
                torch.log(1 + torch.exp(-y_true*torch.squeeze(y_pred))))

            # decoder loss: reparameterized sample, squared-error distortion
            z = mu + sigma*torch.randn_like(mu)
            x_rec = decoder(z, y_true)
            distortion = torch.sum(0.5*(x - x_rec)**2)
            reconstruction_losses.append(distortion.detach().cpu())
            # assign proper weights to each loss, and return their sum
            return 1.0*(divergence + distortion) + 1.0*xentropy

        for i in range(num_iters):
            batch_losses = []
            for batch_data, batch_targets in train_dataloader:
                opt.zero_grad()
                loss = total_loss(batch_data, batch_targets)
                loss.backward()
                opt.step()
                batch_losses.append(loss.item())
            # Record the mean over all batches; the last batch alone may be
            # a partial one and is not representative of the iteration.
            Loss.append(sum(batch_losses) / max(len(batch_losses), 1))
            print('iter: {}; loss: {}'.format(i, Loss[-1]))
        # TODO: calculate the threshold from the training reconstructions.
        # l1_loss = torch.nn.L1Loss()
        # for sample_data, target_data in train_dataloader:
        #     for idx, sample in enumerate(sample_data):
        #         sample = torch.unsqueeze(sample, axis=0)
        #         target = target_data[idx, : ]
        #         mu, _ = encoder(sample, target)
        #         mu = decoder(mu, target)
        #         l1_loss_result = l1_loss(sample, mu)
        #         # reconstruction_losses.append(l1_loss_result.detach().cpu())
        #         threshold =  torch.maximum(threshold, l1_loss_result)

        plt.figure()
        plt.plot(Loss)
        # plt.hist(reconstruction_losses, bins=50)
        plt.savefig("loss.png")
        plt.show()
        plt.close()
        if SAVE_MODEL:
            # NOTE(review): the ':' characters in the timestamp are not
            # valid in Windows file names; the format is kept so names stay
            # consistent with existing checkpoints.
            now = datetime.now().strftime("%m-%d-%Y_%H:%M:%S")
            models_dir = pathlib.Path("./trained_models")
            # torch.save does not create missing directories.
            models_dir.mkdir(parents=True, exist_ok=True)
            torch.save({
                'encoder_state_dict': encoder.state_dict(),
                'decoder_state_dict': decoder.state_dict(),
                'classifier_state_dict': classifier.state_dict(),
                'threshold': threshold
            }, str(models_dir / f"model_{now}.pth"))

    # Evaluating the reconstruction loss
    if EVAL:
        print("Evaluating")
        encoder.eval()
        decoder.eval()
        classifier.eval()
        with torch.no_grad():
            mu, _ = encoder(val_images, val_labels)
            reconstruction = decoder(mu, val_labels)
            assert val_images.shape == reconstruction.shape
            # Mean absolute error per sample (what L1Loss gives per image).
            per_sample_l1 = (val_images - reconstruction).abs().flatten(
                1).mean(dim=1).cpu().numpy()
            # Labels are one-hot, so recover the class index before comparing;
            # `val_labels[i] == 0` is a multi-element tensor and cannot be
            # used as a condition.
            class_idx = val_labels.argmax(dim=1).cpu().numpy()
        # NOTE(review): class index 0 is treated as "positive", as in the
        # original code - confirm against the dataset's label_idx mapping.
        positive_reconstruction_losses = per_sample_l1[class_idx == 0]
        negative_reconstruction_losses = per_sample_l1[class_idx != 0]

        plt.figure()
        plt.hist(negative_reconstruction_losses, bins=50, label='negative')
        plt.hist(positive_reconstruction_losses, bins=50, label='positive')
        plt.legend(loc='upper right')
        plt.savefig("loss_reconstruction.png")
        plt.close()
    # %% let's check the reconstructed images
    if EVAL_AND_SAVE:
        print("Evaluating")
        encoder.eval()
        decoder.eval()
        classifier.eval()
        reconstructed_dir = pathlib.Path("reconstructed")
        reconstructed_dir.mkdir(parents=True, exist_ok=True)
        with torch.no_grad():
            mu, _ = encoder(val_images, val_labels)
            y = classifier(mu).cpu().numpy()
            reconstruction = decoder(mu, val_labels).cpu().numpy()
        images0 = val_images.cpu().numpy()
        true_classes = val_labels.argmax(dim=1).cpu().numpy()
        for i in range(len(images0)):
            # Report the prediction for sample i (the original always read
            # y[0] and tried to convert a whole probability vector to float).
            predicted_class = int(np.argmax(y[i]))
            confidence = float(np.max(y[i]))
            true_class = int(true_classes[i])
            title = (f"reconstructed - Classification {predicted_class} "
                     f"(p={confidence:.3f}) \n original class {true_class}")
            plt.figure()
            plt.subplot(121)
            plt.imshow(reconstruction[i, 0])
            plt.title(title)
            plt.subplot(122)
            plt.imshow(images0[i, 0])
            plt.title("original")
            # Save before show(): once show() returns, the figure may
            # already be released and savefig would write an empty image.
            plt.savefig(
                str(reconstructed_dir / f"{uuid.uuid4()}_{true_class}.png"))
            plt.show()
            plt.close()

    if SAMPLE:
        print("Sampling")
        sampling_path = pathlib.Path("gen_samples")
        # Start from a clean output directory on every run.
        if sampling_path.exists():
            shutil.rmtree(sampling_path)
        encoder.eval()
        decoder.eval()
        classifier.eval()
        with torch.no_grad():
            for k in range(NUM_CLASSES):
                folder_path = sampling_path / f"class_{k}"
                plot_path = folder_path / "plots"
                pkls_path = folder_path / "pkls"
                plot_path.mkdir(parents=True, exist_ok=True)
                pkls_path.mkdir(parents=True, exist_ok=True)
                # The conditioning label is the same for every sample of
                # class k.
                one_hot_y = torch.zeros(1, NUM_CLASSES, device=device)
                one_hot_y[0, k] = 1
                for _ in range(30):
                    z = torch.randn(1, LATENT_DIM, device=device)
                    # (1, channels, rows, columns), kept on `device`.
                    x_gen = decoder(z, one_hot_y)
                    # Channel-last copy for saving: (rows, columns, channels).
                    x_hat = x_gen[0].permute(1, 2, 0).cpu().numpy()
                    # Save as pickle and as image
                    file_stem = f"{uuid.uuid4()}_class_{k}"
                    pkl_path = pkls_path / file_stem
                    with open(str(pkl_path) + ".pkl", "wb") as f:
                        pickle.dump(x_hat, f)
                    plt.figure()
                    plt.imshow(x_hat[..., 0])
                    plt.title(f"sample class {k}")
                    # save the image
                    image_path = plot_path / file_stem
                    plt.savefig(str(image_path) + ".png")
                    # Close the figure so 30 * NUM_CLASSES figures do not
                    # accumulate in memory.
                    plt.close()

                    # Re-encode the generated sample. The encoder needs the
                    # channel-first tensor on `device`, not the permuted CPU
                    # copy. The result is currently unused.
                    mu, sigma = encoder(x_gen, one_hot_y)