val_loss_updated
unknown
python
7 months ago
1.1 kB
6
Indexable
Never
def run_model_on_validation_dataset(model, val_dataset, val_batch_size=16, shuffle=False):
    """Run ``model`` over ``val_dataset`` without gradients and collect its outputs.

    Each batch yielded by the DataLoader must be a mapping with ``'input_ids'``,
    ``'attention_mask'`` and ``'labels'`` tensors. The model is called as
    ``model(input_ids, attention_mask=..., labels=...)`` and its output must
    support ``output['loss']`` and ``output['logits']`` (e.g. a Hugging Face
    ``ModelOutput`` or a plain dict).

    The model is switched to eval mode for the pass, and its previous
    training/eval mode is restored afterwards, even on error.

    Args:
        model: A ``torch.nn.Module``. Inputs are sent to the device of its
            first parameter (CUDA-if-available when it has no parameters).
        val_dataset: A sized dataset accepted by ``DataLoader``.
        val_batch_size: Batch size used for the pass.
        shuffle: Whether to shuffle the data. Defaults to ``False`` so that
            returned logits/labels follow dataset order and are reproducible.

    Returns:
        A tuple ``(labels, result)``. ``labels`` is a CPU tensor of all batch
        labels concatenated along dim 0. ``result`` is a dict with ``'loss'``
        (a NumPy array holding one loss value per batch) and ``'logits'``
        (a CPU tensor of all batch logits concatenated along dim 0, in the
        same order as ``labels``).

    Raises:
        ValueError: If ``val_dataset`` is empty.
    """
    if len(val_dataset) == 0:
        raise ValueError("val_dataset is empty; nothing to validate")

    # Send inputs to wherever the model already lives, so a CPU model on a
    # CUDA-capable host does not fail with a device-mismatch error.
    try:
        device = next(model.parameters()).device
    except StopIteration:
        device = torch.device('cuda') if torch.cuda.is_available() else torch.device('cpu')

    val_loader = DataLoader(val_dataset, batch_size=val_batch_size, shuffle=shuffle)
    print(len(val_dataset))

    batch_losses = []
    batch_logits = []
    batch_labels = []

    # Disable dropout etc. for evaluation, but give the caller back the mode
    # it had (this may be called in the middle of a training loop).
    was_training = model.training
    model.eval()
    try:
        with torch.no_grad():
            for batch in val_loader:
                input_ids = batch['input_ids'].to(device)
                attention_mask = batch['attention_mask'].to(device)
                labels = batch['labels'].to(device)
                outputs = model(input_ids, attention_mask=attention_mask, labels=labels)
                # Keep only what we return, moved to CPU immediately, instead of
                # holding every full output object in device memory.
                batch_losses.append(outputs['loss'].cpu().numpy())
                batch_logits.append(outputs['logits'].cpu())
                batch_labels.append(labels.cpu())
    finally:
        model.train(was_training)

    losses = np.stack(batch_losses)
    # Unweighted mean over batches (a smaller final batch counts the same).
    print("Validation loss: {:.5f}".format(losses.mean()))

    result = {
        'loss': losses,
        'logits': torch.cat(batch_logits),
    }
    if device.type == 'cuda':
        torch.cuda.empty_cache()
    return torch.cat(batch_labels), result