# Paste-site header (not part of the script): Untitled / avatar / unknown /
# plain_text / 15 days ago / 25 kB / 4 / Indexable
import os
import sys
import numpy as np
from PIL import Image
import torch
import cv2
from tqdm import tqdm  # For progress bars
import argparse  # For command-line argument parsing
import logging  # For logging with timestamps

# ---------------------------------------------------------------------------
# 1) DepthAnythingV2 configurations and loader
# ---------------------------------------------------------------------------
from depth_anything_v2.dpt import DepthAnythingV2

# Model configurations
# Keyword arguments for the DepthAnythingV2 constructor, keyed by encoder name.
# Each entry is unpacked as DepthAnythingV2(**model_configs[encoder]).
model_configs = dict(
    vits=dict(encoder='vits', features=64, out_channels=[48, 96, 192, 384]),
    vitb=dict(encoder='vitb', features=128, out_channels=[96, 192, 384, 768]),
    vitl=dict(encoder='vitl', features=256, out_channels=[256, 512, 1024, 1024]),
    vitg=dict(encoder='vitg', features=384, out_channels=[1536, 1536, 1536, 1536]),
)

def load_depth_anything_model(encoder='vitl'):
    """
    Build a Depth Anything V2 model and load its checkpoint.

    The encoder must be a key of `model_configs`; the weights are read from
    'checkpoints/depth_anything_v2_<encoder>.pth' relative to the current
    working directory.

    Returns:
        (model, device): the model in eval mode on the chosen device, and the
        device name ('cuda', 'mps' or 'cpu').

    Raises:
        ValueError: if `encoder` is not a key of `model_configs`.
        FileNotFoundError: if the checkpoint file does not exist.
    """
    logging.info(f"Loading Depth Anything V2 {encoder.upper()} model...")

    # Device preference order: CUDA, then Apple MPS, then CPU.
    if torch.cuda.is_available():
        device = 'cuda'
    elif torch.backends.mps.is_available():
        device = 'mps'
    else:
        device = 'cpu'

    if encoder not in model_configs:
        problem = f"Invalid encoder '{encoder}'. Must be one of: {list(model_configs.keys())}"
        logging.error(problem)
        raise ValueError(problem)

    model = DepthAnythingV2(**model_configs[encoder])

    # Example checkpoint path: "checkpoints/depth_anything_v2_vitb.pth"
    checkpoint_path = f'checkpoints/depth_anything_v2_{encoder}.pth'
    if not os.path.exists(checkpoint_path):
        problem = f"Checkpoint not found at: {checkpoint_path}"
        logging.error(problem)
        raise FileNotFoundError(problem)

    logging.info(f"Using checkpoint: {checkpoint_path}")
    state_dict = torch.load(checkpoint_path, map_location=device)
    model.load_state_dict(state_dict)
    model = model.to(device).eval()

    logging.info(f"Model {encoder.upper()} loaded successfully on {device}.")
    return model, device

# ---------------------------------------------------------------------------
# 2) Depth inference for a cropped tile
# ---------------------------------------------------------------------------
def estimate_depth(tile_img, model, device, input_size=518, tile_size=512):
    """
    Run Depth Anything V2 on one cropped tile.

    The tile is bicubically resized to (input_size, input_size), converted
    from RGB to BGR and passed to `model.infer_image`; the resulting depth map
    is bicubically resized to (tile_size, tile_size).

    `device` is accepted for interface compatibility but is not used here.

    Returns:
        numpy.ndarray: (tile_size, tile_size) float32 depth map.
    """
    model_shape = (input_size, input_size)
    output_shape = (tile_size, tile_size)

    # PIL yields RGB; the model is fed the BGR version of the resized tile.
    resized_rgb = np.array(tile_img.resize(model_shape, Image.BICUBIC))
    bgr_input = cv2.cvtColor(resized_rgb, cv2.COLOR_RGB2BGR)

    with torch.no_grad():
        # NOTE(review): presumably a float32 array matching the input's
        # height/width -- confirm against DepthAnythingV2.infer_image.
        raw_depth = model.infer_image(bgr_input, input_size)

    # Bring the depth map back to the tile's resolution.
    rescaled = Image.fromarray(raw_depth).resize(output_shape, Image.BICUBIC)
    return np.array(rescaled).astype(np.float32)

