# Retrieved from pastecode.io (mail@pastecode.io); posted 11 days ago, 10 kB.
import os
import gc
from PIL import Image, ImageDraw, ImageOps, ImageFont
import torch
from transformers import CLIPSegProcessor, CLIPSegForImageSegmentation
from tqdm import tqdm
from torch.cuda.amp import autocast
import logging
import numpy as np
import cv2
import torch
from transformers import PaliGemmaForConditionalGeneration, AutoProcessor, BitsAndBytesConfig
import requests
from PIL import Image
from tqdm import tqdm
import os
import json
from tqdm import tqdm
# Split PaliGemma across two GPUs: the vision tower on cuda:0, the language
# model and the multimodal projector on cuda:1.
device_map = {
    'vision_tower': 'cuda:0',           # Vision tower on cuda:0
    'language_model': 'cuda:1',         # Language model on cuda:1
    'multi_modal_projector': 'cuda:1',  # Projector on cuda:1, next to the language model
}
model_id = "google/paligemma-3b-ft-vqav2-448"
# 8-bit alternative; not used below (the model is loaded with the 4-bit config).
quantization_config8bit = BitsAndBytesConfig(
    load_in_8bit=True,
    llm_int8_threshold=6.0,  # Outlier threshold for LLM.int8(); adjust as needed
)
# NOTE(review): add bnb_4bit_* options (compute dtype, quant type) here if needed.
quantization_config4bit = BitsAndBytesConfig(load_in_4bit=True)
model = PaliGemmaForConditionalGeneration.from_pretrained(
    model_id, quantization_config=quantization_config4bit, device_map=device_map
)
processor = AutoProcessor.from_pretrained(model_id)
# CLIPSeg produces the coarse text-prompted masks; it is placed on cuda:1.
processor_seg = CLIPSegProcessor.from_pretrained("CIDAS/clipseg-rd64")
model_seg = CLIPSegForImageSegmentation.from_pretrained("CIDAS/clipseg-rd64").to("cuda:1")

def draw_segments(image, masks):
    """Paint every mask region of *image* solid red and return the result.

    Each mask is first fitted (resized/cropped) to the image size with
    nearest-neighbour sampling, then used as the composite mask: where the
    mask is set the pixel comes from a solid red image, elsewhere from the
    image accumulated so far. Masks are presumably PIL images in mode "1"
    or "L" -- TODO confirm at call sites.
    """
    solid_red = Image.new('RGB', image.size, (255, 0, 0))
    painted = image
    for raw_mask in masks:
        fitted_mask = ImageOps.fit(raw_mask, painted.size, method=Image.NEAREST)
        painted = Image.composite(solid_red, painted, fitted_mask)
    return painted

def get_connected_components(binary_mask):
    """Label the connected foreground regions of a 2-D binary mask.

    Args:
        binary_mask: 2-D mask, either a torch tensor (on any device) or an
            array-like; non-zero / True pixels are foreground.

    Returns:
        ``(num_labels, labels_im)`` from ``cv2.connectedComponents``:
        ``num_labels`` includes the background label 0, and ``labels_im``
        holds one integer label per mask pixel.
    """
    if isinstance(binary_mask, torch.Tensor):
        binary_mask = binary_mask.detach().cpu().numpy()
    # OpenCV wants an 8-bit single-channel image; map 0/1 to 0/255.
    mask_u8 = (np.asarray(binary_mask) * 255).astype(np.uint8)
    num_labels, labels_im = cv2.connectedComponents(mask_u8)
    return num_labels, labels_im

def visualize_connected_components(image, labels_im, num_labels, scale_factor_w, scale_factor_h):
    """Return a copy of *image* with a red rectangle around every component.

    Components ``1 .. num_labels - 1`` of *labels_im* are boxed (label 0 is
    skipped). Each bounding box is measured on the label grid and scaled into
    *image* coordinates with the given per-axis factors (truncated to int).
    The input image is left untouched.
    """
    canvas = image.copy()
    pen = ImageDraw.Draw(canvas)
    for component_id in range(1, num_labels):
        only_this = (labels_im == component_id).astype(np.uint8)
        left, top, width, height = cv2.boundingRect(only_this)
        left = int(left * scale_factor_w)
        top = int(top * scale_factor_h)
        width = int(width * scale_factor_w)
        height = int(height * scale_factor_h)
        pen.rectangle([left, top, left + width, top + height], outline="red", width=2)
    return canvas

