# Untitled
import os import sys import argparse import yaml import torch import cv2 import numpy as np import logging import rosbag import shutil import re import traceback import math from Warp.Codes.network import ( get_stitched_result as warp_get_stitched_result, Network as WarpNetwork, build_new_ft_model as warp_build_new_ft_model, ) from Composition.Codes.network import ( build_model as composition_build_model, Network as CompositionNetwork, ) import torchvision.transforms as T from zz_bagparser_w import ( load_xy_positions_specific, process_images, cleanup_redundant_images, crop_images_specific, crop_images_specific_vertical ) # Configure logging logging.basicConfig( level=logging.INFO, format='%(asctime)s - %(levelname)s - %(message)s', handlers=[ logging.FileHandler("stitching_process.log"), logging.StreamHandler(sys.stdout) ] ) # GPU Configuration GPU_DEVICE = "0" COMPOSITION_MODEL_PATH = os.path.join("Composition", "model", "epoch050_model.pth") def load_config(config_path): """ Load YAML configuration file. Parameters: config_path (str): Path to the YAML config file. Returns: dict: Configuration parameters. """ with open(config_path, 'r') as file: config = yaml.safe_load(file) return config # Stitching Functions def perform_manual_warping(cropped_img1_path, cropped_img2_path, output_dir): """ Manually perform warping by creating masks and blending the cropped images horizontally. Parameters: cropped_img1_path (str): Path to the cropped left image. cropped_img2_path (str): Path to the cropped right image. output_dir (str): Directory to save warped images and masks. 
""" img1 = cv2.imread(cropped_img1_path) img2 = cv2.imread(cropped_img2_path) if img1 is None: raise FileNotFoundError(f"Cropped Image 1 not found: {cropped_img1_path}") if img2 is None: raise FileNotFoundError(f"Cropped Image 2 not found: {cropped_img2_path}") # Generate mask1 height1, width1, _ = img1.shape split1 = int(width1 * 0.9) mask1 = np.zeros_like(img1, dtype=np.uint8) mask1[:, :split1, :] = 255 mask1[:, split1:, :] = 0 # Generate mask2 height2, width2, _ = img2.shape split2 = int(width2 * 0.1) mask2 = np.zeros_like(img2, dtype=np.uint8) mask2[:, :split2, :] = 0 mask2[:, split2:, :] = 255 # Create warped images warp1 = cv2.bitwise_and(img1, mask1) warp2 = cv2.bitwise_and(img2, mask2) # Save masks and warped images os.makedirs(output_dir, exist_ok=True) cv2.imwrite(os.path.join(output_dir, "mask1.jpg"), mask1) cv2.imwrite(os.path.join(output_dir, "mask2.jpg"), mask2) cv2.imwrite(os.path.join(output_dir, "warp1.jpg"), warp1) cv2.imwrite(os.path.join(output_dir, "warp2.jpg"), warp2) logging.info(f"Manual warped images and masks saved in {output_dir}") def perform_manual_warping_vertical(cropped_img1_path, cropped_img2_path, output_dir): """ Manually perform warping by creating masks and blending the cropped images vertically. Parameters: cropped_img1_path (str): Path to the cropped bottom image. cropped_img2_path (str): Path to the cropped top image. output_dir (str): Directory to save warped images and masks. 
""" img1 = cv2.imread(cropped_img1_path) img2 = cv2.imread(cropped_img2_path) if img1 is None: raise FileNotFoundError(f"Cropped Image 1 not found: {cropped_img1_path}") if img2 is None: raise FileNotFoundError(f"Cropped Image 2 not found: {cropped_img2_path}") # Generate mask1 height1, width1, _ = img1.shape split1 = int(height1 * 0.9) mask1 = np.zeros_like(img1, dtype=np.uint8) mask1[:split1, :, :] = 255 mask1[split1:, :, :] = 0 # Generate mask2 height2, width2, _ = img2.shape split2 = int(height2 * 0.1) mask2 = np.zeros_like(img2, dtype=np.uint8) mask2[:split2, :, :] = 0 mask2[split2:, :, :] = 255 # Create warped images warp1 = cv2.bitwise_and(img1, mask1) warp2 = cv2.bitwise_and(img2, mask2) # Save masks and warped images os.makedirs(output_dir, exist_ok=True) cv2.imwrite(os.path.join(output_dir, "mask1.jpg"), mask1) cv2.imwrite(os.path.join(output_dir, "mask2.jpg"), mask2) cv2.imwrite(os.path.join(output_dir, "warp1.jpg"), warp1) cv2.imwrite(os.path.join(output_dir, "warp2.jpg"), warp2) logging.info(f"Manual warped images and masks saved in {output_dir}") def perform_composition(net, input_dir, output_dir, config_options): """ Perform composition on the warped outputs to generate the stitched image. Parameters: net (CompositionNetwork): The initialized composition network. input_dir (str): Directory containing warped images and masks. output_dir (str): Directory to save the composed image. config_options (dict): Configuration options from config.yaml. 
""" # Load warped images warp1 = cv2.imread(os.path.join(input_dir, "warp1.jpg")) warp2 = cv2.imread(os.path.join(input_dir, "warp2.jpg")) if warp1 is None or warp2 is None: raise FileNotFoundError("Warped images not found in the input directory.") # Convert to float32 and normalize warp1 = warp1.astype(dtype=np.float32) warp1 = (warp1 / 127.5) - 1.0 warp1 = np.transpose(warp1, [2, 0, 1]) warp1_tensor = torch.tensor(warp1).unsqueeze(0) warp2 = warp2.astype(dtype=np.float32) warp2 = (warp2 / 127.5) - 1.0 warp2 = np.transpose(warp2, [2, 0, 1]) warp2_tensor = torch.tensor(warp2).unsqueeze(0) # Load masks mask1 = cv2.imread(os.path.join(input_dir, "mask1.jpg")) mask2 = cv2.imread(os.path.join(input_dir, "mask2.jpg")) if mask1 is None or mask2 is None: raise FileNotFoundError("Masks not found in the input directory.") mask1 = mask1.astype(dtype=np.float32) / 255.0 mask1 = np.transpose(mask1, [2, 0, 1]) mask1_tensor = torch.tensor(mask1).unsqueeze(0) mask2 = mask2.astype(dtype=np.float32) / 255.0 mask2 = np.transpose(mask2, [2, 0, 1]) mask2_tensor = torch.tensor(mask2).unsqueeze(0) if torch.cuda.is_available(): warp1_tensor = warp1_tensor.cuda() warp2_tensor = warp2_tensor.cuda() mask1_tensor = mask1_tensor.cuda() mask2_tensor = mask2_tensor.cuda() net = net.cuda() # Perform composition net.eval() with torch.no_grad(): batch_out = composition_build_model( net, warp1_tensor, warp2_tensor, mask1_tensor, mask2_tensor ) stitched_image = batch_out["stitched_image"] logging.info(f"Stitched image shape: {stitched_image.shape}") # Save the composed image os.makedirs(output_dir, exist_ok=True) stitched_image_np = ( (stitched_image[0] + 1) * 127.5 ).cpu().numpy().transpose(1, 2, 0) cv2.imwrite( os.path.join(output_dir, "composition.jpg"), stitched_image_np.astype(np.uint8) ) logging.info(f"Composed image saved in {output_dir}") # Save learned masks if enabled if config_options.get('save_learned_masks', False): learned_mask1 = (batch_out['learned_mask1'][0] * 
255).cpu().numpy().transpose(1, 2, 0) learned_mask2 = (batch_out['learned_mask2'][0] * 255).cpu().numpy().transpose(1, 2, 0) cv2.imwrite(os.path.join(input_dir, "learn_mask1.jpg"), learned_mask1.astype(np.uint8)) cv2.imwrite(os.path.join(input_dir, "learn_mask2.jpg"), learned_mask2.astype(np.uint8)) logging.info(f"Learned masks saved in {input_dir}/learn_mask1.jpg and {input_dir}/learn_mask2.jpg") # Save composition_color.jpg if enabled if config_options.get('save_composition_color', False): s1 = ((warp1_tensor[0] + 1) * 127.5 * batch_out['learned_mask1'][0]).cpu().detach().numpy().transpose(1, 2, 0) s2 = ((warp2_tensor[0] + 1) * 127.5 * batch_out['learned_mask2'][0]).cpu().detach().numpy().transpose(1, 2, 0) fusion = np.zeros((warp1_tensor.shape[2], warp1_tensor.shape[3], 3), np.uint8) fusion[..., 0] = s2[..., 0] fusion[..., 1] = (s1[..., 1] * 0.5 + s2[..., 1] * 0.5).astype(np.uint8) fusion[..., 2] = s1[..., 2] composition_color_path = os.path.join(output_dir, "composition_color.jpg") cv2.imwrite(composition_color_path, fusion) logging.info(f"Composition color image saved at {composition_color_path}") def stitch_pair(left_image, right_image, tx_multiplier, warp_net, composition_net, output_dir, level, group_num, pair_num, config_options): """ Stitch a pair of images horizontally using manual warping and composition. Parameters: left_image (str): Path to the left image. right_image (str): Path to the right image. tx_multiplier (int): Multiplier for the base translation in the x-direction. warp_net (WarpNetwork): Initialized warping network (Not used in manual warping). composition_net (CompositionNetwork): Initialized composition network. output_dir (str): Directory to save intermediate and final outputs. level (int): Current stitching level. group_num (str): Group number for unique identification. pair_num (int): Pair number within the group. config_options (dict): Configuration options from config.yaml. Returns: str: Path to the stitched image. 
""" base_tx = config_options['tx_level'] tx = tx_multiplier * base_tx # Logging the images and tx value logging.info(f"Stitching pair {pair_num} at level {level}, group {group_num}") logging.info(f"Left Image: {left_image}") logging.info(f"Right Image: {right_image}") logging.info(f"Effective tx used: {tx}") # Perform cropping cropped = crop_images_specific(left_image, right_image, tx, 0) cropped_img1 = cropped["cropped_img1"] cropped_img2 = cropped["cropped_img2"] unused_img1 = cropped["unused_img1"] unused_img2 = cropped["unused_img2"] # Save cropped images for reference cropped_dir = os.path.join(output_dir, f"level{level}_group{group_num}_pair{pair_num}", "cropped") os.makedirs(cropped_dir, exist_ok=True) cv2.imwrite(os.path.join(cropped_dir, "cropped_img1.jpg"), cropped_img1) cv2.imwrite(os.path.join(cropped_dir, "cropped_img2.jpg"), cropped_img2) cv2.imwrite(os.path.join(cropped_dir, "unused_img1.jpg"), unused_img1) cv2.imwrite(os.path.join(cropped_dir, "unused_img2.jpg"), unused_img2) logging.info(f"Cropped images saved for level {level} group {group_num} pair {pair_num} in {cropped_dir}") # Paths to the cropped images cropped_img1_path = os.path.join(cropped_dir, "cropped_img1.jpg") cropped_img2_path = os.path.join(cropped_dir, "cropped_img2.jpg") # Perform manual warping on cropped regions perform_manual_warping( cropped_img1_path, cropped_img2_path, cropped_dir ) # Perform composition on warped outputs perform_composition(composition_net, cropped_dir, cropped_dir, config_options) # Load the composed overlapping region composed_overlap_path = os.path.join(cropped_dir, "composition.jpg") composed_overlap = cv2.imread(composed_overlap_path) if composed_overlap is None: logging.error(f"Composition failed for level {level} group {group_num} pair {pair_num}. 
Composed image not found at {composed_overlap_path}") return None # Concatenate unused parts with composed overlap horizontally try: final_image = np.hstack((unused_img1, composed_overlap, unused_img2)) except Exception as e: logging.error(f"Error during image concatenation for level {level} group {group_num} pair {pair_num}: {e}") return None logging.info(f"Final stitched image shape for level {level} group {group_num} pair {pair_num}: {final_image.shape}") # Save the final stitched image stitched_image_path = os.path.join(output_dir, f"level{level}_group{group_num}_pair{pair_num}_final.jpg") os.makedirs(os.path.dirname(stitched_image_path), exist_ok=True) cv2.imwrite(stitched_image_path, final_image) logging.info(f'Final stitched image saved for level {level} group {group_num} pair {pair_num} at {stitched_image_path}') return stitched_image_path def stitch_pair_vertical(bottom_image, top_image, ty_multiplier, warp_net, composition_net, output_dir, level, pair_num, config_options): """ Stitch a pair of images vertically using manual warping and composition. Parameters: bottom_image (str): Path to the bottom image. top_image (str): Path to the top image. ty_multiplier (int): Multiplier for the base translation in the y-direction. warp_net (WarpNetwork): Initialized warping network (Not used in manual warping). composition_net (CompositionNetwork): Initialized composition network. output_dir (str): Directory to save intermediate and final outputs. level (int): Current stitching level. pair_num (int): Pair number within the current level. config_options (dict): Configuration options from config.yaml. Returns: str: Path to the stitched image. 
""" base_ty = config_options['ty_level'] ty = ty_multiplier * base_ty # Logging the images and ty value logging.info(f"Stitching pair {pair_num} at level {level}") logging.info(f"Bottom Image: {bottom_image}") logging.info(f"Top Image: {top_image}") logging.info(f"Effective ty used: {ty}") # Perform cropping cropped = crop_images_specific_vertical(bottom_image, top_image, ty) cropped_img1 = cropped["cropped_img1"] # Bottom cropped cropped_img2 = cropped["cropped_img2"] # Top cropped unused_img1 = cropped["unused_img1"] # Bottom unused unused_img2 = cropped["unused_img2"] # Top unused # Save cropped images for reference cropped_dir = os.path.join(output_dir, f"level{level}_pair{pair_num}", "cropped") os.makedirs(cropped_dir, exist_ok=True) cv2.imwrite(os.path.join(cropped_dir, "cropped_img1.jpg"), cropped_img1) cv2.imwrite(os.path.join(cropped_dir, "cropped_img2.jpg"), cropped_img2) cv2.imwrite(os.path.join(cropped_dir, "unused_img1.jpg"), unused_img1) cv2.imwrite(os.path.join(cropped_dir, "unused_img2.jpg"), unused_img2) logging.info(f"Cropped images saved for level {level} pair {pair_num} in {cropped_dir}") # Paths to the cropped images cropped_img1_path = os.path.join(cropped_dir, "cropped_img1.jpg") cropped_img2_path = os.path.join(cropped_dir, "cropped_img2.jpg") # Perform manual warping on cropped regions perform_manual_warping_vertical( cropped_img1_path, cropped_img2_path, cropped_dir ) # Perform composition on warped outputs perform_composition(composition_net, cropped_dir, cropped_dir, config_options) # Load the composed overlapping region composed_overlap_path = os.path.join(cropped_dir, "composition.jpg") composed_overlap = cv2.imread(composed_overlap_path) if composed_overlap is None: logging.error(f"Composition failed for level {level} pair {pair_num}. 
Composed image not found at {composed_overlap_path}") return None # Concatenate unused parts with composed overlap vertically try: # Assuming all images have the same width after cropping final_image = np.vstack((unused_img1, composed_overlap, unused_img2)) except Exception as e: logging.error(f"Error during image concatenation for level {level} pair {pair_num}: {e}") return None logging.info(f"Final stitched image shape for level {level} pair {pair_num}: {final_image.shape}") # Save the final stitched image stitched_image_path = os.path.join(output_dir, f"level{level}_pair{pair_num}_final.jpg") os.makedirs(os.path.dirname(stitched_image_path), exist_ok=True) cv2.imwrite(stitched_image_path, final_image) logging.info(f'Final stitched image saved for level {level} pair {pair_num} at {stitched_image_path}') return stitched_image_path def tree_stitch_images( images, stride_multiplier, stitch_func, warp_net, composition_net, output_dir, level, group_num, config_options, direction='horizontal' ): """ Perform tree-like stitching on a list of images until only one image remains. Parameters: images (list): List of image paths to stitch. stride_multiplier (int): Stride multiplier based on the configuration (e.g., row_stride). stitch_func (function): Stitching function to use (stitch_pair or stitch_pair_vertical). warp_net (WarpNetwork): Warp network instance. composition_net (CompositionNetwork): Composition network instance. output_dir (str): Directory to save outputs. level (int): Starting stitching level. group_num (str): Group identifier. config_options (dict): Configuration options. direction (str): 'horizontal' or 'vertical'. Returns: str: Path to the final stitched image. 
""" images_to_stitch = images.copy() current_level = 0 # Calculate the largest power of two less than or equal to the number of images num_images = len(images_to_stitch) if num_images < 2: logging.warning("Not enough images to stitch.") return images_to_stitch[0] if images_to_stitch else None largest_power = 2**int(math.log2(num_images)) selected_images = images_to_stitch[:largest_power] omitted_images = images_to_stitch[largest_power:] if omitted_images: logging.info(f"Omitting {len(omitted_images)} redundant image(s) to make the count a power of two.") images_to_stitch = selected_images while len(images_to_stitch) > 1: stitched_images = [] # Calculate multiplier based on stride and current level multiplier = stride_multiplier * (2 ** current_level) logging.info(f"Stitching Level {level + current_level}: Multiplier = {multiplier}") for i in range(0, len(images_to_stitch), 2): img1 = images_to_stitch[i] if i + 1 < len(images_to_stitch): img2 = images_to_stitch[i + 1] if direction == 'horizontal': stitched_image = stitch_func( left_image=img1, right_image=img2, tx_multiplier=multiplier, warp_net=warp_net, composition_net=composition_net, output_dir=output_dir, level=level + current_level, group_num=group_num, pair_num=(i // 2) + 1, config_options=config_options ) else: stitched_image = stitch_func( bottom_image=img1, top_image=img2, ty_multiplier=multiplier, warp_net=warp_net, composition_net=composition_net, output_dir=output_dir, level=level + current_level, pair_num=(i // 2) + 1, config_options=config_options ) if stitched_image: stitched_images.append(stitched_image) else: logging.error(f"Failed to stitch {img1} and {img2}") return None else: # If odd number of images, carry the last one forward stitched_images.append(img1) logging.info(f"Odd image {img1} carried forward to the next level.") images_to_stitch = stitched_images current_level += 1 logging.info(f"Completed stitching up to level {level + current_level - 1}") return images_to_stitch[0] if 
images_to_stitch else None def select_power_of_two_rows(rows_to_stitch): """ Selects the largest power-of-two number of rows from the given list. Parameters: rows_to_stitch (list): List of row indices to stitch. Returns: list: Subset of rows_to_stitch with length as a power of two. """ num_rows = len(rows_to_stitch) if num_rows == 0: return [] largest_power = 2**int(math.log2(num_rows)) selected_rows = rows_to_stitch[:largest_power] omitted_rows = rows_to_stitch[largest_power:] if omitted_rows: logging.info(f"Omitting {len(omitted_rows)} redundant row(s) to make the count a power of two: {selected_rows}") return selected_rows def sanitize_filename(filename): """ Sanitize the filename by replacing or removing invalid characters. """ # Replace spaces with underscores filename = filename.replace(" ", "_") # Replace colons with hyphens filename = filename.replace(":", "-") # Remove any other invalid characters (for cross-platform compatibility) # Allow only alphanumerics, underscores, hyphens, and dots filename = re.sub(r'[^A-Za-z0-9._-]', '', filename) return filename def cleanup_unused_resources(scan_path, used_horizontal_scans, used_images_per_scan): """ Delete unused horizontal_scanX directories and unused images within used directories. Parameters: scan_path (str): Path to the current scan directory. used_horizontal_scans (set): Set of used horizontal_scanX directory names. used_images_per_scan (dict): Dictionary mapping horizontal_scanX to set of used image filenames. 
""" # List all horizontal_scanX directories all_horizontal_scans = [ d for d in os.listdir(scan_path) if d.startswith("horizontal_scan") and os.path.isdir(os.path.join(scan_path, d)) ] for hs_dir in all_horizontal_scans: hs_path = os.path.join(scan_path, hs_dir) if hs_dir not in used_horizontal_scans: # Delete the entire directory try: shutil.rmtree(hs_path) logging.info(f"Deleted unused directory: {hs_path}") except Exception as e: logging.error(f"Failed to delete directory {hs_path}: {e}") logging.error(traceback.format_exc()) else: # Delete unused images within the used horizontal_scanX directory used_images = used_images_per_scan.get(hs_dir, set()) all_images = [ f for f in os.listdir(hs_path) if f.lower().endswith('.jpg') ] for img in all_images: if img not in used_images: img_path = os.path.join(hs_path, img) try: os.remove(img_path) logging.info(f"Deleted unused image: {img_path}") except Exception as e: logging.error(f"Failed to delete image {img_path}: {e}") logging.error(traceback.format_exc()) def cleanup_intermediate_outputs(scan_path): """ Delete directories and files starting with 'level' (e.g., 'level1', 'level2', etc.) within the scan directory. Parameters: scan_path (str): Path to the current scan directory. 
""" # List all items in the scan directory for item in os.listdir(scan_path): item_path = os.path.join(scan_path, item) # Check if the item is a directory starting with 'level' if os.path.isdir(item_path) and re.match(r'^level\d+', item): try: shutil.rmtree(item_path) logging.info(f"Deleted intermediate directory: {item_path}") except Exception as e: logging.error(f"Failed to delete intermediate directory {item_path}: {e}") logging.error(traceback.format_exc()) # Check if the item is a file starting with 'level' elif os.path.isfile(item_path) and re.match(r'^level\d+', item): try: os.remove(item_path) logging.info(f"Deleted intermediate file: {item_path}") except Exception as e: logging.error(f"Failed to delete intermediate file {item_path}: {e}") logging.error(traceback.format_exc()) def main(): parser = argparse.ArgumentParser(description="Image Stitching Script") parser.add_argument( '--bag_path', type=str, required=True, help='Path to the ROS bag file.' ) parser.add_argument( '--output_dir', type=str, required=True, help='Directory to save the output images.' ) parser.add_argument( '--config_paths', type=str, required=True, nargs='+', help='Paths to the YAML configuration files for each hive and XY table.' 
) args = parser.parse_args() # Convert to absolute paths BAG_PATH = os.path.abspath(args.bag_path) OUTPUT_BASE_DIR = os.path.abspath(args.output_dir) CONFIG_PATHS = [os.path.abspath(path) for path in args.config_paths] # Define 'Final_Outputs' directory final_outputs_dir = os.path.join(OUTPUT_BASE_DIR, "Final_Outputs") os.makedirs(final_outputs_dir, exist_ok=True) logging.info(f"Final outputs will be saved in {final_outputs_dir}") # Load and validate each configuration configs = [] for config_path in CONFIG_PATHS: if not os.path.exists(config_path): logging.error(f"Configuration file not found at {config_path}") sys.exit(1) config = load_config(config_path) # Validate required fields if 'hive_id' not in config or 'xy_id' not in config: logging.error(f"Configuration file {config_path} missing 'hive_id' or 'xy_id'.") sys.exit(1) configs.append(config) logging.info(f"Loaded configuration from {config_path}") try: # Environment and Network Initialization # Set CUDA environment variables os.environ["CUDA_DEVICES_ORDER"] = "PCI_BUS_ID" os.environ["CUDA_VISIBLE_DEVICES"] = GPU_DEVICE # Add project root to sys.path project_root = os.path.dirname(os.path.abspath(__file__)) sys.path.insert(0, project_root) logging.info("Initializing WarpNetwork and CompositionNetwork...") # Initialize networks warp_net = WarpNetwork() composition_net = CompositionNetwork() logging.info("Networks initialized successfully.") # Load pre-trained models for composition_net if os.path.exists(COMPOSITION_MODEL_PATH): logging.info(f"Loading composition model from {COMPOSITION_MODEL_PATH}...") checkpoint = torch.load( COMPOSITION_MODEL_PATH, map_location=torch.device('cuda' if torch.cuda.is_available() else 'cpu') ) composition_net.load_state_dict(checkpoint["model"]) logging.info(f"Loaded composition model from {COMPOSITION_MODEL_PATH}") else: logging.error(f"No composition model found at {COMPOSITION_MODEL_PATH}") sys.exit(1) # Move networks to GPU if available if torch.cuda.is_available(): 
warp_net = warp_net.cuda() composition_net = composition_net.cuda() logging.info("Networks moved to GPU.") else: logging.info("CUDA not available. Networks are on CPU.") # Bag Parsing and Image Processing logging.info("\n========================") logging.info("Starting Bag Parsing and Processing") logging.info("========================\n") for config in configs: hive_id = config.get('hive_id') xy_id = config.get('xy_id') logging.info(f"\n=== Processing Hive {hive_id}, XY Table {xy_id} ===") # Define paths specific to hive and XY table hive_xy_output_dir = os.path.join(OUTPUT_BASE_DIR, f"hive_{hive_id}", f"xy_{xy_id}") os.makedirs(hive_xy_output_dir, exist_ok=True) logging.info(f"Output directory for Hive {hive_id}, XY {xy_id}: {hive_xy_output_dir}") # Extract configurations from each config file stitching_mode = config.get('stitching', {}).get('mode', 'linear').lower() translations = config.get('translations', {}) tx_level = translations.get('tx_level', -320) ty_level = translations.get('ty_level', -330) config_options = config.get('options', {}) config_options['tx_level'] = tx_level config_options['ty_level'] = ty_level # Retrieve 'resize_shape' from config resize_shape = config_options.get('resize_shape', None) if resize_shape is not None: if isinstance(resize_shape, list) and len(resize_shape) == 2: resize_shape = tuple(resize_shape) logging.info(f"Resize shape set to: {resize_shape}") else: logging.warning(f"Invalid resize_shape format in config: {resize_shape}. 
Setting to None.") resize_shape = None config_options['resize_shape'] = resize_shape # Extract max_images_to_process from stitching MAX_IMAGES_TO_PROCESS = config.get('stitching', {}).get('max_images_to_process', 120) if stitching_mode == 'linear': linear_params = config.get('stitching', {}).get('linear', {}) COLUMNS_TO_STITCH = linear_params.get('columns_to_stitch', []) ROWS_TO_STITCH = linear_params.get('rows_to_stitch', []) elif stitching_mode == 'tree': tree_params = config.get('stitching', {}).get('tree', {}) column_params = tree_params.get('columns', {}) row_params = tree_params.get('rows', {}) else: logging.error(f"Invalid stitching mode: {stitching_mode}. Choose 'linear' or 'tree'.") continue # Skip this config and proceed to the next # Load XY positions for this hive and XY table xy_positions = load_xy_positions_specific(BAG_PATH, hive_id, xy_id) if not xy_positions: logging.warning(f"No XY positions found for Hive {hive_id}, XY {xy_id}. Skipping.") continue # Process images from the bag for this hive and XY table process_images( xy_positions, BAG_PATH, output_dir=hive_xy_output_dir, threshold=0.001, significant_move=0.1, resize_shape=config_options.get('resize_shape'), # Pass from config max_images=MAX_IMAGES_TO_PROCESS, apply_color_filter=config_options.get('apply_color_filter', False), compressed_topic=f"/hive_{hive_id}/xy_{xy_id}/bee_camera/compressed", xy_position_topic=f"/hive_{hive_id}/xy_{xy_id}/xy_position" ) logging.info("\n========================") logging.info("Starting Cleanup Phase") logging.info("========================\n") cleanup_redundant_images(hive_xy_output_dir) logging.info(f"\nAll images processed and saved in the directory: {hive_xy_output_dir}") logging.info("\n========================") logging.info("Starting Stitching Phase") logging.info("========================\n") # Iterate over each scan directory within the hive and XY table for scan_folder in os.listdir(hive_xy_output_dir): scan_path = os.path.join(hive_xy_output_dir, 
scan_folder) if not os.path.isdir(scan_path): logging.info(f"Skipping non-directory: {scan_path}") continue try: logging.info(f"\n--- Stitching Scan: {scan_folder} ---") if stitching_mode == 'linear': # =============================== # Linear Stitching Mode # =============================== logging.info("Mode: Linear Stitching") logging.info(f"Linear Stitching Mode: Columns to stitch: {COLUMNS_TO_STITCH}") logging.info(f"Linear Stitching Mode: Rows to stitch: {ROWS_TO_STITCH}") # Find all horizontal_scanX subfolders horizontal_scan_dirs = sorted([ d for d in os.listdir(scan_path) if d.startswith("horizontal_scan") and os.path.isdir(os.path.join(scan_path, d)) ], key=lambda x: int(x.replace("horizontal_scan", ""))) # Sort numerically if not horizontal_scan_dirs: logging.info(f"No 'horizontal_scanX' folders found in {scan_path}. Skipping stitching.") continue # Select only the specified rows for linear stitching selected_horizontal_scan_dirs = [ d for d in horizontal_scan_dirs if int(d.replace("horizontal_scan", "")) in ROWS_TO_STITCH ] if not selected_horizontal_scan_dirs: logging.warning(f"No 'horizontal_scanX' folders match ROWS_TO_STITCH {ROWS_TO_STITCH} in {scan_path}. Skipping stitching.") continue horizontal_stitched_paths = [] # Dictionaries to track used horizontal_scanX directories and their used images used_horizontal_scans = set() used_images_per_scan = {} for horizontal_scan_dir in selected_horizontal_scan_dirs: images_dir = os.path.join(scan_path, horizontal_scan_dir) scan_images = sorted([ f for f in os.listdir(images_dir) if f.lower().endswith('.jpg') ], key=lambda x: int(x.split('_')[1].split('.')[0])) # Sort by image number logging.info(f"Processing {horizontal_scan_dir} with {len(scan_images)} images.") if len(scan_images) < 2: logging.warning(f"Warning: '{horizontal_scan_dir}' does not contain enough images for stitching. 
Skipping this horizontal scan.") continue # Group images based on COLUMNS_TO_STITCH try: selected_columns = COLUMNS_TO_STITCH # Validate column indices for col in selected_columns: if col < 1 or col > len(scan_images): raise IndexError(f"Column index {col} is out of range for the current horizontal scan with {len(scan_images)} images.") # Sort the selected columns to ensure correct order selected_columns_sorted = sorted(selected_columns) selected_images_sorted = [os.path.join(images_dir, scan_images[col-1]) for col in selected_columns_sorted] except Exception as e: logging.error(f"Error selecting columns to stitch in '{horizontal_scan_dir}': {e}") continue # Track used horizontal_scanX directories and images hs_dir_name = horizontal_scan_dir used_horizontal_scans.add(hs_dir_name) if hs_dir_name not in used_images_per_scan: used_images_per_scan[hs_dir_name] = set() for img_path in selected_images_sorted: img_filename = os.path.basename(img_path) used_images_per_scan[hs_dir_name].add(img_filename) # Stitch selected images sequentially with dynamic tx_multiplier current_stitched = selected_images_sorted[0] for idx in range(1, len(selected_images_sorted)): previous_column = selected_columns_sorted[idx-1] current_column = selected_columns_sorted[idx] difference = current_column - previous_column # Calculate column difference tx_multiplier = difference # Set tx_multiplier based on column difference pair_num = idx # Pair number within the group logging.info(f"Stitching pair {pair_num}: Column {previous_column} and Column {current_column} with tx_multiplier={tx_multiplier}") stitched_image = stitch_pair( left_image=current_stitched, right_image=selected_images_sorted[idx], tx_multiplier=tx_multiplier, # Dynamic multiplier based on column difference warp_net=warp_net, composition_net=composition_net, output_dir=scan_path, level=1, # Level 1 for horizontal stitching group_num=horizontal_scan_dir.replace("horizontal_scan", ""), # Numeric identifier pair_num=pair_num, # Pair 
number within the group config_options=config_options # Updated config_options with tx_level and ty_level ) if stitched_image: current_stitched = stitched_image logging.info(f"Updated current_stitched to: {current_stitched}") else: logging.error(f"Failed to stitch image pair {current_stitched} and {selected_images_sorted[idx]} in '{horizontal_scan_dir}'.") break if stitched_image: horizontal_stitched_paths.append(stitched_image) else: logging.error(f"Horizontal stitching failed for '{horizontal_scan_dir}' in scan: {scan_folder}") if not horizontal_stitched_paths: logging.error(f"No horizontal stitching results available for scan: {scan_folder}. Skipping vertical stitching.") continue logging.info(f"Starting vertical stitching for scan: {scan_folder}") # Initialize current_panoramic with the first stitched image current_panoramic = horizontal_stitched_paths[0] logging.info(f"Initialized current_panoramic with: {current_panoramic}") # Iterate over the remaining stitched images for i in range(1, len(horizontal_stitched_paths)): previous_row = ROWS_TO_STITCH[i-1] current_row = ROWS_TO_STITCH[i] row_diff = current_row - ROWS_TO_STITCH[0] # Maintain original row_diff calculation ty_multiplier = row_diff # Set ty_multiplier based on row difference pair_num = i # Pair number within vertical stitching logging.info(f"Stitching vertically: {current_panoramic} with {horizontal_stitched_paths[i]} (ty_multiplier={ty_multiplier})") stitched_image = stitch_pair_vertical( bottom_image=current_panoramic, top_image=horizontal_stitched_paths[i], ty_multiplier=ty_multiplier, # Dynamic multiplier based on row difference warp_net=warp_net, composition_net=composition_net, output_dir=scan_path, level=2, # Level 2 for vertical stitching pair_num=pair_num, # Pair number within vertical stitching config_options=config_options ) if stitched_image: current_panoramic = stitched_image # Update for cumulative stitching logging.info(f"Updated current_panoramic to: {current_panoramic}") else: 
logging.error(f"Failed to vertically stitch image pair {current_panoramic} and {horizontal_stitched_paths[i]} in scan: {scan_folder}.") break # Save the final panoramic image final_panoramic_stitched = current_panoramic logging.info(f"\n--- Stitching Completed for Scan: {scan_folder} ---") logging.info(f"Final vertically stitched (panoramic) image saved at: {final_panoramic_stitched}") # **Save the Final Stitched Image to Final_Outputs** scan_folder_name = scan_folder # e.g., "2024-11-23 21-48-23_298" sanitized_scan_folder_name = sanitize_filename(scan_folder_name) new_filename = f"{sanitized_scan_folder_name}_hive{hive_id}_xy{xy_id}.jpg" target_path = os.path.join(final_outputs_dir, new_filename) try: shutil.copy(final_panoramic_stitched, target_path) logging.info(f"Final stitched image copied to {target_path}") except Exception as e: logging.error(f"Failed to copy final stitched image to Final_Outputs: {e}") logging.error(traceback.format_exc()) # **Cleanup Unused Resources** cleanup_option = config_options.get('cleanup_unused', False) if cleanup_option: logging.info(f"Starting cleanup of unused resources for scan: {scan_folder}") cleanup_unused_resources( scan_path=scan_path, used_horizontal_scans=used_horizontal_scans, used_images_per_scan=used_images_per_scan ) logging.info(f"Cleanup completed for scan: {scan_folder}") else: logging.info("Cleanup of unused resources is disabled in the configuration.") # **Cleanup Intermediate Outputs** cleanup_intermediate_option = config_options.get('cleanup_intermediate', False) if cleanup_intermediate_option: logging.info(f"Starting cleanup of intermediate outputs for scan: {scan_folder}") cleanup_intermediate_outputs(scan_path=scan_path) logging.info(f"Intermediate cleanup completed for scan: {scan_folder}") else: logging.info("Cleanup of intermediate outputs is disabled in the configuration.") elif stitching_mode == 'tree': # =============================== # Tree-like Stitching Mode # =============================== 
logging.info("Mode: Tree-like Stitching") logging.info(f"Tree-like Stitching Mode: Columns Params: {column_params}") logging.info(f"Tree-like Stitching Mode: Rows Params: {row_params}") # Generate rows to stitch based on start, end, row_stride start_row = row_params.get('start', 1) end_row = row_params.get('end', len([ d for d in os.listdir(scan_path) if d.startswith("horizontal_scan") and os.path.isdir(os.path.join(scan_path, d)) ])) row_stride = row_params.get('row_stride', 1) rows_to_stitch = list(range(start_row, end_row + 1, row_stride)) # Adjust to power-of-two number of rows rows_to_stitch_power_two = select_power_of_two_rows(rows_to_stitch) # Collect image paths based on columns and rows selected_images = [] # Dictionaries to track used horizontal_scanX directories and their used images used_horizontal_scans = set() used_images_per_scan = {} for row in rows_to_stitch_power_two: horizontal_scan_dir = f"horizontal_scan{row}" images_dir = os.path.join(scan_path, horizontal_scan_dir) if not os.path.isdir(images_dir): logging.warning(f"Directory {images_dir} does not exist. Skipping.") continue scan_images = sorted([ f for f in os.listdir(images_dir) if f.lower().endswith('.jpg') ], key=lambda x: int(x.split('_')[1].split('.')[0])) # Sort by image number # Select columns based on start, end, col_stride start_col = column_params.get('start', 1) end_col = column_params.get('end', len(scan_images)) col_stride = column_params.get('col_stride', 1) columns_to_stitch = list(range(start_col, end_col + 1, col_stride)) logging.info(f"Row {row}: Columns to stitch: {columns_to_stitch}") selected_columns_sorted = sorted(columns_to_stitch) selected_images_sorted = [ os.path.join(images_dir, scan_images[col-1]) for col in selected_columns_sorted if 1 <= col <= len(scan_images) ] if not selected_images_sorted: logging.warning(f"No valid columns found for row {row}. 
Skipping.") continue # Track used horizontal_scanX directories and images hs_dir_name = horizontal_scan_dir used_horizontal_scans.add(hs_dir_name) if hs_dir_name not in used_images_per_scan: used_images_per_scan[hs_dir_name] = set() for img_path in selected_images_sorted: img_filename = os.path.basename(img_path) used_images_per_scan[hs_dir_name].add(img_filename) # Perform tree-like stitching on selected images logging.info(f"Performing tree-like stitching on row {row} with {len(selected_images_sorted)} images.") stitched_row_image = tree_stitch_images( images=selected_images_sorted, stride_multiplier=col_stride, # Use col_stride as stride_multiplier stitch_func=stitch_pair, warp_net=warp_net, composition_net=composition_net, output_dir=scan_path, level=1, # Starting level for horizontal stitching group_num=f"row{row}", config_options=config_options, direction='horizontal' ) if stitched_row_image: selected_images.append(stitched_row_image) else: logging.error(f"Tree-like stitching failed for row {row}.") continue if len(selected_images) < 2: logging.warning(f"Not enough stitched rows to perform vertical stitching in scan {scan_folder}. 
Skipping.") continue # Perform tree-like vertical stitching on the stitched rows logging.info("Performing tree-like stitching vertically on stitched rows.") final_panoramic_image = tree_stitch_images( images=selected_images, stride_multiplier=row_stride, # Use row_stride as stride_multiplier stitch_func=stitch_pair_vertical, warp_net=warp_net, composition_net=composition_net, output_dir=scan_path, level=1, # Starting level for vertical stitching group_num="vertical", config_options=config_options, direction='vertical' ) if final_panoramic_image: logging.info(f"Final tree-like stitched panoramic image saved at: {final_panoramic_image}") # **Save the Final Stitched Image to Final_Outputs** scan_folder_name = scan_folder # e.g., "2024-11-23 21-48-23_298" sanitized_scan_folder_name = sanitize_filename(scan_folder_name) new_filename = f"{sanitized_scan_folder_name}_hive{hive_id}_xy{xy_id}.jpg" target_path = os.path.join(final_outputs_dir, new_filename) try: shutil.copy(final_panoramic_image, target_path) logging.info(f"Final stitched image copied to {target_path}") except Exception as e: logging.error(f"Failed to copy final stitched image to Final_Outputs: {e}") logging.error(traceback.format_exc()) else: logging.error(f"Failed to perform tree-like vertical stitching in scan: {scan_folder}") # **Cleanup Unused Resources** cleanup_option = config_options.get('cleanup_unused', False) if cleanup_option: logging.info(f"Starting cleanup of unused resources for scan: {scan_folder}") cleanup_unused_resources( scan_path=scan_path, used_horizontal_scans=used_horizontal_scans, used_images_per_scan=used_images_per_scan ) logging.info(f"Cleanup completed for scan: {scan_folder}") else: logging.info("Cleanup of unused resources is disabled in the configuration.") # **Cleanup Intermediate Outputs** cleanup_intermediate_option = config_options.get('cleanup_intermediate', False) if cleanup_intermediate_option: logging.info(f"Starting cleanup of intermediate outputs for scan: 
{scan_folder}") cleanup_intermediate_outputs(scan_path=scan_path) logging.info(f"Intermediate cleanup completed for scan: {scan_folder}") else: logging.info("Cleanup of intermediate outputs is disabled in the configuration.") else: logging.error(f"Unknown stitching mode: {stitching_mode}. Skipping scan: {scan_folder}") continue except Exception as scan_error: logging.error(f"An error occurred while processing scan '{scan_folder}': {scan_error}") logging.error(traceback.format_exc()) logging.info(f"Skipping scan '{scan_folder}' and continuing with the next one.") continue # Skip to the next scan # Cleanup and Exit logging.info("\n========================") logging.info("Stitching Process Completed") logging.info("========================\n") # Clean up PyTorch models and free CUDA memory if torch.cuda.is_available(): warp_net.cpu() composition_net.cpu() del warp_net del composition_net torch.cuda.empty_cache() logging.info("Cleaned up PyTorch models and cleared CUDA cache.") logging.info("Exiting the script.") sys.exit(0) # Ensure the script exits automatically except Exception as e: logging.error(f"An unexpected error occurred: {e}\n{traceback.format_exc()}") sys.exit(1) if __name__ == "__main__": main()
Leave a Comment