# ---------------------------------------------------------------------------
# 3) Utility to save 8-bit grayscale images
# ---------------------------------------------------------------------------
def save_image(array, path):
    """
    Write a 2D numeric array to `path` as an 8-bit grayscale image.

    Values are clipped to [0, 255] and then cast to uint8 (the cast drops any
    fractional part) before saving.
    """
    pixels = np.clip(array, 0, 255).astype('uint8')
    Image.fromarray(pixels, mode='L').save(path)
    logging.info(f"Saved image to {path}")

# ---------------------------------------------------------------------------
# 4) Process a single tile across all images with "top N smallest depth" logic
# ---------------------------------------------------------------------------
def process_tile(
    tile_position,
    input_directory,
    processed_tiles_directory,
    model,
    device,
    tile_size=512,
    top_n=10,
    percentiles=None,
    input_size=518,
    save_cropped=False,
    cropped_dir=None,
    save_depth_maps=False,
    depth_maps_dir=None
):
    """
    Build the depth-filtered percentile image(s) for one tile.

    For the tile at grid position `tile_position` = (row i, column j):
      - Crops that tile from every image file in `input_directory`
        (sorted by filename); frames that fail to open, or where the tile
        does not fit inside the image, are skipped.
      - Estimates depth for each crop (a failed estimate becomes an all-NaN
        map, so that frame is ignored per pixel) and collects its grayscale.
      - For each pixel, keeps the `top_n` frames with the smallest depth.
      - For each value in `percentiles` (0-100; defaults to [40]), computes
        that percentile of the kept grayscale values with linear
        interpolation between neighbouring ranks and saves the tile as
        'processed_tile_{i}_{j}_top{top_n}_p{percentile}.png' in
        `processed_tiles_directory`.
      - Optionally saves each crop (`save_cropped` + `cropped_dir`) and a
        min/max-normalized depth visualisation (`save_depth_maps` +
        `depth_maps_dir`).

    Returns None. If no frame could be used, logs a warning and writes nothing.
    """
    i, j = tile_position
    if percentiles is None:
        # Default lives here rather than in the signature to avoid sharing a
        # mutable list between calls.
        percentiles = [40]

    # The tile's bounding box is the same for every frame.
    left = j * tile_size
    upper = i * tile_size
    right = left + tile_size
    lower = upper + tile_size

    grayscale_frames = []
    depth_frames = []

    # Collect frames
    for filename in sorted(os.listdir(input_directory)):
        if not filename.lower().endswith(('.png', '.jpg', '.jpeg', '.tiff', '.bmp', '.gif')):
            continue

        stem = os.path.splitext(filename)[0]
        img_path = os.path.join(input_directory, filename)
        try:
            # convert() returns an independent image, so the source file
            # handle can be closed as soon as the block exits.
            with Image.open(img_path) as source:
                img = source.convert('RGB')
        except Exception as e:
            logging.error(f"Failed to open {filename}: {e}")
            continue

        image_width, image_height = img.size
        if right > image_width or lower > image_height:
            # The tile extends beyond this image: optionally save a
            # black-padded crop for inspection, but never stack it.
            if save_cropped and cropped_dir:
                new_tile = Image.new('RGB', (tile_size, tile_size), (0, 0, 0))
                available_tile = img.crop(
                    (left, upper, min(right, image_width), min(lower, image_height))
                )
                new_tile.paste(available_tile, (0, 0))
                cropped_path = os.path.join(cropped_dir, f"cropped_tile_{i}_{j}_{stem}.png")
                new_tile.save(cropped_path)
                logging.info(f"Saved padded cropped image to {cropped_path}")
            logging.warning(f"Tile ({i}, {j}) out of bounds for image {filename} size {img.size}. Skipping.")
            continue

        tile_img = img.crop((left, upper, right, lower))

        # Optionally save the cropped image
        if save_cropped and cropped_dir:
            cropped_path = os.path.join(cropped_dir, f"cropped_tile_{i}_{j}_{stem}.png")
            tile_img.save(cropped_path)
            logging.info(f"Saved cropped image to {cropped_path}")

        # Grayscale intensities for this frame
        grayscale_frames.append(np.array(tile_img.convert('L'), dtype=float))

        # Depth; on failure use NaN so the frame is excluded per pixel later
        try:
            depth_map = estimate_depth(tile_img, model, device, input_size=input_size, tile_size=tile_size)
        except Exception as e:
            logging.error(f"Failed to estimate depth for tile ({i}, {j}) in {filename}: {e}")
            depth_map = np.full((tile_size, tile_size), np.nan)
        depth_frames.append(depth_map)

        # Optionally save the depth map
        if save_depth_maps and depth_maps_dir:
            # Normalize to 0-255 for visualisation; flat or all-NaN maps become black
            if np.all(np.isnan(depth_map)):
                depth_normalized = np.zeros_like(depth_map)
            else:
                depth_min = np.nanmin(depth_map)
                depth_max = np.nanmax(depth_map)
                if depth_max - depth_min > 0:
                    depth_normalized = (depth_map - depth_min) / (depth_max - depth_min) * 255
                else:
                    depth_normalized = np.zeros_like(depth_map)
            depth_normalized = depth_normalized.astype('uint8')
            depth_img = Image.fromarray(depth_normalized, mode='L')
            depth_path = os.path.join(depth_maps_dir, f"depth_tile_{i}_{j}_{stem}.png")
            depth_img.save(depth_path)
            logging.info(f"Saved depth map to {depth_path}")

    if not grayscale_frames:
        logging.warning(f"No frames processed for tile ({i}, {j}).")
        return

    # -------------------------
    # Device-side pixel stacking
    # -------------------------
    grayscale_stack = np.stack(grayscale_frames, axis=0)  # (N, H, W)
    depth_stack = np.stack(depth_frames, axis=0)          # (N, H, W)
    if torch.device(device).type == 'mps':
        # The MPS backend has no float64 support, so moving these float64
        # stacks to it would fail. Cast on the CPU first; other devices keep
        # the original dtypes so their results are unchanged.
        grayscale_stack = grayscale_stack.astype(np.float32)
        depth_stack = depth_stack.astype(np.float32)
    num_frames = grayscale_stack.shape[0]

    grayscale_tensor = torch.from_numpy(grayscale_stack).to(device)  # (N, H, W)
    depth_tensor = torch.from_numpy(depth_stack).to(device)          # (N, H, W)

    # Everything down to the per-percentile loop is independent of the
    # percentile, so it is computed once instead of once per percentile.
    actual_top_n = min(top_n, num_frames)

    # Replace NaN depths with +inf so invalid frames sort to the end
    valid_mask = ~torch.isnan(depth_tensor)
    depth_tensor_fixed = torch.where(
        valid_mask,
        depth_tensor,
        torch.tensor(float('inf'), device=device, dtype=depth_tensor.dtype)
    )

    # Order the grayscale values of every pixel by ascending depth
    _, sorted_indices = torch.sort(depth_tensor_fixed, dim=0)
    sorted_grayscale = torch.gather(grayscale_tensor, dim=0, index=sorted_indices)
    selected = sorted_grayscale[:actual_top_n]  # (actual_top_n, H, W)

    # Per pixel, how many of the kept frames actually have a valid depth
    valid_count = valid_mask.sum(dim=0)
    effective_k = torch.minimum(
        valid_count,
        torch.tensor(actual_top_n, device=device, dtype=valid_count.dtype)
    )
    effective_k_float = effective_k.to(torch.float32)
    effective_k_minus_one = (effective_k - 1).clamp(min=0)

    # Spatial index grids used to pick one rank per pixel
    H, W = tile_size, tile_size
    grid_h = torch.arange(H, device=device).view(H, 1).expand(H, W)
    grid_w = torch.arange(W, device=device).view(1, W).expand(H, W)

    for percentile in percentiles:
        # Fractional rank of the percentile among the valid values: (n - 1) * q
        q = percentile / 100.0
        idx_float = (effective_k_float - 1) * q
        floor_idx = idx_float.floor().long()
        ceil_idx = idx_float.ceil().long()

        # Clamp so the ranks never exceed the number of valid values (minus 1)
        clamped_floor = torch.min(floor_idx, effective_k_minus_one)
        clamped_ceil = torch.min(ceil_idx, effective_k_minus_one)

        # Linear interpolation between the two neighbouring ranks
        lower_val = selected[clamped_floor, grid_h, grid_w]  # (H, W)
        upper_val = selected[clamped_ceil, grid_h, grid_w]   # (H, W)
        weight = (idx_float - clamped_floor.to(torch.float32))
        result = lower_val + (upper_val - lower_val) * weight

        # Pixels with no valid depth at all (effective_k == 0) become 0
        result = torch.where(effective_k > 0, result, torch.zeros_like(result))

        out_tile = result.cpu().numpy()

        processed_tile_path = os.path.join(
            processed_tiles_directory,
            f"processed_tile_{i}_{j}_top{top_n}_p{percentile}.png"
        )
        save_image(out_tile, processed_tile_path)

