Untitled

 avatar
unknown
plain_text
a month ago
21 kB
6
Indexable
import argparse
import json
import os
import os.path as osp
import random
import sys
from collections import defaultdict
from dataclasses import dataclass
from typing import Dict, List, Sequence, Tuple

import matplotlib.pyplot as plt
import numpy as np
import pandas as pd
import torch
import torch.nn.functional as F
from sklearn.base import clone
from sklearn.linear_model import LogisticRegression
from sklearn.metrics import accuracy_score, roc_auc_score
from sklearn.model_selection import GroupShuffleSplit
from sklearn.pipeline import Pipeline
from sklearn.preprocessing import StandardScaler
from tqdm import tqdm

# Put the directory one level above this file's directory at the front of
# sys.path so the project-local imports that follow (simclr, train_simclr,
# utils) resolve when this script is run directly. Must execute before them.
PROJECT_ROOT = osp.abspath(osp.join(osp.dirname(__file__), ".."))
if PROJECT_ROOT not in sys.path:
    sys.path.insert(0, PROJECT_ROOT)

from simclr.models_simple import Lightweight3DResNet
from train_simclr import SimCLRRunner
from utils.utils import load_config


@dataclass(frozen=True)
class PairRecord:
    """One immutable (query, candidate) probe example from a single trajectory.

    Both keys index the embedding cache as ``(trajectory_id, center_frame)``.
    In ``run_compound`` the label is 1 when the candidate center is
    ``delta`` frames after the query center and 0 when it is ``delta``
    frames before it.
    """

    # (trajectory_id, center frame) of the anchor clip.
    query_key: Tuple[str, int]
    # (trajectory_id, center frame) of the clip compared against the anchor.
    candidate_key: Tuple[str, int]
    # 1 -> candidate lies in the future of the query, 0 -> in the past.
    label: int
    # Group id used to keep whole trajectories on one side of the split.
    trajectory_id: str


def parse_args() -> argparse.Namespace:
    """Parse the command line for the arrow-of-time probe.

    Returns:
        Namespace with path options (``config``, ``checkpoint_path``,
        ``data_root``, ``output_dir``), clip/offset options (``clip_len``,
        ``deltas``), probe construction (``feature_forms``,
        ``embedding_pool``), evaluation/debug options (``test_size``,
        ``seed``, ``max_trajectories_per_compound``) and ``device``.
    """
    parser = argparse.ArgumentParser(
        description="Arrow-of-time probe: same-trajectory future-vs-past classification."
    )

    # String-valued path options, registered in this order.
    # NOTE(review): defaults are machine-specific absolute paths.
    path_options = (
        (
            "--config",
            "/home/dhruvagarwal/projects/MitoSpace4D/simclr/config.yaml",
            "Path to SimCLR config.",
        ),
        (
            "--checkpoint-path",
            "/home/dhruvagarwal/projects/MitoSpace4D/runs/lightning_logs/resnetbilstm_encoded_normal_tmeporal_consistent_mtg_only/checkpoints/epoch=250-step=40411-val_loss=0.00.ckpt",
            "Path to frozen SSL checkpoint.",
        ),
        (
            "--data-root",
            "/mnt/aquila/others/MitoSpace4D/2024v3_data/processed_data",
            "Root containing per-compound folders of trajectories (.npy).",
        ),
        (
            "--output-dir",
            "/home/dhruvagarwal/projects/MitoSpace4D/temporal_experiments/results/arrow_of_time_probe",
            "Directory for experiment outputs.",
        ),
    )
    for flag, default_path, help_text in path_options:
        parser.add_argument(flag, type=str, default=default_path, help=help_text)

    # Clip geometry and temporal offsets.
    parser.add_argument(
        "--clip-len", type=int, default=5, help="Fixed odd clip length centered at t."
    )
    parser.add_argument(
        "--deltas", type=int, nargs="+", default=[2, 4, 6, 8], help="Temporal offsets (in frames)."
    )

    # Probe construction.
    known_forms = ["concat", "diff", "rich"]
    parser.add_argument(
        "--feature-forms",
        type=str,
        nargs="+",
        default=list(known_forms),
        choices=known_forms,
        help="Probe feature construction(s).",
    )
    parser.add_argument(
        "--embedding-pool",
        type=str,
        default="last",
        choices=["last", "mean"],
        help="How to pool per-frame clip embeddings.",
    )

    # Evaluation / debugging knobs.
    parser.add_argument(
        "--test-size", type=float, default=0.2, help="Held-out trajectory fraction per compound."
    )
    parser.add_argument("--seed", type=int, default=42, help="Random seed.")
    parser.add_argument(
        "--max-trajectories-per-compound",
        type=int,
        default=None,
        help="Optional cap for debugging/smoke runs.",
    )
    parser.add_argument(
        "--device", type=str, default="cuda", help="Device for embedding extraction. Use cpu/cuda."
    )
    return parser.parse_args()


