Untitled

 avatar
unknown
python
13 days ago
2.6 kB
6
Indexable
def iou_score(pred, target, threshold=0.5, smooth=1e-5, dims=(1, 2)):
    """Return the mean Intersection-over-Union of thresholded predictions vs. targets.

    ``pred`` is binarized with ``pred > threshold``; ``target`` is used as-is,
    so it should already contain 0/1 values.

    Args:
        pred: Tensor of prediction scores. With the default ``dims`` it is
            reduced as a (batch, H, W) layout. NOTE(review): presumably
            per-pixel probabilities (post-sigmoid) -- confirm, since raw logits
            would make the 0.5 threshold meaningless.
        target: Tensor broadcast-compatible with ``pred`` holding the ground truth.
        threshold: Scores strictly greater than this count as foreground.
        smooth: Added to numerator and denominator so an empty prediction
            against an empty target scores 1.0 instead of dividing by zero.
        dims: Dimensions summed to get one IoU per remaining index. The default
            ``(1, 2)`` matches (batch, H, W); pass ``(1, 2, 3)`` for
            (batch, C, H, W) inputs.

    Returns:
        0-dim tensor: the IoU averaged over every non-reduced index.
    """
    pred_bin = (pred > threshold).float()
    intersection = (pred_bin * target).sum(dims)
    union = pred_bin.sum(dims) + target.sum(dims) - intersection
    iou = (intersection + smooth) / (union + smooth)
    return iou.mean()

def _score_sequence(frame_preds, frame_masks, post_processor):
    """Stack buffered frames into one sequence and return ``(dice, iou, consistency)``.

    If ``post_processor`` is given, its ``process_sequence`` is applied to the
    stacked predictions before any metric is computed; masks are not touched.
    """
    sequence_preds = np.stack(frame_preds)
    sequence_masks = np.stack(frame_masks)

    if post_processor is not None:
        sequence_preds = post_processor.process_sequence(sequence_preds)

    dice = dice_score(torch.from_numpy(sequence_preds),
                      torch.from_numpy(sequence_masks))
    iou = iou_score(torch.from_numpy(sequence_preds),
                    torch.from_numpy(sequence_masks))
    consistency = evaluate_temporal_consistency(sequence_preds)
    return dice, iou, consistency


# evaluate_sequences, extended to report IoU alongside Dice and temporal consistency.
def evaluate_sequences(model, test_loader, device, post_processor=None,
                       sequence_length=30, drop_last=True):
    """Evaluate ``model`` on ``test_loader`` in chunks of consecutive frames.

    Predictions and masks are buffered across batches. Once at least
    ``sequence_length`` frames are buffered, the whole buffer is scored as one
    "sequence" and then cleared. A batch is never split, so a chunk can hold
    more than ``sequence_length`` frames when the batch size does not divide it.

    NOTE(review): sequence boundaries come purely from frame counts. This
    assumes the loader yields frames in temporal order (no shuffling) and
    that real sequences line up with these chunks -- confirm against the dataset.

    Args:
        model: Module called as ``model(images)``. It is switched to eval mode
            and left there.
        test_loader: Iterable of ``(images, masks)`` batches. ``masks`` must
            support ``.numpy()`` (i.e. CPU tensors).
        device: Device the images are moved to before the forward pass.
        post_processor: Optional object whose ``process_sequence`` is applied
            to each stacked chunk of predictions before scoring.
        sequence_length: Number of buffered frames that triggers scoring
            (default 30, the previously hard-coded value).
        drop_last: If True (default, original behavior), frames still buffered
            when the loader is exhausted are discarded. If False, they are
            scored as one final, shorter sequence.

    Returns:
        ``(avg_dice, avg_iou, avg_consistency)`` averaged over the scored
        sequences, or ``(0, 0, 0)`` (with a ``UserWarning``) when none was scored.
    """
    import warnings

    model.eval()
    scores = []  # one (dice, iou, consistency) tuple per scored sequence
    frame_preds = []
    frame_masks = []

    with torch.no_grad():
        for images, masks in test_loader:
            outputs = model(images.to(device))
            # Keep index 0 of each sample's output. Presumably this is the single
            # foreground channel of a (B, C, H, W) output -- confirm against the model.
            frame_preds.extend(pred[0] for pred in outputs.cpu().numpy())
            frame_masks.extend(mask.numpy() for mask in masks)

            if len(frame_preds) >= sequence_length:
                scores.append(_score_sequence(frame_preds, frame_masks, post_processor))
                frame_preds = []
                frame_masks = []

        # Frames after the last full chunk are ignored unless drop_last is False.
        if frame_preds and not drop_last:
            scores.append(_score_sequence(frame_preds, frame_masks, post_processor))

    if not scores:
        # Zeros on their own are indistinguishable from a genuinely bad model.
        warnings.warn(
            "evaluate_sequences: no sequence was scored (the loader yielded "
            f"fewer than sequence_length={sequence_length} frames); "
            "returning zeros.",
            stacklevel=2,
        )
        return 0, 0, 0

    num_sequences = len(scores)
    # sum() starts at 0 and adds in order, matching a running total.
    avg_dice = sum(s[0] for s in scores) / num_sequences
    avg_iou = sum(s[1] for s in scores) / num_sequences
    avg_consistency = sum(s[2] for s in scores) / num_sequences
    return avg_dice, avg_iou, avg_consistency

# Run the full sequence-level evaluation on the test set and report the metrics.
print("\nInizio valutazione completa...")
test_dice, test_iou, test_consistency = evaluate_sequences(
    model,
    test_loader,
    device,
    post_processor=post_processor,
)
# One print with newline separators emits exactly the same three report lines.
print(
    f'\nDice Score sul test set (sequenze): {test_dice:.4f}',
    f'IoU Score sul test set (sequenze): {test_iou:.4f}',
    f'Temporal Consistency Score: {test_consistency:.4f}',
    sep='\n',
)
Leave a Comment