Untitled

 avatar
unknown
plain_text
2 years ago
4.9 kB
10
Indexable
from pathlib import Path

import matplotlib.pyplot as plt
import numpy as np
import torch
import torch.nn as nn
import torch.optim as optim
import torch.utils.data as data
from sklearn.preprocessing import StandardScaler

from examples.test_model1 import get_data
from models.model1 import CustomModel
from models.model2 import CustomModelWithAttention
from models.model5 import StudentModel


def reconstruct_ts(inp, amp, shift):
    """Apply an affine transform to *inp*: scale by *amp*, then add *shift*.

    The operands are combined with plain ``*`` and ``+``, so anything that
    supports those operators (numbers, arrays, tensors) is accepted and the
    usual broadcasting rules of the operand types apply.
    """
    scaled = inp * amp
    return scaled + shift


def _rmse(loss_fn, prediction, target) -> float:
    """Return ``sqrt(loss_fn(prediction, target))`` as a Python float."""
    # torch.sqrt keeps the computation in torch instead of relying on
    # NumPy's implicit tensor conversion.
    return torch.sqrt(loss_fn(prediction, target)).item()


def _to_numpy(tensor):
    """Convert a tensor to a NumPy array that matplotlib can plot directly."""
    return tensor.detach().cpu().numpy()


def _train_teacher(model, X_train, y_train, loss_fn, n_epochs, batch_size):
    """Fit the teacher model; return (per-epoch train RMSE list, final prediction)."""
    optimizer = optim.Adam(model.parameters())
    loader = data.DataLoader(
        data.TensorDataset(X_train, y_train), shuffle=True, batch_size=batch_size
    )

    rmse_history = []
    for epoch in range(n_epochs):
        model.train()
        for X_batch, y_batch in loader:
            loss = loss_fn(model(X_batch), y_batch)
            optimizer.zero_grad()
            loss.backward()
            optimizer.step()

        # Evaluate on the full training set after every epoch.
        model.eval()
        with torch.no_grad():
            train_rmse = _rmse(loss_fn, model(X_train), y_train)
        rmse_history.append(train_rmse)
        print("Epoch %d: train RMSE %.4f" % (epoch, train_rmse))

    # Compute the final prediction explicitly rather than reusing a variable
    # left over from the last loop iteration; it is gradient-free because it
    # is later used as a fixed target for the student.
    model.eval()
    with torch.no_grad():
        y_pred = model(X_train)
    return rmse_history, y_pred


def _train_student(model_student, X_train, y_teacher, y_train, loss_fn,
                   n_epochs, batch_size):
    """Fit the student model against the teacher's predictions.

    Returns (per-epoch train RMSE list, per-epoch mean loss1 list,
    per-epoch mean loss2 list, final first output of the student).
    """
    optimizer = optim.Adam(model_student.parameters())
    loader = data.DataLoader(
        data.TensorDataset(X_train, y_teacher, y_train),
        shuffle=True,
        batch_size=batch_size,
    )
    per_sample_mse = nn.MSELoss(reduction="none")

    rmse_history = []
    loss1_history = []
    loss2_history = []
    for epoch in range(n_epochs):
        loss1_sum = 0.0
        loss2_sum = 0.0
        model_student.train()
        for X_batch, y_teacher_batch, y_train_batch in loader:
            # The student returns four values; only amp, shift and res are
            # used during training.
            _, amp, shift, res = model_student(X_batch)
            # Feature index 5 at the last position of axis 1, kept as a
            # trailing singleton dimension.
            # assumes X is rank 3 with >= 6 entries on the last axis - TODO confirm
            inp_A = X_batch[:, -1, 5:6]
            reconstructed_y = reconstruct_ts(inp_A, amp, shift)

            # loss1: reconstruction vs. the teacher's prediction.
            loss1 = loss_fn(reconstructed_y, y_teacher_batch)
            # loss2: `res` is trained to match the per-sample squared error
            # of the reconstruction against the true targets.
            # NOTE(review): loss_prediction is not detached, so loss2 also
            # back-propagates into amp/shift - confirm this is intended.
            loss_prediction = per_sample_mse(reconstructed_y, y_train_batch)
            loss2 = loss_fn(res, loss_prediction)

            loss1_sum += loss1.item()
            loss2_sum += loss2.item()

            loss = loss1 + loss2
            optimizer.zero_grad()
            loss.backward()
            optimizer.step()

        loss1_history.append(loss1_sum / len(loader))
        loss2_history.append(loss2_sum / len(loader))

        # The reported RMSE uses the student's first output, not the
        # reconstruction used in the training losses.
        model_student.eval()
        with torch.no_grad():
            y_pred_student, _, _, _ = model_student(X_train)
            train_rmse = _rmse(loss_fn, y_pred_student, y_train)
        rmse_history.append(train_rmse)
        print("Epoch %d: train RMSE %.4f" % (epoch, train_rmse))

    model_student.eval()
    with torch.no_grad():
        y_pred_student, _, _, _ = model_student(X_train)
    return rmse_history, loss1_history, loss2_history, y_pred_student


