def _infer_model_device(model):
    """Return the device of the model's first parameter.

    Falls back to "CUDA if available, else CPU" when *model* exposes no
    ``parameters()`` method or has no parameters at all.
    """
    try:
        return next(model.parameters()).device
    except (AttributeError, StopIteration):
        return torch.device('cuda' if torch.cuda.is_available() else 'cpu')


def run_model_on_validation_dataset(model, val_dataset, val_batch_size=16):
    """Run *model* over *val_dataset* without gradients and collect the results.

    Each batch is read through the keys ``'input_ids'``, ``'attention_mask'``
    and ``'labels'``; the model is called as
    ``model(input_ids, attention_mask=..., labels=...)`` and its return value
    is read through the keys ``'loss'`` and ``'logits'``.
    NOTE(review): this looks like a Hugging Face-style model output -- confirm
    against the caller.

    Args:
        model: Callable model. If it is a ``torch.nn.Module`` it is switched
            to eval mode for the duration of the call and every sub-module's
            ``training`` flag is restored afterwards.
        val_dataset: Dataset accepted by ``DataLoader``; must support ``len``.
        val_batch_size: Batch size used for the validation loader.

    Returns:
        A tuple ``(labels, result)`` where ``labels`` is the CPU tensor of all
        batch labels concatenated in dataset order, and ``result`` is a dict
        with ``'loss'`` (NumPy array holding one loss value per batch) and
        ``'logits'`` (CPU tensor of all batch logits concatenated in the same
        order as ``labels``).

    Raises:
        ValueError: If the loader yields no batches (e.g. empty dataset).
    """
    # Send inputs to wherever the model actually lives, so a CPU model on a
    # CUDA-capable machine does not hit a device mismatch.
    device = _infer_model_device(model)
    # No shuffling: validation results do not depend on order, and keeping
    # dataset order lets callers align outputs with val_dataset indices.
    val_loader = DataLoader(val_dataset, batch_size=val_batch_size, shuffle=False)
    print(len(val_dataset))

    # Remember each sub-module's mode so eval() can be undone exactly, even if
    # some sub-modules were deliberately left in a different mode.
    if isinstance(model, torch.nn.Module):
        saved_modes = [(module, module.training) for module in model.modules()]
        model.eval()  # disable dropout / use running batch-norm statistics
    else:
        saved_modes = []

    batch_losses = []
    batch_logits = []
    batch_labels = []
    try:
        with torch.no_grad():
            for batch in val_loader:
                input_ids = batch['input_ids'].to(device)
                attention_mask = batch['attention_mask'].to(device)
                labels = batch['labels'].to(device)
                outputs = model(input_ids, attention_mask=attention_mask, labels=labels)
                # Move only what is needed to the CPU right away; holding the
                # whole output object would keep every batch's tensors on the
                # accelerator until the end of the loop.
                batch_losses.append(outputs['loss'].cpu().numpy())
                batch_logits.append(outputs['logits'].cpu())
                batch_labels.append(labels.cpu())
    finally:
        for module, was_training in saved_modes:
            module.training = was_training

    if not batch_losses:
        raise ValueError("val_dataset produced no batches; nothing to evaluate")

    losses = np.stack(batch_losses)
    # Unweighted mean over batches (a smaller final batch counts the same).
    print("Validation loss: {:.5f}".format(losses.mean()))
    result = {
        'loss': losses,
        'logits': torch.cat(batch_logits),
    }
    torch.cuda.empty_cache()
    return torch.cat(batch_labels), result