# ---------------------------------------------------------------------------
# 5) Reconstruct the full image from processed tiles
# ---------------------------------------------------------------------------
def reconstruct_full_image(
    processed_tiles_directory,
    output_image_path,
    tiles_shape,
    tile_size=512,
    top_n=5,
    percentile=40.0
):
    """
    Stitch the processed tiles of one percentile into a single grayscale image.

    `tiles_shape` is (rows, columns). Each tile is looked up in
    `processed_tiles_directory` as
    'processed_tile_{i}_{j}_top{top_n}_p{percentile}.png'. Missing or
    unreadable tiles are replaced with black, and tiles of the wrong size are
    resized to (tile_size, tile_size). A failure to save the final image is
    logged rather than raised.
    """
    rows, cols = tiles_shape
    canvas_width = cols * tile_size
    canvas_height = rows * tile_size
    expected_size = (tile_size, tile_size)

    canvas = Image.new('L', (canvas_width, canvas_height))  # grayscale output
    logging.info(f"Creating a blank full image of size: {canvas_width}x{canvas_height} pixels")

    for i in tqdm(range(rows), desc=f"Reconstructing image for p{percentile}"):
        for j in range(cols):
            tile_filename = f"processed_tile_{i}_{j}_top{top_n}_p{percentile}.png"
            tile_path = os.path.join(processed_tiles_directory, tile_filename)

            if os.path.exists(tile_path):
                try:
                    tile_image = Image.open(tile_path).convert('L')
                    logging.info(f"Loaded processed tile from {tile_path}")
                except Exception as e:
                    logging.error(f"Failed to open processed tile {tile_path}: {e}")
                    # Unreadable tile: fall back to black
                    tile_image = Image.new('L', expected_size, 0)
            else:
                logging.warning(f"Processed tile ({i}, {j}) not found at {tile_path}. Filling with black.")
                tile_image = Image.new('L', expected_size, 0)

            # Force every tile to the grid's cell size before pasting
            if tile_image.size != expected_size:
                logging.warning(f"Tile {tile_filename} has size {tile_image.size}, resizing to ({tile_size}, {tile_size})")
                tile_image = tile_image.resize(expected_size, Image.BICUBIC)

            canvas.paste(tile_image, (j * tile_size, i * tile_size))

    # Save the stitched result; errors are reported through the log only
    try:
        canvas.save(output_image_path)
        logging.info(f"Reconstructed full image saved to {output_image_path}")
    except Exception as e:
        logging.error(f"Failed to save full image to {output_image_path}: {e}")

