Untitled
def iou_score(pred, target, threshold=0.5, smooth=1e-5):
    """Return the mean per-sample IoU between thresholded predictions and a target.

    Each sample along the first dimension is scored independently: ``pred`` is
    binarised with ``pred > threshold``, then intersection and union are
    computed over all remaining dimensions (flattened).  Flattening means that
    e.g. ``(N, H, W)`` predictions and ``(N, 1, H, W)`` masks with the same
    number of elements per sample are compared element-wise instead of being
    broadcast against each other.  ``smooth`` is added to both numerator and
    denominator, keeping the ratio finite; an empty prediction against an
    empty target scores 1.

    Args:
        pred: Tensor of predicted scores with the sample index in dim 0.
            Presumably probabilities in [0, 1] given the 0.5 default
            threshold -- TODO confirm against the model's output activation.
        target: Tensor with the same number of elements per sample as
            ``pred``.  It is NOT binarised here; assumed to already be a
            {0, 1} mask -- TODO confirm.
        threshold: Predictions strictly greater than this count as foreground.
        smooth: Smoothing constant added to intersection and union.

    Returns:
        0-d tensor: the IoU averaged over the first dimension.
    """
    pred_bin = (pred > threshold).float().flatten(1)
    target = target.flatten(1)
    intersection = (pred_bin * target).sum(1)
    union = pred_bin.sum(1) + target.sum(1) - intersection
    iou = (intersection + smooth) / (union + smooth)
    return iou.mean()


def _score_sequence(sequence_preds, sequence_masks, post_processor):
    """Score one stacked sequence; return its (dice, iou, temporal consistency)."""
    if post_processor is not None:
        sequence_preds = post_processor.process_sequence(sequence_preds)
    dice = dice_score(torch.from_numpy(sequence_preds), torch.from_numpy(sequence_masks))
    iou = iou_score(torch.from_numpy(sequence_preds), torch.from_numpy(sequence_masks))
    consistency = evaluate_temporal_consistency(sequence_preds)
    return dice, iou, consistency


# evaluate_sequences, extended to also report IoU.
def evaluate_sequences(model, test_loader, device, post_processor=None, sequence_length=30):
    """Evaluate a model sequence-by-sequence with Dice, IoU and temporal consistency.

    Frames are taken from ``test_loader`` in iteration order and grouped into
    consecutive chunks of exactly ``sequence_length`` frames.  Each chunk is
    optionally passed through ``post_processor.process_sequence`` and then
    scored with ``dice_score``, ``iou_score`` and
    ``evaluate_temporal_consistency``.  Trailing frames that do not fill a
    whole chunk are not evaluated.

    NOTE(review): grouping by position assumes the loader is not shuffled and
    that sequence boundaries fall every ``sequence_length`` frames -- confirm
    against the dataset.

    Args:
        model: Module called as ``model(images)``; switched to eval mode here.
        test_loader: Iterable of ``(images, masks)`` batches; ``masks`` must be
            CPU tensors (``.numpy()`` is called on each item).
        device: Device the images are moved to before the forward pass.
        post_processor: Optional object with ``process_sequence(array)``
            applied to each stacked prediction sequence.
        sequence_length: Number of consecutive frames per evaluated sequence.

    Returns:
        ``(avg_dice, avg_iou, avg_consistency)`` averaged over the evaluated
        sequences, or ``(0, 0, 0)`` if no complete sequence was seen.

    Raises:
        ValueError: if ``sequence_length`` is smaller than 1.
    """
    if sequence_length < 1:
        raise ValueError(f"sequence_length must be >= 1, got {sequence_length}")

    model.eval()
    total_dice = 0
    total_iou = 0
    total_consistency = 0
    num_sequences = 0
    # Frames received from the loader but not yet assigned to a scored sequence.
    pending_preds = []
    pending_masks = []

    with torch.no_grad():
        for images, masks in test_loader:
            outputs = model(images.to(device))
            # Keep only channel 0 of each prediction (assumes outputs are
            # (batch, channels, ...) -- TODO confirm).
            pending_preds.extend(pred[0] for pred in outputs.cpu().numpy())
            pending_masks.extend(mask.numpy() for mask in masks)

            # Cut exactly `sequence_length` frames and carry the remainder
            # forward: flushing the whole buffer would glue frames of the next
            # sequence onto this one whenever the batch size does not divide
            # `sequence_length`.
            while len(pending_preds) >= sequence_length:
                dice, iou, consistency = _score_sequence(
                    np.stack(pending_preds[:sequence_length]),
                    np.stack(pending_masks[:sequence_length]),
                    post_processor,
                )
                del pending_preds[:sequence_length]
                del pending_masks[:sequence_length]
                total_dice += dice
                total_iou += iou
                total_consistency += consistency
                num_sequences += 1

    if num_sequences == 0:
        return 0, 0, 0
    return (
        total_dice / num_sequences,
        total_iou / num_sequences,
        total_consistency / num_sequences,
    )


# Run the full sequence-level evaluation on the test set and report the metrics.
print("\nInizio valutazione completa...")
test_dice, test_iou, test_consistency = evaluate_sequences(
    model, test_loader, device, post_processor=post_processor
)
print(f'\nDice Score sul test set (sequenze): {test_dice:.4f}')
print(f'IoU Score sul test set (sequenze): {test_iou:.4f}')
print(f'Temporal Consistency Score: {test_consistency:.4f}')
Leave a Comment