Untitled
unknown
python
2 years ago
750 B
10
Indexable
def run_model_on_validation_dataset(model, val_dataset):
    """Run ``model`` once over the whole validation set and print its loss.

    The entire dataset is collated into a single batch, so the returned
    ``labels`` and ``outputs`` cover every validation example, in dataset
    order (no shuffling).

    The model is put in eval mode and run under ``torch.no_grad()``. Its
    previous train/eval mode is restored before returning.

    Args:
        model: A ``torch.nn.Module`` called as
            ``model(input_ids, attention_mask=..., labels=...)``. Element 0 of
            its output is treated as the loss (presumably a Hugging Face-style
            model -- TODO confirm against callers).
        val_dataset: A map-style dataset whose items collate into a dict with
            ``'input_ids'``, ``'attention_mask'`` and ``'labels'`` tensors.

    Returns:
        ``(labels, outputs)``: the labels tensor moved to CPU, and the raw
        model outputs for the single full-dataset batch.

    Raises:
        ValueError: if ``val_dataset`` is empty.
    """
    n_examples = len(val_dataset)
    print(n_examples)
    if n_examples == 0:
        raise ValueError("val_dataset is empty; nothing to validate")

    # Send inputs to wherever the model's weights live, so a CPU model on a
    # CUDA-capable machine does not receive CUDA inputs. Parameter-less
    # models fall back to the original CUDA-if-available choice.
    first_param = next(model.parameters(), None)
    if first_param is not None:
        device = first_param.device
    else:
        device = torch.device('cuda') if torch.cuda.is_available() else torch.device('cpu')

    # One batch holding the whole dataset. No shuffling, so the returned
    # labels/outputs come back in a deterministic (dataset) order.
    val_loader = DataLoader(val_dataset, batch_size=n_examples, shuffle=False)
    batch = next(iter(val_loader))

    was_training = model.training
    model.eval()  # disable dropout etc. while evaluating
    try:
        # Validation needs no autograd graph. Building one would keep every
        # activation alive (on the GPU, if used) via the returned outputs.
        with torch.no_grad():
            input_ids = batch['input_ids'].to(device)
            attention_mask = batch['attention_mask'].to(device)
            labels = batch['labels'].to(device)
            outputs = model(input_ids, attention_mask=attention_mask, labels=labels)
    finally:
        model.train(was_training)  # restore the caller's mode

    loss = outputs[0]
    print("Validation loss: {:.5f}".format(loss.item()))

    # Move labels to CPU and drop the device copies of the inputs, so the
    # cache flush below can actually release their memory.
    labels = labels.cpu()
    del batch, input_ids, attention_mask
    torch.cuda.empty_cache()
    return labels, outputs
Editor is loading...