# (Paste-site metadata captured along with this source; not part of the program.)
# Untitled
# unknown
# plain_text
# 9 months ago
# 24 kB
# 12
# Indexable
import os
import sys
import numpy as np
from PIL import Image
import torch
import cv2
from tqdm import tqdm # For progress bars
import argparse # For command-line argument parsing
import logging # For logging with timestamps
# ---------------------------------------------------------------------------
# 1) DepthAnythingV2 configurations and loader
# ---------------------------------------------------------------------------
from depth_anything_v2.dpt import DepthAnythingV2
# Model configurations
# Depth Anything V2 architecture hyper-parameters, keyed by encoder name.
# Each entry is unpacked as keyword arguments into DepthAnythingV2(...):
# (encoder name, number of features, per-stage output channels).
model_configs = {
    name: {
        'encoder': name,
        'features': features,
        'out_channels': list(out_channels),
    }
    for name, features, out_channels in (
        ('vits', 64, (48, 96, 192, 384)),
        ('vitb', 128, (96, 192, 384, 768)),
        ('vitl', 256, (256, 512, 1024, 1024)),
        ('vitg', 384, (1536, 1536, 1536, 1536)),
    )
}
def load_depth_anything_model(encoder='vitl', checkpoint_dir='checkpoints'):
    """
    Load one of the Depth Anything V2 models: 'vits', 'vitb', 'vitl', 'vitg'.

    The device is chosen in order of preference: CUDA, then Apple MPS, then
    CPU. Weights are read from
    ``{checkpoint_dir}/depth_anything_v2_{encoder}.pth``.

    Args:
        encoder: Key into ``model_configs``. Defaults to 'vitl' (Large).
        checkpoint_dir: Directory containing the ``.pth`` checkpoints.
            Defaults to the relative path 'checkpoints'.

    Returns:
        tuple: ``(model, device)``. ``model`` is in eval mode on ``device``.
        ``device`` is one of 'cuda', 'mps' or 'cpu'.

    Raises:
        ValueError: If ``encoder`` is not a known configuration.
        FileNotFoundError: If the checkpoint file does not exist.
    """
    if encoder not in model_configs:
        message = f"Invalid encoder '{encoder}'. Must be one of: {list(model_configs.keys())}"
        logging.error(message)
        raise ValueError(message)
    logging.info(f"Loading Depth Anything V2 {encoder.upper()} model...")

    if torch.cuda.is_available():
        device = 'cuda'
    elif torch.backends.mps.is_available():
        device = 'mps'
    else:
        device = 'cpu'

    # Check for the checkpoint before building the network, so a missing file
    # fails fast instead of after instantiating a potentially large model.
    checkpoint_path = os.path.join(checkpoint_dir, f'depth_anything_v2_{encoder}.pth')
    if not os.path.exists(checkpoint_path):
        logging.error(f"Checkpoint not found at: {checkpoint_path}")
        raise FileNotFoundError(f"Checkpoint not found at: {checkpoint_path}")

    model = DepthAnythingV2(**model_configs[encoder])
    logging.info(f"Using checkpoint: {checkpoint_path}")
    # Deserialize on the CPU, then move the whole module to the target device.
    # This avoids holding two device copies and any map_location quirks on MPS.
    state_dict = torch.load(checkpoint_path, map_location='cpu')
    model.load_state_dict(state_dict)
    model = model.to(device).eval()
    logging.info(f"Model {encoder.upper()} loaded successfully on {device}.")
    return model, device
