Untitled
def calculate_top_k_accuracy(loader, encoder, classifier, device, k=5):
    """Compute top-1 and top-k accuracy of ``classifier(encoder(x))``.

    Both modules are put in eval mode and run under ``torch.no_grad()``.
    Their previous train/eval modes are restored afterwards, even if an
    exception is raised mid-evaluation.

    Args:
        loader: Iterable of batches where ``batch[0]`` is the input tensor and
            ``batch[1]`` holds integer class labels (e.g. a ``DataLoader``).
            ``len(loader)`` is only used for the progress-bar total and may be
            unavailable (generators, iterable-style loaders).
        encoder: Module mapping inputs to features.
        classifier: Module mapping encoder features to per-class scores.
            Presumably shaped ``(batch, num_classes)``, given the
            ``topk(..., dim=1)`` below -- confirm against the model.
        device: Device each batch is moved to before the forward pass.
        k: Number of highest-scoring classes that count as a hit for the
            top-k metric. Clamped to the number of classes.

    Returns:
        ``(top1_acc, topk_acc)`` as percentages in ``[0, 100]``. Both are
        ``0.0`` when the loader yields no samples.
    """
    # The progress bar is purely cosmetic; fall back to plain iteration if
    # tqdm is not installed.
    try:
        from tqdm import tqdm as progress_bar
    except ImportError:
        def progress_bar(iterable, **_kwargs):
            return iterable

    try:
        n_batches = len(loader)
    except TypeError:  # loader without __len__; tqdm accepts total=None
        n_batches = None

    # Remember the caller's modes so we restore them instead of forcing
    # train mode (the original left the encoder stuck in eval).
    encoder_was_training = encoder.training
    classifier_was_training = classifier.training
    encoder.eval()
    classifier.eval()

    top1_correct = 0
    topk_correct = 0
    total = 0
    try:
        with torch.no_grad():
            for data in progress_bar(
                loader, total=n_batches, desc=f"Evaluating Top-{k} Accuracy"
            ):
                x, y = data[0].to(device), data[1].to(device)
                total += y.size(0)

                logits = classifier(encoder(x))

                # Clamp k so topk does not raise when there are fewer than k
                # classes. topk returns indices sorted by descending score,
                # so column 0 is the top-1 prediction.
                k_eff = min(k, logits.size(1))
                _, topk_preds = logits.topk(k_eff, dim=1)

                top1_correct += (topk_preds[:, 0] == y).sum().item()
                # A sample is a top-k hit if its label is among its k
                # predicted class indices.
                hits = topk_preds.eq(y.reshape(-1, 1)).any(dim=1)
                topk_correct += hits.sum().item()
    finally:
        encoder.train(encoder_was_training)
        classifier.train(classifier_was_training)

    if total == 0:  # empty loader: avoid ZeroDivisionError
        return 0.0, 0.0
    top1_acc = top1_correct / total * 100
    topk_acc = topk_correct / total * 100
    return top1_acc, topk_acc
Leave a Comment