# Auto-annotation pipeline: CLIPSeg proposes regions for a general text prompt,
# PaliGemma (quantized, split across two GPUs) verifies each crop against a
# detailed description, and accepted detections are written out in COCO format.

import os
import gc
import json

import cv2
import numpy as np
import torch
from PIL import Image, ImageDraw, ImageOps, ImageFont
from tqdm import tqdm
from transformers import (
    CLIPSegProcessor,
    CLIPSegForImageSegmentation,
    PaliGemmaForConditionalGeneration,
    AutoProcessor,
    BitsAndBytesConfig,
)

# Split the PaliGemma modules across the two GPUs.
device_map = {
    'vision_tower': 'cuda:0',           # vision tower on cuda:0
    'language_model': 'cuda:1',         # language model on cuda:1
    'multi_modal_projector': 'cuda:1',  # multimodal projector on cuda:1
}

model_id = "google/paligemma-3b-ft-vqav2-448"

quantization_config8bit = BitsAndBytesConfig(
    load_in_8bit=True,
    llm_int8_threshold=6.0,  # adjust this threshold based on your needs
    llm_int8_enable_fp32_cpu_offload=True,
)

quantization_config4bit = BitsAndBytesConfig(
    load_in_4bit=True,
    bnb_4bit_compute_dtype=torch.bfloat16,
    bnb_4bit_use_double_quant=True,
    bnb_4bit_quant_type="nf4",
    llm_int8_enable_fp32_cpu_offload=True,
)

model = PaliGemmaForConditionalGeneration.from_pretrained(
    model_id,
    quantization_config=quantization_config4bit,
    device_map=device_map,
)
processor = AutoProcessor.from_pretrained(model_id)

processor_seg = CLIPSegProcessor.from_pretrained("CIDAS/clipseg-rd64")
model_seg = CLIPSegForImageSegmentation.from_pretrained("CIDAS/clipseg-rd64").to("cuda:1")


def draw_segments(image, masks):
    """Overlay the given masks on the image in red (debug helper)."""
    red_image = Image.new('RGB', image.size, (255, 0, 0))
    for mask in masks:
        mask = ImageOps.fit(mask, image.size, method=Image.NEAREST)
        image = Image.composite(red_image, image, mask)
    return image


def get_connected_components(binary_mask):
    """Label connected components in a binary segmentation mask."""
    binary_mask_np = (binary_mask.cpu().numpy() * 255).astype(np.uint8)
    pil_image = Image.fromarray(binary_mask_np, mode='L')
    pil_image.save("clipsegonlymask.jpg")  # debug: raw CLIPSeg mask
    num_labels, labels_im = cv2.connectedComponents(binary_mask_np)
    return num_labels, labels_im


def visualize_connected_components(image, labels_im, num_labels, scale_factor_w, scale_factor_h):
    """Draw a red bounding box around every connected component (debug helper)."""
    overlay = image.copy()
    draw = ImageDraw.Draw(overlay)
    for label in range(1, num_labels):  # label 0 is the background
        component_mask = (labels_im == label).astype(np.uint8)
        x, y, w, h = cv2.boundingRect(component_mask)
        x, y, w, h = (int(x * scale_factor_w), int(y * scale_factor_h),
                      int(w * scale_factor_w), int(h * scale_factor_h))
        draw.rectangle([x, y, x + w, y + h], outline="red", width=2)
    return overlay


def crop_objects(image, labels_im, num_labels, scale_factor_w, scale_factor_h):
    """Crop every connected component out of the original image; return crops and bboxes."""
    cropped_images = []
    bboxes = []
    for label in range(1, num_labels):  # label 0 is the background
        component_mask = (labels_im == label).astype(np.uint8)
        x, y, w, h = cv2.boundingRect(component_mask)
        x, y, w, h = (int(x * scale_factor_w), int(y * scale_factor_h),
                      int(w * scale_factor_w), int(h * scale_factor_h))
        cropped_image = image.crop((x, y, x + w, y + h))
        cropped_images.append(cropped_image)
        bboxes.append([x, y, w, h])
    return cropped_images, bboxes


def segment_and_crop(image_path, text, confidence=0.2):
    """Run CLIPSeg with a text prompt and return crops and bboxes of the detected regions."""
    # Load and process the image
    image = Image.open(image_path).convert("RGB")
    original_size = image.size
    inputs = processor_seg(text=[text], images=image, return_tensors="pt").to("cuda:1")

    # Perform inference
    with torch.inference_mode():
        outputs = model_seg(**inputs)

    # Convert logits to probabilities using sigmoid
    logits = outputs.logits.cpu().detach()
    probs = torch.sigmoid(logits)

    # Apply the confidence threshold to get a binary mask
    binary_mask = probs > confidence

    # Find connected components in the binary mask
    num_labels, labels_im = get_connected_components(binary_mask.squeeze(0))

    # Calculate the scale factors between the original image and the processed image
    resized_size = (352, 352)  # assuming the model resizes the image to 352x352
    scale_factor_w = original_size[0] / resized_size[0]
    scale_factor_h = original_size[1] / resized_size[1]

    # Debug: visualize connected components
    image_with_components = visualize_connected_components(
        image, labels_im, num_labels, scale_factor_w, scale_factor_h)
    image_with_components.save("checkclipseg_components.jpg")

    # Crop the objects from the image
    cropped_images, bboxes = crop_objects(image, labels_im, num_labels, scale_factor_w, scale_factor_h)

    # Clean up
    del inputs
    del outputs
    gc.collect()
    torch.cuda.empty_cache()

    return cropped_images, bboxes


