Untitled
unknown
plain_text
4 years ago
1.8 kB
8
Indexable
from IPython.display import clear_output
def train(model, optimizer, n_epochs=5):
test_losses = []
for epoch in range(n_epochs+1):
# тренировка
for x_train, y_train in tqdm(train_dataloader_big):
optimizer.zero_grad()
x_train = x_train.to(device)
y_pred = model(x_train)
y_train = y_train.type(torch.float).to(device)
y_pred = y_pred.type(torch.float).t()[0].to(device)
loss = torch.sqrt(criterion(y_pred, y_train))
loss.backward()
optimizer.step()
optimizer.zero_grad()
# валидация
if epoch % 1 == 0:
val_loss = []
with torch.no_grad():
for x_val, y_val in tqdm(test_dataloader_big):
y_val = y_val*sigma + mean #масштабирую обратно
y_pred = model(x_val.to(device))*sigma + mean
y_val = y_val.type(torch.float).to(device)
y_pred = y_pred.type(torch.float).t()[0].to(device) # преобразование, аналогичное reshape(-1, 1)
val_loss = torch.sqrt(criterion(y_pred, y_val))
test_losses.append(val_loss)
# печатаем метрики
scheduler.step()
clear_output(wait=True)
print(y_pred) # смотрю на массивы для проверки
print(y_val)
plt.plot([i for i in range(len(test_losses))], list(test_losses), label = 'test loss')
plt.show()
print(test_losses)
print(f"Epoch: {epoch}, loss: {val_loss}")Editor is loading...