def set_seed(seed: int) -> None:
    """Seed Python's ``random``, NumPy's global RNG and torch (CPU, plus CUDA if present)."""
    for seeder in (random.seed, np.random.seed, torch.manual_seed):
        seeder(seed)
    # CUDA generators are only seeded when a GPU is actually visible.
    if not torch.cuda.is_available():
        return
    torch.cuda.manual_seed(seed)
    torch.cuda.manual_seed_all(seed)


def ensure_dir(path: str) -> None:
    """Create ``path`` and any missing parents; do nothing if it already exists."""
    os.makedirs(name=path, exist_ok=True)


def list_compound_dirs(data_root: str) -> List[str]:
    """Return the immediate subdirectories of ``data_root`` (joined with it), sorted.

    Regular files directly under ``data_root`` are ignored.
    """
    joined = (osp.join(data_root, entry) for entry in os.listdir(data_root))
    return sorted(filter(osp.isdir, joined))


def move_axis_to_front(arr: np.ndarray, axis: int) -> np.ndarray:
    """Return ``arr`` with ``axis`` moved to position 0.

    When ``axis`` is already 0 the very same object is returned (no view is made).
    """
    return arr if axis == 0 else np.moveaxis(arr, axis, 0)


def preprocess_trajectory(arr: np.ndarray, expected_num_frames: int = 20) -> np.ndarray:
    """
    Convert trajectory array to shape (T, 1, Z, H, W), MTG-only channel.

    The time axis is the first axis whose length equals ``expected_num_frames``
    (20 by default). For 5-D input, the channel axis is the first non-time axis
    whose length is 1-4. When more than one channel exists, channel index 1 is
    kept (presumably MTG — NOTE(review): confirm the channel order of the
    data); otherwise the single channel is kept. The result is min-max scaled
    over the whole trajectory to approximately [0, 1].

    Args:
        arr: 4-D array (time + three other axes) or 5-D array (time, channel
            + three other axes), axes in any order.
        expected_num_frames: Length used to recognise the time axis.

    Returns:
        float32 array of shape (T, 1, ·, ·, ·) with time first, channel second.

    Raises:
        ValueError: unsupported rank, no axis of length ``expected_num_frames``,
            or (5-D only) no axis of length 1-4 to use as the channel axis.
    """
    arr = np.asarray(arr, dtype=np.float32)
    if arr.ndim not in (4, 5):
        raise ValueError(f"Unsupported trajectory ndim={arr.ndim}, shape={arr.shape}")

    t_axes = [i for i, s in enumerate(arr.shape) if s == expected_num_frames]
    if not t_axes:
        raise ValueError(
            f"Could not identify time axis (size {expected_num_frames}) for shape {arr.shape}"
        )
    # If another axis happens to share this length, the first match wins.
    arr = np.moveaxis(arr, t_axes[0], 0)

    if arr.ndim == 4:
        arr = arr[:, None, ...]  # (T, C=1, Z, H, W)
    else:
        # Identify channel axis among non-time dimensions (typically 1 or 2 channels).
        candidate_channel_axes = [
            i for i, s in enumerate(arr.shape[1:], start=1) if s in (1, 2, 3, 4)
        ]
        if not candidate_channel_axes:
            # Falling back to axis 1 would treat a spatial axis as channels and
            # silently keep a single spatial slice; fail so the caller can skip.
            raise ValueError(
                f"Could not identify channel axis (size 1-4) for shape {arr.shape}"
            )
        arr = np.moveaxis(arr, candidate_channel_axes[0], 1)

    channel_idx = 1 if arr.shape[1] > 1 else 0  # MTG when available, otherwise only channel.
    arr = arr[:, channel_idx:channel_idx + 1]

    arr_min = float(arr.min())
    arr_max = float(arr.max())
    # Epsilon keeps a constant trajectory from dividing by zero (it maps to 0).
    arr = (arr - arr_min) / (arr_max - arr_min + 1e-8)
    return arr.astype(np.float32)


