# -*- coding: utf-8 -*-
from google.colab import drive

# Project folder on Google Drive. In the notebook the working directory is switched
# into it (IPython magic below), presumably so relative paths such as "dataset/" resolve.
drive_root = "/content/drive/MyDrive/Task3"

# Commented out IPython magic to ensure Python compatibility.
# %cd $drive_root

# This serves as a template which will guide you through the implementation of this task.  It is advised
# to first read the whole template and get a sense of the overall structure of the code before trying to fill in any of the TODO gaps
# First, we import necessary libraries:
import numpy as np
from torchvision import transforms
from torch.utils.data import DataLoader, TensorDataset
import os
import torch
from torchvision import transforms
import torchvision.datasets as datasets
import torch.nn as nn
import torch.nn.functional as F
from torchvision import models
from sklearn.model_selection import train_test_split
from sklearn.preprocessing import MinMaxScaler
import matplotlib.pyplot as plt


# Run on the first CUDA GPU when one is available, otherwise fall back to the CPU.
device = torch.device("cuda:0") if torch.cuda.is_available() else torch.device("cpu")

def generate_embeddings(batch_size=64, num_workers=2):
    """
    Transform, resize and normalize the images and then use a pretrained model to extract
    the embeddings.

    Every image under ``dataset/`` (read with ``torchvision.datasets.ImageFolder``) is passed
    through EfficientNet-B4 without its classifier head. The pooled 1792-dim feature vector of
    each image is stored, in ``train_dataset.samples`` order, as one row of
    ``dataset/embeddings.npy``.

    input: batch_size: int, images per forward pass (lower it if you run out of memory)
           num_workers: int, worker processes used by the DataLoader

    output: None, the embeddings are saved to dataset/embeddings.npy
    """
    # We transform the images to the expected input format 380x380, and with the expected
    # normalization for the pretrained model.
    train_transforms = transforms.Compose([
        transforms.Resize((380, 380)),
        transforms.ToTensor(),
        transforms.Normalize(mean=[0.485, 0.456, 0.406],
                             std=[0.229, 0.224, 0.225]),
    ])
    train_dataset = datasets.ImageFolder(root="dataset/", transform=train_transforms)
    # shuffle must stay False: get_data() pairs row i of the saved embeddings with
    # train_dataset.samples[i], so the extraction order has to follow the dataset order.
    train_loader = DataLoader(dataset=train_dataset,
                              batch_size=batch_size,
                              shuffle=False,
                              pin_memory=torch.cuda.is_available(),
                              num_workers=num_workers)

    # We use the pretrained model EfficientNet_B4.
    weights = models.EfficientNet_B4_Weights.DEFAULT
    model = models.efficientnet_b4(weights=weights)

    # Freeze the base layers; this function only runs inference.
    for param in model.features.parameters():
        param.requires_grad = False

    # Remove the classifier head. The remaining layers output (batch, 1792, 1, 1) maps.
    newmodel = torch.nn.Sequential(*(list(model.children())[:-1]))
    newmodel = newmodel.to(device)
    # Inference mode: training-only layers (e.g. dropout) are disabled and normalization
    # layers use their stored statistics, so the embeddings are deterministic.
    newmodel.eval()

    # Collect per-batch results and concatenate once at the end (concatenating inside the
    # loop would copy the growing tensor on every batch).
    chunks = []
    with torch.no_grad():
        for batch, (images, _) in enumerate(train_loader):
            features = newmodel(images.to(device))[:, :, 0, 0]
            chunks.append(features.cpu())
            if (batch + 1) % 5 == 0:
                print((batch + 1) * batch_size)

    embeddings = torch.cat(chunks).numpy()
    np.save('dataset/embeddings.npy', embeddings)

def _load_triplets(file):
    """Return the non-empty lines of *file*; each line holds three image names."""
    with open(file) as f:
        return [line for line in f if line.strip()]


