Untitled
unknown
plain_text
a year ago
1.9 kB
6
Indexable
# Module-level loss histories filled in by train():
#   train_losses_step -- running mean training loss of the current epoch,
#                        sampled every `log_interval` steps (not the loss of
#                        a single batch)
#   val_losses        -- one validation loss appended per epoch
train_losses_step, val_losses = [], []
def train(epoch, log_interval=200, save_model_path='./model_weights'):
    """Run one training epoch, then validate and checkpoint on improvement.

    Relies on module-level names defined elsewhere in the file: ``model``,
    ``optimizer``, ``device``, ``training_loader``, ``training_step``,
    ``validate``, ``EPOCHS``, ``best_accuracy``, ``train_losses_step`` and
    ``val_losses``.

    For each batch from ``training_loader`` (a mapping with keys ``'ids'``,
    ``'mask'``, ``'token_type_ids'`` and ``'targets'``), the tensors are moved
    to ``device`` and passed to ``training_step``. Every ``log_interval``
    steps (including step 0), the running mean loss for this epoch is printed
    and appended to ``train_losses_step``.

    After the epoch, ``validate()`` is called and its
    ``(avg_val_loss, val_accuracy)`` result is printed. The validation loss is
    appended to ``val_losses``. If ``val_accuracy`` beats the global
    ``best_accuracy``, that global is updated and ``model.state_dict()`` is
    saved under ``save_model_path``.

    Args:
        epoch: Zero-based epoch index. It is displayed, and used in the
            checkpoint file name, as ``epoch + 1``.
        log_interval: Print/record the running loss every this many steps.
        save_model_path: Directory for checkpoints; created if missing.

    Side effects:
        Mutates the globals ``best_accuracy``, ``train_losses_step`` and
        ``val_losses``. May write a ``.pth`` checkpoint file.
    """
    global best_accuracy
    model.train()
    # Loop-invariant: number of batches per epoch.
    num_batches = len(training_loader)
    running_loss = 0
    for step, data in enumerate(training_loader):
        input_ids = data['ids'].to(device)
        attention_mask = data['mask'].to(device)
        token_type_ids = data['token_type_ids'].to(device)
        targets = data['targets'].to(device)
        # NOTE(review): training_step is assumed to perform the
        # forward/backward/optimizer update and return a scalar loss tensor
        # (it must support .item()) -- confirm against its definition.
        loss = training_step(input_ids, attention_mask, token_type_ids, targets, model, optimizer)
        running_loss += loss.item()
        # Every `log_interval` steps, record the running mean loss of the
        # epoch so far (not the current batch's loss).
        if step % log_interval == 0:
            avg_loss = running_loss / (step + 1)
            print(f"Epoch {epoch + 1}/{EPOCHS}, Step {step + 1}/{num_batches}")
            print(f" Running Loss: {avg_loss:.4f}")
            train_losses_step.append(avg_loss)
    # An empty loader would otherwise raise ZeroDivisionError here.
    avg_train_loss = running_loss / num_batches if num_batches else float("nan")
    avg_val_loss, val_accuracy = validate()
    print(f"Epoch {epoch + 1}/{EPOCHS} - End of epoch")
    print(f" Training Loss: {avg_train_loss:.4f}")
    print(f" Validation Loss: {avg_val_loss:.4f}")
    print(f" Validation Accuracy: {val_accuracy:.4f}")
    val_losses.append(avg_val_loss)
    if val_accuracy > best_accuracy:
        best_accuracy = val_accuracy
        # exist_ok=True replaces the racy "check exists, then create" pair.
        os.makedirs(save_model_path, exist_ok=True)
        model_save_path = os.path.join(save_model_path, f"model_epoch_{epoch + 1}_acc{best_accuracy:.4f}.pth")
        torch.save(model.state_dict(), model_save_path)
        print(f"Model saved to {model_save_path}")
# Best validation accuracy seen so far. train() reads and updates it via
# `global` and only saves a checkpoint when it is exceeded.
best_accuracy = 0
# Run EPOCHS epochs. Each call trains, validates and possibly checkpoints.
for epoch in range(EPOCHS):
    train(epoch)
Leave a Comment