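# NOTE(review): the lines below are paste-site metadata that preceded the code;
# they are kept as comments so the module can be parsed.
# Untitled
# unknown
# plain_text
# a year ago
# 56 kB
# 15
# Indexable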
import os
import sys
import argparse
import yaml
import torch
import cv2
import numpy as np
import logging
import rosbag
import shutil
import re
import traceback
import math
from Warp.Codes.network import (
get_stitched_result as warp_get_stitched_result,
Network as WarpNetwork,
build_new_ft_model as warp_build_new_ft_model,
)
from Composition.Codes.network import (
build_model as composition_build_model,
Network as CompositionNetwork,
)
import torchvision.transforms as T
from zz_bagparser_w import (
load_xy_positions_specific,
process_images,
cleanup_redundant_images,
crop_images_specific,
crop_images_specific_vertical
)
# ---------------------------------------------------------------------------
# Logging: INFO and above, mirrored to "stitching_process.log" (opened relative
# to the current working directory) and to stdout.
# ---------------------------------------------------------------------------
logging.basicConfig(
    handlers=[logging.FileHandler("stitching_process.log"), logging.StreamHandler(sys.stdout)],
    format='%(asctime)s - %(levelname)s - %(message)s',
    level=logging.INFO,
)

# ---------------------------------------------------------------------------
# GPU / model configuration
# ---------------------------------------------------------------------------
# GPU index that main() exports through CUDA_VISIBLE_DEVICES.
GPU_DEVICE = "0"
# Composition-network checkpoint; a relative path, so it resolves against the
# current working directory.
COMPOSITION_MODEL_PATH = os.path.join("Composition", "model", "epoch050_model.pth")
def load_config(config_path):
    """
    Load a YAML configuration file.

    The file is parsed with ``yaml.safe_load``, which does not construct
    arbitrary Python objects. An empty file yields ``{}`` rather than ``None``
    so that callers can apply ``in`` / ``.get`` to the result directly.

    Parameters:
    config_path (str): Path to the YAML config file.

    Returns:
    dict: Parsed configuration (expected to be a mapping; ``{}`` for an empty
    document).
    """
    with open(config_path, 'r', encoding='utf-8') as file:
        config = yaml.safe_load(file)
    # safe_load returns None for an empty document; main()'s
    # `'hive_id' not in config` check would then raise TypeError instead of
    # reporting the missing keys and exiting cleanly.
    return {} if config is None else config
# Stitching Functions
def perform_manual_warping(cropped_img1_path, cropped_img2_path, output_dir):
    """
    Manually perform warping by creating masks and blending the cropped images horizontally.

    img1 keeps columns [0, int(0.9 * width)) and has the remaining columns
    blanked. img2 has columns [0, int(0.1 * width)) blanked and keeps the rest.
    The masks and the masked images are written to ``output_dir`` as
    mask1.jpg, mask2.jpg, warp1.jpg and warp2.jpg. This is the layout that
    perform_composition() reads.

    Parameters:
    cropped_img1_path (str): Path to the cropped left image.
    cropped_img2_path (str): Path to the cropped right image.
    output_dir (str): Directory to save warped images and masks (created if missing).

    Raises:
    FileNotFoundError: If either cropped image cannot be read.
    OSError: If any of the output images cannot be written.
    """
    img1 = cv2.imread(cropped_img1_path)
    img2 = cv2.imread(cropped_img2_path)
    if img1 is None:
        raise FileNotFoundError(f"Cropped Image 1 not found: {cropped_img1_path}")
    if img2 is None:
        raise FileNotFoundError(f"Cropped Image 2 not found: {cropped_img2_path}")

    # mask1: keep the left 90% of img1's columns, blank the rest.
    split1 = int(img1.shape[1] * 0.9)
    mask1 = np.zeros_like(img1, dtype=np.uint8)
    mask1[:, :split1, :] = 255

    # mask2: blank the left 10% of img2's columns, keep the rest.
    split2 = int(img2.shape[1] * 0.1)
    mask2 = np.zeros_like(img2, dtype=np.uint8)
    mask2[:, split2:, :] = 255

    warp1 = cv2.bitwise_and(img1, mask1)
    warp2 = cv2.bitwise_and(img2, mask2)

    os.makedirs(output_dir, exist_ok=True)
    # NOTE(review): masks are stored as JPEG (lossy), so values near the split
    # edge may no longer be exactly 0/255 when re-read — confirm this is acceptable.
    outputs = (
        ("mask1.jpg", mask1),
        ("mask2.jpg", mask2),
        ("warp1.jpg", warp1),
        ("warp2.jpg", warp2),
    )
    for filename, image in outputs:
        path = os.path.join(output_dir, filename)
        # cv2.imwrite signals failure through its return value. Without this
        # check, perform_composition() could silently read files left behind by
        # an earlier run in the same directory.
        if not cv2.imwrite(path, image):
            raise OSError(f"Failed to write {path}")
    logging.info(f"Manual warped images and masks saved in {output_dir}")
def perform_manual_warping_vertical(cropped_img1_path, cropped_img2_path, output_dir):
    """
    Manually perform warping by creating masks and blending the cropped images vertically.

    img1 keeps rows [0, int(0.9 * height)) and has the remaining rows blanked.
    img2 has rows [0, int(0.1 * height)) blanked and keeps the rest. The masks
    and the masked images are written to ``output_dir`` as mask1.jpg,
    mask2.jpg, warp1.jpg and warp2.jpg. This is the layout that
    perform_composition() reads.

    Parameters:
    cropped_img1_path (str): Path to the cropped bottom image.
    cropped_img2_path (str): Path to the cropped top image.
    output_dir (str): Directory to save warped images and masks (created if missing).

    Raises:
    FileNotFoundError: If either cropped image cannot be read.
    OSError: If any of the output images cannot be written.
    """
    img1 = cv2.imread(cropped_img1_path)
    img2 = cv2.imread(cropped_img2_path)
    if img1 is None:
        raise FileNotFoundError(f"Cropped Image 1 not found: {cropped_img1_path}")
    if img2 is None:
        raise FileNotFoundError(f"Cropped Image 2 not found: {cropped_img2_path}")

    # mask1: keep the first 90% of img1's rows, blank the rest.
    split1 = int(img1.shape[0] * 0.9)
    mask1 = np.zeros_like(img1, dtype=np.uint8)
    mask1[:split1, :, :] = 255

    # mask2: blank the first 10% of img2's rows, keep the rest.
    split2 = int(img2.shape[0] * 0.1)
    mask2 = np.zeros_like(img2, dtype=np.uint8)
    mask2[split2:, :, :] = 255

    warp1 = cv2.bitwise_and(img1, mask1)
    warp2 = cv2.bitwise_and(img2, mask2)

    os.makedirs(output_dir, exist_ok=True)
    # NOTE(review): masks are stored as JPEG (lossy), so values near the split
    # edge may no longer be exactly 0/255 when re-read — confirm this is acceptable.
    outputs = (
        ("mask1.jpg", mask1),
        ("mask2.jpg", mask2),
        ("warp1.jpg", warp1),
        ("warp2.jpg", warp2),
    )
    for filename, image in outputs:
        path = os.path.join(output_dir, filename)
        # cv2.imwrite signals failure through its return value. Without this
        # check, perform_composition() could silently read files left behind by
        # an earlier run in the same directory.
        if not cv2.imwrite(path, image):
            raise OSError(f"Failed to write {path}")
    logging.info(f"Manual warped images and masks saved in {output_dir}")
def perform_composition(net, input_dir, output_dir, config_options):
"""
Perform composition on the warped outputs to generate the stitched image.
Parameters:
net (CompositionNetwork): The initialized composition network.
input_dir (str): Directory containing warped images and masks.
output_dir (str): Directory to save the composed image.
config_options (dict): Configuration options from config.yaml.
"""
# Load warped images
warp1 = cv2.imread(os.path.join(input_dir, "warp1.jpg"))
warp2 = cv2.imread(os.path.join(input_dir, "warp2.jpg"))
if warp1 is None or warp2 is None:
raise FileNotFoundError("Warped images not found in the input directory.")
# Convert to float32 and normalize
warp1 = warp1.astype(dtype=np.float32)
warp1 = (warp1 / 127.5) - 1.0
warp1 = np.transpose(warp1, [2, 0, 1])
warp1_tensor = torch.tensor(warp1).unsqueeze(0)
warp2 = warp2.astype(dtype=np.float32)
warp2 = (warp2 / 127.5) - 1.0
warp2 = np.transpose(warp2, [2, 0, 1])
warp2_tensor = torch.tensor(warp2).unsqueeze(0)
# Load masks
mask1 = cv2.imread(os.path.join(input_dir, "mask1.jpg"))
mask2 = cv2.imread(os.path.join(input_dir, "mask2.jpg"))
if mask1 is None or mask2 is None:
raise FileNotFoundError("Masks not found in the input directory.")
mask1 = mask1.astype(dtype=np.float32) / 255.0
mask1 = np.transpose(mask1, [2, 0, 1])
mask1_tensor = torch.tensor(mask1).unsqueeze(0)
mask2 = mask2.astype(dtype=np.float32) / 255.0
mask2 = np.transpose(mask2, [2, 0, 1])
mask2_tensor = torch.tensor(mask2).unsqueeze(0)
if torch.cuda.is_available():
warp1_tensor = warp1_tensor.cuda()
warp2_tensor = warp2_tensor.cuda()
mask1_tensor = mask1_tensor.cuda()
mask2_tensor = mask2_tensor.cuda()
net = net.cuda()
# Perform composition
net.eval()
with torch.no_grad():
batch_out = composition_build_model(
net, warp1_tensor, warp2_tensor, mask1_tensor, mask2_tensor
)
stitched_image = batch_out["stitched_image"]
logging.info(f"Stitched image shape: {stitched_image.shape}")
# Save the composed image
os.makedirs(output_dir, exist_ok=True)
stitched_image_np = (
(stitched_image[0] + 1) * 127.5
).cpu().numpy().transpose(1, 2, 0)
cv2.imwrite(
os.path.join(output_dir, "composition.jpg"), stitched_image_np.astype(np.uint8)
)
logging.info(f"Composed image saved in {output_dir}")
# Save learned masks if enabled
if config_options.get('save_learned_masks', False):
learned_mask1 = (batch_out['learned_mask1'][0] * 255).cpu().numpy().transpose(1, 2, 0)
learned_mask2 = (batch_out['learned_mask2'][0] * 255).cpu().numpy().transpose(1, 2, 0)
cv2.imwrite(os.path.join(input_dir, "learn_mask1.jpg"), learned_mask1.astype(np.uint8))
cv2.imwrite(os.path.join(input_dir, "learn_mask2.jpg"), learned_mask2.astype(np.uint8))
logging.info(f"Learned masks saved in {input_dir}/learn_mask1.jpg and {input_dir}/learn_mask2.jpg")
# Save composition_color.jpg if enabled
if config_options.get('save_composition_color', False):
s1 = ((warp1_tensor[0] + 1) * 127.5 * batch_out['learned_mask1'][0]).cpu().detach().numpy().transpose(1, 2, 0)
s2 = ((warp2_tensor[0] + 1) * 127.5 * batch_out['learned_mask2'][0]).cpu().detach().numpy().transpose(1, 2, 0)
fusion = np.zeros((warp1_tensor.shape[2], warp1_tensor.shape[3], 3), np.uint8)
fusion[..., 0] = s2[..., 0]
fusion[..., 1] = (s1[..., 1] * 0.5 + s2[..., 1] * 0.5).astype(np.uint8)
fusion[..., 2] = s1[..., 2]
composition_color_path = os.path.join(output_dir, "composition_color.jpg")
cv2.imwrite(composition_color_path, fusion)
logging.info(f"Composition color image saved at {composition_color_path}")
def stitch_pair(left_image, right_image, tx_multiplier, warp_net, composition_net, output_dir, level, group_num, pair_num, config_options):
    """
    Stitch a pair of images horizontally using manual warping and composition.

    The pipeline has four steps:
    1. ``crop_images_specific`` splits the images into the overlap to blend
       and the untouched remainders.
    2. The overlaps are masked by ``perform_manual_warping``.
    3. The masked overlaps are blended by ``perform_composition``.
    4. The result is joined with the remainders by ``np.hstack``.

    Intermediate files go to
    ``<output_dir>/level{level}_group{group_num}_pair{pair_num}/cropped``.

    Parameters:
    left_image (str): Path to the left image.
    right_image (str): Path to the right image.
    tx_multiplier (int): Multiplier for the base translation in the x-direction.
    warp_net (WarpNetwork): Initialized warping network (Not used in manual warping).
    composition_net (CompositionNetwork): Initialized composition network.
    output_dir (str): Directory to save intermediate and final outputs.
    level (int): Current stitching level.
    group_num (str): Group number for unique identification.
    pair_num (int): Pair number within the group.
    config_options (dict): Configuration options from config.yaml (must contain 'tx_level').

    Returns:
    str or None: Path to the stitched image. Returns None (after logging) when
    any of the following fails:
    - writing the cropped overlaps,
    - reading the composed overlap back,
    - the concatenation,
    - writing the final image.
    Exceptions raised by cropping, warping or composition propagate.
    """
    tx = tx_multiplier * config_options['tx_level']

    logging.info(f"Stitching pair {pair_num} at level {level}, group {group_num}")
    logging.info(f"Left Image: {left_image}")
    logging.info(f"Right Image: {right_image}")
    logging.info(f"Effective tx used: {tx}")

    cropped = crop_images_specific(left_image, right_image, tx, 0)
    cropped_img1 = cropped["cropped_img1"]
    cropped_img2 = cropped["cropped_img2"]
    unused_img1 = cropped["unused_img1"]
    unused_img2 = cropped["unused_img2"]

    cropped_dir = os.path.join(output_dir, f"level{level}_group{group_num}_pair{pair_num}", "cropped")
    os.makedirs(cropped_dir, exist_ok=True)
    cropped_img1_path = os.path.join(cropped_dir, "cropped_img1.jpg")
    cropped_img2_path = os.path.join(cropped_dir, "cropped_img2.jpg")
    # The overlaps are re-read from disk by perform_manual_warping, and the
    # directory name is deterministic. A file from an earlier run may therefore
    # exist, so an unnoticed failed write would blend stale data.
    for path, image in ((cropped_img1_path, cropped_img1), (cropped_img2_path, cropped_img2)):
        if not cv2.imwrite(path, image):
            logging.error(f"Failed to write {path} for level {level} group {group_num} pair {pair_num}")
            return None
    # The unused parts are saved for reference only; the in-memory arrays are used below.
    cv2.imwrite(os.path.join(cropped_dir, "unused_img1.jpg"), unused_img1)
    cv2.imwrite(os.path.join(cropped_dir, "unused_img2.jpg"), unused_img2)
    logging.info(f"Cropped images saved for level {level} group {group_num} pair {pair_num} in {cropped_dir}")

    perform_manual_warping(cropped_img1_path, cropped_img2_path, cropped_dir)
    perform_composition(composition_net, cropped_dir, cropped_dir, config_options)

    composed_overlap_path = os.path.join(cropped_dir, "composition.jpg")
    composed_overlap = cv2.imread(composed_overlap_path)
    if composed_overlap is None:
        logging.error(f"Composition failed for level {level} group {group_num} pair {pair_num}. Composed image not found at {composed_overlap_path}")
        return None

    # Left remainder | blended overlap | right remainder (heights must match).
    try:
        final_image = np.hstack((unused_img1, composed_overlap, unused_img2))
    except Exception as e:
        logging.error(f"Error during image concatenation for level {level} group {group_num} pair {pair_num}: {e}")
        return None
    logging.info(f"Final stitched image shape for level {level} group {group_num} pair {pair_num}: {final_image.shape}")

    # output_dir already exists: cropped_dir was created beneath it above.
    stitched_image_path = os.path.join(output_dir, f"level{level}_group{group_num}_pair{pair_num}_final.jpg")
    # Callers feed the returned path into the next stitch, so never return a
    # path whose write failed.
    if not cv2.imwrite(stitched_image_path, final_image):
        logging.error(f"Failed to write final stitched image for level {level} group {group_num} pair {pair_num} to {stitched_image_path}")
        return None
    logging.info(f'Final stitched image saved for level {level} group {group_num} pair {pair_num} at {stitched_image_path}')
    return stitched_image_path
def stitch_pair_vertical(bottom_image, top_image, ty_multiplier, warp_net, composition_net, output_dir, level, pair_num, config_options):
    """
    Stitch a pair of images vertically using manual warping and composition.

    The pipeline has four steps:
    1. ``crop_images_specific_vertical`` splits the images into the overlap to
       blend and the untouched remainders.
    2. The overlaps are masked by ``perform_manual_warping_vertical``.
    3. The masked overlaps are blended by ``perform_composition``.
    4. The result is joined with the remainders by ``np.vstack``.

    The stacking order is: unused part of ``bottom_image``, composed overlap,
    unused part of ``top_image``. Row 0 of the output therefore comes from
    ``bottom_image`` (NOTE(review): confirm the bottom/top naming matches
    image row order). Intermediate files go to
    ``<output_dir>/level{level}_pair{pair_num}/cropped``.

    Parameters:
    bottom_image (str): Path to the bottom image.
    top_image (str): Path to the top image.
    ty_multiplier (int): Multiplier for the base translation in the y-direction.
    warp_net (WarpNetwork): Initialized warping network (Not used in manual warping).
    composition_net (CompositionNetwork): Initialized composition network.
    output_dir (str): Directory to save intermediate and final outputs.
    level (int): Current stitching level.
    pair_num (int): Pair number within the current level.
    config_options (dict): Configuration options from config.yaml (must contain 'ty_level').

    Returns:
    str or None: Path to the stitched image. Returns None (after logging) when
    any of the following fails:
    - writing the cropped overlaps,
    - reading the composed overlap back,
    - the concatenation,
    - writing the final image.
    Exceptions raised by cropping, warping or composition propagate.
    """
    ty = ty_multiplier * config_options['ty_level']

    logging.info(f"Stitching pair {pair_num} at level {level}")
    logging.info(f"Bottom Image: {bottom_image}")
    logging.info(f"Top Image: {top_image}")
    logging.info(f"Effective ty used: {ty}")

    cropped = crop_images_specific_vertical(bottom_image, top_image, ty)
    cropped_img1 = cropped["cropped_img1"]  # Bottom cropped
    cropped_img2 = cropped["cropped_img2"]  # Top cropped
    unused_img1 = cropped["unused_img1"]  # Bottom unused
    unused_img2 = cropped["unused_img2"]  # Top unused

    cropped_dir = os.path.join(output_dir, f"level{level}_pair{pair_num}", "cropped")
    os.makedirs(cropped_dir, exist_ok=True)
    cropped_img1_path = os.path.join(cropped_dir, "cropped_img1.jpg")
    cropped_img2_path = os.path.join(cropped_dir, "cropped_img2.jpg")
    # The overlaps are re-read from disk by perform_manual_warping_vertical,
    # and the directory name is deterministic. A file from an earlier run may
    # therefore exist, so an unnoticed failed write would blend stale data.
    for path, image in ((cropped_img1_path, cropped_img1), (cropped_img2_path, cropped_img2)):
        if not cv2.imwrite(path, image):
            logging.error(f"Failed to write {path} for level {level} pair {pair_num}")
            return None
    # The unused parts are saved for reference only; the in-memory arrays are used below.
    cv2.imwrite(os.path.join(cropped_dir, "unused_img1.jpg"), unused_img1)
    cv2.imwrite(os.path.join(cropped_dir, "unused_img2.jpg"), unused_img2)
    logging.info(f"Cropped images saved for level {level} pair {pair_num} in {cropped_dir}")

    perform_manual_warping_vertical(cropped_img1_path, cropped_img2_path, cropped_dir)
    perform_composition(composition_net, cropped_dir, cropped_dir, config_options)

    composed_overlap_path = os.path.join(cropped_dir, "composition.jpg")
    composed_overlap = cv2.imread(composed_overlap_path)
    if composed_overlap is None:
        logging.error(f"Composition failed for level {level} pair {pair_num}. Composed image not found at {composed_overlap_path}")
        return None

    # All three parts must share the same width for np.vstack.
    try:
        final_image = np.vstack((unused_img1, composed_overlap, unused_img2))
    except Exception as e:
        logging.error(f"Error during image concatenation for level {level} pair {pair_num}: {e}")
        return None
    logging.info(f"Final stitched image shape for level {level} pair {pair_num}: {final_image.shape}")

    # output_dir already exists: cropped_dir was created beneath it above.
    stitched_image_path = os.path.join(output_dir, f"level{level}_pair{pair_num}_final.jpg")
    # Callers feed the returned path into the next stitch, so never return a
    # path whose write failed.
    if not cv2.imwrite(stitched_image_path, final_image):
        logging.error(f"Failed to write final stitched image for level {level} pair {pair_num} to {stitched_image_path}")
        return None
    logging.info(f'Final stitched image saved for level {level} pair {pair_num} at {stitched_image_path}')
    return stitched_image_path
def tree_stitch_images(
    images,
    stride_multiplier,
    stitch_func,
    warp_net,
    composition_net,
    output_dir,
    level,
    group_num,
    config_options,
    direction='horizontal'
):
    """
    Perform tree-like stitching on a list of images until only one image remains.

    Only the first 2**k images are used, where 2**k is the largest power of
    two <= len(images); the rest are dropped (and logged). Neighbouring
    images are then stitched pairwise, level by level. At tree depth d the
    translation multiplier is ``stride_multiplier * 2**d`` and the reported
    level is ``level + d``.

    Parameters:
    images (sequence): Image paths to stitch (not modified).
    stride_multiplier (int): Stride multiplier based on the configuration (e.g., row_stride).
    stitch_func (function): Stitching function to use (stitch_pair or stitch_pair_vertical).
        It is called with keyword arguments and must return a truthy path on success.
    warp_net (WarpNetwork): Warp network instance.
    composition_net (CompositionNetwork): Composition network instance.
    output_dir (str): Directory to save outputs.
    level (int): Starting stitching level.
    group_num (str): Group identifier (passed only for horizontal stitching).
    config_options (dict): Configuration options.
    direction (str): 'horizontal' or 'vertical'; any value other than
        'horizontal' uses the vertical keyword set.

    Returns:
    str or None: Path to the final stitched image. Returns the single input
    when fewer than two images are given, and None if the input is empty or
    any pair fails to stitch.
    """
    images_to_stitch = list(images)
    num_images = len(images_to_stitch)
    if num_images < 2:
        logging.warning("Not enough images to stitch.")
        return images_to_stitch[0] if images_to_stitch else None

    # Largest power of two <= num_images, computed exactly with integer ops
    # (int(math.log2(n)) can round up for large n).
    largest_power = 1 << (num_images.bit_length() - 1)
    omitted_count = num_images - largest_power
    if omitted_count:
        logging.info(f"Omitting {omitted_count} redundant image(s) to make the count a power of two.")
    images_to_stitch = images_to_stitch[:largest_power]

    current_level = 0
    while len(images_to_stitch) > 1:
        multiplier = stride_multiplier * (2 ** current_level)
        stitch_level = level + current_level
        logging.info(f"Stitching Level {stitch_level}: Multiplier = {multiplier}")
        common_kwargs = dict(
            warp_net=warp_net,
            composition_net=composition_net,
            output_dir=output_dir,
            level=stitch_level,
            config_options=config_options,
        )
        stitched_images = []
        # The count is a power of two, so every level pairs up evenly.
        pairs = zip(images_to_stitch[0::2], images_to_stitch[1::2])
        for pair_num, (img1, img2) in enumerate(pairs, start=1):
            if direction == 'horizontal':
                stitched_image = stitch_func(
                    left_image=img1,
                    right_image=img2,
                    tx_multiplier=multiplier,
                    group_num=group_num,
                    pair_num=pair_num,
                    **common_kwargs,
                )
            else:
                stitched_image = stitch_func(
                    bottom_image=img1,
                    top_image=img2,
                    ty_multiplier=multiplier,
                    pair_num=pair_num,
                    **common_kwargs,
                )
            if not stitched_image:
                logging.error(f"Failed to stitch {img1} and {img2}")
                return None
            stitched_images.append(stitched_image)
        images_to_stitch = stitched_images
        current_level += 1

    logging.info(f"Completed stitching up to level {level + current_level - 1}")
    return images_to_stitch[0]
def select_power_of_two_rows(rows_to_stitch):
    """
    Select the largest power-of-two number of rows from the given list.

    The leading ``2**k`` entries are kept, where ``2**k`` is the largest power
    of two <= ``len(rows_to_stitch)``. Any remaining rows are dropped and an
    info message is logged.

    Parameters:
    rows_to_stitch (list): List of row indices to stitch.

    Returns:
    list: Prefix of rows_to_stitch whose length is a power of two ([] for empty input).
    """
    num_rows = len(rows_to_stitch)
    if num_rows == 0:
        return []
    # Exact integer computation; int(math.log2(n)) can round up for large n.
    largest_power = 1 << (num_rows.bit_length() - 1)
    selected_rows = rows_to_stitch[:largest_power]
    if num_rows > largest_power:
        logging.info(f"Omitting {num_rows - largest_power} redundant row(s) to make the count a power of two: {selected_rows}")
    return selected_rows
def sanitize_filename(filename):
    """
    Make ``filename`` safe to use as a file name on any platform.

    Spaces become underscores and colons become hyphens. Every remaining
    character that is not an ASCII letter, digit, '.', '_' or '-' is then
    dropped.
    """
    substitutions = str.maketrans({" ": "_", ":": "-"})
    translated = filename.translate(substitutions)
    return re.sub(r'[^A-Za-z0-9._-]', '', translated)
def cleanup_unused_resources(scan_path, used_horizontal_scans, used_images_per_scan):
    """
    Prune a scan directory down to the horizontal scans and images actually used.

    Only subdirectories of ``scan_path`` whose names start with
    "horizontal_scan" are considered. Any such directory not listed in
    ``used_horizontal_scans`` is removed recursively. Inside each kept
    directory, every ``.jpg`` file (case-insensitive extension) that is not
    listed for it in ``used_images_per_scan`` is deleted; other file types
    are left alone. Deletion failures are logged with a traceback and do not
    stop the cleanup.

    Parameters:
    scan_path (str): Path to the current scan directory.
    used_horizontal_scans (set): Names of horizontal_scanX directories to keep.
    used_images_per_scan (dict): Maps a horizontal_scanX name to the set of image filenames to keep.
    """
    def _is_horizontal_scan_dir(name):
        return name.startswith("horizontal_scan") and os.path.isdir(os.path.join(scan_path, name))

    scan_dirs = list(filter(_is_horizontal_scan_dir, os.listdir(scan_path)))
    for hs_dir in scan_dirs:
        hs_path = os.path.join(scan_path, hs_dir)

        if hs_dir in used_horizontal_scans:
            keep = used_images_per_scan.get(hs_dir, set())
            stale_images = [
                name for name in os.listdir(hs_path)
                if name.lower().endswith('.jpg') and name not in keep
            ]
            for img in stale_images:
                img_path = os.path.join(hs_path, img)
                try:
                    os.remove(img_path)
                    logging.info(f"Deleted unused image: {img_path}")
                except Exception as e:
                    logging.error(f"Failed to delete image {img_path}: {e}")
                    logging.error(traceback.format_exc())
            continue

        # Directory was not used at all: drop it entirely.
        try:
            shutil.rmtree(hs_path)
            logging.info(f"Deleted unused directory: {hs_path}")
        except Exception as e:
            logging.error(f"Failed to delete directory {hs_path}: {e}")
            logging.error(traceback.format_exc())
def cleanup_intermediate_outputs(scan_path):
    """
    Remove intermediate stitching outputs from a scan directory.

    An entry of ``scan_path`` is deleted when its name begins with 'level'
    followed by at least one digit (e.g. 'level1_group2_pair1' or
    'level2_pair1_final.jpg'). Directories are removed recursively and regular
    files individually. The match is case-sensitive. Failures are logged with
    a traceback and do not stop the sweep.

    Parameters:
    scan_path (str): Path to the current scan directory.
    """
    level_prefix = re.compile(r'^level\d+')
    for name in os.listdir(scan_path):
        if level_prefix.match(name) is None:
            continue
        target = os.path.join(scan_path, name)
        if os.path.isdir(target):
            try:
                shutil.rmtree(target)
                logging.info(f"Deleted intermediate directory: {target}")
            except Exception as e:
                logging.error(f"Failed to delete intermediate directory {target}: {e}")
                logging.error(traceback.format_exc())
        elif os.path.isfile(target):
            try:
                os.remove(target)
                logging.info(f"Deleted intermediate file: {target}")
            except Exception as e:
                logging.error(f"Failed to delete intermediate file {target}: {e}")
                logging.error(traceback.format_exc())
def main():
parser = argparse.ArgumentParser(description="Image Stitching Script")
parser.add_argument(
'--bag_path',
type=str,
required=True,
help='Path to the ROS bag file.'
)
parser.add_argument(
'--output_dir',
type=str,
required=True,
help='Directory to save the output images.'
)
parser.add_argument(
'--config_paths',
type=str,
required=True,
nargs='+',
help='Paths to the YAML configuration files for each hive and XY table.'
)
args = parser.parse_args()
# Convert to absolute paths
BAG_PATH = os.path.abspath(args.bag_path)
OUTPUT_BASE_DIR = os.path.abspath(args.output_dir)
CONFIG_PATHS = [os.path.abspath(path) for path in args.config_paths]
# Define 'Final_Outputs' directory
final_outputs_dir = os.path.join(OUTPUT_BASE_DIR, "Final_Outputs")
os.makedirs(final_outputs_dir, exist_ok=True)
logging.info(f"Final outputs will be saved in {final_outputs_dir}")
# Load and validate each configuration
configs = []
for config_path in CONFIG_PATHS:
if not os.path.exists(config_path):
logging.error(f"Configuration file not found at {config_path}")
sys.exit(1)
config = load_config(config_path)
# Validate required fields
if 'hive_id' not in config or 'xy_id' not in config:
logging.error(f"Configuration file {config_path} missing 'hive_id' or 'xy_id'.")
sys.exit(1)
configs.append(config)
logging.info(f"Loaded configuration from {config_path}")
try:
# Environment and Network Initialization
# Set CUDA environment variables
os.environ["CUDA_DEVICES_ORDER"] = "PCI_BUS_ID"
os.environ["CUDA_VISIBLE_DEVICES"] = GPU_DEVICE
# Add project root to sys.path
project_root = os.path.dirname(os.path.abspath(__file__))
sys.path.insert(0, project_root)
logging.info("Initializing WarpNetwork and CompositionNetwork...")
# Initialize networks
warp_net = WarpNetwork()
composition_net = CompositionNetwork()
logging.info("Networks initialized successfully.")
# Load pre-trained models for composition_net
if os.path.exists(COMPOSITION_MODEL_PATH):
logging.info(f"Loading composition model from {COMPOSITION_MODEL_PATH}...")
checkpoint = torch.load(
COMPOSITION_MODEL_PATH,
map_location=torch.device('cuda' if torch.cuda.is_available() else 'cpu')
)
composition_net.load_state_dict(checkpoint["model"])
logging.info(f"Loaded composition model from {COMPOSITION_MODEL_PATH}")
else:
logging.error(f"No composition model found at {COMPOSITION_MODEL_PATH}")
sys.exit(1)
# Move networks to GPU if available
if torch.cuda.is_available():
warp_net = warp_net.cuda()
composition_net = composition_net.cuda()
logging.info("Networks moved to GPU.")
else:
logging.info("CUDA not available. Networks are on CPU.")
# Bag Parsing and Image Processing
logging.info("\n========================")
logging.info("Starting Bag Parsing and Processing")
logging.info("========================\n")
for config in configs:
hive_id = config.get('hive_id')
xy_id = config.get('xy_id')
logging.info(f"\n=== Processing Hive {hive_id}, XY Table {xy_id} ===")
# Define paths specific to hive and XY table
hive_xy_output_dir = os.path.join(OUTPUT_BASE_DIR, f"hive_{hive_id}", f"xy_{xy_id}")
os.makedirs(hive_xy_output_dir, exist_ok=True)
logging.info(f"Output directory for Hive {hive_id}, XY {xy_id}: {hive_xy_output_dir}")
# Extract configurations from each config file
stitching_mode = config.get('stitching', {}).get('mode', 'linear').lower()
translations = config.get('translations', {})
tx_level = translations.get('tx_level', -320)
ty_level = translations.get('ty_level', -330)
config_options = config.get('options', {})
config_options['tx_level'] = tx_level
config_options['ty_level'] = ty_level
# Retrieve 'resize_shape' from config
resize_shape = config_options.get('resize_shape', None)
if resize_shape is not None:
if isinstance(resize_shape, list) and len(resize_shape) == 2:
resize_shape = tuple(resize_shape)
logging.info(f"Resize shape set to: {resize_shape}")
else:
logging.warning(f"Invalid resize_shape format in config: {resize_shape}. Setting to None.")
resize_shape = None
config_options['resize_shape'] = resize_shape
# Extract max_images_to_process from stitching
MAX_IMAGES_TO_PROCESS = config.get('stitching', {}).get('max_images_to_process', 120)
if stitching_mode == 'linear':
linear_params = config.get('stitching', {}).get('linear', {})
COLUMNS_TO_STITCH = linear_params.get('columns_to_stitch', [])
ROWS_TO_STITCH = linear_params.get('rows_to_stitch', [])
elif stitching_mode == 'tree':
tree_params = config.get('stitching', {}).get('tree', {})
column_params = tree_params.get('columns', {})
row_params = tree_params.get('rows', {})
else:
logging.error(f"Invalid stitching mode: {stitching_mode}. Choose 'linear' or 'tree'.")
continue # Skip this config and proceed to the next
# Load XY positions for this hive and XY table
xy_positions = load_xy_positions_specific(BAG_PATH, hive_id, xy_id)
if not xy_positions:
logging.warning(f"No XY positions found for Hive {hive_id}, XY {xy_id}. Skipping.")
continue
# Process images from the bag for this hive and XY table
process_images(
xy_positions,
BAG_PATH,
output_dir=hive_xy_output_dir,
threshold=0.001,
significant_move=0.1,
resize_shape=config_options.get('resize_shape'), # Pass from config
max_images=MAX_IMAGES_TO_PROCESS,
apply_color_filter=config_options.get('apply_color_filter', False),
compressed_topic=f"/hive_{hive_id}/xy_{xy_id}/bee_camera/compressed",
xy_position_topic=f"/hive_{hive_id}/xy_{xy_id}/xy_position"
)
logging.info("\n========================")
logging.info("Starting Cleanup Phase")
logging.info("========================\n")
cleanup_redundant_images(hive_xy_output_dir)
logging.info(f"\nAll images processed and saved in the directory: {hive_xy_output_dir}")
logging.info("\n========================")
logging.info("Starting Stitching Phase")
logging.info("========================\n")
# Iterate over each scan directory within the hive and XY table
for scan_folder in os.listdir(hive_xy_output_dir):
scan_path = os.path.join(hive_xy_output_dir, scan_folder)
if not os.path.isdir(scan_path):
logging.info(f"Skipping non-directory: {scan_path}")
continue
try:
logging.info(f"\n--- Stitching Scan: {scan_folder} ---")
if stitching_mode == 'linear':
# ===============================
# Linear Stitching Mode
# ===============================
logging.info("Mode: Linear Stitching")
logging.info(f"Linear Stitching Mode: Columns to stitch: {COLUMNS_TO_STITCH}")
logging.info(f"Linear Stitching Mode: Rows to stitch: {ROWS_TO_STITCH}")
# Find all horizontal_scanX subfolders
horizontal_scan_dirs = sorted([
d for d in os.listdir(scan_path)
if d.startswith("horizontal_scan") and os.path.isdir(os.path.join(scan_path, d))
], key=lambda x: int(x.replace("horizontal_scan", ""))) # Sort numerically
if not horizontal_scan_dirs:
logging.info(f"No 'horizontal_scanX' folders found in {scan_path}. Skipping stitching.")
continue
# Select only the specified rows for linear stitching
selected_horizontal_scan_dirs = [
d for d in horizontal_scan_dirs
if int(d.replace("horizontal_scan", "")) in ROWS_TO_STITCH
]
if not selected_horizontal_scan_dirs:
logging.warning(f"No 'horizontal_scanX' folders match ROWS_TO_STITCH {ROWS_TO_STITCH} in {scan_path}. Skipping stitching.")
continue
horizontal_stitched_paths = []
# Dictionaries to track used horizontal_scanX directories and their used images
used_horizontal_scans = set()
used_images_per_scan = {}
for horizontal_scan_dir in selected_horizontal_scan_dirs:
images_dir = os.path.join(scan_path, horizontal_scan_dir)
scan_images = sorted([
f for f in os.listdir(images_dir)
if f.lower().endswith('.jpg')
], key=lambda x: int(x.split('_')[1].split('.')[0])) # Sort by image number
logging.info(f"Processing {horizontal_scan_dir} with {len(scan_images)} images.")
if len(scan_images) < 2:
logging.warning(f"Warning: '{horizontal_scan_dir}' does not contain enough images for stitching. Skipping this horizontal scan.")
continue
# Group images based on COLUMNS_TO_STITCH
try:
selected_columns = COLUMNS_TO_STITCH
# Validate column indices
for col in selected_columns:
if col < 1 or col > len(scan_images):
raise IndexError(f"Column index {col} is out of range for the current horizontal scan with {len(scan_images)} images.")
# Sort the selected columns to ensure correct order
selected_columns_sorted = sorted(selected_columns)
selected_images_sorted = [os.path.join(images_dir, scan_images[col-1]) for col in selected_columns_sorted]
except Exception as e:
logging.error(f"Error selecting columns to stitch in '{horizontal_scan_dir}': {e}")
continue
# Track used horizontal_scanX directories and images
hs_dir_name = horizontal_scan_dir
used_horizontal_scans.add(hs_dir_name)
if hs_dir_name not in used_images_per_scan:
used_images_per_scan[hs_dir_name] = set()
for img_path in selected_images_sorted:
img_filename = os.path.basename(img_path)
used_images_per_scan[hs_dir_name].add(img_filename)
# Stitch selected images sequentially with dynamic tx_multiplier
current_stitched = selected_images_sorted[0]
for idx in range(1, len(selected_images_sorted)):
previous_column = selected_columns_sorted[idx-1]
current_column = selected_columns_sorted[idx]
difference = current_column - previous_column # Calculate column difference
tx_multiplier = difference # Set tx_multiplier based on column difference
pair_num = idx # Pair number within the group
logging.info(f"Stitching pair {pair_num}: Column {previous_column} and Column {current_column} with tx_multiplier={tx_multiplier}")
stitched_image = stitch_pair(
left_image=current_stitched,
right_image=selected_images_sorted[idx],
tx_multiplier=tx_multiplier, # Dynamic multiplier based on column difference
warp_net=warp_net,
composition_net=composition_net,
output_dir=scan_path,
level=1, # Level 1 for horizontal stitching
group_num=horizontal_scan_dir.replace("horizontal_scan", ""), # Numeric identifier
pair_num=pair_num, # Pair number within the group
config_options=config_options # Updated config_options with tx_level and ty_level
)
if stitched_image:
current_stitched = stitched_image
logging.info(f"Updated current_stitched to: {current_stitched}")
else:
logging.error(f"Failed to stitch image pair {current_stitched} and {selected_images_sorted[idx]} in '{horizontal_scan_dir}'.")
break
if stitched_image:
horizontal_stitched_paths.append(stitched_image)
else:
logging.error(f"Horizontal stitching failed for '{horizontal_scan_dir}' in scan: {scan_folder}")
if not horizontal_stitched_paths:
logging.error(f"No horizontal stitching results available for scan: {scan_folder}. Skipping vertical stitching.")
continue
logging.info(f"Starting vertical stitching for scan: {scan_folder}")
# Initialize current_panoramic with the first stitched image
current_panoramic = horizontal_stitched_paths[0]
logging.info(f"Initialized current_panoramic with: {current_panoramic}")
# Iterate over the remaining stitched images
for i in range(1, len(horizontal_stitched_paths)):
previous_row = ROWS_TO_STITCH[i-1]
current_row = ROWS_TO_STITCH[i]
row_diff = current_row - ROWS_TO_STITCH[0] # Maintain original row_diff calculation
ty_multiplier = row_diff # Set ty_multiplier based on row difference
pair_num = i # Pair number within vertical stitching
logging.info(f"Stitching vertically: {current_panoramic} with {horizontal_stitched_paths[i]} (ty_multiplier={ty_multiplier})")
stitched_image = stitch_pair_vertical(
bottom_image=current_panoramic,
top_image=horizontal_stitched_paths[i],
ty_multiplier=ty_multiplier, # Dynamic multiplier based on row difference
warp_net=warp_net,
composition_net=composition_net,
output_dir=scan_path,
level=2, # Level 2 for vertical stitching
pair_num=pair_num, # Pair number within vertical stitching
config_options=config_options
)
if stitched_image:
current_panoramic = stitched_image # Update for cumulative stitching
logging.info(f"Updated current_panoramic to: {current_panoramic}")
else:
logging.error(f"Failed to vertically stitch image pair {current_panoramic} and {horizontal_stitched_paths[i]} in scan: {scan_folder}.")
break
# Save the final panoramic image
final_panoramic_stitched = current_panoramic
logging.info(f"\n--- Stitching Completed for Scan: {scan_folder} ---")
logging.info(f"Final vertically stitched (panoramic) image saved at: {final_panoramic_stitched}")
# **Save the Final Stitched Image to Final_Outputs**
scan_folder_name = scan_folder # e.g., "2024-11-23 21-48-23_298"
sanitized_scan_folder_name = sanitize_filename(scan_folder_name)
new_filename = f"{sanitized_scan_folder_name}_hive{hive_id}_xy{xy_id}.jpg"
target_path = os.path.join(final_outputs_dir, new_filename)
try:
shutil.copy(final_panoramic_stitched, target_path)
logging.info(f"Final stitched image copied to {target_path}")
except Exception as e:
logging.error(f"Failed to copy final stitched image to Final_Outputs: {e}")
logging.error(traceback.format_exc())
# **Cleanup Unused Resources**
cleanup_option = config_options.get('cleanup_unused', False)
if cleanup_option:
logging.info(f"Starting cleanup of unused resources for scan: {scan_folder}")
cleanup_unused_resources(
scan_path=scan_path,
used_horizontal_scans=used_horizontal_scans,
used_images_per_scan=used_images_per_scan
)
logging.info(f"Cleanup completed for scan: {scan_folder}")
else:
logging.info("Cleanup of unused resources is disabled in the configuration.")
# **Cleanup Intermediate Outputs**
cleanup_intermediate_option = config_options.get('cleanup_intermediate', False)
if cleanup_intermediate_option:
logging.info(f"Starting cleanup of intermediate outputs for scan: {scan_folder}")
cleanup_intermediate_outputs(scan_path=scan_path)
logging.info(f"Intermediate cleanup completed for scan: {scan_folder}")
else:
logging.info("Cleanup of intermediate outputs is disabled in the configuration.")
elif stitching_mode == 'tree':
# ===============================
# Tree-like Stitching Mode
# ===============================
logging.info("Mode: Tree-like Stitching")
logging.info(f"Tree-like Stitching Mode: Columns Params: {column_params}")
logging.info(f"Tree-like Stitching Mode: Rows Params: {row_params}")
# Generate rows to stitch based on start, end, row_stride
start_row = row_params.get('start', 1)
end_row = row_params.get('end', len([
d for d in os.listdir(scan_path)
if d.startswith("horizontal_scan") and os.path.isdir(os.path.join(scan_path, d))
]))
row_stride = row_params.get('row_stride', 1)
rows_to_stitch = list(range(start_row, end_row + 1, row_stride))
# Adjust to power-of-two number of rows
rows_to_stitch_power_two = select_power_of_two_rows(rows_to_stitch)
# Collect image paths based on columns and rows
selected_images = []
# Dictionaries to track used horizontal_scanX directories and their used images
used_horizontal_scans = set()
used_images_per_scan = {}
for row in rows_to_stitch_power_two:
horizontal_scan_dir = f"horizontal_scan{row}"
images_dir = os.path.join(scan_path, horizontal_scan_dir)
if not os.path.isdir(images_dir):
logging.warning(f"Directory {images_dir} does not exist. Skipping.")
continue
scan_images = sorted([
f for f in os.listdir(images_dir)
if f.lower().endswith('.jpg')
], key=lambda x: int(x.split('_')[1].split('.')[0])) # Sort by image number
# Select columns based on start, end, col_stride
start_col = column_params.get('start', 1)
end_col = column_params.get('end', len(scan_images))
col_stride = column_params.get('col_stride', 1)
columns_to_stitch = list(range(start_col, end_col + 1, col_stride))
logging.info(f"Row {row}: Columns to stitch: {columns_to_stitch}")
selected_columns_sorted = sorted(columns_to_stitch)
selected_images_sorted = [
os.path.join(images_dir, scan_images[col-1])
for col in selected_columns_sorted
if 1 <= col <= len(scan_images)
]
if not selected_images_sorted:
logging.warning(f"No valid columns found for row {row}. Skipping.")
continue
# Track used horizontal_scanX directories and images
hs_dir_name = horizontal_scan_dir
used_horizontal_scans.add(hs_dir_name)
if hs_dir_name not in used_images_per_scan:
used_images_per_scan[hs_dir_name] = set()
for img_path in selected_images_sorted:
img_filename = os.path.basename(img_path)
used_images_per_scan[hs_dir_name].add(img_filename)
# Perform tree-like stitching on selected images
logging.info(f"Performing tree-like stitching on row {row} with {len(selected_images_sorted)} images.")
stitched_row_image = tree_stitch_images(
images=selected_images_sorted,
stride_multiplier=col_stride, # Use col_stride as stride_multiplier
stitch_func=stitch_pair,
warp_net=warp_net,
composition_net=composition_net,
output_dir=scan_path,
level=1, # Starting level for horizontal stitching
group_num=f"row{row}",
config_options=config_options,
direction='horizontal'
)
if stitched_row_image:
selected_images.append(stitched_row_image)
else:
logging.error(f"Tree-like stitching failed for row {row}.")
continue
if len(selected_images) < 2:
logging.warning(f"Not enough stitched rows to perform vertical stitching in scan {scan_folder}. Skipping.")
continue
# Perform tree-like vertical stitching on the stitched rows
logging.info("Performing tree-like stitching vertically on stitched rows.")
final_panoramic_image = tree_stitch_images(
images=selected_images,
stride_multiplier=row_stride, # Use row_stride as stride_multiplier
stitch_func=stitch_pair_vertical,
warp_net=warp_net,
composition_net=composition_net,
output_dir=scan_path,
level=1, # Starting level for vertical stitching
group_num="vertical",
config_options=config_options,
direction='vertical'
)
if final_panoramic_image:
logging.info(f"Final tree-like stitched panoramic image saved at: {final_panoramic_image}")
# **Save the Final Stitched Image to Final_Outputs**
scan_folder_name = scan_folder # e.g., "2024-11-23 21-48-23_298"
sanitized_scan_folder_name = sanitize_filename(scan_folder_name)
new_filename = f"{sanitized_scan_folder_name}_hive{hive_id}_xy{xy_id}.jpg"
target_path = os.path.join(final_outputs_dir, new_filename)
try:
shutil.copy(final_panoramic_image, target_path)
logging.info(f"Final stitched image copied to {target_path}")
except Exception as e:
logging.error(f"Failed to copy final stitched image to Final_Outputs: {e}")
logging.error(traceback.format_exc())
else:
logging.error(f"Failed to perform tree-like vertical stitching in scan: {scan_folder}")
# **Cleanup Unused Resources**
cleanup_option = config_options.get('cleanup_unused', False)
if cleanup_option:
logging.info(f"Starting cleanup of unused resources for scan: {scan_folder}")
cleanup_unused_resources(
scan_path=scan_path,
used_horizontal_scans=used_horizontal_scans,
used_images_per_scan=used_images_per_scan
)
logging.info(f"Cleanup completed for scan: {scan_folder}")
else:
logging.info("Cleanup of unused resources is disabled in the configuration.")
# **Cleanup Intermediate Outputs**
cleanup_intermediate_option = config_options.get('cleanup_intermediate', False)
if cleanup_intermediate_option:
logging.info(f"Starting cleanup of intermediate outputs for scan: {scan_folder}")
cleanup_intermediate_outputs(scan_path=scan_path)
logging.info(f"Intermediate cleanup completed for scan: {scan_folder}")
else:
logging.info("Cleanup of intermediate outputs is disabled in the configuration.")
else:
logging.error(f"Unknown stitching mode: {stitching_mode}. Skipping scan: {scan_folder}")
continue
except Exception as scan_error:
logging.error(f"An error occurred while processing scan '{scan_folder}': {scan_error}")
logging.error(traceback.format_exc())
logging.info(f"Skipping scan '{scan_folder}' and continuing with the next one.")
continue # Skip to the next scan
# Cleanup and Exit
logging.info("\n========================")
logging.info("Stitching Process Completed")
logging.info("========================\n")
# Clean up PyTorch models and free CUDA memory
if torch.cuda.is_available():
warp_net.cpu()
composition_net.cpu()
del warp_net
del composition_net
torch.cuda.empty_cache()
logging.info("Cleaned up PyTorch models and cleared CUDA cache.")
logging.info("Exiting the script.")
sys.exit(0) # Ensure the script exits automatically
except Exception as e:
logging.error(f"An unexpected error occurred: {e}\n{traceback.format_exc()}")
sys.exit(1)
if __name__ == "__main__":
    # Script entry point: invoke main() only when this file is executed
    # directly, so importing it as a module does not start a stitching run.
    main()
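# NOTE(review): the two lines below are paste-site UI residue, not code;
# kept as comments so the file remains valid Python.
# Editor is loading...
# Leave a Comment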