# ---------------------------------------------------------------------------
# 2) Depth inference for a cropped tile
# ---------------------------------------------------------------------------
def estimate_depth(tile_img, model, device, input_size=518, tile_size=512):
    """
    Estimate depth for a single cropped tile using Depth Anything V2.

    Steps:
      1. Resize the tile (bicubic) to ``input_size`` x ``input_size``.
      2. Convert RGB to BGR, presumably the channel order ``model.infer_image``
         expects (as with cv2-loaded images). TODO confirm.
      3. Run inference.
      4. Resize the depth map (bicubic) back to ``tile_size`` x ``tile_size``.

    Args:
        tile_img: PIL RGB image of the tile.
        model: Loaded Depth Anything V2 model exposing ``infer_image``.
        device: Not used by this function. It is kept for interface
            compatibility; inference runs wherever ``model`` lives.
        input_size: Side length used for inference.
        tile_size: Side length of the returned depth map.

    Returns:
        numpy.ndarray: float32 array of shape (tile_size, tile_size).
    """
    model_input = tile_img.resize((input_size, input_size), Image.BICUBIC)
    bgr = cv2.cvtColor(np.array(model_input), cv2.COLOR_RGB2BGR)
    with torch.no_grad():
        raw_depth = model.infer_image(bgr, input_size)

    # Image.fromarray maps float data to PIL mode 'F' only for float32. A
    # float16 output (half-precision model) or a float64 output would make
    # PIL raise, so coerce before converting.
    depth = np.ascontiguousarray(raw_depth, dtype=np.float32)
    depth_resized = Image.fromarray(depth).resize((tile_size, tile_size), Image.BICUBIC)
    return np.array(depth_resized, dtype=np.float32)
# ---------------------------------------------------------------------------
# 3) Utility to save 8-bit grayscale images
# ---------------------------------------------------------------------------
def save_image(array, path):
    """
    Save a 2-D float or int array as an 8-bit grayscale image.

    This script uses PNG; PIL infers the format from ``path``'s extension.

    Before the uint8 cast, values are rounded to the nearest integer and
    clipped to [0, 255]. NaN becomes 0, and +inf/-inf saturate to 255/0.
    Rounding avoids the downward bias of up to one gray level that a plain
    truncating cast gives fractional inputs such as interpolated percentiles.

    Raises:
        ValueError: If ``array`` is not 2-D.
    """
    values = np.asarray(array, dtype=np.float64)
    if values.ndim != 2:
        raise ValueError(f"save_image expects a 2-D array, got shape {values.shape}")
    values = np.nan_to_num(values, nan=0.0)
    array_8u = np.clip(np.rint(values), 0, 255).astype(np.uint8)
    # PIL maps a 2-D uint8 array to mode 'L' automatically. The explicit
    # ``mode=`` argument to fromarray is deprecated as of Pillow 11.3.
    img = Image.fromarray(array_8u)
    img.save(path)
    logging.info(f"Saved image to {path}")
# ---------------------------------------------------------------------------
# 4) Process a single tile across all images with "top N smallest depth" logic
# ---------------------------------------------------------------------------
def _save_depth_preview(depth_map, path):
    """Write a min-max normalized 8-bit grayscale preview of ``depth_map`` to ``path``."""
    depth = np.asarray(depth_map, dtype=np.float64)
    # Normalize over finite values only. NaN (failed inference) or +/-inf would
    # otherwise poison the min/max and the uint8 cast.
    finite = np.isfinite(depth)
    preview = np.zeros(depth.shape, dtype=np.uint8)
    if finite.any():
        depth_min = depth[finite].min()
        depth_max = depth[finite].max()
        if depth_max > depth_min:
            scaled = (depth[finite] - depth_min) / (depth_max - depth_min) * 255
            preview[finite] = scaled.astype(np.uint8)
    Image.fromarray(preview).save(path)
    logging.info(f"Saved depth map to {path}")


