Untitled
unknown
plain_text
2 years ago
4.9 kB
8
Indexable
import numpy as np
import torch
import torch.optim as optim
import torch.utils.data as data
import torch.nn as nn
from sklearn.preprocessing import StandardScaler
import matplotlib.pyplot as plt
from examples.test_model1 import get_data
from models.model1 import CustomModel
from models.model2 import CustomModelWithAttention
from models.model5 import StudentModel
def reconstruct_ts(inp, amp, shift):
    """Rescale ``inp`` by ``amp`` and offset it by ``shift``.

    Computes ``inp * amp + shift`` element-wise, so it works for plain numbers,
    NumPy arrays and torch tensors alike (with the usual broadcasting rules).

    Args:
        inp: Values to rescale.
        amp: Multiplicative factor applied to ``inp``.
        shift: Additive offset applied after scaling.

    Returns:
        The rescaled and shifted values.
    """
    scaled = inp * amp
    return scaled + shift
def _train_teacher(X_train, y_train, loss_fn, n_epochs, batch_size):
    """Fit a fresh ``CustomModel`` to ``(X_train, y_train)``.

    Returns:
        ``(model, rmse_history)`` where ``rmse_history`` holds the full-training-set
        RMSE measured after every epoch.
    """
    model = CustomModel()
    optimizer = optim.Adam(model.parameters())
    loader = data.DataLoader(
        data.TensorDataset(X_train, y_train), shuffle=True, batch_size=batch_size
    )
    rmse_history = []
    for epoch in range(n_epochs):
        model.train()
        for X_batch, y_batch in loader:
            loss = loss_fn(model(X_batch), y_batch)
            optimizer.zero_grad()
            loss.backward()
            optimizer.step()

        model.eval()
        with torch.no_grad():
            # torch.sqrt rather than np.sqrt: stays a tensor op, so it also works
            # for tensors living on a GPU.
            train_rmse = torch.sqrt(loss_fn(model(X_train), y_train)).item()
        rmse_history.append(train_rmse)
        print("Epoch %d: train RMSE %.4f" % (epoch, train_rmse))
    return model, rmse_history


def _train_student(X_train, teacher_pred, y_train, loss_fn, n_epochs, batch_size):
    """Distil ``teacher_pred`` into a fresh ``StudentModel``.

    The student returns ``(y_pred, amp, shift, res)``. Its reconstruction
    ``reconstruct_ts(anchor, amp, shift)`` is fit to the teacher's predictions
    (loss1), and ``res`` is fit to the per-element squared error of that
    reconstruction against the ground truth (loss2).

    Returns:
        ``(model, rmse_history, loss1_history, loss2_history)``. The loss
        histories hold per-epoch means over the mini-batches.
    """
    model = StudentModel()
    optimizer = optim.Adam(model.parameters())
    loader = data.DataLoader(
        data.TensorDataset(X_train, teacher_pred, y_train),
        shuffle=True,
        batch_size=batch_size,
    )
    per_element_mse = nn.MSELoss(reduction="none")
    # Index of the input feature that amp/shift rescale.
    # NOTE(review): meaning of feature 5 comes from get_data(); confirm.
    anchor_feature = 5

    rmse_history, loss1_history, loss2_history = [], [], []
    for epoch in range(n_epochs):
        loss1_sum = 0.0
        loss2_sum = 0.0
        model.train()
        for X_batch, teacher_batch, y_batch in loader:
            _, amp, shift, res = model(X_batch)
            # Same as X_batch[:, :, f:f+1][:, -1, :]: the last position along
            # axis 1 of the anchor feature, kept 2-D as (batch, 1).
            anchor = X_batch[:, -1, anchor_feature:anchor_feature + 1]
            reconstructed_y = reconstruct_ts(anchor, amp, shift)

            loss1 = loss_fn(reconstructed_y, teacher_batch)
            # Target for the residual-error head. It is detached so loss2 only
            # trains `res` to predict the error. Without detach, loss2 would also
            # push the reconstruction toward (or, when `res` overestimates, away
            # from) the ground truth just to match `res`.
            error_target = per_element_mse(reconstructed_y, y_batch).detach()
            loss2 = loss_fn(res, error_target)

            loss = loss1 + loss2
            optimizer.zero_grad()
            loss.backward()
            optimizer.step()

            loss1_sum += loss1.item()
            loss2_sum += loss2.item()

        loss1_history.append(loss1_sum / len(loader))
        loss2_history.append(loss2_sum / len(loader))

        model.eval()
        with torch.no_grad():
            y_pred_student, _, _, _ = model(X_train)
            # NOTE(review): RMSE is measured on the model's first output, not on
            # the reconstruction trained above; confirm this is intended.
            train_rmse = torch.sqrt(loss_fn(y_pred_student, y_train)).item()
        rmse_history.append(train_rmse)
        print("Epoch %d: train RMSE %.4f" % (epoch, train_rmse))
    return model, rmse_history, loss1_history, loss2_history


