Untitled
plain_text
21 days ago
4.9 kB
2
Indexable
Never
import os

import numpy as np
import torch
import torch.optim as optim
import torch.utils.data as data
import torch.nn as nn
from sklearn.preprocessing import StandardScaler  # noqa: F401  (currently unused)
import matplotlib.pyplot as plt

from examples.test_model1 import get_data
from models.model1 import CustomModel
from models.model2 import CustomModelWithAttention  # noqa: F401  (currently unused)
from models.model5 import StudentModel


def reconstruct_ts(inp, amp, shift):
    """Return the affine rescaling ``inp * amp + shift``.

    Used by the student training loop to rebuild a prediction from one input
    feature and the student's predicted ``amp`` / ``shift``. Operands follow
    normal tensor/array broadcasting rules.
    """
    reconstructed_ts = inp * amp + shift
    return reconstructed_ts


def _as_numpy(values):
    """Convert a tensor (or array-like) to a NumPy array suitable for plotting."""
    if isinstance(values, torch.Tensor):
        return values.detach().cpu().numpy()
    return np.asarray(values)


def _train_teacher(model, X_train, y_train, n_epochs, batch_size):
    """Fit the teacher with MSE; return (per-epoch train RMSE list, final predictions on X_train)."""
    optimizer = optim.Adam(model.parameters())
    loss_fn = nn.MSELoss()
    loader = data.DataLoader(data.TensorDataset(X_train, y_train), shuffle=True, batch_size=batch_size)

    rmse_history = []
    y_pred = None
    for epoch in range(n_epochs):
        model.train()
        for X_batch, y_batch in loader:
            loss = loss_fn(model(X_batch), y_batch)
            optimizer.zero_grad()
            loss.backward()
            optimizer.step()

        # Full-training-set evaluation after every epoch.
        model.eval()
        with torch.no_grad():
            y_pred = model(X_train)
            train_rmse = torch.sqrt(loss_fn(y_pred, y_train))
        rmse_history.append(train_rmse.item())
        print("Epoch %d: train RMSE %.4f" % (epoch, train_rmse.item()))

    # Predictions from the final epoch (computed under no_grad) become the
    # student's soft targets.
    return rmse_history, y_pred


def _train_student(model_student, X_train, y_teacher, y_train, n_epochs, batch_size):
    """Train the student on teacher targets plus an auxiliary error-prediction head.

    Returns (per-epoch train RMSE, per-epoch mean loss1, per-epoch mean loss2,
    final student predictions on X_train).
    """
    optimizer = optim.Adam(model_student.parameters())
    loss_fn = nn.MSELoss()
    mse_criterion = nn.MSELoss(reduction='none')  # per-element squared error
    loader = data.DataLoader(
        data.TensorDataset(X_train, y_teacher, y_train), shuffle=True, batch_size=batch_size
    )

    rmse_history = []
    loss1_history = []
    loss2_history = []
    y_pred_student = None
    for epoch in range(n_epochs):
        loss1_sum = 0.0
        loss2_sum = 0.0
        model_student.train()
        for X_batch, y_teacher_batch, y_true_batch in loader:
            _, amp, shift, res = model_student(X_batch)
            # Last time step of input feature 5, kept 2-D via the 5:6 slice.
            # NOTE(review): assumes X is (batch, seq, features) — confirm.
            inp_A = X_batch[:, -1, 5:6]
            reconstructed_y = reconstruct_ts(inp_A, amp, shift)

            # loss1: match the teacher's predictions (distillation).
            loss1 = loss_fn(reconstructed_y, y_teacher_batch)
            # loss2: `res` regresses the per-element squared error against the
            # ground truth. The target is detached so this term only trains the
            # `res` head; otherwise it would also push reconstructed_y's error
            # toward whatever `res` currently predicts.
            loss_prediction = mse_criterion(reconstructed_y, y_true_batch).detach()
            loss2 = loss_fn(res, loss_prediction)

            loss = loss1 + loss2
            optimizer.zero_grad()
            loss.backward()
            optimizer.step()

            loss1_sum += loss1.item()
            loss2_sum += loss2.item()

        # Mean per-batch loss over the epoch.
        loss1_history.append(loss1_sum / len(loader))
        loss2_history.append(loss2_sum / len(loader))

        model_student.eval()
        with torch.no_grad():
            # NOTE(review): RMSE is measured on the model's first output, whereas
            # loss1 trains reconstructed_y — confirm StudentModel makes these agree.
            y_pred_student, _, _, _ = model_student(X_train)
            train_rmse = torch.sqrt(loss_fn(y_pred_student, y_train))
        rmse_history.append(train_rmse.item())
        print("Epoch %d: train RMSE %.4f" % (epoch, train_rmse.item()))

    return rmse_history, loss1_history, loss2_history, y_pred_student