def _stack_by_smallest_depth(grayscale_stack, depth_stack, top_n, percentiles, device):
    """
    For each pixel, keep the ``top_n`` smallest-depth frames and take percentiles.

    Args:
        grayscale_stack: (N, H, W) array of grayscale intensities.
        depth_stack: (N, H, W) array of depth values. NaN marks a missing value.
        top_n: Maximum number of frames kept per pixel. Clamped to N.
        percentiles: Iterable of percentiles in [0, 100].
        device: Torch device used for the computation.

    Returns:
        list: ``(percentile, out_tile)`` pairs. Each ``out_tile`` is an (H, W)
        float32 numpy array. Pixels with no valid depth in any frame are 0.
    """
    # Use float32 throughout. MPS has no float64 support, and 0-255 grayscale
    # values are exactly representable in float32.
    gray = torch.from_numpy(np.ascontiguousarray(grayscale_stack, dtype=np.float32)).to(device)
    depth = torch.from_numpy(np.ascontiguousarray(depth_stack, dtype=np.float32)).to(device)
    k = min(top_n, depth.shape[0])

    # NOTE(review): Depth Anything V2 is usually described as predicting
    # relative *inverse* depth (larger = closer). If so, "smallest depth"
    # selects the farthest surface per pixel. Confirm this is the intended order.
    valid_mask = ~torch.isnan(depth)
    # Replace NaN depths with +inf so they sort after every valid value.
    depth_for_sort = torch.where(valid_mask, depth, torch.full_like(depth, float('inf')))
    order = torch.sort(depth_for_sort, dim=0).indices
    nearest = torch.gather(gray, 0, order[:k])  # (k, H, W)

    # If a pixel has fewer than k valid depths, its trailing candidates came
    # from invalid frames. Blank them out so nanquantile ignores them.
    effective_n = valid_mask.sum(dim=0).clamp(max=k)  # (H, W)
    rank = torch.arange(k, device=nearest.device).view(k, 1, 1)
    nearest = nearest.masked_fill(rank >= effective_n.unsqueeze(0), float('nan'))
    has_valid = effective_n > 0

    # All of the work above is independent of the percentile, so it is done
    # once rather than once per percentile.
    outputs = []
    for percentile in percentiles:
        result = torch.nanquantile(nearest, percentile / 100.0, dim=0, interpolation='linear')
        # Pixels with no valid depth at all get 0.
        result = torch.where(has_valid, result, torch.zeros_like(result))
        outputs.append((percentile, result.cpu().numpy()))
    return outputs