def get_valid_centers(num_frames: int, clip_len: int, delta: int) -> List[int]:
    """List every center ``t`` for which the clips centered at ``t - delta``,
    ``t`` and ``t + delta`` (each ``clip_len`` frames, i.e. ``clip_len // 2``
    frames either side) all lie within frames ``0 .. num_frames - 1``.

    Returns an empty list when no such center exists.
    """
    margin = delta + clip_len // 2
    # range() is empty whenever its start is not below its stop.
    return list(range(margin, num_frames - margin))


def extract_clip_embedding(
    model: torch.nn.Module,
    clip: np.ndarray,
    device: str,
    embedding_pool: str,
) -> np.ndarray:
    """Embed a single clip and pool its per-frame features into one vector.

    Args:
        model: Encoder whose forward returns a 3-tuple; only the first element
            is used. It is indexed as (batch, frames, dim) below.
        clip: One clip as a float32 array; a leading batch axis of 1 is added.
        device: Torch device string the clip is moved to.
        embedding_pool: ``"last"`` keeps the final frame's vector; any other
            value averages over frames.

    Returns:
        1-D float32 vector. Per-frame features are L2-normalised before
        pooling, so ``"last"`` is unit-norm while ``"mean"`` generally is not.
    """
    batch = torch.from_numpy(clip).unsqueeze(0).to(device)  # (1, T, C, Z, H, W)
    with torch.no_grad():
        per_frame, _, _ = model(batch)
        per_frame = F.normalize(per_frame, dim=-1)
        pooled = per_frame[:, -1, :] if embedding_pool == "last" else per_frame.mean(dim=1)
    return pooled.squeeze(0).detach().cpu().numpy().astype(np.float32)


def build_features(
    pair_records: Sequence[PairRecord],
    emb_cache: Dict[Tuple[str, int], np.ndarray],
    feature_form: str,
) -> Tuple[np.ndarray, np.ndarray, np.ndarray]:
    """Turn pair records into a probe design matrix.

    With query embedding ``q`` and candidate embedding ``c`` looked up in
    ``emb_cache``, each row is:

    * ``"concat"``: ``[q, c]``
    * ``"diff"``:   ``c - q``
    * ``"rich"``:   ``[q, c, c - q, q * c]``

    Returns:
        ``(X, y, groups)``: stacked feature rows, labels and trajectory ids,
        one entry per record, in input order.

    Raises:
        KeyError: a record key is missing from ``emb_cache``.
        ValueError: unknown ``feature_form``, or (from ``np.stack``) when
            ``pair_records`` is empty.
    """
    combiners = {
        "concat": lambda q, c: np.concatenate([q, c], axis=0),
        "diff": lambda q, c: c - q,
        "rich": lambda q, c: np.concatenate([q, c, c - q, q * c], axis=0),
    }
    features, labels, traj_ids = [], [], []
    for record in pair_records:
        query_emb = emb_cache[record.query_key]
        cand_emb = emb_cache[record.candidate_key]
        combine = combiners.get(feature_form)
        if combine is None:
            raise ValueError(f"Unknown feature form: {feature_form}")
        features.append(combine(query_emb, cand_emb))
        labels.append(record.label)
        traj_ids.append(record.trajectory_id)
    return np.stack(features), np.array(labels), np.array(traj_ids)


