Untitled

 avatar
unknown
plain_text
a month ago
969 B
3
Indexable
def calculate_top_k_accuracy(loader, encoder, classifier, device, k=5):
    """Evaluate top-1 and top-k accuracy of ``classifier(encoder(x))``.

    Both models are switched to eval mode for the duration of the pass and
    are put back into whatever mode (train/eval) they were in on entry, even
    if evaluation raises.

    Args:
        loader: Sized iterable of batches. Each batch is indexed so that
            ``batch[0]`` holds the inputs and ``batch[1]`` the labels.
        encoder: Module mapping inputs to features.
        classifier: Module mapping features to per-class scores laid out
            along dimension 1.
        device: Device that inputs and labels are moved to.
        k: Number of highest-scoring classes counted for top-k accuracy.

    Returns:
        A ``(top1_acc, topk_acc)`` tuple of percentages in ``[0, 100]``.
        Both are ``0.0`` when the loader yields no samples.
    """
    # Remember the caller's modes: an encoder may be deliberately frozen in
    # eval mode, or both models may be mid-training.
    encoder_was_training = encoder.training
    classifier_was_training = classifier.training

    encoder.eval()
    classifier.eval()
    top1_correct = 0
    topk_correct = 0
    total = 0

    try:
        with torch.no_grad():
            for data in tqdm(loader, total=len(loader), desc=f"Evaluating Top-{k} Accuracy"):
                x, y = data[0].to(device), data[1].to(device)
                total += y.size(0)

                # Forward pass
                encoder_output = encoder(x)
                classifier_output = classifier(encoder_output)

                # Indices of the k highest scores per sample, best first.
                _, topk_preds = classifier_output.topk(k, dim=1)
                top1_preds = topk_preds[:, 0]

                # Top-1 accuracy
                top1_correct += (top1_preds == y).sum().item()

                # Top-k accuracy: top-k indices are distinct, so each row
                # matches its label at most once and a plain sum counts hits.
                topk_correct += torch.sum(topk_preds.eq(y.view(-1, 1))).item()
    finally:
        encoder.train(encoder_was_training)
        classifier.train(classifier_was_training)

    if total == 0:
        # Nothing was evaluated; avoid dividing by zero.
        return 0.0, 0.0

    top1_acc = top1_correct / total * 100
    topk_acc = topk_correct / total * 100
    return top1_acc, topk_acc
Leave a Comment