def process_tile(
    tile_position,
    input_directory,
    processed_tiles_directory,
    model,
    device,
    tile_size=512,
    top_n=10,
    percentiles=(40,),
    input_size=518,
    save_cropped=False,
    cropped_dir=None,
    save_depth_maps=False,
    depth_maps_dir=None
):
    """
    Build percentile-stacked output tiles for one tile position.

    Each image in ``input_directory`` (sorted by name) is:
      - cropped to tile ``(i, j)``;
      - converted to grayscale;
      - passed through ``estimate_depth``.

    Then, per pixel, the ``top_n`` frames with the smallest depth are kept.
    Each requested percentile of their grayscale values is saved as
    ``processed_tile_{i}_{j}_top{top_n}_p{percentile}.png`` in
    ``processed_tiles_directory``.

    Args:
        tile_position: ``(i, j)`` tile row and column index.
        input_directory: Directory of input frames.
        processed_tiles_directory: Output directory for the stacked tiles.
        model: Depth model passed to ``estimate_depth``.
        device: Torch device the model (and the stacking math) runs on.
        tile_size: Edge length of the square tile in pixels.
        top_n: Number of smallest-depth frames kept per pixel. Must be >= 1.
        percentiles: Iterable of percentiles in [0, 100].
        input_size: Side length the tile is resized to for inference.
        save_cropped, cropped_dir: If both are set, save each cropped RGB tile.
            Tiles that fall outside an image are saved zero-padded.
        save_depth_maps, depth_maps_dir: If both are set, save a min-max
            normalized 8-bit preview of each depth map.

    Frames that cannot be opened, or that do not fully contain the tile, are
    skipped. Frames whose depth estimation fails contribute NaN depth and are
    ignored for every pixel.

    Raises:
        ValueError: If ``top_n`` < 1. Checked up front, before any
            expensive inference.
    """
    if top_n < 1:
        raise ValueError(f"top_n must be at least 1, got {top_n}")
    i, j = tile_position
    left = j * tile_size
    upper = i * tile_size
    right = left + tile_size
    lower = upper + tile_size
    grayscale_frames = []
    depth_frames = []

    for filename in sorted(os.listdir(input_directory)):
        if not filename.lower().endswith(('.png', '.jpg', '.jpeg', '.tiff', '.bmp', '.gif')):
            continue
        stem = os.path.splitext(filename)[0]
        img_path = os.path.join(input_directory, filename)
        try:
            with Image.open(img_path) as src:
                img = src.convert('RGB')
        except Exception as e:
            logging.error(f"Failed to open {filename}: {e}")
            continue
        image_width, image_height = img.size
        cropped_path = None
        if save_cropped and cropped_dir:
            cropped_path = os.path.join(cropped_dir, f"cropped_tile_{i}_{j}_{stem}.png")

        if right > image_width or lower > image_height:
            # The tile extends beyond this image. Optionally save a zero-padded
            # copy of the overlapping part, then skip the frame.
            if cropped_path:
                new_tile = Image.new('RGB', (tile_size, tile_size), (0, 0, 0))
                avail_right = min(right, image_width)
                avail_lower = min(lower, image_height)
                # A tile entirely outside the image has no overlap, and
                # cropping with right < left (or lower < upper) raises.
                if avail_right > left and avail_lower > upper:
                    new_tile.paste(img.crop((left, upper, avail_right, avail_lower)), (0, 0))
                new_tile.save(cropped_path)
                logging.info(f"Saved padded cropped image to {cropped_path}")
            logging.warning(f"Tile ({i}, {j}) out of bounds for image {filename} size {img.size}. Skipping.")
            continue

        tile_img = img.crop((left, upper, right, lower))
        if cropped_path:
            tile_img.save(cropped_path)
            logging.info(f"Saved cropped image to {cropped_path}")

        grayscale_frames.append(np.asarray(tile_img.convert('L'), dtype=np.float32))

        try:
            depth_map = estimate_depth(tile_img, model, device, input_size=input_size, tile_size=tile_size)
        except Exception as e:
            logging.error(f"Failed to estimate depth for tile ({i}, {j}) in {filename}: {e}")
            depth_map = np.full((tile_size, tile_size), np.nan, dtype=np.float32)
        depth_frames.append(np.asarray(depth_map, dtype=np.float32))

        if save_depth_maps and depth_maps_dir:
            depth_path = os.path.join(depth_maps_dir, f"depth_tile_{i}_{j}_{stem}.png")
            _save_depth_preview(depth_map, depth_path)

    if not grayscale_frames:
        logging.warning(f"No frames processed for tile ({i}, {j}).")
        return

    stacked = _stack_by_smallest_depth(
        np.stack(grayscale_frames, axis=0),
        np.stack(depth_frames, axis=0),
        top_n,
        percentiles,
        device,
    )
    for percentile, out_tile in stacked:
        processed_tile_path = os.path.join(
            processed_tiles_directory,
            f"processed_tile_{i}_{j}_top{top_n}_p{percentile}.png"
        )
        save_image(out_tile, processed_tile_path)
