Untitled
import argparse  # For command-line argument parsing
import logging  # For logging with timestamps
import os
import sys

import cv2
import numpy as np
import torch
from PIL import Image
from tqdm import tqdm  # For progress bars

# ---------------------------------------------------------------------------
# 1) DepthAnythingV2 configurations and loader
# ---------------------------------------------------------------------------
from depth_anything_v2.dpt import DepthAnythingV2

# File extensions (lower-case) that are treated as input frames.
IMAGE_EXTENSIONS = ('.png', '.jpg', '.jpeg', '.tiff', '.bmp', '.gif')

# Model configurations
model_configs = {
    'vits': {
        'encoder': 'vits',
        'features': 64,
        'out_channels': [48, 96, 192, 384]
    },
    'vitb': {
        'encoder': 'vitb',
        'features': 128,
        'out_channels': [96, 192, 384, 768]
    },
    'vitl': {
        'encoder': 'vitl',
        'features': 256,
        'out_channels': [256, 512, 1024, 1024]
    },
    'vitg': {
        'encoder': 'vitg',
        'features': 384,
        'out_channels': [1536, 1536, 1536, 1536]
    }
}


def _select_device():
    """Return 'cuda', 'mps' or 'cpu', preferring accelerators in that order."""
    if torch.cuda.is_available():
        return 'cuda'
    # Older torch builds have no `torch.backends.mps`; treat that as "no MPS".
    mps_backend = getattr(torch.backends, 'mps', None)
    if mps_backend is not None and mps_backend.is_available():
        return 'mps'
    return 'cpu'


def load_depth_anything_model(encoder='vitl'):
    """
    Loads one of the Depth Anything V2 models: 'vits', 'vitb', 'vitl', 'vitg'.
    Defaults to 'vitl' (Large) if none is specified.

    The weights are read from 'checkpoints/depth_anything_v2_<encoder>.pth',
    relative to the current working directory.

    Args:
        encoder: Key into `model_configs`.

    Returns:
        (model, device): the model in eval mode on `device`, and the device
        name as a string ('cuda', 'mps' or 'cpu').

    Raises:
        ValueError: if `encoder` is not a key of `model_configs`.
        FileNotFoundError: if the checkpoint file does not exist.
    """
    # Validate first so that a bad encoder name fails before any other work.
    if encoder not in model_configs:
        logging.error(f"Invalid encoder '{encoder}'. Must be one of: {list(model_configs.keys())}")
        raise ValueError(f"Invalid encoder '{encoder}'. Must be one of: {list(model_configs.keys())}")

    logging.info(f"Loading Depth Anything V2 {encoder.upper()} model...")
    device = _select_device()

    config = model_configs[encoder]
    model = DepthAnythingV2(**config)

    # Example checkpoint path: "checkpoints/depth_anything_v2_vitb.pth"
    checkpoint_path = f'checkpoints/depth_anything_v2_{encoder}.pth'
    if not os.path.exists(checkpoint_path):
        logging.error(f"Checkpoint not found at: {checkpoint_path}")
        raise FileNotFoundError(f"Checkpoint not found at: {checkpoint_path}")

    # Load weights
    # NOTE(review): torch.load unpickles the file; only load checkpoints from
    # a trusted source (consider weights_only=True where the torch version
    # in use supports it).
    logging.info(f"Using checkpoint: {checkpoint_path}")
    model.load_state_dict(torch.load(checkpoint_path, map_location=device))
    model = model.to(device).eval()

    logging.info(f"Model {encoder.upper()} loaded successfully on {device}.")
    return model, device


# ---------------------------------------------------------------------------
# 2) Depth inference for a cropped tile
# ---------------------------------------------------------------------------
def estimate_depth(tile_img, model, device, input_size=518, tile_size=512):
    """
    Estimates depth for a single cropped tile using Depth Anything V2.

    - Resizes the tile image to `input_size` before inference.
    - Resizes the depth map back to `tile_size`.

    Args:
        tile_img: RGB PIL image of the tile.
        model: Object exposing `infer_image(bgr_array, input_size)`.
        device: Unused here; kept so existing callers keep working.
        input_size: Side length the tile is resized to before inference.
        tile_size: Side length of the returned depth map.

    Returns:
        depth_map_resized (numpy.ndarray): (tile_size, tile_size) float32 array.
    """
    # Resize the tile to input_size for model inference
    resized_tile = tile_img.resize((input_size, input_size), Image.BICUBIC)
    # infer_image is fed a BGR array (OpenCV channel order).
    tile_bgr = cv2.cvtColor(np.array(resized_tile), cv2.COLOR_RGB2BGR)

    with torch.no_grad():
        # Per the original author's note: float32 array of shape (input_size, input_size)
        depth_map = model.infer_image(tile_bgr, input_size)

    # Resize the depth map back to tile_size
    depth_pil = Image.fromarray(depth_map)
    depth_pil = depth_pil.resize((tile_size, tile_size), Image.BICUBIC)
    return np.array(depth_pil).astype(np.float32)