def _label_axis(ax, xlabel, ylabel):
    """Apply axis labels and show the legend."""
    ax.set_xlabel(xlabel)
    ax.set_ylabel(ylabel)
    ax.legend()


def _plot_results(n_epochs, teacher_rmse, student_rmse, loss1_history, loss2_history,
                  y_train, y_pred, y_pred_student, plot_path):
    """Draw the 3x2 summary figure and save it to ``plot_path`` (creating its directory)."""
    epochs = range(n_epochs)
    target = _as_numpy(y_train)

    # 3x2 grid: teacher (error, fit), student (error, fit), loss1, loss2.
    fig, ((ax1, ax2), (ax3, ax4), (ax5, ax6)) = plt.subplots(3, 2, figsize=(8, 10))

    ax1.plot(epochs, teacher_rmse, label="Training Error for Teacher model")
    _label_axis(ax1, "Epoch", "Error")

    ax2.plot(target, label="Target y")
    ax2.plot(_as_numpy(y_pred), label="Predicted y")
    _label_axis(ax2, "Sample", "Value")

    ax3.plot(epochs, student_rmse, label="Training Error for student Model")
    _label_axis(ax3, "Epoch", "Error")

    ax4.plot(target, label="Target y")
    ax4.plot(_as_numpy(y_pred_student), label="Predicted y")
    _label_axis(ax4, "Sample", "Value")

    ax5.plot(epochs, loss1_history, label="Loss 1")
    _label_axis(ax5, "Epoch", "Loss")

    ax6.plot(epochs, loss2_history, label="Loss 2")
    _label_axis(ax6, "Epoch", "Loss")

    plt.tight_layout()

    # savefig does not create missing directories.
    plot_dir = os.path.dirname(plot_path)
    if plot_dir:
        os.makedirs(plot_dir, exist_ok=True)
    plt.savefig(plot_path)


def run_model5(n_epochs=2000, batch_size=128, plot_path="./plots/training_plots_5_3.png"):
    """Train a CustomModel teacher, distill it into a StudentModel, and save plots.

    The teacher is fit to ``get_data()`` with MSE. Its final predictions on the
    training set become the student's targets (loss1). The student also trains
    an auxiliary output ``res`` to predict its own per-element squared error
    against the ground truth (loss2). A 3x2 figure of errors, fits, and losses
    is written to ``plot_path``.

    Args:
        n_epochs: Epochs for each of the teacher and student phases (>= 1).
        batch_size: Mini-batch size for both DataLoaders.
        plot_path: Output image path; its directory is created if missing.

    Raises:
        ValueError: If ``n_epochs`` is less than 1.
    """
    if n_epochs < 1:
        raise ValueError("n_epochs must be >= 1, got %r" % (n_epochs,))

    model = CustomModel()
    X_train, y_train = get_data()
    teacher_rmse, y_pred = _train_teacher(model, X_train, y_train, n_epochs, batch_size)

    model_student = StudentModel()
    student_rmse, loss1_history, loss2_history, y_pred_student = _train_student(
        model_student, X_train, y_pred, y_train, n_epochs, batch_size
    )

    _plot_results(n_epochs, teacher_rmse, student_rmse, loss1_history, loss2_history,
                  y_train, y_pred, y_pred_student, plot_path)