def check_image_description(image, description, model, processor):
    """Ask PaliGemma whether the crop fits the description; treat a 'yes' answer as a match."""
    prompt = f"Is the image fitting the description: '{description}'?"
    print(image.size)  # debug
    model_inputs = processor(text=prompt, images=image, return_tensors="pt").to(model.device)
    input_len = model_inputs["input_ids"].shape[-1]
    with torch.inference_mode():
        generation = model.generate(**model_inputs, max_new_tokens=100, do_sample=False)
        generation = generation[0][input_len:]
        decoded = processor.decode(generation, skip_special_tokens=True)
    return 'yes' in decoded.lower()


def process_image(image_path, general_description, detailed_description):
    """Segment with the general description, then keep only crops PaliGemma confirms."""
    cropped_images, bboxes = segment_and_crop(image_path, general_description)
    fitting_crops = []
    fitting_bboxes = []
    for cropped_image, bbox in zip(cropped_images, bboxes):
        # Skip degenerate or tiny boxes
        if (bbox[2] * bbox[3]) < 20 or bbox[2] < 3 or bbox[3] < 3:
            continue
        if check_image_description(cropped_image, detailed_description, model, processor):
            fitting_crops.append(cropped_image)
            fitting_bboxes.append(bbox)
    return fitting_crops, fitting_bboxes


def draw_boxes_on_image(image, bboxes, detailed_description):
    """Draw the accepted bounding boxes and their label on the image."""
    draw = ImageDraw.Draw(image)
    try:
        # Load a TrueType/OpenType font; adjust the font size as needed
        font = ImageFont.truetype("arial.ttf", 20)
    except IOError:
        font = ImageFont.load_default()  # fall back to the default font
    for bbox in bboxes:
        x, y, w, h = bbox
        draw.rectangle([x, y, x + w, y + h], outline="red", width=2)
        draw.text((x, y), detailed_description, fill="red", font=font)
    return image


def process_dataset(dataset_path, description_pairs, output_path):
    """Walk the dataset, annotate every .jpg image, and write COCO-style annotations."""
    os.makedirs(output_path, exist_ok=True)
    crops_path = os.path.join(output_path, "crops")
    os.makedirs(crops_path, exist_ok=True)
    annotated_images_path = os.path.join(output_path, "annotated_images")
    os.makedirs(annotated_images_path, exist_ok=True)
    annotations_path = os.path.join(output_path, "annotations")
    os.makedirs(annotations_path, exist_ok=True)

    annotations = {
        "info": {
            "description": "Generated annotations",
            "version": "1.0",
            "year": 2024
        },
        "licenses": [],
        "images": [],
        "annotations": [],
        "categories": []
    }

    category_id = 1
    annotations["categories"].append({
        "id": category_id,
        "name": "damaged pill",
        "supercategory": "pill"
    })

    image_id = 0
    annotation_id = 0

    for root, _, files in tqdm(os.walk(dataset_path), desc="Processing dataset"):
        for filename in files:
            if not filename.endswith("jpg"):
                continue
            image_path = os.path.join(root, filename)
            for general_description, detailed_description in description_pairs:
                fitting_crops, fitting_bboxes = process_image(
                    image_path, general_description, detailed_description)
                image = Image.open(image_path).convert("RGB")
                if fitting_crops:
                    annotations["images"].append({
                        "id": image_id,
                        "file_name": os.path.relpath(image_path, dataset_path),
                        "height": image.height,
                        "width": image.width
                    })
                    for i, (crop, bbox) in enumerate(zip(fitting_crops, fitting_bboxes)):
                        crop_path = os.path.join(crops_path, f"{image_id}_{i}.png")
                        crop.save(crop_path)
                        annotations["annotations"].append({
                            "id": annotation_id,
                            "image_id": image_id,
                            "category_id": category_id,
                            "bbox": bbox,
                            "area": bbox[2] * bbox[3],
                            "segmentation": [],  # optional, add segmentation if needed
                            "iscrowd": 0
                        })
                        annotation_id += 1
                # Always save the annotated image
                annotated_image = draw_boxes_on_image(image.copy(), fitting_bboxes, detailed_description)
                annotated_image_save_path = os.path.join(
                    annotated_images_path, os.path.relpath(image_path, dataset_path))
                os.makedirs(os.path.dirname(annotated_image_save_path), exist_ok=True)
                annotated_image.save(annotated_image_save_path)
                image_id += 1

    with open(os.path.join(annotations_path, "annotations.json"), "w") as f:
        json.dump(annotations, f, indent=2)


# Example usage
dataset_path = "alldata/gunsss"
description_pairs = [("gun", "gun")]
output_path = "alldata/gunsss_annotated"
process_dataset(dataset_path, description_pairs, output_path)