# ---------------------------------------------------------------------------
# 3) Utility to save 8-bit grayscale images
# ---------------------------------------------------------------------------
def save_image(array, path):
    """
    Save a 2D float or int array as an 8-bit PNG (grayscale).

    Values are clipped to [0, 255] and truncated to uint8. NaN entries are
    written as 0, because casting NaN to an integer type is undefined.
    """
    finite = np.nan_to_num(array, nan=0.0)
    array_8u = np.clip(finite, 0, 255).astype('uint8')
    img = Image.fromarray(array_8u, mode='L')
    img.save(path)
    logging.info(f"Saved image to {path}")


# ---------------------------------------------------------------------------
# 4) Process a single tile across all images with "top N smallest depth" logic
# ---------------------------------------------------------------------------
def _save_padded_crop(img, left, upper, tile_size, path):
    """Save the in-bounds part of a tile, padded with black to tile_size x tile_size."""
    image_width, image_height = img.size
    padded = Image.new('RGB', (tile_size, tile_size), (0, 0, 0))
    # If the tile origin itself lies outside the image there is nothing to
    # paste; cropping would be handed an inverted box.
    if left < image_width and upper < image_height:
        right = min(left + tile_size, image_width)
        lower = min(upper + tile_size, image_height)
        padded.paste(img.crop((left, upper, right, lower)), (0, 0))
    padded.save(path)
    logging.info(f"Saved padded cropped image to {path}")


def _save_depth_preview(depth_map, path):
    """Min-max normalise a depth map to 0..255 and save it as a grayscale image."""
    if np.all(np.isnan(depth_map)):
        depth_normalized = np.zeros_like(depth_map)
    else:
        depth_min = np.nanmin(depth_map)
        depth_max = np.nanmax(depth_map)
        if depth_max - depth_min > 0:
            depth_normalized = (depth_map - depth_min) / (depth_max - depth_min) * 255
        else:
            # Constant depth map: avoid dividing by zero.
            depth_normalized = np.zeros_like(depth_map)
    depth_img = Image.fromarray(depth_normalized.astype('uint8'), mode='L')
    depth_img.save(path)
    logging.info(f"Saved depth map to {path}")


def _select_top_n_by_depth(grayscale_tensor, depth_tensor, top_n):
    """
    For every pixel, gather the grayscale values of the `top_n` frames with
    the smallest depth.

    Returns:
        (selected, has_valid): `selected` has shape (top_n, H, W) with NaN in
        slots that have no valid (non-NaN) depth; `has_valid` is an (H, W)
        bool tensor marking pixels with at least one valid depth.
    """
    device = depth_tensor.device

    # Create a mask of valid (non-NaN) depth values
    valid_mask = ~torch.isnan(depth_tensor)
    # Replace NaNs with +inf so they sort to the end
    depth_fixed = torch.where(
        valid_mask,
        depth_tensor,
        torch.tensor(float('inf'), device=device, dtype=depth_tensor.dtype)
    )

    # Sort depths along the frame axis (dim=0) and keep only the indices of
    # the first `top_n` frames per pixel before gathering grayscale values.
    _, sorted_indices = torch.sort(depth_fixed, dim=0)
    selected = torch.gather(grayscale_tensor, 0, sorted_indices[:top_n])

    # Number of valid entries per pixel, capped at top_n
    effective_n = torch.clamp(valid_mask.sum(dim=0), max=top_n)

    # For pixels with fewer than `top_n` valid values, mark the extra
    # candidate slots as NaN so nanquantile ignores them.
    slot_index = torch.arange(top_n, device=device).view(top_n, 1, 1)
    selected.masked_fill_(slot_index >= effective_n.unsqueeze(0), float('nan'))

    return selected, effective_n > 0