def grouped_split_indices(
    y: np.ndarray, groups: np.ndarray, test_size: float, seed: int, tries: int = 32
) -> Tuple[np.ndarray, np.ndarray]:
    """Split sample indices so that no group (trajectory) is in both train and test.

    The number of held-out groups is ``round(test_size * n_groups)`` clamped to
    ``[1, n_groups - 1]``. Shuffles with seeds ``seed, seed + 1, ...`` are tried
    (at most ``tries``) until both sides contain both classes.

    Args:
        y: Per-sample labels.
        groups: Per-sample group ids.
        test_size: Desired fraction of groups to hold out.
        seed: Base random seed.
        tries: Maximum number of reshuffles.

    Returns:
        ``(train_idx, test_idx)`` integer index arrays.

    Raises:
        RuntimeError: fewer than 2 groups, or no split with both classes on
            each side was found within ``tries`` attempts.
    """
    n_groups = len(np.unique(groups))
    if n_groups < 2:
        raise RuntimeError("Need at least 2 trajectories for grouped holdout.")

    n_test_groups = min(max(1, int(round(test_size * n_groups))), n_groups - 1)

    for i in range(tries):
        # Pass the integer group count rather than the fraction
        # n_test_groups / n_groups: GroupShuffleSplit applies ceil() to float
        # sizes, so round-off in the fraction could hold out an extra group.
        splitter = GroupShuffleSplit(
            n_splits=1, test_size=n_test_groups, random_state=seed + i
        )
        train_idx, test_idx = next(splitter.split(np.zeros_like(y), y, groups=groups))
        if len(np.unique(y[train_idx])) < 2 or len(np.unique(y[test_idx])) < 2:
            continue
        return train_idx, test_idx
    raise RuntimeError("Could not produce valid grouped split with both classes in train/test.")


def evaluate_probe(
    x: np.ndarray, y: np.ndarray, groups: np.ndarray, test_size: float, seed: int
) -> Dict[str, float]:
    """Score a standardized logistic-regression probe on a trajectory-grouped split.

    A second probe with identical hyperparameters is trained on a random
    permutation of the training labels (seeded by ``seed``). Its test
    accuracy/AUC are reported as a label-shuffle control.

    Returns:
        Dict with split sizes, per-side trajectory counts, the number of
        trajectories shared between train and test, train/test positive
        rates, accuracy/AUC of the real and label-shuffled probes (0.5
        probability threshold), and a fixed ``chance_baseline`` of 0.5.
    """
    train_idx, test_idx = grouped_split_indices(y, groups, test_size, seed)
    x_tr, x_te = x[train_idx], x[test_idx]
    y_tr, y_te = y[train_idx], y[test_idx]
    g_tr, g_te = groups[train_idx], groups[test_idx]

    probe = Pipeline(
        [
            ("scaler", StandardScaler()),
            (
                "logreg",
                LogisticRegression(max_iter=3000, random_state=seed, class_weight="balanced"),
            ),
        ]
    )
    # Unfitted twin with the same hyperparameters for the label-shuffle control.
    control = clone(probe)

    def _fit_positive_proba(model: Pipeline, labels: np.ndarray) -> np.ndarray:
        model.fit(x_tr, labels)
        return model.predict_proba(x_te)[:, 1]

    def _threshold(scores: np.ndarray) -> np.ndarray:
        return (scores >= 0.5).astype(np.int32)

    real_scores = _fit_positive_proba(probe, y_tr)
    control_scores = _fit_positive_proba(control, np.random.default_rng(seed).permutation(y_tr))

    return {
        "n_samples": int(len(y)),
        "n_train": int(len(train_idx)),
        "n_test": int(len(test_idx)),
        "n_train_trajectories": int(np.unique(g_tr).size),
        "n_test_trajectories": int(np.unique(g_te).size),
        "split_overlap_trajectories": int(np.intersect1d(g_tr, g_te).size),
        "train_positive_rate": float(y_tr.mean()),
        "test_positive_rate": float(y_te.mean()),
        "accuracy": float(accuracy_score(y_te, _threshold(real_scores))),
        "auc": float(roc_auc_score(y_te, real_scores)),
        "sanity_label_shuffle_accuracy": float(accuracy_score(y_te, _threshold(control_scores))),
        "sanity_label_shuffle_auc": float(roc_auc_score(y_te, control_scores)),
        "chance_baseline": 0.5,
    }