def _triplets_to_features(triplets, file_to_embedding, augment):
    """Turn triplet lines into (X, y): row (A, B, C) labelled 1, plus (A, C, B) labelled 0 if *augment*."""
    X = []
    y = []
    for t in triplets:
        emb_a, emb_b, emb_c = (file_to_embedding[name] for name in t.split())
        X.append(np.hstack([emb_a, emb_b, emb_c]))
        y.append(1)
        # Generating negative samples (data augmentation)
        if augment:
            X.append(np.hstack([emb_a, emb_c, emb_b]))
            y.append(0)
    return np.vstack(X), np.hstack(y)


def get_data(file, train=True):
    """
    Load the triplets from the file and generate the features and labels.

    Each feature row is the concatenation of the (min-max normalized) embeddings of the three
    images of a triplet. A triplet line "A B C" yields the row (A, B, C) with label 1. For
    training data, the swapped row (A, C, B) with label 0 is added as a negative sample.

    input: file: string, the path to the file containing the triplets
           train: boolean, whether the data is for training or testing

    output: if train: X, y, X_v, y_v, X_t, y_t -- features/labels of the training split
                (about 70 % of the triplets), the validation split (20 %) and a held-out
                final test split (10 %)
            otherwise: X, y -- features (in file order) and all-ones labels
    """
    triplets = _load_triplets(file)

    # generate training data from triplets
    train_dataset = datasets.ImageFolder(root="dataset/", transform=None)
    # Image file name without directory or extension; these are the lookup keys for the
    # names in the triplet file. os.path handles both '/' and '\\' separators.
    filenames = [os.path.splitext(os.path.basename(path))[0] for path, _ in train_dataset.samples]
    embeddings = np.load('dataset/embeddings.npy')

    # We normalize the embeddings across the dataset. Min-max is used because the embeddings
    # don't have a gaussian distribution.
    embeddings = MinMaxScaler().fit_transform(embeddings)

    # Assumes row i of embeddings.npy belongs to train_dataset.samples[i], i.e. that the
    # embeddings were extracted from the same ImageFolder without shuffling.
    file_to_embedding = dict(zip(filenames, embeddings))

    if not train:
        return _triplets_to_features(triplets, file_to_embedding, augment=False)

    train_triplets, val_triplets = train_test_split(triplets, test_size=0.2, random_state=1)  # 20 percent for validation
    train_triplets, test_triplets = train_test_split(train_triplets, test_size=0.125,
                                                     random_state=1)  # 0.8 * 0.125 = 10 percent for final test

    X, y = _triplets_to_features(train_triplets, file_to_embedding, augment=True)
    X_v, y_v = _triplets_to_features(val_triplets, file_to_embedding, augment=True)
    X_t, y_t = _triplets_to_features(test_triplets, file_to_embedding, augment=True)
    return X, y, X_v, y_v, X_t, y_t

# Hint: adjust batch_size and num_workers to your PC configuration, so that you don't run out of memory
def create_loader_from_np(X, y=None, train=True, batch_size=64, shuffle=True, num_workers=2):
    """
    Create a torch.utils.data.DataLoader object from numpy arrays containing the data.

    input: X: numpy array, the features
           y: numpy array, the labels (required when train is True)
           train: boolean, whether the loader also yields the labels
           batch_size: int, number of samples per batch
           shuffle: boolean, whether to reshuffle the samples every epoch
                    (pass False for test data so the output order matches the input order)
           num_workers: int, number of worker processes loading the data

    output: loader: torch.data.util.DataLoader, the object containing the data; it yields
            [features, labels] batches when train is True and [features] batches otherwise

    raises: ValueError if train is True but y is None
    """
    features = torch.as_tensor(np.asarray(X), dtype=torch.float)
    if train:
        if y is None:
            raise ValueError("y must be provided when train=True")
        dataset = TensorDataset(features, torch.as_tensor(np.asarray(y), dtype=torch.long))
    else:
        dataset = TensorDataset(features)
    return DataLoader(dataset=dataset,
                      batch_size=batch_size,
                      shuffle=shuffle,
                      # pinned memory only helps (and is only available) with a CUDA device
                      pin_memory=torch.cuda.is_available(),
                      num_workers=num_workers)

