# Paste-site header (not code): Untitled | unknown | plain_text | 2 years ago | 4.9 kB | 10 | Indexable
import os

import matplotlib.pyplot as plt
import numpy as np
import torch
import torch.nn as nn
import torch.optim as optim
import torch.utils.data as data
from sklearn.preprocessing import StandardScaler

from examples.test_model1 import get_data
from models.model1 import CustomModel
from models.model2 import CustomModelWithAttention
from models.model5 import StudentModel
def reconstruct_ts(inp, amp, shift):
    """Rescale ``inp`` by ``amp`` and then offset it by ``shift``.

    Computes ``inp * amp + shift`` using the operands' own arithmetic, so
    tensors/arrays broadcast element-wise and plain numbers work as scalars.
    """
    scaled = inp * amp
    return scaled + shift
def _train_teacher(X_train, y_train, n_epochs, loss_fn, batch_size):
    """Train a fresh teacher ``CustomModel`` on ``(X_train, y_train)``.

    Returns:
        (rmse_history, y_teacher): the full-training-set RMSE after every
        epoch (Python floats), and the teacher's predictions on ``X_train``
        from the last evaluation pass. Those predictions are produced under
        ``torch.no_grad()``, so they carry no autograd graph and can be used
        directly as distillation targets.
    """
    model = CustomModel()
    optimizer = optim.Adam(model.parameters())
    loader = data.DataLoader(
        data.TensorDataset(X_train, y_train), shuffle=True, batch_size=batch_size
    )
    rmse_history = []
    y_teacher = None
    for epoch in range(n_epochs):
        model.train()
        for X_batch, y_batch in loader:
            loss = loss_fn(model(X_batch), y_batch)
            optimizer.zero_grad()
            loss.backward()
            optimizer.step()

        # Evaluate on the whole training set.
        model.eval()
        with torch.no_grad():
            y_teacher = model(X_train)
            train_rmse = torch.sqrt(loss_fn(y_teacher, y_train)).item()
        rmse_history.append(train_rmse)
        print("Epoch %d: train RMSE %.4f" % (epoch, train_rmse))
    return rmse_history, y_teacher


def _train_student(X_train, y_teacher, y_train, n_epochs, loss_fn, batch_size):
    """Train a fresh ``StudentModel`` to mimic the teacher.

    The student's output ``(y_pred, amp, shift, res)`` is used as follows:
      * loss1: ``reconstruct_ts(last-step feature 5, amp, shift)`` vs. the
        teacher's prediction (distillation).
      * loss2: ``res`` regressed onto the per-element squared error of that
        reconstruction against the ground truth.

    Returns:
        (rmse_history, loss1_history, loss2_history, y_pred_student): per-epoch
        training RMSE of the student's first output, per-epoch mean loss1 and
        loss2 over batches, and the student's first output on ``X_train`` from
        the last evaluation pass (computed under ``torch.no_grad()``).
    """
    model_student = StudentModel()
    optimizer = optim.Adam(model_student.parameters())
    loader = data.DataLoader(
        data.TensorDataset(X_train, y_teacher, y_train),
        shuffle=True,
        batch_size=batch_size,
    )
    per_element_mse = nn.MSELoss(reduction="none")
    rmse_history, loss1_history, loss2_history = [], [], []
    y_pred_student = None
    for epoch in range(n_epochs):
        loss1_sum = 0.0
        loss2_sum = 0.0
        model_student.train()
        for X_batch, y_teacher_batch, y_true_batch in loader:
            _, amp, shift, res = model_student(X_batch)
            # Last time step of input channel 5 (assumes X is
            # (batch, time, features) -- TODO confirm against get_data()).
            inp_A = X_batch[:, -1, 5:6]
            reconstructed_y = reconstruct_ts(inp_A, amp, shift)

            loss1 = loss_fn(reconstructed_y, y_teacher_batch)
            # The per-element error is the *target* for the `res` head. It is
            # detached so that loss2 trains only the error predictor and does not
            # also push amp/shift toward whatever error `res` currently predicts.
            error_target = per_element_mse(reconstructed_y, y_true_batch).detach()
            loss2 = loss_fn(res, error_target)

            loss = loss1 + loss2
            optimizer.zero_grad()
            loss.backward()
            optimizer.step()

            loss1_sum += loss1.item()
            loss2_sum += loss2.item()

        loss1_history.append(loss1_sum / len(loader))
        loss2_history.append(loss2_sum / len(loader))

        model_student.eval()
        with torch.no_grad():
            # NOTE(review): training optimises the reconstructed output, but the
            # reported RMSE uses the model's first output -- confirm they match.
            y_pred_student, _, _, _ = model_student(X_train)
            train_rmse = torch.sqrt(loss_fn(y_pred_student, y_train)).item()
        rmse_history.append(train_rmse)
        print("Epoch %d: train RMSE %.4f" % (epoch, train_rmse))
    return rmse_history, loss1_history, loss2_history, y_pred_student


def _plot_results(
    y_train,
    y_teacher,
    y_pred_student,
    teacher_rmse,
    student_rmse,
    loss1_history,
    loss2_history,
    out_path="./plots/training_plots_5_3.png",
):
    """Draw a 3x2 grid of training curves and prediction overlays, save it to
    ``out_path`` (creating the directory if needed), and close the figure."""
    fig, ((ax1, ax2), (ax3, ax4), (ax5, ax6)) = plt.subplots(3, 2, figsize=(8, 10))

    # Left column and bottom-right: per-epoch curves.
    curve_panels = (
        (ax1, teacher_rmse, "Training Error for Teacher model", "Error"),
        (ax3, student_rmse, "Training Error for student Model", "Error"),
        (ax5, loss1_history, "Loss 1", "Loss"),
        (ax6, loss2_history, "Loss 2", "Loss"),
    )
    for ax, values, label, ylabel in curve_panels:
        ax.plot(range(len(values)), values, label=label)
        ax.set_xlabel("Epoch")
        ax.set_ylabel(ylabel)
        ax.legend()

    # Right column (top two rows): targets vs. teacher / student predictions.
    for ax, prediction in ((ax2, y_teacher), (ax4, y_pred_student)):
        ax.plot(y_train, label="Target y")
        ax.plot(prediction, label="Predicted y")
        ax.set_xlabel("Sample")
        ax.set_ylabel("Value")
        ax.legend()

    fig.tight_layout()
    os.makedirs(os.path.dirname(out_path) or ".", exist_ok=True)
    fig.savefig(out_path)
    # Release the figure; repeated calls would otherwise accumulate open figures.
    plt.close(fig)


def run_model5(n_epochs=2000, batch_size=128):
    """Run teacher -> student distillation on ``get_data()`` and plot the results.

    1. Train a ``CustomModel`` teacher with MSE for ``n_epochs`` epochs.
    2. Train a ``StudentModel`` on the teacher's final training-set predictions
       (see ``_train_student`` for the two loss terms).
    3. Save a 3x2 summary figure to ``./plots/training_plots_5_3.png``.

    Per-epoch training RMSE is printed for both models.
    """
    loss_fn = nn.MSELoss()
    X_train, y_train = get_data()

    teacher_rmse, y_teacher = _train_teacher(
        X_train, y_train, n_epochs, loss_fn, batch_size
    )
    student_rmse, loss1_history, loss2_history, y_pred_student = _train_student(
        X_train, y_teacher, y_train, n_epochs, loss_fn, batch_size
    )
    _plot_results(
        y_train,
        y_teacher,
        y_pred_student,
        teacher_rmse,
        student_rmse,
        loss1_history,
        loss2_history,
    )
# Paste-site footer residue (not code): "Editor is loading..."