Untitled

mail@pastecode.io avatar
unknown
python
a year ago
750 B
2
Indexable
Never
def run_model_on_validation_dataset(model, val_dataset):
    """Run ``model`` once over the whole validation set and print its loss.

    Every example is collated into a single batch and passed through the
    model with autograd disabled. If ``model`` is a ``torch.nn.Module``, it
    is put in eval mode for the pass and restored to its previous mode
    afterwards. The first element of the model output is treated as the
    loss and printed.

    Args:
        model: Callable as ``model(input_ids, attention_mask=..., labels=...)``.
            Its output must be indexable with the loss at position 0 (e.g. a
            tuple or a Hugging Face-style model output). This is an
            assumption based on how the output is used below.
        val_dataset: Map-style dataset whose items are dicts containing
            ``'input_ids'``, ``'attention_mask'`` and ``'labels'`` tensors.

    Returns:
        ``(labels, outputs)``. ``labels`` holds the batch labels moved to
        CPU. ``outputs`` is the raw model output for the same batch. Both
        are in dataset order.

    Raises:
        ValueError: If ``val_dataset`` is empty.
    """
    n_examples = len(val_dataset)
    if n_examples == 0:
        # DataLoader rejects batch_size=0, and the loop below would never
        # bind `labels`/`outputs`; fail with a clear message instead.
        raise ValueError("val_dataset is empty")

    is_module = isinstance(model, torch.nn.Module)

    # Send inputs to wherever the model's weights already live. Fall back to
    # "CUDA if available" for callables without parameters.
    first_param = next(model.parameters(), None) if is_module else None
    if first_param is not None:
        device = first_param.device
    else:
        device = torch.device('cuda') if torch.cuda.is_available() else torch.device('cpu')

    # One batch holding every example. Order is kept so results are deterministic.
    val_loader = DataLoader(val_dataset, batch_size=n_examples, shuffle=False)
    print(n_examples)

    was_training = is_module and model.training
    if is_module:
        model.eval()  # disable dropout etc. so the validation loss is deterministic
    try:
        # No autograd graph: evaluation only, and far less memory for a full-dataset batch.
        with torch.no_grad():
            for batch in val_loader:
                input_ids = batch['input_ids'].to(device)
                attention_mask = batch['attention_mask'].to(device)
                labels = batch['labels'].to(device)
                outputs = model(input_ids, attention_mask=attention_mask, labels=labels)
                loss = outputs[0]
                print("Validation loss: {:.5f}".format(float(loss)))
    finally:
        if is_module:
            model.train(was_training)  # restore the caller's train/eval mode

    # Hand labels back on CPU and release cached CUDA blocks we may have used.
    labels = labels.cpu()
    if device.type == 'cuda':
        torch.cuda.empty_cache()
    return labels, outputs