# TODO: define a model. Here, the basic structure is defined, but you need to fill in the details
class Net(nn.Module):
    """
    The model class, which defines our classifier.

    A one-hidden-layer perceptron with dropout on the input and on the hidden layer. Its
    sigmoid output is the predicted probability that a feature row has label 1.
    """

    def __init__(self, in_features=5376, h1=2048, h2=128, out_features=1, dropout=1):
        """
        The constructor of the model.

        input: in_features: int, input size (default 5376 = 3 concatenated 1792-dim embeddings)
               h1: int, width of the hidden layer
               h2: int, currently unused; kept so existing calls that pass it keep working
               out_features: int, number of outputs
               dropout: float, multiplier on the base dropout rates (0.2 on the input,
                        0.5 after the hidden layer); each scaled rate must stay in [0, 1]
        """
        super().__init__()
        self.fc1 = nn.Linear(in_features, h1)
        self.out = nn.Linear(h1, out_features)
        self.dropout1 = nn.Dropout(0.2 * dropout)
        self.dropout2 = nn.Dropout(0.5 * dropout)
        # Not used by forward(); presumably left over from a removed second hidden layer.
        self.dropout3 = nn.Dropout(0.2 * dropout)

    def forward(self, x):
        """
        The forward pass of the model.

        input: x: torch.Tensor, the input to the model, shape (..., in_features)
        output: x: torch.Tensor, the output of the model, shape (..., out_features),
                values in (0, 1)
        """
        x = self.dropout1(x)
        x = F.relu(self.fc1(x))
        # Dropout is applied once here. Previously dropout2 ran twice, which raised the
        # effective rate to 1 - (1 - p)^2, and was followed by a no-op second ReLU.
        x = self.dropout2(x)
        return torch.sigmoid(self.out(x))

def train_model(train_loader, val_loader, fval_loader, alpha, wd, d, n_epochs=100):
    """
    The training procedure of the model; it accepts the training data, defines the model
    and then trains it.

    input: train_loader: torch.data.util.DataLoader, yields [features, labels] training batches
           val_loader: torch.data.util.DataLoader, labelled batches whose mean loss is
                       printed after every epoch
           fval_loader: torch.data.util.DataLoader, labelled held-out batches whose accuracy
                        is printed every 5 epochs (via val_model)
           alpha: float, SGD learning rate
           wd: float, SGD weight decay
           d: float, dropout multiplier passed to Net
           n_epochs: int, number of passes over the training data

    output: model: torch.nn.Module, the trained model (in eval mode, on `device`)
    """
    model = Net(dropout=d).to(device)

    criterion = nn.BCELoss()  # binary cross-entropy on the sigmoid outputs of Net
    optimizer = torch.optim.SGD(model.parameters(), lr=alpha, momentum=0.9, weight_decay=wd)

    for epoch in range(n_epochs):
        model.train()  # enable dropout (val_model switches the model to eval mode)
        batch_train_losses = []
        for X, y in train_loader:
            optimizer.zero_grad()
            y_pred = torch.flatten(model(X.to(device)))
            loss = criterion(y_pred, y.float().to(device))
            loss.backward()
            optimizer.step()
            batch_train_losses.append(loss.item())

        model.eval()
        batch_val_losses = []
        with torch.no_grad():
            for X, y in val_loader:
                y_pred = torch.flatten(model(X.to(device)))
                batch_val_losses.append(criterion(y_pred, y.float().to(device)).item())

        T_loss = np.mean(batch_train_losses)  # mean train loss over the epoch
        V_loss = np.mean(batch_val_losses)  # mean validation loss over the epoch

        print(f'Epoch {epoch} and train_loss is: {T_loss}')
        print(f'Epoch {epoch} and val_loss is: {V_loss}')
        if epoch % 5 == 4:
            val_model(model, fval_loader)

    model.eval()
    return model

