Untitled
unknown
plain_text
4 years ago
1.8 kB
5
Indexable
from IPython.display import clear_output def train(model, optimizer, n_epochs=5): test_losses = [] for epoch in range(n_epochs+1): # тренировка for x_train, y_train in tqdm(train_dataloader_big): optimizer.zero_grad() x_train = x_train.to(device) y_pred = model(x_train) y_train = y_train.type(torch.float).to(device) y_pred = y_pred.type(torch.float).t()[0].to(device) loss = torch.sqrt(criterion(y_pred, y_train)) loss.backward() optimizer.step() optimizer.zero_grad() # валидация if epoch % 1 == 0: val_loss = [] with torch.no_grad(): for x_val, y_val in tqdm(test_dataloader_big): y_val = y_val*sigma + mean #масштабирую обратно y_pred = model(x_val.to(device))*sigma + mean y_val = y_val.type(torch.float).to(device) y_pred = y_pred.type(torch.float).t()[0].to(device) # преобразование, аналогичное reshape(-1, 1) val_loss = torch.sqrt(criterion(y_pred, y_val)) test_losses.append(val_loss) # печатаем метрики scheduler.step() clear_output(wait=True) print(y_pred) # смотрю на массивы для проверки print(y_val) plt.plot([i for i in range(len(test_losses))], list(test_losses), label = 'test loss') plt.show() print(test_losses) print(f"Epoch: {epoch}, loss: {val_loss}")
Editor is loading...