import os
import sys
import argparse
import yaml
import torch
import cv2
import numpy as np
import logging
import rosbag
import shutil
import re
import traceback
import math
from Warp.Codes.network import (
    get_stitched_result as warp_get_stitched_result,
    Network as WarpNetwork,
    build_new_ft_model as warp_build_new_ft_model,
)
from Composition.Codes.network import (
    build_model as composition_build_model,
    Network as CompositionNetwork,
)
import torchvision.transforms as T

from zz_bagparser_w import (
    load_xy_positions_specific,
    process_images,
    cleanup_redundant_images,
    crop_images_specific,
    crop_images_specific_vertical
)

# Configure logging
logging.basicConfig(
    level=logging.INFO,
    format='%(asctime)s - %(levelname)s - %(message)s',
    handlers=[
        logging.FileHandler("stitching_process.log"),
        logging.StreamHandler(sys.stdout)
    ]
)

# GPU Configuration
GPU_DEVICE = "0"

COMPOSITION_MODEL_PATH = os.path.join("Composition", "model", "epoch050_model.pth")


def load_config(config_path):
    """
    Load YAML configuration file.

    Parameters:
        config_path (str): Path to the YAML config file.

    Returns:
        dict: Configuration parameters.
    """
    with open(config_path, 'r') as file:
        config = yaml.safe_load(file)
    return config
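
# Illustrative config layout (a sketch assembled from the keys this script
# reads in main(); example values are placeholders, not recommendations):
#
#   hive_id: 1
#   xy_id: 1
#   translations:
#     tx_level: -320
#     ty_level: -330
#   stitching:
#     mode: linear                  # 'linear' or 'tree'
#     max_images_to_process: 120
#     linear:
#       columns_to_stitch: [1, 2, 3]
#       rows_to_stitch: [1, 2]
#     tree:
#       columns: {start: 1, end: 8, col_stride: 1}
#       rows: {start: 1, end: 4, row_stride: 1}
#   options:
#     resize_shape: [640, 480]
#     apply_color_filter: false
#     save_learned_masks: false
#     save_composition_color: false
#     cleanup_unused: false
#     cleanup_intermediate: false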


# Stitching Functions

def perform_manual_warping(cropped_img1_path, cropped_img2_path, output_dir):
    """
    Manually emulate the warp stage by building complementary masks for a horizontal overlap and applying them to the cropped images.

    Parameters:
        cropped_img1_path (str): Path to the cropped left image.
        cropped_img2_path (str): Path to the cropped right image.
        output_dir (str): Directory to save warped images and masks.
    """
    img1 = cv2.imread(cropped_img1_path)
    img2 = cv2.imread(cropped_img2_path)

    if img1 is None:
        raise FileNotFoundError(f"Cropped Image 1 not found: {cropped_img1_path}")
    if img2 is None:
        raise FileNotFoundError(f"Cropped Image 2 not found: {cropped_img2_path}")

    # Generate mask1
    height1, width1, _ = img1.shape
    split1 = int(width1 * 0.9)
    mask1 = np.zeros_like(img1, dtype=np.uint8)
    mask1[:, :split1, :] = 255
    mask1[:, split1:, :] = 0 

    # Generate mask2
    height2, width2, _ = img2.shape
    split2 = int(width2 * 0.1)
    mask2 = np.zeros_like(img2, dtype=np.uint8)
    mask2[:, :split2, :] = 0
    mask2[:, split2:, :] = 255

    # Create warped images
    warp1 = cv2.bitwise_and(img1, mask1)
    warp2 = cv2.bitwise_and(img2, mask2)

    # Save masks and warped images
    os.makedirs(output_dir, exist_ok=True)
    cv2.imwrite(os.path.join(output_dir, "mask1.jpg"), mask1)
    cv2.imwrite(os.path.join(output_dir, "mask2.jpg"), mask2)
    cv2.imwrite(os.path.join(output_dir, "warp1.jpg"), warp1)
    cv2.imwrite(os.path.join(output_dir, "warp2.jpg"), warp2)
    logging.info(f"Manual warped images and masks saved in {output_dir}")


def perform_manual_warping_vertical(cropped_img1_path, cropped_img2_path, output_dir):
    """
    Manually emulate the warp stage by building complementary masks for a vertical overlap and applying them to the cropped images.

    Parameters:
        cropped_img1_path (str): Path to the cropped bottom image.
        cropped_img2_path (str): Path to the cropped top image.
        output_dir (str): Directory to save warped images and masks.
    """

    img1 = cv2.imread(cropped_img1_path)
    img2 = cv2.imread(cropped_img2_path)

    if img1 is None:
        raise FileNotFoundError(f"Cropped Image 1 not found: {cropped_img1_path}")
    if img2 is None:
        raise FileNotFoundError(f"Cropped Image 2 not found: {cropped_img2_path}")

    # Generate mask1
    height1, width1, _ = img1.shape
    split1 = int(height1 * 0.9)
    mask1 = np.zeros_like(img1, dtype=np.uint8)
    mask1[:split1, :, :] = 255
    mask1[split1:, :, :] = 0 

    # Generate mask2
    height2, width2, _ = img2.shape
    split2 = int(height2 * 0.1)
    mask2 = np.zeros_like(img2, dtype=np.uint8)
    mask2[:split2, :, :] = 0
    mask2[split2:, :, :] = 255 

    # Create warped images
    warp1 = cv2.bitwise_and(img1, mask1)
    warp2 = cv2.bitwise_and(img2, mask2)

    # Save masks and warped images
    os.makedirs(output_dir, exist_ok=True)
    cv2.imwrite(os.path.join(output_dir, "mask1.jpg"), mask1)
    cv2.imwrite(os.path.join(output_dir, "mask2.jpg"), mask2)
    cv2.imwrite(os.path.join(output_dir, "warp1.jpg"), warp1)
    cv2.imwrite(os.path.join(output_dir, "warp2.jpg"), warp2)
    logging.info(f"Manual warped images and masks saved in {output_dir}")


def perform_composition(net, input_dir, output_dir, config_options):
    """
    Perform composition on the warped outputs to generate the stitched image.

    Parameters:
        net (CompositionNetwork): The initialized composition network.
        input_dir (str): Directory containing warped images and masks.
        output_dir (str): Directory to save the composed image.
        config_options (dict): Configuration options from config.yaml.
    """
    # Load warped images
    warp1 = cv2.imread(os.path.join(input_dir, "warp1.jpg"))
    warp2 = cv2.imread(os.path.join(input_dir, "warp2.jpg"))

    if warp1 is None or warp2 is None:
        raise FileNotFoundError("Warped images not found in the input directory.")

    # Normalize pixels from [0, 255] to [-1, 1] and reorder HWC -> CHW before batching
    warp1 = warp1.astype(dtype=np.float32)
    warp1 = (warp1 / 127.5) - 1.0
    warp1 = np.transpose(warp1, [2, 0, 1])
    warp1_tensor = torch.tensor(warp1).unsqueeze(0)

    warp2 = warp2.astype(dtype=np.float32)
    warp2 = (warp2 / 127.5) - 1.0
    warp2 = np.transpose(warp2, [2, 0, 1])
    warp2_tensor = torch.tensor(warp2).unsqueeze(0)

    # Load masks
    mask1 = cv2.imread(os.path.join(input_dir, "mask1.jpg"))
    mask2 = cv2.imread(os.path.join(input_dir, "mask2.jpg"))

    if mask1 is None or mask2 is None:
        raise FileNotFoundError("Masks not found in the input directory.")

    mask1 = mask1.astype(dtype=np.float32) / 255.0
    mask1 = np.transpose(mask1, [2, 0, 1])
    mask1_tensor = torch.tensor(mask1).unsqueeze(0)

    mask2 = mask2.astype(dtype=np.float32) / 255.0
    mask2 = np.transpose(mask2, [2, 0, 1])
    mask2_tensor = torch.tensor(mask2).unsqueeze(0)

    if torch.cuda.is_available():
        warp1_tensor = warp1_tensor.cuda()
        warp2_tensor = warp2_tensor.cuda()
        mask1_tensor = mask1_tensor.cuda()
        mask2_tensor = mask2_tensor.cuda()
        net = net.cuda()

    # Perform composition
    net.eval()
    with torch.no_grad():
        batch_out = composition_build_model(
            net, warp1_tensor, warp2_tensor, mask1_tensor, mask2_tensor
        )
    stitched_image = batch_out["stitched_image"]

    logging.info(f"Stitched image shape: {stitched_image.shape}")

    # Save the composed image
    os.makedirs(output_dir, exist_ok=True)
    stitched_image_np = (
        (stitched_image[0] + 1) * 127.5
    ).cpu().numpy().transpose(1, 2, 0)
    cv2.imwrite(
        os.path.join(output_dir, "composition.jpg"), stitched_image_np.astype(np.uint8)
    )
    logging.info(f"Composed image saved in {output_dir}")

    # Save learned masks if enabled
    if config_options.get('save_learned_masks', False):
        learned_mask1 = (batch_out['learned_mask1'][0] * 255).cpu().numpy().transpose(1, 2, 0)
        learned_mask2 = (batch_out['learned_mask2'][0] * 255).cpu().numpy().transpose(1, 2, 0)

        cv2.imwrite(os.path.join(input_dir, "learn_mask1.jpg"), learned_mask1.astype(np.uint8))
        cv2.imwrite(os.path.join(input_dir, "learn_mask2.jpg"), learned_mask2.astype(np.uint8))
        logging.info(f"Learned masks saved in {input_dir}/learn_mask1.jpg and {input_dir}/learn_mask2.jpg")

    # Save composition_color.jpg if enabled
    if config_options.get('save_composition_color', False):
        s1 = ((warp1_tensor[0] + 1) * 127.5 * batch_out['learned_mask1'][0]).cpu().detach().numpy().transpose(1, 2, 0)
        s2 = ((warp2_tensor[0] + 1) * 127.5 * batch_out['learned_mask2'][0]).cpu().detach().numpy().transpose(1, 2, 0)
        fusion = np.zeros((warp1_tensor.shape[2], warp1_tensor.shape[3], 3), np.uint8)
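        # OpenCV stores images as BGR: blue is taken from warp2, red from warp1,
        # and green is their average, so the resulting image (presumably) shows
        # how much each input contributes at every pixel of the overlap.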
        fusion[..., 0] = s2[..., 0]
        fusion[..., 1] = (s1[..., 1] * 0.5 + s2[..., 1] * 0.5).astype(np.uint8)
        fusion[..., 2] = s1[..., 2]
        composition_color_path = os.path.join(output_dir, "composition_color.jpg")
        cv2.imwrite(composition_color_path, fusion)
        logging.info(f"Composition color image saved at {composition_color_path}")


def stitch_pair(left_image, right_image, tx_multiplier, warp_net, composition_net, output_dir, level, group_num, pair_num, config_options):
    """
    Stitch a pair of images horizontally using manual warping and composition.

    Parameters:
        left_image (str): Path to the left image.
        right_image (str): Path to the right image.
        tx_multiplier (int): Multiplier for the base translation in the x-direction.
        warp_net (WarpNetwork): Initialized warping network (Not used in manual warping).
        composition_net (CompositionNetwork): Initialized composition network.
        output_dir (str): Directory to save intermediate and final outputs.
        level (int): Current stitching level.
        group_num (str): Group number for unique identification.
        pair_num (int): Pair number within the group.
        config_options (dict): Configuration options from config.yaml.

    Returns:
        str: Path to the stitched image.
    """
    base_tx = config_options['tx_level']
    tx = tx_multiplier * base_tx

    # Logging the images and tx value
    logging.info(f"Stitching pair {pair_num} at level {level}, group {group_num}")
    logging.info(f"Left Image: {left_image}")
    logging.info(f"Right Image: {right_image}")
    logging.info(f"Effective tx used: {tx}")

    # Perform cropping
    cropped = crop_images_specific(left_image, right_image, tx, 0)
    cropped_img1 = cropped["cropped_img1"]
    cropped_img2 = cropped["cropped_img2"]
    unused_img1 = cropped["unused_img1"]
    unused_img2 = cropped["unused_img2"]

    # Save cropped images for reference
    cropped_dir = os.path.join(output_dir, f"level{level}_group{group_num}_pair{pair_num}", "cropped")
    os.makedirs(cropped_dir, exist_ok=True)
    cv2.imwrite(os.path.join(cropped_dir, "cropped_img1.jpg"), cropped_img1)
    cv2.imwrite(os.path.join(cropped_dir, "cropped_img2.jpg"), cropped_img2)
    cv2.imwrite(os.path.join(cropped_dir, "unused_img1.jpg"), unused_img1)
    cv2.imwrite(os.path.join(cropped_dir, "unused_img2.jpg"), unused_img2)
    logging.info(f"Cropped images saved for level {level} group {group_num} pair {pair_num} in {cropped_dir}")

    # Paths to the cropped images
    cropped_img1_path = os.path.join(cropped_dir, "cropped_img1.jpg")
    cropped_img2_path = os.path.join(cropped_dir, "cropped_img2.jpg")

    # Perform manual warping on cropped regions
    perform_manual_warping(
        cropped_img1_path, cropped_img2_path, cropped_dir
    )

    # Perform composition on warped outputs
    perform_composition(composition_net, cropped_dir, cropped_dir, config_options)

    # Load the composed overlapping region
    composed_overlap_path = os.path.join(cropped_dir, "composition.jpg")
    composed_overlap = cv2.imread(composed_overlap_path)

    if composed_overlap is None:
        logging.error(f"Composition failed for level {level} group {group_num} pair {pair_num}. Composed image not found at {composed_overlap_path}")
        return None

    # Concatenate unused parts with composed overlap horizontally
    try:
        final_image = np.hstack((unused_img1, composed_overlap, unused_img2))
    except Exception as e:
        logging.error(f"Error during image concatenation for level {level} group {group_num} pair {pair_num}: {e}")
        return None

    logging.info(f"Final stitched image shape for level {level} group {group_num} pair {pair_num}: {final_image.shape}")

    # Save the final stitched image
    stitched_image_path = os.path.join(output_dir, f"level{level}_group{group_num}_pair{pair_num}_final.jpg")
    os.makedirs(os.path.dirname(stitched_image_path), exist_ok=True)
    cv2.imwrite(stitched_image_path, final_image)
    logging.info(f'Final stitched image saved for level {level} group {group_num} pair {pair_num} at {stitched_image_path}')

    return stitched_image_path


def stitch_pair_vertical(bottom_image, top_image, ty_multiplier, warp_net, composition_net, output_dir, level, pair_num, config_options):
    """
    Stitch a pair of images vertically using manual warping and composition.

    Parameters:
        bottom_image (str): Path to the bottom image.
        top_image (str): Path to the top image.
        ty_multiplier (int): Multiplier for the base translation in the y-direction.
        warp_net (WarpNetwork): Initialized warping network (Not used in manual warping).
        composition_net (CompositionNetwork): Initialized composition network.
        output_dir (str): Directory to save intermediate and final outputs.
        level (int): Current stitching level.
        pair_num (int): Pair number within the current level.
        config_options (dict): Configuration options from config.yaml.

    Returns:
        str: Path to the stitched image.
    """
    base_ty = config_options['ty_level']
    ty = ty_multiplier * base_ty

    # Logging the images and ty value
    logging.info(f"Stitching pair {pair_num} at level {level}")
    logging.info(f"Bottom Image: {bottom_image}")
    logging.info(f"Top Image: {top_image}")
    logging.info(f"Effective ty used: {ty}")

    # Perform cropping
    cropped = crop_images_specific_vertical(bottom_image, top_image, ty)
    cropped_img1 = cropped["cropped_img1"]  # Bottom cropped
    cropped_img2 = cropped["cropped_img2"]  # Top cropped
    unused_img1 = cropped["unused_img1"]    # Bottom unused
    unused_img2 = cropped["unused_img2"]    # Top unused

    # Save cropped images for reference
    cropped_dir = os.path.join(output_dir, f"level{level}_pair{pair_num}", "cropped")
    os.makedirs(cropped_dir, exist_ok=True)
    cv2.imwrite(os.path.join(cropped_dir, "cropped_img1.jpg"), cropped_img1)
    cv2.imwrite(os.path.join(cropped_dir, "cropped_img2.jpg"), cropped_img2)
    cv2.imwrite(os.path.join(cropped_dir, "unused_img1.jpg"), unused_img1)
    cv2.imwrite(os.path.join(cropped_dir, "unused_img2.jpg"), unused_img2)
    logging.info(f"Cropped images saved for level {level} pair {pair_num} in {cropped_dir}")

    # Paths to the cropped images
    cropped_img1_path = os.path.join(cropped_dir, "cropped_img1.jpg")
    cropped_img2_path = os.path.join(cropped_dir, "cropped_img2.jpg")

    # Perform manual warping on cropped regions
    perform_manual_warping_vertical(
        cropped_img1_path, cropped_img2_path, cropped_dir
    )

    # Perform composition on warped outputs
    perform_composition(composition_net, cropped_dir, cropped_dir, config_options)

    # Load the composed overlapping region
    composed_overlap_path = os.path.join(cropped_dir, "composition.jpg")
    composed_overlap = cv2.imread(composed_overlap_path)

    if composed_overlap is None:
        logging.error(f"Composition failed for level {level} pair {pair_num}. Composed image not found at {composed_overlap_path}")
        return None

    # Concatenate unused parts with composed overlap vertically
    try:
        # Assuming all images have the same width after cropping
        final_image = np.vstack((unused_img1, composed_overlap, unused_img2))
    except Exception as e:
        logging.error(f"Error during image concatenation for level {level} pair {pair_num}: {e}")
        return None

    logging.info(f"Final stitched image shape for level {level} pair {pair_num}: {final_image.shape}")

    # Save the final stitched image
    stitched_image_path = os.path.join(output_dir, f"level{level}_pair{pair_num}_final.jpg")
    os.makedirs(os.path.dirname(stitched_image_path), exist_ok=True)
    cv2.imwrite(stitched_image_path, final_image)
    logging.info(f'Final stitched image saved for level {level} pair {pair_num} at {stitched_image_path}')

    return stitched_image_path


def tree_stitch_images(
    images,
    stride_multiplier,
    stitch_func,
    warp_net,
    composition_net,
    output_dir,
    level,
    group_num,
    config_options,
    direction='horizontal'
):
    """
    Perform tree-like stitching on a list of images until only one image remains.

    Parameters:
        images (list): List of image paths to stitch.
        stride_multiplier (int): Stride multiplier based on the configuration (e.g., row_stride).
        stitch_func (function): Stitching function to use (stitch_pair or stitch_pair_vertical).
        warp_net (WarpNetwork): Warp network instance.
        composition_net (CompositionNetwork): Composition network instance.
        output_dir (str): Directory to save outputs.
        level (int): Starting stitching level.
        group_num (str): Group identifier.
        config_options (dict): Configuration options.
        direction (str): 'horizontal' or 'vertical'.

    Returns:
        str: Path to the final stitched image.
    """

    images_to_stitch = images.copy()
    current_level = 0

    # Calculate the largest power of two less than or equal to the number of images
    num_images = len(images_to_stitch)
    if num_images < 2:
        logging.warning("Not enough images to stitch.")
        return images_to_stitch[0] if images_to_stitch else None

    largest_power = 2**int(math.log2(num_images))
    selected_images = images_to_stitch[:largest_power]
    omitted_images = images_to_stitch[largest_power:]

    if omitted_images:
        logging.info(f"Omitting {len(omitted_images)} redundant image(s) to make the count a power of two.")

    images_to_stitch = selected_images

    while len(images_to_stitch) > 1:
        stitched_images = []
        # Calculate multiplier based on stride and current level
        multiplier = stride_multiplier * (2 ** current_level)
        logging.info(f"Stitching Level {level + current_level}: Multiplier = {multiplier}")

        for i in range(0, len(images_to_stitch), 2):
            img1 = images_to_stitch[i]
            if i + 1 < len(images_to_stitch):
                img2 = images_to_stitch[i + 1]
                if direction == 'horizontal':
                    stitched_image = stitch_func(
                        left_image=img1,
                        right_image=img2,
                        tx_multiplier=multiplier,
                        warp_net=warp_net,
                        composition_net=composition_net,
                        output_dir=output_dir,
                        level=level + current_level,
                        group_num=group_num,
                        pair_num=(i // 2) + 1,
                        config_options=config_options
                    )
                else:
                    stitched_image = stitch_func(
                        bottom_image=img1,
                        top_image=img2,
                        ty_multiplier=multiplier,
                        warp_net=warp_net,
                        composition_net=composition_net,
                        output_dir=output_dir,
                        level=level + current_level,
                        pair_num=(i // 2) + 1,
                        config_options=config_options
                    )
                if stitched_image:
                    stitched_images.append(stitched_image)
                else:
                    logging.error(f"Failed to stitch {img1} and {img2}")
                    return None
            else:
                # If odd number of images, carry the last one forward
                stitched_images.append(img1)
                logging.info(f"Odd image {img1} carried forward to the next level.")

        images_to_stitch = stitched_images
        current_level += 1

    logging.info(f"Completed stitching up to level {level + current_level - 1}")
    return images_to_stitch[0] if images_to_stitch else None
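
# A sketch of how the tree reduction proceeds (image names are placeholders):
# with 8 inputs [a, b, c, d, e, f, g, h] and stride_multiplier = 1,
#   level 0 stitches (a,b), (c,d), (e,f), (g,h) with multiplier 1,
#   level 1 stitches the 4 results pairwise with multiplier 2,
#   level 2 stitches the final pair with multiplier 4.
# The multiplier doubles because each stitched image spans twice as many
# originals as at the previous level.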


def select_power_of_two_rows(rows_to_stitch):
    """
    Selects the largest power-of-two number of rows from the given list.

    Parameters:
        rows_to_stitch (list): List of row indices to stitch.

    Returns:
        list: Subset of rows_to_stitch with length as a power of two.
    """
    num_rows = len(rows_to_stitch)
    if num_rows == 0:
        return []
    largest_power = 2**int(math.log2(num_rows))
    selected_rows = rows_to_stitch[:largest_power]
    omitted_rows = rows_to_stitch[largest_power:]
    if omitted_rows:
        logging.info(f"Omitting {len(omitted_rows)} redundant row(s) to make the count a power of two: {selected_rows}")
    return selected_rows
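
# For example, 6 candidate rows [1, 2, 3, 4, 5, 6] give a largest power of two
# of 4, so [1, 2, 3, 4] are kept and [5, 6] are omitted.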


def sanitize_filename(filename):
    """
    Sanitize the filename by replacing or removing invalid characters.
    """
    # Replace spaces with underscores
    filename = filename.replace(" ", "_")
    # Replace colons with hyphens
    filename = filename.replace(":", "-")
    # Remove any other invalid characters (for cross-platform compatibility)
    # Allow only alphanumerics, underscores, hyphens, and dots
    filename = re.sub(r'[^A-Za-z0-9._-]', '', filename)
    return filename
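
# For example, a scan folder name like "2024-11-23 21-48-23_298" becomes
# "2024-11-23_21-48-23_298": spaces turn into underscores, colons into hyphens,
# and any remaining character outside [A-Za-z0-9._-] is dropped.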


def cleanup_unused_resources(scan_path, used_horizontal_scans, used_images_per_scan):
    """
    Delete unused horizontal_scanX directories and unused images within used directories.

    Parameters:
        scan_path (str): Path to the current scan directory.
        used_horizontal_scans (set): Set of used horizontal_scanX directory names.
        used_images_per_scan (dict): Dictionary mapping horizontal_scanX to set of used image filenames.
    """
    # List all horizontal_scanX directories
    all_horizontal_scans = [
        d for d in os.listdir(scan_path)
        if d.startswith("horizontal_scan") and os.path.isdir(os.path.join(scan_path, d))
    ]

    for hs_dir in all_horizontal_scans:
        hs_path = os.path.join(scan_path, hs_dir)
        if hs_dir not in used_horizontal_scans:
            # Delete the entire directory
            try:
                shutil.rmtree(hs_path)
                logging.info(f"Deleted unused directory: {hs_path}")
            except Exception as e:
                logging.error(f"Failed to delete directory {hs_path}: {e}")
                logging.error(traceback.format_exc())
        else:
            # Delete unused images within the used horizontal_scanX directory
            used_images = used_images_per_scan.get(hs_dir, set())
            all_images = [
                f for f in os.listdir(hs_path)
                if f.lower().endswith('.jpg')
            ]
            for img in all_images:
                if img not in used_images:
                    img_path = os.path.join(hs_path, img)
                    try:
                        os.remove(img_path)
                        logging.info(f"Deleted unused image: {img_path}")
                    except Exception as e:
                        logging.error(f"Failed to delete image {img_path}: {e}")
                        logging.error(traceback.format_exc())


def cleanup_intermediate_outputs(scan_path):
    """
    Delete directories and files starting with 'level' (e.g., 'level1', 'level2', etc.) within the scan directory.

    Parameters:
        scan_path (str): Path to the current scan directory.
    """
    # List all items in the scan directory
    for item in os.listdir(scan_path):
        item_path = os.path.join(scan_path, item)
        
        # Check if the item is a directory starting with 'level'
        if os.path.isdir(item_path) and re.match(r'^level\d+', item):
            try:
                shutil.rmtree(item_path)
                logging.info(f"Deleted intermediate directory: {item_path}")
            except Exception as e:
                logging.error(f"Failed to delete intermediate directory {item_path}: {e}")
                logging.error(traceback.format_exc())
        
        # Check if the item is a file starting with 'level'
        elif os.path.isfile(item_path) and re.match(r'^level\d+', item):
            try:
                os.remove(item_path)
                logging.info(f"Deleted intermediate file: {item_path}")
            except Exception as e:
                logging.error(f"Failed to delete intermediate file {item_path}: {e}")
                logging.error(traceback.format_exc())


def main():

    parser = argparse.ArgumentParser(description="Image Stitching Script")
    parser.add_argument(
        '--bag_path',
        type=str,
        required=True,
        help='Path to the ROS bag file.'
    )
    parser.add_argument(
        '--output_dir',
        type=str,
        required=True,
        help='Directory to save the output images.'
    )
    parser.add_argument(
        '--config_paths',
        type=str,
        required=True,
        nargs='+',
        help='Paths to the YAML configuration files for each hive and XY table.'
    )
    args = parser.parse_args()

    # Convert to absolute paths
    BAG_PATH = os.path.abspath(args.bag_path)
    OUTPUT_BASE_DIR = os.path.abspath(args.output_dir)
    CONFIG_PATHS = [os.path.abspath(path) for path in args.config_paths]

    # Define 'Final_Outputs' directory
    final_outputs_dir = os.path.join(OUTPUT_BASE_DIR, "Final_Outputs")
    os.makedirs(final_outputs_dir, exist_ok=True)
    logging.info(f"Final outputs will be saved in {final_outputs_dir}")

    # Load and validate each configuration
    configs = []
    for config_path in CONFIG_PATHS:
        if not os.path.exists(config_path):
            logging.error(f"Configuration file not found at {config_path}")
            sys.exit(1)
        config = load_config(config_path)
        # Validate required fields
        if 'hive_id' not in config or 'xy_id' not in config:
            logging.error(f"Configuration file {config_path} missing 'hive_id' or 'xy_id'.")
            sys.exit(1)
        configs.append(config)
        logging.info(f"Loaded configuration from {config_path}")

    try:

        # Environment and Network Initialization

        # Set CUDA environment variables
        os.environ["CUDA_DEVICES_ORDER"] = "PCI_BUS_ID"
        os.environ["CUDA_VISIBLE_DEVICES"] = GPU_DEVICE

        # Add project root to sys.path
        project_root = os.path.dirname(os.path.abspath(__file__))
        sys.path.insert(0, project_root)

        logging.info("Initializing WarpNetwork and CompositionNetwork...")
        # Initialize networks
        warp_net = WarpNetwork()
        composition_net = CompositionNetwork()
        logging.info("Networks initialized successfully.")

        # Load pre-trained models for composition_net
        if os.path.exists(COMPOSITION_MODEL_PATH):
            logging.info(f"Loading composition model from {COMPOSITION_MODEL_PATH}...")
            checkpoint = torch.load(
                COMPOSITION_MODEL_PATH, 
                map_location=torch.device('cuda' if torch.cuda.is_available() else 'cpu')
            )
            composition_net.load_state_dict(checkpoint["model"])
            logging.info(f"Loaded composition model from {COMPOSITION_MODEL_PATH}")
        else:
            logging.error(f"No composition model found at {COMPOSITION_MODEL_PATH}")
            sys.exit(1)

        # Move networks to GPU if available
        if torch.cuda.is_available():
            warp_net = warp_net.cuda()
            composition_net = composition_net.cuda()
            logging.info("Networks moved to GPU.")
        else:
            logging.info("CUDA not available. Networks are on CPU.")

        # Bag Parsing and Image Processing

        logging.info("\n========================")
        logging.info("Starting Bag Parsing and Processing")
        logging.info("========================\n")

        for config in configs:
            hive_id = config.get('hive_id')
            xy_id = config.get('xy_id')

            logging.info(f"\n=== Processing Hive {hive_id}, XY Table {xy_id} ===")
            # Define paths specific to hive and XY table
            hive_xy_output_dir = os.path.join(OUTPUT_BASE_DIR, f"hive_{hive_id}", f"xy_{xy_id}")
            os.makedirs(hive_xy_output_dir, exist_ok=True)
            logging.info(f"Output directory for Hive {hive_id}, XY {xy_id}: {hive_xy_output_dir}")

            # Extract configurations from each config file
            stitching_mode = config.get('stitching', {}).get('mode', 'linear').lower()
            translations = config.get('translations', {})
            tx_level = translations.get('tx_level', -320)
            ty_level = translations.get('ty_level', -330)
            config_options = config.get('options', {})
            config_options['tx_level'] = tx_level
            config_options['ty_level'] = ty_level

            # Retrieve 'resize_shape' from config
            resize_shape = config_options.get('resize_shape', None)
            if resize_shape is not None:
                if isinstance(resize_shape, list) and len(resize_shape) == 2:
                    resize_shape = tuple(resize_shape)
                    logging.info(f"Resize shape set to: {resize_shape}")
                else:
                    logging.warning(f"Invalid resize_shape format in config: {resize_shape}. Setting to None.")
                    resize_shape = None
            config_options['resize_shape'] = resize_shape

            # Extract max_images_to_process from stitching
            MAX_IMAGES_TO_PROCESS = config.get('stitching', {}).get('max_images_to_process', 120)

            if stitching_mode == 'linear':
                linear_params = config.get('stitching', {}).get('linear', {})
                COLUMNS_TO_STITCH = linear_params.get('columns_to_stitch', [])
                ROWS_TO_STITCH = linear_params.get('rows_to_stitch', [])
            elif stitching_mode == 'tree':
                tree_params = config.get('stitching', {}).get('tree', {})
                column_params = tree_params.get('columns', {})
                row_params = tree_params.get('rows', {})
            else:
                logging.error(f"Invalid stitching mode: {stitching_mode}. Choose 'linear' or 'tree'.")
                continue  # Skip this config and proceed to the next

            # Load XY positions for this hive and XY table
            xy_positions = load_xy_positions_specific(BAG_PATH, hive_id, xy_id)
            if not xy_positions:
                logging.warning(f"No XY positions found for Hive {hive_id}, XY {xy_id}. Skipping.")
                continue

            # Process images from the bag for this hive and XY table
            process_images(
                xy_positions, 
                BAG_PATH, 
                output_dir=hive_xy_output_dir, 
                threshold=0.001, 
                significant_move=0.1, 
                resize_shape=config_options.get('resize_shape'),  # Pass from config
                max_images=MAX_IMAGES_TO_PROCESS,
                apply_color_filter=config_options.get('apply_color_filter', False),
                compressed_topic=f"/hive_{hive_id}/xy_{xy_id}/bee_camera/compressed",
                xy_position_topic=f"/hive_{hive_id}/xy_{xy_id}/xy_position"
            )

            logging.info("\n========================")
            logging.info("Starting Cleanup Phase")
            logging.info("========================\n")
            cleanup_redundant_images(hive_xy_output_dir)

            logging.info(f"\nAll images processed and saved in the directory: {hive_xy_output_dir}")

            logging.info("\n========================")
            logging.info("Starting Stitching Phase")
            logging.info("========================\n")

            # Iterate over each scan directory within the hive and XY table
            for scan_folder in os.listdir(hive_xy_output_dir):
                scan_path = os.path.join(hive_xy_output_dir, scan_folder)

                if not os.path.isdir(scan_path):
                    logging.info(f"Skipping non-directory: {scan_path}")
                    continue

                try:
                    logging.info(f"\n--- Stitching Scan: {scan_folder} ---")

                    if stitching_mode == 'linear':
                        # ===============================
                        # Linear Stitching Mode
                        # ===============================
                        logging.info("Mode: Linear Stitching")

                        logging.info(f"Linear Stitching Mode: Columns to stitch: {COLUMNS_TO_STITCH}")
                        logging.info(f"Linear Stitching Mode: Rows to stitch: {ROWS_TO_STITCH}")

                        # Find all horizontal_scanX subfolders
                        horizontal_scan_dirs = sorted([
                            d for d in os.listdir(scan_path)
                            if d.startswith("horizontal_scan") and os.path.isdir(os.path.join(scan_path, d))
                        ], key=lambda x: int(x.replace("horizontal_scan", "")))  # Sort numerically

                        if not horizontal_scan_dirs:
                            logging.info(f"No 'horizontal_scanX' folders found in {scan_path}. Skipping stitching.")
                            continue

                        # Select only the specified rows for linear stitching
                        selected_horizontal_scan_dirs = [
                            d for d in horizontal_scan_dirs 
                            if int(d.replace("horizontal_scan", "")) in ROWS_TO_STITCH
                        ]

                        if not selected_horizontal_scan_dirs:
                            logging.warning(f"No 'horizontal_scanX' folders match ROWS_TO_STITCH {ROWS_TO_STITCH} in {scan_path}. Skipping stitching.")
                            continue

                        horizontal_stitched_paths = []
                        # Dictionaries to track used horizontal_scanX directories and their used images
                        used_horizontal_scans = set()
                        used_images_per_scan = {}

                        for horizontal_scan_dir in selected_horizontal_scan_dirs:
                            images_dir = os.path.join(scan_path, horizontal_scan_dir)
                            scan_images = sorted([
                                f for f in os.listdir(images_dir) 
                                if f.lower().endswith('.jpg')
                            ], key=lambda x: int(x.split('_')[1].split('.')[0]))  # Sort by image number

                            logging.info(f"Processing {horizontal_scan_dir} with {len(scan_images)} images.")

                            if len(scan_images) < 2:
                                logging.warning(f"Warning: '{horizontal_scan_dir}' does not contain enough images for stitching. Skipping this horizontal scan.")
                                continue

                            # Group images based on COLUMNS_TO_STITCH
                            try:
                                selected_columns = COLUMNS_TO_STITCH
                                # Validate column indices
                                for col in selected_columns:
                                    if col < 1 or col > len(scan_images):
                                        raise IndexError(f"Column index {col} is out of range for the current horizontal scan with {len(scan_images)} images.")

                                # Sort the selected columns to ensure correct order
                                selected_columns_sorted = sorted(selected_columns)
                                selected_images_sorted = [os.path.join(images_dir, scan_images[col-1]) for col in selected_columns_sorted]
                            except Exception as e:
                                logging.error(f"Error selecting columns to stitch in '{horizontal_scan_dir}': {e}")
                                continue

                            # Track used horizontal_scanX directories and images
                            hs_dir_name = horizontal_scan_dir
                            used_horizontal_scans.add(hs_dir_name)
                            if hs_dir_name not in used_images_per_scan:
                                used_images_per_scan[hs_dir_name] = set()
                            for img_path in selected_images_sorted:
                                img_filename = os.path.basename(img_path)
                                used_images_per_scan[hs_dir_name].add(img_filename)

                            # Stitch selected images sequentially with dynamic tx_multiplier
                            current_stitched = selected_images_sorted[0]
                            stitched_image = current_stitched  # fallback so a single selected column is still carried forward
                            for idx in range(1, len(selected_images_sorted)):
                                previous_column = selected_columns_sorted[idx-1]
                                current_column = selected_columns_sorted[idx]
                                difference = current_column - previous_column  # Calculate column difference
                                tx_multiplier = difference  # Set tx_multiplier based on column difference

                                pair_num = idx  # Pair number within the group

                                logging.info(f"Stitching pair {pair_num}: Column {previous_column} and Column {current_column} with tx_multiplier={tx_multiplier}")

                                stitched_image = stitch_pair(
                                    left_image=current_stitched,
                                    right_image=selected_images_sorted[idx],
                                    tx_multiplier=tx_multiplier,  # Dynamic multiplier based on column difference
                                    warp_net=warp_net,
                                    composition_net=composition_net,
                                    output_dir=scan_path,
                                    level=1,  # Level 1 for horizontal stitching
                                    group_num=horizontal_scan_dir.replace("horizontal_scan", ""),  # Numeric identifier
                                    pair_num=pair_num,  # Pair number within the group
                                    config_options=config_options  # Updated config_options with tx_level and ty_level
                                )

                                if stitched_image:
                                    current_stitched = stitched_image
                                    logging.info(f"Updated current_stitched to: {current_stitched}")
                                else:
                                    logging.error(f"Failed to stitch image pair {current_stitched} and {selected_images_sorted[idx]} in '{horizontal_scan_dir}'.")
                                    break

                            if stitched_image:
                                horizontal_stitched_paths.append(stitched_image)
                            else:
                                logging.error(f"Horizontal stitching failed for '{horizontal_scan_dir}' in scan: {scan_folder}")

                        if not horizontal_stitched_paths:
                            logging.error(f"No horizontal stitching results available for scan: {scan_folder}. Skipping vertical stitching.")
                            continue

                        logging.info(f"Starting vertical stitching for scan: {scan_folder}")

                        # Initialize current_panoramic with the first stitched image
                        current_panoramic = horizontal_stitched_paths[0]
                        logging.info(f"Initialized current_panoramic with: {current_panoramic}")

                        # Iterate over the remaining stitched images
                        for i in range(1, len(horizontal_stitched_paths)):

                            previous_row = ROWS_TO_STITCH[i-1]
                            current_row = ROWS_TO_STITCH[i]
                            row_diff = current_row - ROWS_TO_STITCH[0]  # Maintain original row_diff calculation
                            ty_multiplier = row_diff  # Set ty_multiplier based on row difference

                            pair_num = i  # Pair number within vertical stitching

                            logging.info(f"Stitching vertically: {current_panoramic} with {horizontal_stitched_paths[i]} (ty_multiplier={ty_multiplier})")

                            stitched_image = stitch_pair_vertical(
                                bottom_image=current_panoramic,
                                top_image=horizontal_stitched_paths[i],
                                ty_multiplier=ty_multiplier,  # Dynamic multiplier based on row difference
                                warp_net=warp_net,
                                composition_net=composition_net,
                                output_dir=scan_path,
                                level=2,  # Level 2 for vertical stitching
                                pair_num=pair_num,  # Pair number within vertical stitching
                                config_options=config_options
                            )

                            if stitched_image:
                                current_panoramic = stitched_image  # Update for cumulative stitching
                                logging.info(f"Updated current_panoramic to: {current_panoramic}")
                            else:
                                logging.error(f"Failed to vertically stitch image pair {current_panoramic} and {horizontal_stitched_paths[i]} in scan: {scan_folder}.")
                                break

                        # Save the final panoramic image
                        final_panoramic_stitched = current_panoramic
                        logging.info(f"\n--- Stitching Completed for Scan: {scan_folder} ---")
                        logging.info(f"Final vertically stitched (panoramic) image saved at: {final_panoramic_stitched}")

                        # **Save the Final Stitched Image to Final_Outputs**
                        scan_folder_name = scan_folder  # e.g., "2024-11-23 21-48-23_298"
                        sanitized_scan_folder_name = sanitize_filename(scan_folder_name)
                        new_filename = f"{sanitized_scan_folder_name}_hive{hive_id}_xy{xy_id}.jpg"
                        target_path = os.path.join(final_outputs_dir, new_filename)

                        try:
                            shutil.copy(final_panoramic_stitched, target_path)
                            logging.info(f"Final stitched image copied to {target_path}")
                        except Exception as e:
                            logging.error(f"Failed to copy final stitched image to Final_Outputs: {e}")
                            logging.error(traceback.format_exc())

                        # **Cleanup Unused Resources**
                        cleanup_option = config_options.get('cleanup_unused', False)
                        if cleanup_option:
                            logging.info(f"Starting cleanup of unused resources for scan: {scan_folder}")
                            cleanup_unused_resources(
                                scan_path=scan_path,
                                used_horizontal_scans=used_horizontal_scans,
                                used_images_per_scan=used_images_per_scan
                            )
                            logging.info(f"Cleanup completed for scan: {scan_folder}")
                        else:
                            logging.info("Cleanup of unused resources is disabled in the configuration.")

                        # **Cleanup Intermediate Outputs**
                        cleanup_intermediate_option = config_options.get('cleanup_intermediate', False)
                        if cleanup_intermediate_option:
                            logging.info(f"Starting cleanup of intermediate outputs for scan: {scan_folder}")
                            cleanup_intermediate_outputs(scan_path=scan_path)
                            logging.info(f"Intermediate cleanup completed for scan: {scan_folder}")
                        else:
                            logging.info("Cleanup of intermediate outputs is disabled in the configuration.")

                    elif stitching_mode == 'tree':
                        # ===============================
                        # Tree-like Stitching Mode
                        # ===============================
                        logging.info("Mode: Tree-like Stitching")

                        logging.info(f"Tree-like Stitching Mode: Columns Params: {column_params}")
                        logging.info(f"Tree-like Stitching Mode: Rows Params: {row_params}")

                        # Generate rows to stitch based on start, end, row_stride
                        start_row = row_params.get('start', 1)
                        end_row = row_params.get('end', len([
                            d for d in os.listdir(scan_path)
                            if d.startswith("horizontal_scan") and os.path.isdir(os.path.join(scan_path, d))
                        ]))
                        row_stride = row_params.get('row_stride', 1)
                        rows_to_stitch = list(range(start_row, end_row + 1, row_stride))
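                        # e.g. start=1, end=8, row_stride=2 -> rows_to_stitch = [1, 3, 5, 7]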

                        # Adjust to power-of-two number of rows
                        rows_to_stitch_power_two = select_power_of_two_rows(rows_to_stitch)

                        # Collect the stitched image produced for each selected row
                        selected_images = []
                        # Dictionaries to track used horizontal_scanX directories and their used images
                        used_horizontal_scans = set()
                        used_images_per_scan = {}

                        for row in rows_to_stitch_power_two:
                            horizontal_scan_dir = f"horizontal_scan{row}"
                            images_dir = os.path.join(scan_path, horizontal_scan_dir)
                            if not os.path.isdir(images_dir):
                                logging.warning(f"Directory {images_dir} does not exist. Skipping.")
                                continue

                            scan_images = sorted([
                                f for f in os.listdir(images_dir)
                                if f.lower().endswith('.jpg')
                            ], key=lambda x: int(x.split('_')[1].split('.')[0]))  # Sort by image number

                            # Select columns based on start, end, col_stride
                            start_col = column_params.get('start', 1)
                            end_col = column_params.get('end', len(scan_images))
                            col_stride = column_params.get('col_stride', 1)
                            columns_to_stitch = list(range(start_col, end_col + 1, col_stride))

                            logging.info(f"Row {row}: Columns to stitch: {columns_to_stitch}")

                            selected_columns_sorted = sorted(columns_to_stitch)
                            selected_images_sorted = [
                                os.path.join(images_dir, scan_images[col-1])
                                for col in selected_columns_sorted
                                if 1 <= col <= len(scan_images)
                            ]

                            if not selected_images_sorted:
                                logging.warning(f"No valid columns found for row {row}. Skipping.")
                                continue

                            # Track used horizontal_scanX directories and images
                            hs_dir_name = horizontal_scan_dir
                            used_horizontal_scans.add(hs_dir_name)
                            if hs_dir_name not in used_images_per_scan:
                                used_images_per_scan[hs_dir_name] = set()
                            for img_path in selected_images_sorted:
                                img_filename = os.path.basename(img_path)
                                used_images_per_scan[hs_dir_name].add(img_filename)

                            # Perform tree-like stitching on selected images
                            logging.info(f"Performing tree-like stitching on row {row} with {len(selected_images_sorted)} images.")
                            stitched_row_image = tree_stitch_images(
                                images=selected_images_sorted,
                                stride_multiplier=col_stride,  # Use col_stride as stride_multiplier
                                stitch_func=stitch_pair,
                                warp_net=warp_net,
                                composition_net=composition_net,
                                output_dir=scan_path,
                                level=1,  # Starting level for horizontal stitching
                                group_num=f"row{row}",
                                config_options=config_options,
                                direction='horizontal'
                            )
                            if stitched_row_image:
                                selected_images.append(stitched_row_image)
                            else:
                                logging.error(f"Tree-like stitching failed for row {row}.")
                                continue

                        if len(selected_images) < 2:
                            logging.warning(f"Not enough stitched rows to perform vertical stitching in scan {scan_folder}. Skipping.")
                            continue

                        # Perform tree-like vertical stitching on the stitched rows
                        logging.info("Performing tree-like stitching vertically on stitched rows.")
                        final_panoramic_image = tree_stitch_images(
                            images=selected_images,
                            stride_multiplier=row_stride,  # Use row_stride as stride_multiplier
                            stitch_func=stitch_pair_vertical,
                            warp_net=warp_net,
                            composition_net=composition_net,
                            output_dir=scan_path,
                            level=1,  # Starting level for vertical stitching
                            group_num="vertical",
                            config_options=config_options,
                            direction='vertical'
                        )
                        if final_panoramic_image:
                            logging.info(f"Final tree-like stitched panoramic image saved at: {final_panoramic_image}")

                            # **Save the Final Stitched Image to Final_Outputs**
                            scan_folder_name = scan_folder  # e.g., "2024-11-23 21-48-23_298"
                            sanitized_scan_folder_name = sanitize_filename(scan_folder_name)
                            new_filename = f"{sanitized_scan_folder_name}_hive{hive_id}_xy{xy_id}.jpg"
                            target_path = os.path.join(final_outputs_dir, new_filename)

                            try:
                                shutil.copy(final_panoramic_image, target_path)
                                logging.info(f"Final stitched image copied to {target_path}")
                            except Exception as e:
                                logging.error(f"Failed to copy final stitched image to Final_Outputs: {e}")
                                logging.error(traceback.format_exc())
                        else:
                            logging.error(f"Failed to perform tree-like vertical stitching in scan: {scan_folder}")

                        # **Cleanup Unused Resources**
                        cleanup_option = config_options.get('cleanup_unused', False)
                        if cleanup_option:
                            logging.info(f"Starting cleanup of unused resources for scan: {scan_folder}")
                            cleanup_unused_resources(
                                scan_path=scan_path,
                                used_horizontal_scans=used_horizontal_scans,
                                used_images_per_scan=used_images_per_scan
                            )
                            logging.info(f"Cleanup completed for scan: {scan_folder}")
                        else:
                            logging.info("Cleanup of unused resources is disabled in the configuration.")

                        # **Cleanup Intermediate Outputs**
                        cleanup_intermediate_option = config_options.get('cleanup_intermediate', False)
                        if cleanup_intermediate_option:
                            logging.info(f"Starting cleanup of intermediate outputs for scan: {scan_folder}")
                            cleanup_intermediate_outputs(scan_path=scan_path)
                            logging.info(f"Intermediate cleanup completed for scan: {scan_folder}")
                        else:
                            logging.info("Cleanup of intermediate outputs is disabled in the configuration.")

                    else:
                        logging.error(f"Unknown stitching mode: {stitching_mode}. Skipping scan: {scan_folder}")
                        continue

                except Exception as scan_error:
                    logging.error(f"An error occurred while processing scan '{scan_folder}': {scan_error}")
                    logging.error(traceback.format_exc())
                    logging.info(f"Skipping scan '{scan_folder}' and continuing with the next one.")
                    continue  # Skip to the next scan

        # Cleanup and Exit
        logging.info("\n========================")
        logging.info("Stitching Process Completed")
        logging.info("========================\n")

        # Clean up PyTorch models and free CUDA memory
        if torch.cuda.is_available():
            warp_net.cpu()
            composition_net.cpu()
            del warp_net
            del composition_net
            torch.cuda.empty_cache()
            logging.info("Cleaned up PyTorch models and cleared CUDA cache.")

        logging.info("Exiting the script.")
        sys.exit(0)  # Ensure the script exits automatically

    except Exception as e:
        logging.error(f"An unexpected error occurred: {e}\n{traceback.format_exc()}")
        sys.exit(1)


if __name__ == "__main__":
    main()
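

# Example invocation (script name and paths are placeholders; pass one config
# file per hive / XY table combination):
#
#   python stitch_from_bag.py \
#       --bag_path /data/bags/session.bag \
#       --output_dir /data/stitched \
#       --config_paths configs/hive1_xy1.yaml configs/hive1_xy2.yaml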