def val_model(model, loader):
    """
    Evaluate the trained model on labelled data and print its accuracy.

    Predictions are rounded to 0/1 with a 0.5 threshold before being compared with the labels.

    input: model: torch.nn.Module, the trained model (assumed to live on `device`)
           loader: torch.data.util.DataLoader, yields [features, labels] batches

    output: acc: float, the fraction of correctly predicted samples (also printed)

    raises: ValueError if the loader yields no batches
    """
    model.eval()  # disable dropout for inference
    predictions = []
    labels = []
    # Iterate over the data
    with torch.no_grad():  # We don't need to compute gradients for evaluation
        for x_batch, y_batch in loader:
            predicted = model(x_batch.to(device)).cpu().numpy()
            # Rounding the predictions to 0 or 1
            predictions.append((predicted >= 0.5).astype(np.int64))
            labels.append(np.reshape(y_batch.numpy(), (len(y_batch), 1)))
    if not predictions:
        raise ValueError("val_model: the loader yielded no batches")
    predictions = np.vstack(predictions)
    labels = np.vstack(labels)
    acc = np.sum(labels == predictions) / len(predictions)
    print(f'Final accuracy calculated for the validation data is: {acc}')
    return acc

def test_model(model, loader, alpha, wd, dropout):
    """
    The testing procedure of the model; it accepts the testing data and the trained model and
    then tests the model on it.

    input: model: torch.nn.Module, the trained model (assumed to live on `device`)
           loader: torch.data.util.DataLoader, yields [features] batches; it should not
                   shuffle, so that the saved predictions follow the input order
           alpha, wd, dropout: hyperparameters, used only to name the output file

    output: None, the 0/1 predictions (one per line) are saved to
            results{alpha},{wd},{dropout}.txt

    raises: ValueError if the loader yields no batches
    """
    model.eval()  # disable dropout for inference
    predictions = []
    # Iterate over the test data
    with torch.no_grad():  # We don't need to compute gradients for testing
        for (x_batch,) in loader:
            predicted = model(x_batch.to(device)).cpu().numpy()
            # Rounding the predictions to 0 or 1
            predictions.append((predicted >= 0.5).astype(np.int64))
    if not predictions:
        raise ValueError("test_model: the loader yielded no batches")
    predictions = np.vstack(predictions)
    np.savetxt(f"results{alpha},{wd},{dropout}.txt", predictions, fmt='%i')

if __name__ == '__main__':
    TRAIN_TRIPLETS = 'train_triplets.txt'
    TEST_TRIPLETS = 'test_triplets.txt'

    # Generate the embedding for each image in the dataset. This runs once; later runs
    # reuse the file saved on disk.
    if not os.path.exists('dataset/embeddings.npy'):
        generate_embeddings()

    # Load the training data (with validation and held-out test splits) and the testing data.
    X, y, X_v, y_v, X_t, y_t = get_data(TRAIN_TRIPLETS)
    X_test, _ = get_data(TEST_TRIPLETS, train=False)

    # Create data loaders for the training and testing data. The test loader must not
    # shuffle, so the predictions keep the order of the test triplets.
    train_loader = create_loader_from_np(X, y, train=True, batch_size=64)
    val_loader = create_loader_from_np(X_v, y_v, train=True, batch_size=128)
    fval_loader = create_loader_from_np(X_t, y_t, train=True, batch_size=128)
    test_loader = create_loader_from_np(X_test, train=False, batch_size=2048, shuffle=False)

    # Small grid over learning rate, weight decay and dropout multiplier.
    for alpha in [0.001]:
        for wd in [0.0001, 0.001]:
            for d in [0.5, 1, 1.5]:
                print('hyperparameters:')
                print(alpha, wd, d)
                model = train_model(train_loader, val_loader, fval_loader, alpha, wd, d)
                # final validation of the accuracy on the 10% test split held out in get_data
                val_model(model, fval_loader)
                # test the model on the test data
                test_model(model, test_loader, alpha, wd, d)
                print(f"Results saved to results{alpha},{wd},{d}.txt")