def _plot_results(teacher_rmse, student_rmse, loss1_history, loss2_history,
                  y_train, y_pred, y_pred_student, plot_path):
    """Draw the 3x2 diagnostics figure and save it to *plot_path*."""
    fig, ((ax1, ax2), (ax3, ax4), (ax5, ax6)) = plt.subplots(3, 2, figsize=(8, 10))

    target = _to_numpy(y_train)

    # Row 1: teacher training curve and teacher fit.
    ax1.plot(range(len(teacher_rmse)), teacher_rmse, label="Training Error for Teacher model")
    ax1.set_xlabel("Epoch")
    ax1.set_ylabel("Error")
    ax1.legend()

    ax2.plot(target, label="Target y")
    ax2.plot(_to_numpy(y_pred), label="Predicted y")
    ax2.set_xlabel("Sample")
    ax2.set_ylabel("Value")
    ax2.legend()

    # Row 2: student training curve and student fit.
    ax3.plot(range(len(student_rmse)), student_rmse, label="Training Error for student Model")
    ax3.set_xlabel("Epoch")
    ax3.set_ylabel("Error")
    ax3.legend()

    ax4.plot(target, label="Target y")
    ax4.plot(_to_numpy(y_pred_student), label="Predicted y")
    ax4.set_xlabel("Sample")
    ax4.set_ylabel("Value")
    ax4.legend()

    # Row 3: the two student loss components.
    ax5.plot(range(len(loss1_history)), loss1_history, label="Loss 1")
    ax5.set_xlabel("Epoch")
    ax5.set_ylabel("Loss")
    ax5.legend()

    ax6.plot(range(len(loss2_history)), loss2_history, label="Loss 2")
    ax6.set_xlabel("Epoch")
    ax6.set_ylabel("Loss")
    ax6.legend()

    fig.tight_layout()

    # savefig does not create missing directories, so create the target
    # folder first.
    Path(plot_path).parent.mkdir(parents=True, exist_ok=True)
    fig.savefig(plot_path)
    # Release the figure so repeated runs do not accumulate open figures.
    plt.close(fig)


def run_model5(n_epochs: int = 2000, batch_size: int = 128,
               plot_path: str = "./plots/training_plots_5_3.png") -> None:
    """Train a teacher model, then a student model on its output, and plot both.

    Steps:
      1. Train ``CustomModel`` on the data returned by ``get_data()`` with
         Adam and an MSE loss.
      2. Train ``StudentModel`` with two summed losses: the reconstruction
         (built from the student's ``amp``/``shift`` outputs) against the
         teacher's prediction, and the student's ``res`` output against the
         per-sample squared error of the reconstruction vs. the true targets.
      3. Save a 3x2 figure with the training curves, fits and loss components.

    Args:
        n_epochs: Number of epochs for each of the two training phases.
        batch_size: Mini-batch size for both data loaders.
        plot_path: Destination of the saved figure; parent directories are
            created if missing.
    """
    model = CustomModel()
    loss_fn = nn.MSELoss()
    X_train, y_train = get_data()

    train_rmse_list, y_pred = _train_teacher(
        model, X_train, y_train, loss_fn, n_epochs, batch_size
    )

    model_student = StudentModel()
    train_rmse_student_list, loss1_list, loss2_list, y_pred_student = _train_student(
        model_student, X_train, y_pred, y_train, loss_fn, n_epochs, batch_size
    )

    _plot_results(
        train_rmse_list,
        train_rmse_student_list,
        loss1_list,
        loss2_list,
        y_train,
        y_pred,
        y_pred_student,
        plot_path,
    )
Editor is loading...