def process_tile(
    tile_position,
    input_directory,
    processed_tiles_directory,
    model,
    device,
    tile_size=512,
    top_n=10,
    percentiles=(40,),
    input_size=518,
    save_cropped=False,
    cropped_dir=None,
    save_depth_maps=False,
    depth_maps_dir=None
):
    """
    For each tile:
      - Crops that tile from each frame
      - Resizes the tile to `input_size` and estimates Depth
      - Collects grayscale
      - For each pixel, pick the top_n frames with the smallest depth
      - Computes the specified percentiles of those top_n grayscale intensities
      - Saves the resulting tiles for each percentile
      - Optionally saves cropped images and depth maps

    Args:
        tile_position: (row, column) index of the tile.
        input_directory: Directory holding the input frames.
        processed_tiles_directory: Directory the result tiles are written to
            as 'processed_tile_<i>_<j>_top<top_n>_p<percentile>.png'.
        model, device: As returned by `load_depth_anything_model`.
        tile_size: Side length of the (square) tile in pixels.
        top_n: Number of smallest-depth frames considered per pixel (>= 1).
        percentiles: Iterable of percentiles in the range 0..100.
        input_size: Side length used for model inference.
        save_cropped / cropped_dir: Save every cropped tile when both are set.
        save_depth_maps / depth_maps_dir: Save a normalised depth preview per
            frame when both are set.

    Raises:
        ValueError: if `top_n` is smaller than 1.

    Frames that cannot be opened, or where the tile does not fit entirely
    inside the image, are skipped. If depth estimation fails for a frame, its
    depth map is all-NaN and the frame is ignored during selection.
    """
    if top_n < 1:
        raise ValueError(f"top_n must be at least 1, got {top_n}")

    i, j = tile_position
    left = j * tile_size
    upper = i * tile_size
    right = left + tile_size
    lower = upper + tile_size

    grayscale_frames = []
    depth_frames = []

    # Collect frames
    for filename in sorted(os.listdir(input_directory)):
        if not filename.lower().endswith(IMAGE_EXTENSIONS):
            continue

        frame_stem = os.path.splitext(filename)[0]
        img_path = os.path.join(input_directory, filename)
        try:
            # convert() returns an independent image, so the file handle can
            # be released right away.
            with Image.open(img_path) as raw_img:
                img = raw_img.convert('RGB')
        except Exception as e:
            logging.error(f"Failed to open {filename}: {e}")
            continue

        image_width, image_height = img.size
        cropped_path = None
        if save_cropped and cropped_dir:
            cropped_path = os.path.join(cropped_dir, f"cropped_tile_{i}_{j}_{frame_stem}.png")

        # Check tile bounds
        if right > image_width or lower > image_height:
            # The tile extends beyond image boundaries
            if cropped_path:
                _save_padded_crop(img, left, upper, tile_size, cropped_path)
            logging.warning(f"Tile ({i}, {j}) out of bounds for image {filename} size {img.size}. Skipping.")
            continue

        # Crop the tile
        tile_img = img.crop((left, upper, right, lower))

        # Optionally save the cropped image
        if cropped_path:
            tile_img.save(cropped_path)
            logging.info(f"Saved cropped image to {cropped_path}")

        # Convert to grayscale
        grayscale_frames.append(np.array(tile_img.convert('L'), dtype=float))

        # Depth
        try:
            depth_map = estimate_depth(tile_img, model, device, input_size=input_size, tile_size=tile_size)
        except Exception as e:
            logging.error(f"Failed to estimate depth for tile ({i}, {j}) in {filename}: {e}")
            # float32 to match the dtype of successfully estimated maps.
            depth_map = np.full((tile_size, tile_size), np.nan, dtype=np.float32)
        depth_frames.append(depth_map)

        # Optionally save the depth map
        if save_depth_maps and depth_maps_dir:
            depth_path = os.path.join(depth_maps_dir, f"depth_tile_{i}_{j}_{frame_stem}.png")
            _save_depth_preview(depth_map, depth_path)

    if not grayscale_frames:
        logging.warning(f"No frames processed for tile ({i}, {j}).")
        return

    # -------------------------
    # GPU-based pixel stacking
    # -------------------------
    # Stack frames into arrays of shape (num_frames, tile_size, tile_size)
    grayscale_stack = np.stack(grayscale_frames, axis=0)  # (N, H, W)
    depth_stack = np.stack(depth_frames, axis=0)          # (N, H, W)
    if str(device).startswith('mps'):
        # The MPS backend has no float64 support; downcast before the move.
        grayscale_stack = grayscale_stack.astype(np.float32)
        depth_stack = depth_stack.astype(np.float32)

    # Convert stacks to torch tensors and move to the same device as the model
    grayscale_tensor = torch.from_numpy(grayscale_stack).to(device)
    depth_tensor = torch.from_numpy(depth_stack).to(device)

    # The depth-based selection does not depend on the percentile, so it is
    # done once and shared by every percentile below.
    actual_top_n = min(top_n, grayscale_tensor.shape[0])
    selected, has_valid = _select_top_n_by_depth(grayscale_tensor, depth_tensor, actual_top_n)

    for percentile in percentiles:
        # Compute the quantile using torch.nanquantile (which ignores NaNs)
        q = percentile / 100.0
        result = torch.nanquantile(selected, q, dim=0, interpolation='linear')

        # For pixels with no valid depth values, set output to 0
        result = torch.where(has_valid, result, torch.zeros_like(result))
        out_tile = result.cpu().numpy()

        # Save the output tile image
        processed_tile_path = os.path.join(
            processed_tiles_directory,
            f"processed_tile_{i}_{j}_top{top_n}_p{percentile}.png"
        )
        save_image(out_tile, processed_tile_path)