# ---------------------------------------------------------------------------
# 6) Main Function
# ---------------------------------------------------------------------------
def main():
    """
    Command-line entry point: parse arguments, process tiles, stitch results.

    Modes:
      - Single-tile mode (both --tile_i and --tile_j given): processes only
        that tile; no full image is reconstructed.
      - Batch mode (neither given): processes every tile of the grid derived
        from the first image in the input directory, then reconstructs one
        full image per percentile, saved as '<output base>_p<percentile><ext>'.

    Exits with status 1 on an invalid output extension, a missing input
    directory, a model-loading failure, or when no usable sample image exists.
    Argument errors detected by the parser exit through `parser.error`.

    Raises:
        ValueError: if a percentile is outside [0, 100] or the requested
            single tile lies outside the tile grid.
    """
    # Initialize logging
    logging.basicConfig(
        level=logging.INFO,
        format='%(asctime)s - %(levelname)s - %(message)s',
        handlers=[
            logging.StreamHandler(sys.stdout)
        ]
    )

    parser = argparse.ArgumentParser(description="Percentile Stacking for Beeless Image Generation")

    # Optional arguments for processing a single tile
    parser.add_argument(
        '--tile_i',
        type=int,
        help="Tile row index (e.g., 4)"
    )
    parser.add_argument(
        '--tile_j',
        type=int,
        help="Tile column index (e.g., 8)"
    )

    # Other optional parameters
    parser.add_argument(
        '--encoder',
        type=str,
        default='vitl',
        choices=list(model_configs.keys()),
        help="Depth Anything V2 encoder model to use (default: 'vitl')"
    )
    parser.add_argument(
        '--input_directory',
        type=str,
        default="percentile_stacking_22_12_2024_41images",
        help="Directory containing input frames"
    )
    parser.add_argument(
        '--processed_tiles_directory',
        type=str,
        default="good_depth_anything_large_filtered_stack_filtertop5_40thpercentile/processed_tiles",
        help="Directory to store processed tiles"
    )
    parser.add_argument(
        '--output_image_path',
        type=str,
        default="good_depth_anything_large_filtered_stack_filtertop5/output_images/clean_full_image_top5_p40.png",
        help="Base path to save the final reconstructed images (must include file extension, e.g., '.png')"
    )
    parser.add_argument(
        '--tile_size',
        type=int,
        default=512,
        help="Size of each tile (default: 512)"
    )
    parser.add_argument(
        '--top_n',
        type=int,
        default=5,
        help="Number of top frames to consider based on depth (default: 5)"
    )
    parser.add_argument(
        '--percentiles',
        type=float,
        nargs='+',
        default=[40.0],
        help="List of desired percentiles for stacking (default: [40.0])"
    )
    parser.add_argument(
        '--input_size',
        type=int,
        default=518,
        help="Input size for model inference (default: 518)"
    )
    # Arguments for saving cropped images and depth maps
    parser.add_argument(
        '--save_cropped',
        action='store_true',
        help="Flag to save cropped tile images"
    )
    parser.add_argument(
        '--cropped_dir',
        type=str,
        default=None,
        help="Directory to save cropped tile images (required if --save_cropped is set)"
    )
    parser.add_argument(
        '--save_depth_maps',
        action='store_true',
        help="Flag to save depth maps as images"
    )
    parser.add_argument(
        '--depth_maps_dir',
        type=str,
        default=None,
        help="Directory to save depth map images (required if --save_depth_maps is set)"
    )

    args = parser.parse_args()

    # One extension list serves both output validation and input discovery
    valid_extensions = ('.png', '.jpg', '.jpeg', '.tiff', '.bmp', '.gif')

    # Validate output_image_path has a valid extension
    base_output, ext = os.path.splitext(args.output_image_path)
    if ext.lower() not in valid_extensions:
        logging.error(f"Invalid or missing file extension in output_image_path: '{args.output_image_path}'. "
                      f"Please use one of the following extensions: {', '.join(valid_extensions)}")
        sys.exit(1)

    # Validate arguments related to saving cropped images and depth maps
    if args.save_cropped and not args.cropped_dir:
        parser.error("--save_cropped requires --cropped_dir to be specified.")

    if args.save_depth_maps and not args.depth_maps_dir:
        parser.error("--save_depth_maps requires --depth_maps_dir to be specified.")

    # A lone --tile_i or --tile_j would otherwise silently run the full batch
    if (args.tile_i is None) != (args.tile_j is None):
        parser.error("--tile_i and --tile_j must be specified together.")

    # Non-positive sizes cannot work (tile_size=0 would divide by zero below)
    for option in ('tile_size', 'top_n', 'input_size'):
        if getattr(args, option) <= 0:
            parser.error(f"--{option} must be a positive integer.")

    # Validate percentiles
    for percentile in args.percentiles:
        if not (0 <= percentile <= 100):
            logging.error(f"Percentile '{percentile}' is out of bounds. Must be between 0 and 100.")
            raise ValueError("Percentiles must be between 0 and 100.")

    # Fail fast, before creating directories or loading the model
    if not os.path.isdir(args.input_directory):
        logging.critical(f"Input directory not found: {args.input_directory}")
        sys.exit(1)

    # Create directories if they don't exist
    os.makedirs(args.processed_tiles_directory, exist_ok=True)
    output_dir = os.path.dirname(args.output_image_path)
    if output_dir:
        # dirname is '' for a bare filename, and os.makedirs('') raises
        os.makedirs(output_dir, exist_ok=True)
    if args.save_cropped:
        os.makedirs(args.cropped_dir, exist_ok=True)
    if args.save_depth_maps:
        os.makedirs(args.depth_maps_dir, exist_ok=True)

    if torch.cuda.is_available():
        gpu_name = torch.cuda.get_device_name(0)
        logging.info(f"Using GPU: {gpu_name}")
    else:
        logging.info("No GPU available.")

    # Load model
    try:
        model, device = load_depth_anything_model(encoder=args.encoder)
    except Exception as e:
        logging.critical(f"Failed to load model: {e}")
        sys.exit(1)

    # Determine image dims & tile grid from first valid image
    sample_image_path = None
    for fname in sorted(os.listdir(args.input_directory)):
        if fname.lower().endswith(valid_extensions):
            sample_image_path = os.path.join(args.input_directory, fname)
            break

    if not sample_image_path:
        logging.critical(f"No valid images found in {args.input_directory}")
        sys.exit(1)

    try:
        # Only the dimensions are needed; the context manager closes the file
        with Image.open(sample_image_path) as sample_img:
            img_width, img_height = sample_img.size
    except Exception as e:
        logging.critical(f"Failed to open sample image {sample_image_path}: {e}")
        sys.exit(1)

    # Only whole tiles are used; any remainder at the right/bottom is ignored
    tiles_x = img_width // args.tile_size
    tiles_y = img_height // args.tile_size
    tiles_shape = (tiles_y, tiles_x)

    logging.info(f"Image size: {img_width} x {img_height}")
    logging.info(f"Tiles: {tiles_y} rows x {tiles_x} columns, tile_size={args.tile_size}")
    logging.info(f"Model input size: {args.input_size}")
    logging.info(f"Percentiles to process: {args.percentiles}")

    # Arguments shared by every process_tile call
    tile_kwargs = dict(
        input_directory=args.input_directory,
        processed_tiles_directory=args.processed_tiles_directory,
        model=model,
        device=device,
        tile_size=args.tile_size,
        top_n=args.top_n,
        percentiles=args.percentiles,
        input_size=args.input_size,
        save_cropped=args.save_cropped,
        cropped_dir=args.cropped_dir,
        save_depth_maps=args.save_depth_maps,
        depth_maps_dir=args.depth_maps_dir
    )

    # Phase 1: Process tiles
    if args.tile_i is not None:
        # Single Tile Mode (tile_j is also set, enforced above)
        i = args.tile_i
        j = args.tile_j

        if i < 0 or i >= tiles_y or j < 0 or j >= tiles_x:
            message = (f"Tile indices ({i}, {j}) are out of bounds. "
                       f"Valid ranges: 0 <= i < {tiles_y}, 0 <= j < {tiles_x}")
            logging.error(message)
            raise ValueError(message)

        logging.info(f"Processing single tile ({i}, {j}) with top_n={args.top_n}, percentiles={args.percentiles}, and input_size={args.input_size}...")
        process_tile(tile_position=(i, j), **tile_kwargs)
        logging.info(f"Tile ({i}, {j}) processed successfully.")
    else:
        # Batch Mode: Process all tiles
        logging.info(f"Processing all tiles with top_n={args.top_n}, percentiles={args.percentiles}, and input_size={args.input_size}...")
        for i in tqdm(range(tiles_y), desc="Processing tiles row-wise"):
            for j in range(tiles_x):
                process_tile(tile_position=(i, j), **tile_kwargs)
        logging.info("All tiles processed.")

        # Phase 2: Reconstruct full images for each percentile
        for percentile in args.percentiles:
            # One output file per percentile: '<base>_p<percentile><ext>'
            output_image_path = f"{base_output}_p{percentile}{ext}"

            logging.info(f"Reconstructing full image for percentile {percentile}...")
            reconstruct_full_image(
                processed_tiles_directory=args.processed_tiles_directory,
                output_image_path=output_image_path,
                tiles_shape=tiles_shape,
                tile_size=args.tile_size,
                top_n=args.top_n,
                percentile=percentile
            )
            logging.info(f"Reconstruction for percentile {percentile} complete. Saved to {output_image_path}")

    logging.info("All processes completed successfully.")

if __name__ == "__main__":
    main()
# Paste-site footer (not part of the script): Editor is loading... / Leave a Comment