Untitled

 avatar
unknown
plain_text
4 years ago
1.8 kB
5
Indexable
from IPython.display import clear_output

def train(model, optimizer, n_epochs=5):
    test_losses = []
    for epoch in range(n_epochs+1):
        # тренировка
        for x_train, y_train in tqdm(train_dataloader_big):
            optimizer.zero_grad()
            x_train = x_train.to(device)
            y_pred = model(x_train)
            y_train = y_train.type(torch.float).to(device)
            y_pred = y_pred.type(torch.float).t()[0].to(device)


            loss = torch.sqrt(criterion(y_pred, y_train))
            loss.backward()
            optimizer.step()
            optimizer.zero_grad()

        # валидация

        if epoch % 1 == 0:
            val_loss = []
            
            with torch.no_grad():
                for x_val, y_val in tqdm(test_dataloader_big):
                    y_val = y_val*sigma + mean #масштабирую обратно
                    y_pred = model(x_val.to(device))*sigma + mean 
                    
                    y_val = y_val.type(torch.float).to(device)
                    y_pred = y_pred.type(torch.float).t()[0].to(device) # преобразование, аналогичное reshape(-1, 1)

                    val_loss = torch.sqrt(criterion(y_pred, y_val))
                    test_losses.append(val_loss)

            # печатаем метрики
            
        scheduler.step()
        clear_output(wait=True)
        print(y_pred) # смотрю на массивы для проверки
        print(y_val)

        plt.plot([i for i in range(len(test_losses))], list(test_losses), label = 'test loss')
        plt.show()
        print(test_losses)
        print(f"Epoch: {epoch}, loss: {val_loss}")
Editor is loading...