# ---------------------------------------------------------------------------
# 5) Reconstruct the full image from processed tiles
# ---------------------------------------------------------------------------
def reconstruct_full_image(
    processed_tiles_directory,
    output_image_path,
    tiles_shape,
    tile_size=512,
    top_n=5,
    percentile=40.0
):
    """
    Reconstructs the full image by stitching together processed tiles for a specific percentile.

    Args:
        processed_tiles_directory: Directory containing the processed tiles.
        output_image_path: Where the stitched grayscale image is saved.
        tiles_shape: (rows, columns) of the tile grid.
        tile_size: Side length of each tile in pixels.
        top_n, percentile: Used to build the expected tile filenames
            'processed_tile_<i>_<j>_top<top_n>_p<percentile>.png'; they must
            match the values the tiles were written with.

    Missing or unreadable tiles are replaced by black tiles. A failure to
    save the final image is logged, not raised.
    """
    tiles_y, tiles_x = tiles_shape
    full_height = tiles_y * tile_size
    full_width = tiles_x * tile_size

    logging.info(f"Creating a blank full image of size: {full_width}x{full_height} pixels")
    full_image = Image.new('L', (full_width, full_height))  # Assuming grayscale output

    for i in tqdm(range(tiles_y), desc=f"Reconstructing image for p{percentile}"):
        for j in range(tiles_x):
            # Construct the expected tile filename
            tile_filename = f"processed_tile_{i}_{j}_top{top_n}_p{percentile}.png"
            tile_path = os.path.join(processed_tiles_directory, tile_filename)

            if not os.path.exists(tile_path):
                logging.warning(f"Processed tile ({i}, {j}) not found at {tile_path}. Filling with black.")
                # Create a black tile
                tile_image = Image.new('L', (tile_size, tile_size), 0)
            else:
                try:
                    with Image.open(tile_path) as raw_tile:
                        tile_image = raw_tile.convert('L')
                    logging.info(f"Loaded processed tile from {tile_path}")
                except Exception as e:
                    logging.error(f"Failed to open processed tile {tile_path}: {e}")
                    # Create a black tile in case of error
                    tile_image = Image.new('L', (tile_size, tile_size), 0)

            # Ensure tile size matches
            if tile_image.size != (tile_size, tile_size):
                logging.warning(f"Tile {tile_filename} has size {tile_image.size}, resizing to ({tile_size}, {tile_size})")
                tile_image = tile_image.resize((tile_size, tile_size), Image.BICUBIC)

            # Paste the tile into the full image
            full_image.paste(tile_image, (j * tile_size, i * tile_size))

    # Save the final reconstructed image
    try:
        full_image.save(output_image_path)
        logging.info(f"Reconstructed full image saved to {output_image_path}")
    except Exception as e:
        logging.error(f"Failed to save full image to {output_image_path}: {e}")