def load_model(config_path: str, checkpoint_path: str, device: str) -> torch.nn.Module:
    """Restore the encoder from ``checkpoint_path`` and return it in eval mode on ``device``.

    The backbone is a ``Lightweight3DResNet`` (embedding size 2048, augmentation
    disabled) built from the config's ``data_params.transforms`` section. It is
    wrapped by ``SimCLRRunner.load_from_checkpoint``, and the runner's
    ``.model`` is returned.
    """
    cfg = load_config(config_path)
    encoder = Lightweight3DResNet(
        embedding_size=2048,
        cfg_aug=cfg["data_params"]["transforms"],
        apply_aug=False,
    )
    # NOTE(review): strict=False tolerates missing/unexpected state-dict keys
    # without reporting them; confirm the encoder weights actually load.
    runner = SimCLRRunner.load_from_checkpoint(
        checkpoint_path, model=encoder, cfg=cfg, strict=False
    )
    runner.eval()
    # nn.Module.to() and .eval() both return the module itself.
    return runner.model.to(device).eval()


def plot_metric_curve(
    df: pd.DataFrame,
    y_col: str,
    out_path: str,
    title: str,
    ylabel: str,
) -> None:
    """Save a line plot of ``df[y_col]`` versus ``df["delta"]`` to ``out_path``.

    The y-axis is fixed to [0, 1], and a dashed horizontal line marks 0.5 (chance).
    """
    fig, ax = plt.subplots(figsize=(7, 4))
    ax.plot(df["delta"].to_numpy(), df[y_col].to_numpy(), marker="o", linewidth=2)
    ax.axhline(0.5, linestyle="--", linewidth=1.5)
    ax.set_ylim(0.0, 1.0)
    ax.set_xlabel("Delta (frames)")
    ax.set_ylabel(ylabel)
    ax.set_title(title)
    ax.grid(True, alpha=0.25)
    fig.tight_layout()
    fig.savefig(out_path, dpi=200)
    plt.close(fig)


