import os

import numpy as np
import torch
import torch.optim as optim
import torch.utils.data as data
import torch.nn as nn
import matplotlib.pyplot as plt

from examples.test_model1 import get_data
from models.model1 import CustomModel
from models.model5 import StudentModel


def reconstruct_ts(inp, amp, shift):
    """Rebuild a target value by scaling the base input by the predicted amplitude and adding the predicted shift."""
    return inp * amp + shift


def run_model5():
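    """Two-stage training: fit a teacher model on the training data, then fit a
    student model that reconstructs the teacher's predictions from a predicted
    amplitude and shift and also predicts its own squared reconstruction error.
    """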
    model = CustomModel()
    optimizer = optim.Adam(model.parameters())
    loss_fn = nn.MSELoss()
    X_train, y_train = get_data()

    loader = data.DataLoader(data.TensorDataset(X_train, y_train), shuffle=True, batch_size=128)

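    # Stage 1: train the teacher model on the training data.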
    n_epochs = 2000
    train_rmse_list = []
    train_rmse_student_list = []

    for epoch in range(n_epochs):
        model.train()
        for X_batch, y_batch in loader:
            y_pred = model(X_batch)
            loss = loss_fn(y_pred, y_batch)
            optimizer.zero_grad()
            loss.backward()
            optimizer.step()
        # Validation
        # if epoch % 100 != 0:
        #     continue
        model.eval()
        with torch.no_grad():
            y_pred = model(X_train)
            train_rmse = torch.sqrt(loss_fn(y_pred, y_train))
            train_rmse_list.append(train_rmse.item())  # Convert to Python scalar for storage

            # y_pred = model(X_test)
            # test_rmse = np.sqrt(loss_fn(y_pred, y_test))
        print("Epoch %d: train RMSE %.4f" % (epoch, train_rmse))

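    # Stage 2: train the student model against the teacher's predictions.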
    model_student = StudentModel()
    optimizer2 = optim.Adam(model_student.parameters())

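    # Each batch yields the inputs, the teacher's predictions, and the ground-truth targets.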
    loader = data.DataLoader(data.TensorDataset(X_train, y_pred, y_train), shuffle=True, batch_size=128)

    mse_criterion = nn.MSELoss(reduction='none')
    loss1_list = []
    loss2_list = []

    for epoch in range(n_epochs):
        loss1_sum = 0.0
        loss2_sum = 0.0
        model_student.train()
        for X_batch, y_batch, y_train_batch in loader:
            y_pred_student, amp, shift, res = model_student(X_batch)
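            # Use the feature at index 5, at the last timestep, as the base signal for reconstruction.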
            inp_A = X_batch[:, :, 5:6]
            inp_A = inp_A[:, -1, :]
            reconstructed_y = reconstruct_ts(inp_A, amp, shift)
            loss1 = loss_fn(reconstructed_y, y_batch)
            loss1_sum += loss1.item()
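            # Train the residual output res to predict the per-sample squared error of the reconstruction.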
            loss_prediction = mse_criterion(reconstructed_y, y_train_batch)
            loss2 = loss_fn(res, loss_prediction)
            loss2_sum += loss2.item()
            loss = loss1 + loss2
            # loss = loss1
            optimizer2.zero_grad()
            loss.backward()
            optimizer2.step()
        # Record the epoch-average distillation loss (loss1) and error-prediction loss (loss2)
        loss1_list.append(loss1_sum / len(loader))
        loss2_list.append(loss2_sum / len(loader))
        # Validation
        # if epoch % 100 != 0:
        #     continue
        model_student.eval()
        with torch.no_grad():
            y_pred_student, amp, shift, res = model_student(X_train)
            train_rmse1 = torch.sqrt(loss_fn(y_pred_student, y_train))
            train_rmse_student_list.append(train_rmse1.item())  # Convert to Python scalar for storage

            # y_pred = model(X_test)
            # test_rmse = np.sqrt(loss_fn(y_pred, y_test))
        print("Epoch %d: train RMSE %.4f" % (epoch, train_rmse1))

    # Create a 3x2 grid of subplots for the training diagnostics
    fig, ((ax1, ax2), (ax3, ax4), (ax5, ax6)) = plt.subplots(3, 2, figsize=(8, 10))

    # Plot the training error over epochs
    ax1.plot(range(n_epochs), train_rmse_list, label="Training Error for Teacher Model")
    ax1.set_xlabel("Epoch")
    ax1.set_ylabel("Error")
    ax1.legend()

    # Plot the teacher's predictions (y_pred) against the target values (y_train)
    ax2.plot(y_train, label="Target y")
    ax2.plot(y_pred, label="Predicted y")
    ax2.set_xlabel("Sample")
    ax2.set_ylabel("Value")
    ax2.legend()

    ax3.plot(range(n_epochs), train_rmse_student_list, label="Training Error for Student Model")
    ax3.set_xlabel("Epoch")
    ax3.set_ylabel("Error")
    ax3.legend()

    # Plot the student's predictions (y_pred_student) against the target values (y_train)
    ax4.plot(y_train, label="Target y")
    ax4.plot(y_pred_student, label="Predicted y")
    ax4.set_xlabel("Sample")
    ax4.set_ylabel("Value")
    ax4.legend()

    # Plot loss1 over epochs
    ax5.plot(range(n_epochs), loss1_list, label="Loss 1")
    ax5.set_xlabel("Epoch")
    ax5.set_ylabel("Loss")
    ax5.legend()

    # Plot loss2 over epochs
    ax6.plot(range(n_epochs), loss2_list, label="Loss 2")
    ax6.set_xlabel("Epoch")
    ax6.set_ylabel("Loss")
    ax6.legend()

    # Adjust layout for better visualization
    plt.tight_layout()

    # Save the figure as an image file (create the output directory if needed)
    os.makedirs("./plots", exist_ok=True)
    plt.savefig("./plots/training_plots_5_3.png")


if __name__ == "__main__":
    run_model5()