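# Untitled
#
# Paste-site metadata captured with the code (kept as comments so the file
# parses as Python): avatar | unknown | plain_text | a year ago | 10 kB | 3 | Indexable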
import os
import gc
from PIL import Image, ImageDraw, ImageOps, ImageFont
import torch
from transformers import CLIPSegProcessor, CLIPSegForImageSegmentation
from tqdm import tqdm
from torch.cuda.amp import autocast
import logging
import numpy as np
import cv2
import torch
from transformers import PaliGemmaForConditionalGeneration, AutoProcessor, BitsAndBytesConfig
import requests
from PIL import Image
from tqdm import tqdm
import os
import json
from tqdm import tqdm
# Split PaliGemma across two GPUs: the vision tower goes to cuda:0, while the
# language model and the multimodal projector go to cuda:1.
device_map = {
    'vision_tower': 'cuda:0',     # Assign Vision Tower to cuda:0
    'language_model': 'cuda:1',   # Assign Language Model to cuda:1
    'multi_modal_projector': 'cuda:1' # Assign MultiModal Projector to cuda:1
}
# PaliGemma checkpoint (VQAv2 fine-tune) used to verify candidate crops.
model_id = "google/paligemma-3b-ft-vqav2-448"
# 8-bit quantization config.
# NOTE(review): not used by the from_pretrained call below, which passes
# quantization_config4bit instead.
quantization_config8bit  = BitsAndBytesConfig(
    load_in_8bit=True, 
    llm_int8_threshold=6.0, # Adjust this threshold based on your needs
    llm_int8_enable_fp32_cpu_offload=True,
    
)
# 4-bit NF4 quantization with double quantization and bfloat16 compute dtype.
# This is the config actually used to load `model` below.
# NOTE(review): device_map above assigns no module to "cpu", so the fp32 CPU
# offload flag presumably has no effect here; confirm.
quantization_config4bit = BitsAndBytesConfig(load_in_4bit=True, 
                                         bnb_4bit_compute_dtype=torch.bfloat16,
                                         bnb_4bit_use_double_quant=True,
                                         bnb_4bit_quant_type="nf4",
                                         llm_int8_enable_fp32_cpu_offload=True)
# VQA model and processor, loaded at import time with the split device_map.
model = PaliGemmaForConditionalGeneration.from_pretrained(model_id, quantization_config=quantization_config4bit, device_map=device_map)
processor = AutoProcessor.from_pretrained(model_id)
# CLIPSeg text-prompted segmentation model, placed on cuda:1 (segment_and_crop
# moves its inputs to the same device).
processor_seg = CLIPSegProcessor.from_pretrained("CIDAS/clipseg-rd64")
model_seg = CLIPSegForImageSegmentation.from_pretrained("CIDAS/clipseg-rd64").to("cuda:1")

def draw_segments(image, masks):
    """Paint the regions selected by each mask solid red and return the result.

    Each mask is resized to the image size with nearest-neighbour sampling
    (``ImageOps.fit``) and used as the compositing mask against a solid red
    layer. Masks are applied one after another, so the red regions accumulate.

    NOTE(review): masks are presumably mode "1" or "L" PIL images, as
    ``Image.composite`` requires; confirm against callers.
    """
    solid_red = Image.new('RGB', image.size, (255, 0, 0))
    result = image
    for raw_mask in masks:
        fitted_mask = ImageOps.fit(raw_mask, image.size, method=Image.NEAREST)
        result = Image.composite(solid_red, result, fitted_mask)
    return result

def get_connected_components(binary_mask):
    """Label the connected regions of a mask tensor.

    The mask is moved to the CPU and scaled to an 8-bit image (foreground 255,
    background 0). For debugging, that image is also written to
    ``clipsegonlymask.jpg`` in the current working directory.

    Returns:
        ``(num_labels, labels_im)`` as produced by ``cv2.connectedComponents``;
        label 0 is the background.
    """
    mask_array = binary_mask.cpu().numpy()
    mask_u8 = (mask_array * 255).astype(np.uint8)
    # Debug dump of the thresholded CLIPSeg mask.
    debug_view = Image.fromarray(mask_u8, mode='L')
    debug_view.save("clipsegonlymask.jpg")
    label_count, label_map = cv2.connectedComponents(mask_u8)
    return label_count, label_map

def visualize_connected_components(image, labels_im, num_labels, scale_factor_w, scale_factor_h):
    """Return a copy of *image* with a red outline around each labelled component.

    Args:
        image: PIL image at original resolution; it is not modified.
        labels_im: label array from ``cv2.connectedComponents`` (0 = background).
        num_labels: number of labels, including the background label 0.
        scale_factor_w, scale_factor_h: factors that map mask coordinates to
            image coordinates; scaled values are truncated with ``int``.
    """
    annotated = image.copy()
    painter = ImageDraw.Draw(annotated)
    for component_id in range(1, num_labels):
        rect = cv2.boundingRect((labels_im == component_id).astype(np.uint8))
        left = int(rect[0] * scale_factor_w)
        top = int(rect[1] * scale_factor_h)
        right = left + int(rect[2] * scale_factor_w)
        bottom = top + int(rect[3] * scale_factor_h)
        painter.rectangle([left, top, right, bottom], outline="red", width=2)
    return annotated

def crop_objects(image, labels_im, num_labels, scale_factor_w, scale_factor_h):
    """Cut one sub-image out of *image* for every labelled component.

    The bounding rectangle of each label (1 .. num_labels - 1) is computed in
    mask coordinates and scaled to image coordinates with ``int`` truncation.

    Returns:
        ``(cropped_images, bboxes)``, where each bbox is ``[x, y, w, h]`` in
        image pixels and lines up with the crop at the same index.
    """
    crops = []
    boxes = []
    for component_id in range(1, num_labels):
        component_mask = (labels_im == component_id).astype(np.uint8)
        raw_x, raw_y, raw_w, raw_h = cv2.boundingRect(component_mask)
        box = [
            int(raw_x * scale_factor_w),
            int(raw_y * scale_factor_h),
            int(raw_w * scale_factor_w),
            int(raw_h * scale_factor_h),
        ]
        left, top, width, height = box
        crops.append(image.crop((left, top, left + width, top + height)))
        boxes.append(box)
    return crops, boxes

def segment_and_crop(image_path, text, confidence=0.2):
    """Segment *text* in an image with CLIPSeg and crop every detected region.

    Args:
        image_path: path of the image file to segment.
        text: CLIPSeg text prompt describing what to look for.
        confidence: threshold on the sigmoid of the CLIPSeg logits; mask
            pixels above it count as foreground.

    Returns:
        ``(cropped_images, bboxes)``: one PIL crop and one ``[x, y, w, h]`` box
        (in original-image pixels) per connected component of the mask.

    Side effects:
        Writes the debug image ``checkclipseg_components.jpg`` to the current
        working directory. ``get_connected_components`` also writes its own
        debug mask there. Frees GPU cache memory before returning.
    """
    # Use a context manager so the file handle is released promptly;
    # convert() returns an independent in-memory image.
    with Image.open(image_path) as source:
        image = source.convert("RGB")
    original_w, original_h = image.size

    inputs = processor_seg(text=[text], images=image, return_tensors="pt").to("cuda:1")
    with torch.inference_mode():
        outputs = model_seg(**inputs)

    # Sigmoid turns logits into per-pixel probabilities; threshold them.
    probs = torch.sigmoid(outputs.logits.detach().cpu())
    binary_mask = probs > confidence

    # squeeze(0) drops a leading batch dim of size 1 if the logits carry one.
    num_labels, labels_im = get_connected_components(binary_mask.squeeze(0))

    # Map mask coordinates back to the original image using the mask's real
    # size, instead of assuming the processor's default 352x352 output.
    # Width and height are scaled independently because the processor resizes
    # without preserving the aspect ratio.
    mask_h, mask_w = labels_im.shape[:2]
    scale_factor_w = original_w / mask_w
    scale_factor_h = original_h / mask_h

    # Debug: visualize connected components.
    image_with_components = visualize_connected_components(image, labels_im, num_labels, scale_factor_w, scale_factor_h)
    image_with_components.save("checkclipseg_components.jpg")

    cropped_images, bboxes = crop_objects(image, labels_im, num_labels, scale_factor_w, scale_factor_h)

    # Release GPU-side tensors before the next (large) VQA model call.
    del inputs
    del outputs
    gc.collect()
    torch.cuda.empty_cache()

    return cropped_images, bboxes

def check_image_description(image, description, model, processor):
    """Ask the VQA model whether *image* matches *description*.

    Args:
        image: image (typically a crop) handed to *processor*; only its
            ``size`` is read directly, for debug logging.
        description: free-text description inserted into a yes/no question.
        model: generation model exposing ``.device`` and ``.generate``.
        processor: callable that builds model inputs from text and images and
            exposes ``.decode``.

    Returns:
        True if the decoded answer contains the whole word "yes"
        (case-insensitive), otherwise False.
    """
    import re

    prompt = f"Is the image fitting the description: '{description}'?"
    logging.getLogger(__name__).debug("Checking crop of size %s", image.size)
    model_inputs = processor(text=prompt, images=image, return_tensors="pt").to(model.device)
    input_len = model_inputs["input_ids"].shape[-1]
    with torch.inference_mode():
        generation = model.generate(**model_inputs, max_new_tokens=100, do_sample=False)
    # Drop the first input_len tokens (the prompt) so only the answer is decoded.
    answer = processor.decode(generation[0][input_len:], skip_special_tokens=True)
    # Match "yes" as a whole word: a plain substring test would also accept
    # answers such as "eyes" or "yesterday".
    return "yes" in re.findall(r"[a-z]+", answer.lower())

def process_image(image_path, general_description, detailed_description):
    """Find regions matching *general_description* and keep those the VQA model confirms.

    CLIPSeg (via ``segment_and_crop``) proposes candidate crops for the general
    description. Candidates with an area below 20 px, or a width or height
    below 3 px, are skipped. The remaining crops are kept when
    ``check_image_description`` accepts *detailed_description* for them.

    Returns:
        ``(fitting_crops, fitting_bboxes)`` in the order produced by
        ``segment_and_crop``.
    """
    candidate_crops, candidate_boxes = segment_and_crop(image_path, general_description)
    kept_crops, kept_boxes = [], []
    for crop, box in zip(candidate_crops, candidate_boxes):
        width, height = box[2], box[3]
        too_small = (width * height) < 20 or width < 3 or height < 3
        if too_small:
            continue
        if not check_image_description(crop, detailed_description, model, processor):
            continue
        kept_crops.append(crop)
        kept_boxes.append(box)
    return kept_crops, kept_boxes

def draw_boxes_on_image(image, bboxes, detailed_description):
    """Draw every ``[x, y, w, h]`` box in red on *image*, labelled with the description.

    The image is modified in place and also returned. Text uses Arial at 20 pt
    when that font file can be loaded, otherwise PIL's default font.
    """
    pen = ImageDraw.Draw(image)
    try:
        label_font = ImageFont.truetype("arial.ttf", 20)
    except OSError:
        # Font file missing: fall back to the bundled default font.
        label_font = ImageFont.load_default()
    for left, top, width, height in bboxes:
        pen.rectangle([left, top, left + width, top + height], outline="red", width=2)
        pen.text((left, top), detailed_description, fill="red", font=label_font)
    return image

def process_dataset(dataset_path, description_pairs, output_path,
                    category_name="damaged pill", supercategory="pill",
                    extensions=(".jpg", ".jpeg")):
    """Run the segment-and-verify pipeline over every image under *dataset_path*.

    For each image file and each ``(general_description, detailed_description)``
    pair, ``process_image`` finds verified regions. Outputs under *output_path*:

    * ``crops/{image_id}_{i}.png``: every verified crop. ``i`` counts across all
      description pairs of the image, so later pairs never overwrite earlier ones.
    * ``annotated_images/<relative path>``: the image with the boxes of all
      pairs drawn on it. Written for every image, even one with no boxes,
      as long as at least one description pair is given.
    * ``annotations/annotations.json``: COCO-style dict with one ``images``
      entry per image that has at least one verified crop, and one
      ``annotations`` entry per crop (bbox ``[x, y, w, h]``).

    Args:
        dataset_path: root directory, walked recursively.
        description_pairs: iterable of ``(general, detailed)`` description tuples.
        output_path: output root; created if missing.
        category_name, supercategory: name of the single category (id 1)
            written to the annotations. Defaults keep the previous values.
        extensions: file suffixes treated as images, matched case-insensitively.
    """
    pairs = list(description_pairs)
    suffixes = tuple(ext.lower() for ext in extensions)

    crops_path = os.path.join(output_path, "crops")
    annotated_images_path = os.path.join(output_path, "annotated_images")
    annotations_path = os.path.join(output_path, "annotations")
    for directory in (output_path, crops_path, annotated_images_path, annotations_path):
        os.makedirs(directory, exist_ok=True)

    category_id = 1
    annotations = {
        "info": {
            "description": "Generated annotations",
            "version": "1.0",
            "year": 2024
        },
        "licenses": [],
        "images": [],
        "annotations": [],
        "categories": [{
            "id": category_id,
            "name": category_name,
            "supercategory": supercategory
        }]
    }

    image_id = 0
    annotation_id = 0

    for root, _, files in tqdm(os.walk(dataset_path), desc="Processing dataset"):
        for filename in files:
            if not filename.lower().endswith(suffixes):
                continue
            image_path = os.path.join(root, filename)
            relative_path = os.path.relpath(image_path, dataset_path)

            # Load once per file and release the handle immediately.
            with Image.open(image_path) as source:
                image = source.convert("RGB")

            annotated_image = image.copy()
            image_registered = False
            crop_index = 0  # running index across all pairs for this image

            for general_description, detailed_description in pairs:
                fitting_crops, fitting_bboxes = process_image(image_path, general_description, detailed_description)

                # Register the image only once, however many pairs match it.
                if fitting_crops and not image_registered:
                    annotations["images"].append({
                        "id": image_id,
                        "file_name": relative_path,
                        "height": image.height,
                        "width": image.width
                    })
                    image_registered = True

                for crop, bbox in zip(fitting_crops, fitting_bboxes):
                    crop.save(os.path.join(crops_path, f"{image_id}_{crop_index}.png"))
                    crop_index += 1
                    annotations["annotations"].append({
                        "id": annotation_id,
                        "image_id": image_id,
                        "category_id": category_id,
                        "bbox": bbox,
                        "area": bbox[2] * bbox[3],
                        "segmentation": [],  # Optional, add segmentation if needed
                        "iscrowd": 0
                    })
                    annotation_id += 1

                # Accumulate every pair's boxes on the same annotated image.
                annotated_image = draw_boxes_on_image(annotated_image, fitting_bboxes, detailed_description)

            # Save the annotated image even when nothing matched.
            if pairs:
                annotated_image_save_path = os.path.join(annotated_images_path, relative_path)
                os.makedirs(os.path.dirname(annotated_image_save_path), exist_ok=True)
                annotated_image.save(annotated_image_save_path)

            image_id += 1

    with open(os.path.join(annotations_path, "annotations.json"), "w", encoding="utf-8") as f:
        json.dump(annotations, f, indent=2)
# Example usage: run the gun-detection pipeline over one dataset directory.
# NOTE(review): this call, like the model loading above, executes at import
# time; consider an `if __name__ == "__main__":` guard.
description_pairs = [("gun", "gun")]
dataset_path = "alldata/gunsss"
output_path = "alldata/gunsss_annotated"
process_dataset(
    dataset_path=dataset_path,
    description_pairs=description_pairs,
    output_path=output_path,
)
# Paste-site footer captured with the code (kept as comments so the file
# parses as Python): Editor is loading... | Leave a Comment