# ---------------------------------------------------------------------------
# 5) Reconstruct the full image from processed tiles
# ---------------------------------------------------------------------------
def reconstruct_full_image(
    processed_tiles_directory,
    output_image_path,
    tiles_shape,
    tile_size=512,
    top_n=5,
    percentile=40.0
):
    """
    Stitch the processed tiles for one percentile into a single grayscale image.

    Tiles are looked up as
    ``processed_tile_{i}_{j}_top{top_n}_p{percentile}.png`` in
    ``processed_tiles_directory``. Each is pasted at
    ``(j * tile_size, i * tile_size)``.

    Missing or unreadable tiles become black, and tiles of the wrong size are
    resized (bicubic). A failure to save the result is logged, not raised.

    Args:
        processed_tiles_directory: Directory containing the processed tiles.
        output_image_path: Destination path of the stitched image.
        tiles_shape: ``(rows, columns)`` of the tile grid.
        tile_size: Edge length of each square tile in pixels.
        top_n: Value embedded in the tile filenames. It must match the value
            used when the tiles were written.
        percentile: Value embedded in the tile filenames. It must match the
            value used when the tiles were written.
    """
    rows, cols = tiles_shape
    canvas_width = cols * tile_size
    canvas_height = rows * tile_size
    canvas = Image.new('L', (canvas_width, canvas_height))
    logging.info(f"Creating a blank full image of size: {canvas_width}x{canvas_height} pixels")

    def blank_tile():
        return Image.new('L', (tile_size, tile_size), 0)

    expected_size = (tile_size, tile_size)
    for row in tqdm(range(rows), desc=f"Reconstructing image for p{percentile}"):
        for col in range(cols):
            tile_filename = f"processed_tile_{row}_{col}_top{top_n}_p{percentile}.png"
            tile_path = os.path.join(processed_tiles_directory, tile_filename)

            if os.path.exists(tile_path):
                try:
                    tile = Image.open(tile_path).convert('L')
                    logging.info(f"Loaded processed tile from {tile_path}")
                except Exception as exc:
                    logging.error(f"Failed to open processed tile {tile_path}: {exc}")
                    tile = blank_tile()
            else:
                logging.warning(f"Processed tile ({row}, {col}) not found at {tile_path}. Filling with black.")
                tile = blank_tile()

            if tile.size != expected_size:
                logging.warning(f"Tile {tile_filename} has size {tile.size}, resizing to ({tile_size}, {tile_size})")
                tile = tile.resize(expected_size, Image.BICUBIC)

            canvas.paste(tile, (col * tile_size, row * tile_size))

    try:
        canvas.save(output_image_path)
        logging.info(f"Reconstructed full image saved to {output_image_path}")
    except Exception as exc:
        logging.error(f"Failed to save full image to {output_image_path}: {exc}")
# ---------------------------------------------------------------------------
# 6) Main Function
# ---------------------------------------------------------------------------
# File extensions treated as images, both for input frames and for the output path.
_IMAGE_EXTENSIONS = ('.png', '.jpg', '.jpeg', '.tiff', '.bmp', '.gif')


def _build_arg_parser():
    """Create the command-line parser for the percentile-stacking script."""
    parser = argparse.ArgumentParser(description="Percentile Stacking for Beeless Image Generation")
    # Optional arguments for processing a single tile
    parser.add_argument(
        '--tile_i',
        type=int,
        help="Tile row index (e.g., 4)"
    )
    parser.add_argument(
        '--tile_j',
        type=int,
        help="Tile column index (e.g., 8)"
    )
    # Other optional parameters
    parser.add_argument(
        '--encoder',
        type=str,
        default='vitl',
        choices=list(model_configs),
        help="Depth Anything V2 encoder model to use (default: 'vitl')"
    )
    parser.add_argument(
        '--input_directory',
        type=str,
        default="percentile_stacking_22_12_2024_41images",
        help="Directory containing input frames"
    )
    parser.add_argument(
        '--processed_tiles_directory',
        type=str,
        default="good_depth_anything_large_filtered_stack_filtertop5_40thpercentile/processed_tiles",
        help="Directory to store processed tiles"
    )
    parser.add_argument(
        '--output_image_path',
        type=str,
        default="good_depth_anything_large_filtered_stack_filtertop5/output_images/clean_full_image_top5_p40.png",
        help="Base path to save the final reconstructed images (must include file extension, e.g., '.png')"
    )
    parser.add_argument(
        '--tile_size',
        type=int,
        default=512,
        help="Size of each tile (default: 512)"
    )
    parser.add_argument(
        '--top_n',
        type=int,
        default=5,
        help="Number of top frames to consider based on depth (default: 5)"
    )
    parser.add_argument(
        '--percentiles',
        type=float,
        nargs='+',
        default=[40.0],
        help="List of desired percentiles for stacking (default: [40.0])"
    )
    parser.add_argument(
        '--input_size',
        type=int,
        default=518,
        help="Input size for model inference (default: 518)"
    )
    # Arguments for saving cropped images and depth maps
    parser.add_argument(
        '--save_cropped',
        action='store_true',
        help="Flag to save cropped tile images"
    )
    parser.add_argument(
        '--cropped_dir',
        type=str,
        default=None,
        help="Directory to save cropped tile images (required if --save_cropped is set)"
    )
    parser.add_argument(
        '--save_depth_maps',
        action='store_true',
        help="Flag to save depth maps as images"
    )
    parser.add_argument(
        '--depth_maps_dir',
        type=str,
        default=None,
        help="Directory to save depth map images (required if --save_depth_maps is set)"
    )
    return parser