def crop_objects(image, labels_im, num_labels, scale_factor_w, scale_factor_h):
    """Crop every labelled connected component out of *image*.

    Args:
        image: Full-resolution PIL image to crop from.
        labels_im: Label map from ``cv2.connectedComponents`` (label 0 is
            the background and is skipped).
        num_labels: Number of labels in *labels_im*, background included.
        scale_factor_w, scale_factor_h: Factors mapping label-map
            coordinates to *image* coordinates.

    Returns:
        ``(cropped_images, bboxes)``: parallel lists; ``bboxes[i]`` is the
        ``[x, y, w, h]`` box (in *image* pixels) of ``cropped_images[i]``.
    """
    cropped_images = []
    bboxes = []
    for label in range(1, num_labels):
        component_mask = (labels_im == label).astype(np.uint8)
        x, y, w, h = cv2.boundingRect(component_mask)
        x = int(x * scale_factor_w)
        y = int(y * scale_factor_h)
        w = int(w * scale_factor_w)
        h = int(h * scale_factor_h)
        # Previously the crop was computed but never stored, so callers
        # always received an empty crop list.
        cropped_images.append(image.crop((x, y, x + w, y + h)))
        bboxes.append([x, y, w, h])
    return cropped_images, bboxes

def segment_and_crop(image_path, text, confidence=0.2):
    """Segment *text* in an image with CLIPSeg and crop each detected region.

    Args:
        image_path: Path of the image file; it is converted to RGB.
        text: Text prompt given to the CLIPSeg model.
        confidence: Sigmoid-probability threshold; mask pixels with a
            probability strictly above it are foreground.

    Returns:
        ``(cropped_images, bboxes)`` as returned by ``crop_objects``: one
        crop and one ``[x, y, w, h]`` box (original-image pixels) per
        connected foreground component.
    """
    with Image.open(image_path) as source:
        image = source.convert("RGB")
    original_width, original_height = image.size

    # Send the inputs to whatever device the segmentation model lives on
    # instead of hard-coding it here.
    inputs = processor_seg(text=[text], images=image, return_tensors="pt").to(model_seg.device)
    with torch.inference_mode():
        outputs = model_seg(**inputs)

    probs = torch.sigmoid(outputs.logits.cpu())
    binary_mask = probs > confidence
    # Drop references to the (possibly GPU-resident) tensors early.
    del inputs, outputs

    num_labels, labels_im = get_connected_components(binary_mask.squeeze(0))

    # Derive the scale from the mask actually produced rather than assuming
    # a fixed 352x352 model output.
    mask_height, mask_width = labels_im.shape[:2]
    scale_factor_w = original_width / mask_width
    scale_factor_h = original_height / mask_height

    return crop_objects(image, labels_im, num_labels, scale_factor_w, scale_factor_h)

def check_image_description(image, description, model, processor):
    """Ask the VQA model whether *image* matches *description*.

    Args:
        image: Image handed to *processor* (a PIL crop in this file).
        description: Free text inserted into the yes/no question.
        model: Generative model exposing ``generate`` and ``device``.
        processor: Builds model inputs and decodes generated token ids.

    Returns:
        True only when the first word of the decoded answer is "yes"
        (case-insensitive, surrounding punctuation ignored).
    """
    prompt = f"Is the image fitting the description: '{description}'?"
    model_inputs = processor(text=prompt, images=image, return_tensors="pt").to(model.device)
    input_len = model_inputs["input_ids"].shape[-1]
    with torch.inference_mode():
        generation = model.generate(**model_inputs, max_new_tokens=100, do_sample=False)
    # Slice off the prompt tokens so only the generated answer is decoded.
    answer = processor.decode(generation[0][input_len:], skip_special_tokens=True)
    # Compare the first word instead of searching for a substring, so that
    # answers like "eyes" or "no, not yes" are not read as a yes.
    words = answer.strip().lower().split()
    return bool(words) and words[0].strip(".,!?;:'\"") == "yes"

def process_image(image_path, general_description, detailed_description, min_area=20, min_side=3):
    """Find regions matching a coarse prompt and keep those the VQA model confirms.

    Candidate regions come from ``segment_and_crop`` using
    *general_description*. Boxes with an area below *min_area* pixels, or
    with width or height below *min_side*, are discarded as noise before the
    comparatively expensive ``check_image_description`` call with
    *detailed_description*.

    Returns:
        ``(fitting_crops, fitting_bboxes)``: parallel lists of the accepted
        crops and their ``[x, y, w, h]`` boxes.
    """
    cropped_images, bboxes = segment_and_crop(image_path, general_description)
    fitting_crops = []
    fitting_bboxes = []
    for cropped_image, bbox in zip(cropped_images, bboxes):
        _, _, width, height = bbox
        if width * height < min_area or width < min_side or height < min_side:
            continue
        if check_image_description(cropped_image, detailed_description, model, processor):
            fitting_crops.append(cropped_image)
            fitting_bboxes.append(bbox)
    return fitting_crops, fitting_bboxes

