Untitled
unknown
plain_text
8 months ago
1.9 kB
4
Indexable
train_losses_step = []  # Running-average training loss, sampled every `log_interval` steps
val_losses = []  # Validation loss, one entry per epoch


def train(epoch, log_interval=200, save_model_path='./model_weights'):
    """Run one training epoch, validate, and checkpoint on improved accuracy.

    Relies on module-level names defined elsewhere in the file: ``model``,
    ``optimizer``, ``training_loader``, ``device``, ``EPOCHS``,
    ``training_step`` and ``validate``.

    Args:
        epoch: Zero-based epoch index (printed and saved as ``epoch + 1``).
        log_interval: Log the running-average loss every this many steps.
        save_model_path: Directory where checkpoints are written; created
            if it does not exist.

    Side effects:
        Appends to ``train_losses_step`` and ``val_losses``, updates the
        global ``best_accuracy``, prints progress, and writes the model's
        ``state_dict`` to disk whenever validation accuracy improves.
    """
    global best_accuracy
    model.train()
    running_loss = 0
    num_batches = len(training_loader)

    for step, data in enumerate(training_loader):
        # NOTE(review): assumes each batch is a dict of tensors with these
        # four keys -- confirm against the Dataset implementation.
        input_ids = data['ids'].to(device)
        attention_mask = data['mask'].to(device)
        token_type_ids = data['token_type_ids'].to(device)
        targets = data['targets'].to(device)

        # `training_step` is defined elsewhere; presumably it performs the
        # forward/backward pass and optimizer update and returns the loss
        # tensor (only `.item()` is used here).
        loss = training_step(input_ids, attention_mask, token_type_ids,
                             targets, model, optimizer)
        running_loss += loss.item()

        # Record the running-average loss every `log_interval` steps.
        if step % log_interval == 0:
            avg_loss = running_loss / (step + 1)
            print(f"Epoch {epoch + 1}/{EPOCHS}, Step {step + 1}/{num_batches}")
            print(f" Running Loss: {avg_loss:.4f}")
            train_losses_step.append(avg_loss)

    # An empty loader has no defined mean loss; report NaN instead of
    # raising an uninformative ZeroDivisionError.
    avg_train_loss = running_loss / num_batches if num_batches else float('nan')
    # `validate` is defined elsewhere; it is unpacked as (loss, accuracy).
    avg_val_loss, val_accuracy = validate()

    print(f"Epoch {epoch + 1}/{EPOCHS} - End of epoch")
    print(f" Training Loss: {avg_train_loss:.4f}")
    print(f" Validation Loss: {avg_val_loss:.4f}")
    print(f" Validation Accuracy: {val_accuracy:.4f}")
    val_losses.append(avg_val_loss)  # Keep the per-epoch validation loss

    # Checkpoint only when validation accuracy strictly improves.
    if val_accuracy > best_accuracy:
        best_accuracy = val_accuracy
        # exist_ok=True avoids the check-then-create race of
        # `os.path.exists` followed by `os.makedirs`.
        os.makedirs(save_model_path, exist_ok=True)
        model_save_path = os.path.join(
            save_model_path,
            f"model_epoch_{epoch + 1}_acc{best_accuracy:.4f}.pth",
        )
        torch.save(model.state_dict(), model_save_path)
        print(f"Model saved to {model_save_path}")


best_accuracy = 0
for epoch in range(EPOCHS):
    train(epoch)
Editor is loading...
Leave a Comment