def _find_first_image(directory):
    """Return the path of the first file, by sorted name, with an image extension, or None."""
    for fname in sorted(os.listdir(directory)):
        if fname.lower().endswith(_IMAGE_EXTENSIONS):
            return os.path.join(directory, fname)
    return None


def main():
    """
    Parse arguments, process tiles, and reconstruct one full image per percentile.

    Phase 1 runs ``process_tile`` on either the single tile given by
    ``--tile_i``/``--tile_j``, or on every tile of the grid. The grid is
    derived from the first image in ``--input_directory``.

    Phase 2 stitches one full image per requested percentile, saved as
    ``<output base>_p<percentile><ext>``.

    All arguments are validated before any directory is created or the model
    is loaded.
    """
    # Initialize logging
    logging.basicConfig(
        level=logging.INFO,
        format='%(asctime)s - %(levelname)s - %(message)s',
        handlers=[
            logging.StreamHandler(sys.stdout)
        ]
    )
    parser = _build_arg_parser()
    args = parser.parse_args()

    # Validate output_image_path has a valid extension
    valid_extensions = list(_IMAGE_EXTENSIONS)
    base_output, ext = os.path.splitext(args.output_image_path)
    if ext.lower() not in valid_extensions:
        logging.error(f"Invalid or missing file extension in output_image_path: '{args.output_image_path}'. "
                      f"Please use one of the following extensions: {', '.join(valid_extensions)}")
        sys.exit(1)

    # Validate the remaining arguments up front, so bad invocations fail
    # before touching the filesystem or loading the (large) model.
    if args.save_cropped and not args.cropped_dir:
        parser.error("--save_cropped requires --cropped_dir to be specified.")
    if args.save_depth_maps and not args.depth_maps_dir:
        parser.error("--save_depth_maps requires --depth_maps_dir to be specified.")
    if (args.tile_i is None) != (args.tile_j is None):
        # Giving only one index used to fall through silently to batch mode.
        parser.error("--tile_i and --tile_j must be specified together.")
    if args.tile_size <= 0:
        parser.error("--tile_size must be a positive integer.")
    if args.input_size <= 0:
        parser.error("--input_size must be a positive integer.")
    if args.top_n < 1:
        parser.error("--top_n must be at least 1.")
    for percentile in args.percentiles:
        if not (0 <= percentile <= 100):
            parser.error(f"Percentile '{percentile}' is out of bounds. Must be between 0 and 100.")

    if not os.path.isdir(args.input_directory):
        logging.critical(f"Input directory not found: {args.input_directory}")
        sys.exit(1)

    # Create directories if they don't exist
    os.makedirs(args.processed_tiles_directory, exist_ok=True)
    output_dir = os.path.dirname(args.output_image_path)
    if output_dir:
        # dirname() is '' for a bare filename, and os.makedirs('') raises.
        os.makedirs(output_dir, exist_ok=True)
    if args.save_cropped:
        os.makedirs(args.cropped_dir, exist_ok=True)
    if args.save_depth_maps:
        os.makedirs(args.depth_maps_dir, exist_ok=True)

    # Determine image dims & tile grid from first valid image
    sample_image_path = _find_first_image(args.input_directory)
    if not sample_image_path:
        logging.critical(f"No valid images found in {args.input_directory}")
        sys.exit(1)
    try:
        with Image.open(sample_image_path) as sample_img:
            # convert() forces a full decode, so a corrupt sample fails here.
            img_width, img_height = sample_img.convert('RGB').size
    except Exception as e:
        logging.critical(f"Failed to open sample image {sample_image_path}: {e}")
        sys.exit(1)

    tiles_x = img_width // args.tile_size
    tiles_y = img_height // args.tile_size
    if tiles_x == 0 or tiles_y == 0:
        logging.critical(f"Image size {img_width} x {img_height} is smaller than tile_size={args.tile_size}; "
                         f"no complete tiles to process.")
        sys.exit(1)
    tiles_shape = (tiles_y, tiles_x)
    logging.info(f"Image size: {img_width} x {img_height}")
    logging.info(f"Tiles: {tiles_y} rows x {tiles_x} columns, tile_size={args.tile_size}")
    logging.info(f"Model input size: {args.input_size}")
    logging.info(f"Percentiles to process: {args.percentiles}")

    single_tile = args.tile_i is not None
    if single_tile:
        i, j = args.tile_i, args.tile_j
        if not (0 <= i < tiles_y and 0 <= j < tiles_x):
            parser.error(f"Tile indices ({i}, {j}) are out of bounds. "
                         f"Valid ranges: 0 <= i < {tiles_y}, 0 <= j < {tiles_x}")

    if torch.cuda.is_available():
        logging.info(f"Using GPU: {torch.cuda.get_device_name(0)}")
    else:
        # The model loader may still select MPS; it logs the chosen device.
        logging.info("No CUDA GPU available.")

    # Load model
    try:
        model, device = load_depth_anything_model(encoder=args.encoder)
    except Exception as e:
        logging.critical(f"Failed to load model: {e}")
        sys.exit(1)

    # Keyword arguments shared by every process_tile call.
    tile_kwargs = dict(
        input_directory=args.input_directory,
        processed_tiles_directory=args.processed_tiles_directory,
        model=model,
        device=device,
        tile_size=args.tile_size,
        top_n=args.top_n,
        percentiles=args.percentiles,
        input_size=args.input_size,
        save_cropped=args.save_cropped,
        cropped_dir=args.cropped_dir,
        save_depth_maps=args.save_depth_maps,
        depth_maps_dir=args.depth_maps_dir,
    )

    # Phase 1: Process tiles
    if single_tile:
        logging.info(f"Processing single tile ({i}, {j}) with top_n={args.top_n}, "
                     f"percentiles={args.percentiles}, and input_size={args.input_size}...")
        process_tile(tile_position=(i, j), **tile_kwargs)
        logging.info(f"Tile ({i}, {j}) processed successfully.")
    else:
        logging.info(f"Processing all tiles with top_n={args.top_n}, "
                     f"percentiles={args.percentiles}, and input_size={args.input_size}...")
        for i in tqdm(range(tiles_y), desc="Processing tiles row-wise"):
            for j in range(tiles_x):
                process_tile(tile_position=(i, j), **tile_kwargs)
        logging.info("All tiles processed.")

    # Phase 2: Reconstruct one full image per percentile. Tiles missing in
    # single-tile mode are filled with black by reconstruct_full_image.
    for percentile in args.percentiles:
        output_image_path = f"{base_output}_p{percentile}{ext}"
        logging.info(f"Reconstructing full image for percentile {percentile}...")
        reconstruct_full_image(
            processed_tiles_directory=args.processed_tiles_directory,
            output_image_path=output_image_path,
            tiles_shape=tiles_shape,
            tile_size=args.tile_size,
            top_n=args.top_n,
            percentile=percentile
        )
        logging.info(f"Reconstruction for percentile {percentile} complete. Saved to {output_image_path}")
    logging.info("All processes completed successfully.")


if __name__ == "__main__":
    main()
# (Paste-site page text captured along with this source; not part of the program.)
# Editor is loading...
# Leave a Comment