def _plot_training(teacher_rmse, student_rmse, loss1_history, loss2_history,
                   y_train, y_pred, y_pred_student, plot_path):
    """Draw the 3x2 summary figure, save it to ``plot_path`` and close it."""
    from pathlib import Path

    fig, ((ax1, ax2), (ax3, ax4), (ax5, ax6)) = plt.subplots(3, 2, figsize=(8, 10))

    # Training error of the teacher over epochs.
    ax1.plot(range(len(teacher_rmse)), teacher_rmse, label="Training Error for Teacher model")
    ax1.set_xlabel("Epoch")
    ax1.set_ylabel("Error")
    ax1.legend()

    # Teacher predictions vs. targets.
    ax2.plot(y_train, label="Target y")
    ax2.plot(y_pred, label="Predicted y")
    ax2.set_xlabel("Sample")
    ax2.set_ylabel("Value")
    ax2.legend()

    # Training error of the student over epochs.
    ax3.plot(range(len(student_rmse)), student_rmse, label="Training Error for student Model")
    ax3.set_xlabel("Epoch")
    ax3.set_ylabel("Error")
    ax3.legend()

    # Student predictions vs. targets.
    ax4.plot(y_train, label="Target y")
    ax4.plot(y_pred_student, label="Predicted y")
    ax4.set_xlabel("Sample")
    ax4.set_ylabel("Value")
    ax4.legend()

    # Per-epoch mean of loss1 (reconstruction vs. teacher).
    ax5.plot(range(len(loss1_history)), loss1_history, label="Loss 1")
    ax5.set_xlabel("Epoch")
    ax5.set_ylabel("Loss")
    ax5.legend()

    # Per-epoch mean of loss2 (residual-error head).
    ax6.plot(range(len(loss2_history)), loss2_history, label="Loss 2")
    ax6.set_xlabel("Epoch")
    ax6.set_ylabel("Loss")
    ax6.legend()

    fig.tight_layout()
    # savefig fails if the output directory does not exist yet.
    Path(plot_path).parent.mkdir(parents=True, exist_ok=True)
    fig.savefig(plot_path)
    # Release the figure; pyplot otherwise keeps it alive for the process lifetime.
    plt.close(fig)


def run_model5(n_epochs=2000, batch_size=128, plot_path="./plots/training_plots_5_3.png"):
    """Train a teacher model, distil it into a student model and plot both runs.

    Steps:

    1. ``CustomModel`` (the teacher) is fit to the data from ``get_data()``.
    2. The teacher's predictions on the full training inputs become the targets
       for ``StudentModel``.
    3. The student is trained on reconstruction loss plus residual-error loss.
    4. A 3x2 figure of training curves and predictions is written to ``plot_path``.

    Per-epoch train RMSE is printed for both phases.

    Args:
        n_epochs: Number of epochs for each of the two training phases.
        batch_size: Mini-batch size for both data loaders.
        plot_path: Destination image file; its parent directory is created if
            missing.
    """
    loss_fn = nn.MSELoss()
    X_train, y_train = get_data()

    teacher, teacher_rmse = _train_teacher(X_train, y_train, loss_fn, n_epochs, batch_size)

    # Teacher predictions on the whole training set are the distillation targets.
    teacher.eval()
    with torch.no_grad():
        y_pred = teacher(X_train)

    student, student_rmse, loss1_history, loss2_history = _train_student(
        X_train, y_pred, y_train, loss_fn, n_epochs, batch_size
    )

    student.eval()
    with torch.no_grad():
        y_pred_student, _, _, _ = student(X_train)

    _plot_training(
        teacher_rmse, student_rmse, loss1_history, loss2_history,
        y_train, y_pred, y_pred_student, plot_path,
    )
Editor is loading...