def run_compound(
    compound_dir: str,
    model: torch.nn.Module,
    args: argparse.Namespace,
    device: str,
) -> Tuple[pd.DataFrame, Dict]:
    """Run the arrow-of-time probe for every (delta, feature form) of one compound.

    Each ``.npy`` trajectory in ``compound_dir`` is processed in sorted order,
    optionally capped by ``args.max_trajectories_per_compound``.

    * Every clip center needed by any delta is embedded once.
    * Each valid center ``t`` yields two pairs sharing query ``t``:
      candidate ``t + delta`` with label 1 (future) and candidate
      ``t - delta`` with label 0 (past).
    * Each (delta, feature form) is scored with ``evaluate_probe``.

    Duplicate values in ``args.deltas`` are processed once.

    Args:
        compound_dir: Directory holding this compound's trajectory files.
        model: Encoder passed to ``extract_clip_embedding``.
        args: Parsed CLI namespace. Reads ``clip_len``, ``deltas``,
            ``feature_forms``, ``embedding_pool``, ``test_size``, ``seed`` and
            ``max_trajectories_per_compound``.
        device: Torch device string for embedding extraction.

    Returns:
        ``(result_df, metadata)``. ``result_df`` has one row per (delta,
        feature form) with a ``status`` column; ``"ok"`` rows also carry the
        probe metrics. ``metadata`` holds bookkeeping counts for the compound.
    """
    compound_name = osp.basename(compound_dir.rstrip("/"))
    file_paths = sorted(
        osp.join(compound_dir, f)
        for f in os.listdir(compound_dir)
        if f.endswith(".npy") and osp.isfile(osp.join(compound_dir, f))
    )
    if args.max_trajectories_per_compound is not None:
        file_paths = file_paths[: args.max_trajectories_per_compound]

    # Order-preserving de-duplication: a repeated delta would otherwise
    # duplicate its pairs and result rows.
    deltas = list(dict.fromkeys(args.deltas))
    half = args.clip_len // 2  # a clip centered at c spans frames [c - half, c + half]

    emb_cache: Dict[Tuple[str, int], np.ndarray] = {}
    pairs_by_delta: Dict[int, List[PairRecord]] = defaultdict(list)
    skipped_traj = 0
    skipped_delta_no_centers: Dict[int, int] = defaultdict(int)

    for fpath in tqdm(file_paths, desc=f"{compound_name}: trajectories"):
        traj_id = osp.basename(fpath)
        try:
            arr = preprocess_trajectory(np.load(fpath, mmap_mode="r"))
        except Exception as err:  # best-effort: one bad file must not abort the compound
            skipped_traj += 1
            # Report the reason instead of silently dropping the trajectory.
            tqdm.write(f"[{compound_name}] skipping {traj_id}: {type(err).__name__}: {err}")
            continue

        num_frames = arr.shape[0]
        valid_centers_by_delta: Dict[int, List[int]] = {}
        needed_centers = set()
        for delta in deltas:
            centers = get_valid_centers(num_frames, args.clip_len, delta)
            valid_centers_by_delta[delta] = centers
            if not centers:
                skipped_delta_no_centers[delta] += 1
            for t in centers:
                needed_centers.update((t - delta, t, t + delta))

        # Embed each distinct clip center once; different deltas share them.
        for center in sorted(needed_centers):
            cache_key = (traj_id, center)
            if cache_key in emb_cache:
                continue
            clip = arr[center - half : center + half + 1]  # (clip_len, 1, Z, H, W)
            emb_cache[cache_key] = extract_clip_embedding(
                model=model, clip=clip, device=device, embedding_pool=args.embedding_pool
            )

        # Balanced labels per center: future candidate -> 1, past candidate -> 0.
        for delta, centers in valid_centers_by_delta.items():
            for t in centers:
                query_key = (traj_id, t)
                pairs_by_delta[delta].append(
                    PairRecord(
                        query_key=query_key,
                        candidate_key=(traj_id, t + delta),
                        label=1,
                        trajectory_id=traj_id,
                    )
                )
                pairs_by_delta[delta].append(
                    PairRecord(
                        query_key=query_key,
                        candidate_key=(traj_id, t - delta),
                        label=0,
                        trajectory_id=traj_id,
                    )
                )

    rows: List[Dict] = []
    for delta in deltas:
        pair_records = pairs_by_delta.get(delta, [])
        base = {"compound": compound_name, "delta": int(delta)}
        if not pair_records:
            # One row per requested feature form so every (delta, form) is accounted for.
            rows.extend(
                {**base, "feature_form": form, "status": "skipped_no_valid_pairs"}
                for form in args.feature_forms
            )
            continue

        # The trajectory count does not depend on the feature form.
        unique_group_count = len({rec.trajectory_id for rec in pair_records})
        for feature_form in args.feature_forms:
            row = {**base, "feature_form": feature_form}
            if unique_group_count < 2:
                row["status"] = "skipped_insufficient_trajectories"
                row["n_samples"] = int(len(pair_records))
                row["n_unique_trajectories"] = int(unique_group_count)
            else:
                x, y, groups = build_features(pair_records, emb_cache, feature_form)
                row["status"] = "ok"
                row.update(
                    evaluate_probe(x, y, groups, test_size=args.test_size, seed=args.seed)
                )
            rows.append(row)

    result_df = pd.DataFrame(rows).sort_values(["delta", "feature_form"])
    metadata = {
        "compound": compound_name,
        "n_trajectories_seen": len(file_paths),
        "n_trajectories_skipped": skipped_traj,
        "n_cached_clips": len(emb_cache),
        "skipped_delta_no_centers_count": {
            str(k): int(v) for k, v in skipped_delta_no_centers.items()
        },
        "clip_len": int(args.clip_len),
        "deltas": [int(d) for d in deltas],
        "embedding_pool": args.embedding_pool,
    }
    return result_df, metadata