def draw_boxes_on_image(image, bboxes, detailed_description):
    """Draw each ``[x, y, w, h]`` box in red on *image*, labelled with text.

    *image* is drawn on in place and also returned. The label
    *detailed_description* is written at each box's top-left corner using
    arial.ttf at size 20, or PIL's built-in default font when that file
    cannot be loaded.
    """
    draw = ImageDraw.Draw(image)
    try:
        font = ImageFont.truetype("arial.ttf", 20)
    except OSError:
        font = ImageFont.load_default()
    for x, y, w, h in bboxes:
        draw.rectangle([x, y, x + w, y + h], outline="red", width=2)
        draw.text((x, y), detailed_description, fill="red", font=font)
    return image

def process_dataset(dataset_path, description_pairs, output_path,
                    category_name="damaged pill", supercategory="pill"):
    """Annotate every JPEG under *dataset_path* and write COCO-style output.

    Each image is run through ``process_image`` once per
    ``(general_description, detailed_description)`` pair. Results written
    under *output_path*:

    * ``crops/<image_id>_<n>.png`` -- every accepted crop of an image,
      numbered across all description pairs so no crop is overwritten;
    * ``annotated_images/<path relative to dataset_path>`` -- the image with
      all accepted boxes drawn, saved even when nothing was accepted;
    * ``annotations/annotations.json`` -- a COCO-style dict with one image
      entry per image that has at least one accepted box, one annotation
      per box, and a single category.

    Args:
        dataset_path: Root directory, walked recursively.
        description_pairs: Iterable of
            ``(general_description, detailed_description)`` tuples.
        output_path: Output root directory; created if missing.
        category_name: Name of the single category all boxes belong to.
        supercategory: Supercategory of that category.
    """
    crops_path = os.path.join(output_path, "crops")
    annotated_images_path = os.path.join(output_path, "annotated_images")
    annotations_path = os.path.join(output_path, "annotations")
    for directory in (output_path, crops_path, annotated_images_path, annotations_path):
        os.makedirs(directory, exist_ok=True)

    category_id = 1
    annotations = {
        "info": {
            "description": "Generated annotations",
            "version": "1.0",
            "year": 2024,
        },
        "licenses": [],
        "images": [],
        "annotations": [],
        "categories": [
            {"id": category_id, "name": category_name, "supercategory": supercategory},
        ],
    }

    # Materialize once: the pairs are iterated again for every image.
    description_pairs = list(description_pairs)
    image_id = 0
    annotation_id = 0
    for root, _, files in tqdm(os.walk(dataset_path), desc="Processing dataset"):
        for filename in files:
            if not filename.lower().endswith((".jpg", ".jpeg")):
                continue
            image_path = os.path.join(root, filename)
            relative_path = os.path.relpath(image_path, dataset_path)
            with Image.open(image_path) as source:
                image = source.convert("RGB")

            annotated_image = image.copy()
            crop_count = 0
            for general_description, detailed_description in description_pairs:
                fitting_crops, fitting_bboxes = process_image(
                    image_path, general_description, detailed_description
                )
                for crop, bbox in zip(fitting_crops, fitting_bboxes):
                    crop.save(os.path.join(crops_path, f"{image_id}_{crop_count}.png"))
                    crop_count += 1
                    annotations["annotations"].append({
                        "id": annotation_id,
                        "image_id": image_id,
                        "category_id": category_id,
                        "bbox": bbox,
                        "area": bbox[2] * bbox[3],
                        "segmentation": [],  # Optional, add segmentation if needed
                        "iscrowd": 0,
                    })
                    annotation_id += 1
                annotated_image = draw_boxes_on_image(
                    annotated_image, fitting_bboxes, detailed_description
                )

            # Register the image once, and only if some box was accepted.
            if crop_count:
                annotations["images"].append({
                    "id": image_id,
                    "file_name": relative_path,
                    "height": image.height,
                    "width": image.width,
                })

            # Always save the annotated image, even when it has no boxes.
            annotated_image_save_path = os.path.join(annotated_images_path, relative_path)
            os.makedirs(os.path.dirname(annotated_image_save_path), exist_ok=True)
            annotated_image.save(annotated_image_save_path)
            image_id += 1

    with open(os.path.join(annotations_path, "annotations.json"), "w", encoding="utf-8") as f:
        json.dump(annotations, f, indent=2)
# Example usage: runs only when this file is executed as a script, so that
# importing its helpers does not process the whole dataset.
if __name__ == "__main__":
    dataset_path = "alldata/gunsss"
    description_pairs = [("gun", "gun")]
    output_path = "alldata/gunsss_annotated"
    process_dataset(dataset_path, description_pairs, output_path)
# pastecode.io page footer: "Leave a Comment"