# ---------------------------------------------------------------------------
# 6) Main Function
# ---------------------------------------------------------------------------
def _build_arg_parser():
    """Build the command-line parser for the stacking script."""
    parser = argparse.ArgumentParser(description="Percentile Stacking for Beeless Image Generation")

    # Optional arguments for processing a single tile
    parser.add_argument(
        '--tile_i',
        type=int,
        help="Tile row index (e.g., 4)"
    )
    parser.add_argument(
        '--tile_j',
        type=int,
        help="Tile column index (e.g., 8)"
    )

    # Other optional parameters
    parser.add_argument(
        '--encoder',
        type=str,
        default='vitl',
        choices=list(model_configs),
        help="Depth Anything V2 encoder model to use (default: 'vitl')"
    )
    parser.add_argument(
        '--input_directory',
        type=str,
        default="percentile_stacking_22_12_2024_41images",
        help="Directory containing input frames"
    )
    parser.add_argument(
        '--processed_tiles_directory',
        type=str,
        default="good_depth_anything_large_filtered_stack_filtertop5_40thpercentile/processed_tiles",
        help="Directory to store processed tiles"
    )
    parser.add_argument(
        '--output_image_path',
        type=str,
        default="good_depth_anything_large_filtered_stack_filtertop5/output_images/clean_full_image_top5_p40.png",
        help="Base path to save the final reconstructed images (must include file extension, e.g., '.png')"
    )
    parser.add_argument(
        '--tile_size',
        type=int,
        default=512,
        help="Size of each tile (default: 512)"
    )
    parser.add_argument(
        '--top_n',
        type=int,
        default=5,
        help="Number of top frames to consider based on depth (default: 5)"
    )
    parser.add_argument(
        '--percentiles',
        type=float,
        nargs='+',
        default=[40.0],
        help="List of desired percentiles for stacking (default: [40.0])"
    )
    parser.add_argument(
        '--input_size',
        type=int,
        default=518,
        help="Input size for model inference (default: 518)"
    )

    # New arguments for saving cropped images and depth maps
    parser.add_argument(
        '--save_cropped',
        action='store_true',
        help="Flag to save cropped tile images"
    )
    parser.add_argument(
        '--cropped_dir',
        type=str,
        default=None,
        help="Directory to save cropped tile images (required if --save_cropped is set)"
    )
    parser.add_argument(
        '--save_depth_maps',
        action='store_true',
        help="Flag to save depth maps as images"
    )
    parser.add_argument(
        '--depth_maps_dir',
        type=str,
        default=None,
        help="Directory to save depth map images (required if --save_depth_maps is set)"
    )
    return parser


