Untitled

 avatar
unknown
plain_text
8 months ago
1.9 kB
4
Indexable
# History of the running mean training loss; train() appends one value
# every `log_interval` steps.
train_losses_step = list()
# History of the validation loss; train() appends one value per epoch.
val_losses = list()

def train(epoch, log_interval=200, save_model_path='./model_weights'):
    """Train the global ``model`` for one epoch, validate, and checkpoint on improvement.

    Relies on module-level names defined elsewhere in this file: ``model``,
    ``optimizer``, ``training_loader``, ``device``, ``training_step``,
    ``validate``, ``EPOCHS``, ``train_losses_step``, ``val_losses`` and
    ``best_accuracy`` (the last one is rebound via ``global``).

    Args:
        epoch: Zero-based epoch index; printed 1-based in the log output.
        log_interval: Print the running mean loss (and append it to
            ``train_losses_step``) every ``log_interval`` steps, starting at
            step 0. Must be a positive integer.
        save_model_path: Directory in which improved checkpoints are written;
            created if it does not already exist.

    Raises:
        ValueError: If ``log_interval`` is not positive.
    """
    global best_accuracy
    if log_interval <= 0:
        # `step % 0` would otherwise fail with ZeroDivisionError inside the loop.
        raise ValueError(f"log_interval must be positive, got {log_interval}")

    model.train()
    running_loss = 0
    num_steps = len(training_loader)

    for step, data in enumerate(training_loader):
        input_ids = data['ids'].to(device)
        attention_mask = data['mask'].to(device)
        token_type_ids = data['token_type_ids'].to(device)
        targets = data['targets'].to(device)

        loss = training_step(input_ids, attention_mask, token_type_ids, targets, model, optimizer)
        # Accumulate a plain Python number (loss is presumably a scalar tensor).
        running_loss += loss.item()

        # Record the running mean loss every `log_interval` steps.
        if step % log_interval == 0:
            avg_loss = running_loss / (step + 1)
            print(f"Epoch {epoch + 1}/{EPOCHS}, Step {step + 1}/{num_steps}")
            print(f"  Running Loss: {avg_loss:.4f}")
            train_losses_step.append(avg_loss)

    # An empty loader reports NaN instead of raising ZeroDivisionError.
    avg_train_loss = running_loss / num_steps if num_steps else float('nan')

    avg_val_loss, val_accuracy = validate()

    print(f"Epoch {epoch + 1}/{EPOCHS} - End of epoch")
    print(f"  Training Loss: {avg_train_loss:.4f}")
    print(f"  Validation Loss: {avg_val_loss:.4f}")
    print(f"  Validation Accuracy: {val_accuracy:.4f}")

    val_losses.append(avg_val_loss)

    # Checkpoint only when validation accuracy strictly improves.
    if val_accuracy > best_accuracy:
        best_accuracy = val_accuracy
        # exist_ok=True avoids the exists()/makedirs() race and is a no-op
        # when the directory is already there.
        os.makedirs(save_model_path, exist_ok=True)

        model_save_path = os.path.join(save_model_path, f"model_epoch_{epoch + 1}_acc{best_accuracy:.4f}.pth")
        torch.save(model.state_dict(), model_save_path)
        print(f"Model saved to {model_save_path}")

# Best validation accuracy seen so far. train() declares it `global`,
# compares each epoch's accuracy against it, and overwrites it on improvement.
best_accuracy = 0

# One train-and-validate pass per epoch; the index is zero-based here and
# printed 1-based inside train().
for epoch in range(EPOCHS):
    train(epoch=epoch)
Editor is loading...
Leave a Comment