def _plot_aggregate_curve(
    agg: pd.DataFrame, metric: str, ylabel: str, title: str, out_path: str
) -> None:
    """Plot ``<metric>_mean`` ± ``<metric>_std`` against ``delta`` and save to ``out_path``."""
    plt.figure(figsize=(7, 4))
    plt.errorbar(
        agg["delta"].to_numpy(),
        agg[f"{metric}_mean"].to_numpy(),
        yerr=agg[f"{metric}_std"].to_numpy(),
        marker="o",
    )
    plt.axhline(0.5, linestyle="--", linewidth=1.5)
    plt.ylim(0.0, 1.0)
    plt.xlabel("Delta (frames)")
    plt.ylabel(ylabel)
    plt.title(title)
    plt.grid(True, alpha=0.25)
    plt.tight_layout()
    plt.savefig(out_path, dpi=200)
    plt.close()


def main() -> None:
    """Validate CLI arguments, run the probe for every compound, and write outputs.

    Files written under ``args.output_dir``:

    * ``per_compound/<name>/metrics.csv`` and ``metadata.json``. When any
      ``concat`` rows succeeded, also ``accuracy_vs_delta.png`` and
      ``auc_vs_delta.png``.
    * ``all_compounds_metrics.csv`` and ``run_metadata.json``.
    * When any ``concat`` rows succeeded: ``aggregate_primary_concat.csv``
      plus accuracy/AUC plots (mean ± std across compounds).

    Raises:
        ValueError: invalid ``--clip-len``, ``--deltas`` or ``--test-size``.
        RuntimeError: no compound directories under ``--data-root``.
    """
    args = parse_args()
    if args.clip_len % 2 == 0:
        raise ValueError("--clip-len must be odd to define a center frame.")
    if args.clip_len < 1:
        raise ValueError(f"--clip-len must be a positive odd integer, got {args.clip_len}.")
    # delta == 0 makes the "future" and "past" candidates the same clip (pure label
    # noise); a negative delta silently swaps the two labels.
    if any(d <= 0 for d in args.deltas):
        raise ValueError(f"--deltas must all be positive integers, got {args.deltas}.")
    if not 0.0 < args.test_size < 1.0:
        raise ValueError(f"--test-size must be in (0, 1), got {args.test_size}.")

    set_seed(args.seed)
    ensure_dir(args.output_dir)
    compound_out_root = osp.join(args.output_dir, "per_compound")
    ensure_dir(compound_out_root)

    device = args.device
    # Also covers indexed devices such as "cuda:1".
    if device.startswith("cuda") and not torch.cuda.is_available():
        print("CUDA requested but unavailable. Falling back to CPU.")
        device = "cpu"

    # Discover inputs before the (expensive) model load so an empty root fails fast.
    compound_dirs = list_compound_dirs(args.data_root)
    if not compound_dirs:
        raise RuntimeError(f"No compound directories found under {args.data_root}.")

    model = load_model(
        config_path=args.config, checkpoint_path=args.checkpoint_path, device=device
    )

    all_rows = []
    all_metadata = []
    for compound_dir in compound_dirs:
        compound_name = osp.basename(compound_dir.rstrip("/"))
        print(f"\n==== Processing compound: {compound_name} ====")
        compound_df, metadata = run_compound(
            compound_dir=compound_dir, model=model, args=args, device=device
        )
        all_rows.append(compound_df)
        all_metadata.append(metadata)

        compound_out_dir = osp.join(compound_out_root, compound_name)
        ensure_dir(compound_out_dir)
        compound_df.to_csv(osp.join(compound_out_dir, "metrics.csv"), index=False)
        with open(osp.join(compound_out_dir, "metadata.json"), "w", encoding="utf-8") as f:
            json.dump(metadata, f, indent=2)

        primary_df = compound_df[
            (compound_df["status"] == "ok") & (compound_df["feature_form"] == "concat")
        ].sort_values("delta")
        if len(primary_df) > 0:
            plot_metric_curve(
                primary_df,
                y_col="accuracy",
                out_path=osp.join(compound_out_dir, "accuracy_vs_delta.png"),
                title=f"{compound_name}: Accuracy vs Delta",
                ylabel="Accuracy",
            )
            plot_metric_curve(
                primary_df,
                y_col="auc",
                out_path=osp.join(compound_out_dir, "auc_vs_delta.png"),
                title=f"{compound_name}: AUC vs Delta",
                ylabel="ROC-AUC",
            )

    if not all_rows:
        raise RuntimeError("No compound results were produced.")

    all_results_df = pd.concat(all_rows, ignore_index=True)
    all_results_df.to_csv(osp.join(args.output_dir, "all_compounds_metrics.csv"), index=False)

    with open(osp.join(args.output_dir, "run_metadata.json"), "w", encoding="utf-8") as f:
        json.dump(
            {
                "config": args.config,
                "checkpoint_path": args.checkpoint_path,
                "data_root": args.data_root,
                "clip_len": args.clip_len,
                "deltas": args.deltas,
                "feature_forms": args.feature_forms,
                "embedding_pool": args.embedding_pool,
                "test_size": args.test_size,
                "seed": args.seed,
                "device": device,
                "n_compounds": len(compound_dirs),
                "per_compound_metadata": all_metadata,
            },
            f,
            indent=2,
        )

    primary = all_results_df[
        (all_results_df["status"] == "ok") & (all_results_df["feature_form"] == "concat")
    ].copy()
    if len(primary) > 0:
        agg = (
            primary.groupby("delta")[["accuracy", "auc"]]
            .agg(["mean", "std", "count"])
            .reset_index()
        )
        # Flatten ("metric", "stat") columns to "metric_stat"; "delta" has an empty stat.
        agg.columns = ["_".join(part for part in col if part) for col in agg.columns]
        agg.to_csv(osp.join(args.output_dir, "aggregate_primary_concat.csv"), index=False)

        # std is NaN for deltas covered by a single compound; such points get no error bar.
        _plot_aggregate_curve(
            agg,
            metric="accuracy",
            ylabel="Accuracy",
            title="Aggregate Accuracy vs Delta (Primary: concat)",
            out_path=osp.join(args.output_dir, "aggregate_accuracy_vs_delta.png"),
        )
        _plot_aggregate_curve(
            agg,
            metric="auc",
            ylabel="ROC-AUC",
            title="Aggregate AUC vs Delta (Primary: concat)",
            out_path=osp.join(args.output_dir, "aggregate_auc_vs_delta.png"),
        )

    print(f"\nSaved outputs to: {args.output_dir}")


# Run the experiment only when executed as a script, not when imported.
if __name__ == "__main__":
    main()
Editor is loading...
Leave a Comment