def main():
    """
    Main function to handle argument parsing, processing, and reconstruction.

    Flow: parse and validate arguments, create output directories, derive the
    tile grid from the first image in the input directory, load the model,
    process either one tile (--tile_i and --tile_j) or every tile, then stitch
    one output image per requested percentile.

    Exits with status 1 on unrecoverable setup errors; raises ValueError for
    out-of-range percentiles or tile indices.
    """
    # Initialize logging
    logging.basicConfig(
        level=logging.INFO,
        format='%(asctime)s - %(levelname)s - %(message)s',
        handlers=[
            logging.StreamHandler(sys.stdout)
        ]
    )

    parser = _build_arg_parser()
    args = parser.parse_args()

    # Validate output_image_path has a valid extension
    valid_extensions = list(IMAGE_EXTENSIONS)
    base_output, ext = os.path.splitext(args.output_image_path)
    if ext.lower() not in valid_extensions:
        logging.error(f"Invalid or missing file extension in output_image_path: '{args.output_image_path}'. "
                      f"Please use one of the following extensions: {', '.join(valid_extensions)}")
        sys.exit(1)

    # Validate arguments related to saving cropped images and depth maps
    if args.save_cropped and not args.cropped_dir:
        parser.error("--save_cropped requires --cropped_dir to be specified.")
    if args.save_depth_maps and not args.depth_maps_dir:
        parser.error("--save_depth_maps requires --depth_maps_dir to be specified.")

    # A lone --tile_i or --tile_j would otherwise silently run batch mode.
    if (args.tile_i is None) != (args.tile_j is None):
        parser.error("--tile_i and --tile_j must be specified together.")

    # Validate percentiles
    for percentile in args.percentiles:
        if not (0 <= percentile <= 100):
            logging.error(f"Percentile '{percentile}' is out of bounds. Must be between 0 and 100.")
            raise ValueError("Percentiles must be between 0 and 100.")

    if not os.path.isdir(args.input_directory):
        logging.critical(f"Input directory not found: {args.input_directory}")
        sys.exit(1)

    # Create directories if they don't exist
    os.makedirs(args.processed_tiles_directory, exist_ok=True)
    output_dir = os.path.dirname(args.output_image_path)
    if output_dir:
        # dirname is '' for a bare filename, and os.makedirs('') raises.
        os.makedirs(output_dir, exist_ok=True)
    if args.save_cropped:
        os.makedirs(args.cropped_dir, exist_ok=True)
    if args.save_depth_maps:
        os.makedirs(args.depth_maps_dir, exist_ok=True)

    # Determine image dims & tile grid from first valid image. This is done
    # before loading the model so that input problems fail fast.
    sample_image_path = None
    for fname in sorted(os.listdir(args.input_directory)):
        if fname.lower().endswith(IMAGE_EXTENSIONS):
            sample_image_path = os.path.join(args.input_directory, fname)
            break

    if not sample_image_path:
        logging.critical(f"No valid images found in {args.input_directory}")
        sys.exit(1)

    try:
        with Image.open(sample_image_path) as sample_img:
            img_width, img_height = sample_img.size
    except Exception as e:
        logging.critical(f"Failed to open sample image {sample_image_path}: {e}")
        sys.exit(1)

    tiles_x = img_width // args.tile_size
    tiles_y = img_height // args.tile_size
    tiles_shape = (tiles_y, tiles_x)

    logging.info(f"Image size: {img_width} x {img_height}")
    logging.info(f"Tiles: {tiles_y} rows x {tiles_x} columns, tile_size={args.tile_size}")
    logging.info(f"Model input size: {args.input_size}")
    logging.info(f"Percentiles to process: {args.percentiles}")

    if tiles_x == 0 or tiles_y == 0:
        logging.critical(f"tile_size={args.tile_size} is larger than the image ({img_width} x {img_height}); no tiles to process.")
        sys.exit(1)

    if torch.cuda.is_available():
        gpu_name = torch.cuda.get_device_name(0)
        logging.info(f"Using GPU: {gpu_name}")
    else:
        logging.info("No GPU available.")

    # Load model
    try:
        model, device = load_depth_anything_model(encoder=args.encoder)
    except Exception as e:
        logging.critical(f"Failed to load model: {e}")
        sys.exit(1)

    # Keyword arguments shared by every process_tile call.
    tile_kwargs = dict(
        input_directory=args.input_directory,
        processed_tiles_directory=args.processed_tiles_directory,
        model=model,
        device=device,
        tile_size=args.tile_size,
        top_n=args.top_n,
        percentiles=args.percentiles,
        input_size=args.input_size,
        save_cropped=args.save_cropped,
        cropped_dir=args.cropped_dir,
        save_depth_maps=args.save_depth_maps,
        depth_maps_dir=args.depth_maps_dir
    )

    # Phase 1: Process tiles
    if args.tile_i is not None and args.tile_j is not None:
        # Single Tile Mode
        i = args.tile_i
        j = args.tile_j

        if i < 0 or i >= tiles_y or j < 0 or j >= tiles_x:
            logging.error(f"Tile indices ({i}, {j}) are out of bounds. Valid ranges: 0 <= i < {tiles_y}, 0 <= j < {tiles_x}")
            raise ValueError(f"Tile indices ({i}, {j}) are out of bounds. Valid ranges: 0 <= i < {tiles_y}, 0 <= j < {tiles_x}")

        logging.info(f"Processing single tile ({i}, {j}) with top_n={args.top_n}, percentiles={args.percentiles}, and input_size={args.input_size}...")
        process_tile(tile_position=(i, j), **tile_kwargs)
        logging.info(f"Tile ({i}, {j}) processed successfully.")
    else:
        # Batch Mode: Process all tiles
        logging.info(f"Processing all tiles with top_n={args.top_n}, percentiles={args.percentiles}, and input_size={args.input_size}...")
        for i in tqdm(range(tiles_y), desc="Processing tiles row-wise"):
            for j in range(tiles_x):
                process_tile(tile_position=(i, j), **tile_kwargs)
        logging.info("All tiles processed.")

    # Phase 2: Reconstruct full images for each percentile
    for percentile in args.percentiles:
        # Define output image path for the current percentile
        percentile_suffix = f"_p{percentile}"
        output_image_path = f"{base_output}{percentile_suffix}{ext}"

        logging.info(f"Reconstructing full image for percentile {percentile}...")
        reconstruct_full_image(
            processed_tiles_directory=args.processed_tiles_directory,
            output_image_path=output_image_path,
            tiles_shape=tiles_shape,
            tile_size=args.tile_size,
            top_n=args.top_n,
            percentile=percentile
        )
        logging.info(f"Reconstruction for percentile {percentile} complete. Saved to {output_image_path}")

    logging.info("All processes completed successfully.")


if __name__ == "__main__":
    main()
Leave a Comment