val_loss_updated

 avatar
unknown
python
2 years ago
1.1 kB
7
Indexable
def run_model_on_validation_dataset(model, val_dataset, val_batch_size=16):
    """Run ``model`` over ``val_dataset`` and collect losses, logits and labels.

    Each batch produced by the loader must be a mapping with the keys
    ``'input_ids'``, ``'attention_mask'`` and ``'labels'``.  The model is
    called as ``model(input_ids, attention_mask=..., labels=...)`` and its
    output must expose ``'loss'`` and ``'logits'`` entries.

    Args:
        model: Callable model.  When it is a ``torch.nn.Module`` it is
            switched to eval mode for the duration of the run and its
            previous training flag is restored afterwards.
        val_dataset: Sized dataset to evaluate on.
        val_batch_size: Number of samples per batch.

    Returns:
        A tuple ``(labels, result)`` where ``labels`` is a CPU tensor with
        the labels of all batches concatenated in dataset order, and
        ``result`` is a dict with:

        * ``'loss'``: NumPy array holding one loss value per batch.
        * ``'logits'``: CPU tensor with the logits of all batches
          concatenated along the first dimension.

    Raises:
        ValueError: If the dataset yields no batches.
    """
    is_module = isinstance(model, torch.nn.Module)

    # Feed inputs to the device the model already lives on; only fall back to
    # the "CUDA if available" guess when the model has no parameters to ask.
    first_param = next(model.parameters(), None) if is_module else None
    if first_param is not None:
        device = first_param.device
    else:
        device = torch.device('cuda') if torch.cuda.is_available() else torch.device('cpu')

    # No shuffling: evaluation results should come back in dataset order so
    # callers can line them up with the original samples.
    val_loader = DataLoader(val_dataset, batch_size=val_batch_size, shuffle=False)
    print(len(val_dataset))

    batch_losses = []
    batch_logits = []
    batch_labels = []

    # Dropout / batch-norm must be in inference behaviour while validating.
    was_training = is_module and model.training
    if is_module:
        model.eval()
    try:
        with torch.no_grad():
            for batch in val_loader:
                input_ids = batch['input_ids'].to(device)
                attention_mask = batch['attention_mask'].to(device)
                labels = batch['labels'].to(device)
                outputs = model(input_ids, attention_mask=attention_mask, labels=labels)
                # Move what we keep to the CPU right away so that per-batch
                # outputs do not pile up in accelerator memory.
                batch_losses.append(outputs['loss'].cpu().numpy())
                batch_logits.append(outputs['logits'].cpu())
                batch_labels.append(labels.cpu())
    finally:
        if was_training:
            model.train()

    if not batch_losses:
        raise ValueError("val_dataset produced no batches; nothing to evaluate")

    print("Validation loss: {:.5f}".format(np.array(batch_losses).mean()))

    result = {
        'loss': np.stack(batch_losses),
        'logits': torch.cat(batch_logits),
    }
    if torch.cuda.is_available():
        torch.cuda.empty_cache()
    return torch.cat(batch_labels), result