Untitled
unknown
plain_text
10 months ago
969 B
5
Indexable
def calculate_top_k_accuracy(loader, encoder, classifier, device, k=5):
    """Compute top-1 and top-k classification accuracy over ``loader``.

    Every batch is moved to ``device`` and passed through ``encoder`` and
    then ``classifier`` with gradients disabled. The ``k`` highest-scoring
    entries along dim 1 of the classifier output are compared with the
    labels.

    Both modules are switched to eval mode for the duration of the call and
    put back into whatever mode (train or eval) they were in beforehand,
    even if evaluation raises.

    Args:
        loader: Sized iterable of batches (``len(loader)`` is called).
            Element 0 of each batch is the input and element 1 the labels.
            NOTE(review): labels are assumed to be integer class indices of
            shape ``(batch,)`` -- confirm against the dataset.
        encoder: Module producing the features fed to ``classifier``.
        classifier: Module producing per-class scores with classes on dim 1.
        device: Device the inputs and labels are moved to.
        k: Number of top predictions that count as a hit.

    Returns:
        A ``(top1_acc, topk_acc)`` tuple of percentages in ``[0, 100]``.
        ``(0.0, 0.0)`` is returned when the loader yields no samples.
    """
    # Remember the current modes so evaluation does not silently flip a
    # frozen (eval-mode) encoder to train mode or vice versa.
    encoder_was_training = encoder.training
    classifier_was_training = classifier.training
    encoder.eval()
    classifier.eval()

    top1_correct = 0
    topk_correct = 0
    total = 0
    try:
        with torch.no_grad():
            progress = tqdm(
                loader,
                total=len(loader),
                desc=f"Evaluating Top-{k} Accuracy",
            )
            for data in progress:
                x, y = data[0].to(device), data[1].to(device)
                total += y.size(0)

                # Forward pass
                classifier_output = classifier(encoder(x))

                # topk returns indices sorted by descending score, so
                # column 0 is the arg-max (top-1) prediction.
                _, topk_preds = classifier_output.topk(k, dim=1)

                top1_correct += (topk_preds[:, 0] == y).sum().item()
                # The k indices in a row are distinct, so each row adds at
                # most one match to this sum.
                topk_correct += topk_preds.eq(y.view(-1, 1)).sum().item()
    finally:
        encoder.train(encoder_was_training)
        classifier.train(classifier_was_training)

    if total == 0:
        # Nothing was evaluated; avoid dividing by zero.
        return 0.0, 0.0

    top1_acc = top1_correct / total * 100
    topk_acc = topk_correct / total * 100
    return top1_acc, topk_acc
Leave a Comment