import os
import gc
import json

import cv2
import numpy as np
import torch
from PIL import Image, ImageDraw, ImageFont, ImageOps
from tqdm import tqdm
from transformers import (
    AutoProcessor,
    BitsAndBytesConfig,
    CLIPSegForImageSegmentation,
    CLIPSegProcessor,
    PaliGemmaForConditionalGeneration,
)
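
# Pipeline overview:
#   1. CLIPSeg produces a heatmap for a coarse text prompt; the thresholded
#      mask is split into connected components.
#   2. Each component's bounding box is cropped from the original image.
#   3. PaliGemma answers a yes/no question per crop to keep only crops that
#      match a more detailed description.
#   4. Kept crops and boxes are written out as COCO-style annotations.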
# Split PaliGemma across two GPUs: vision tower on cuda:0, language model and
# multimodal projector on cuda:1.
device_map = {
    'vision_tower': 'cuda:0',           # Vision tower on cuda:0
    'language_model': 'cuda:1',         # Language model on cuda:1
    'multi_modal_projector': 'cuda:1',  # Multimodal projector on cuda:1
}
model_id = "google/paligemma-3b-ft-vqav2-448"
# 8-bit quantization config, kept as an unused alternative; the 4-bit config
# below is the one actually passed to from_pretrained.
quantization_config8bit = BitsAndBytesConfig(
    load_in_8bit=True,
    llm_int8_threshold=6.0,  # Adjust this threshold based on your needs
    llm_int8_enable_fp32_cpu_offload=True,
)
# 4-bit NF4 quantization with bfloat16 compute and double quantization.
quantization_config4bit = BitsAndBytesConfig(
    load_in_4bit=True,
    bnb_4bit_compute_dtype=torch.bfloat16,
    bnb_4bit_use_double_quant=True,
    bnb_4bit_quant_type="nf4",
    llm_int8_enable_fp32_cpu_offload=True,
)
model = PaliGemmaForConditionalGeneration.from_pretrained(
    model_id,
    quantization_config=quantization_config4bit,
    device_map=device_map,
)
processor = AutoProcessor.from_pretrained(model_id)
# CLIPSeg is small enough to run unquantized on the second GPU.
processor_seg = CLIPSegProcessor.from_pretrained("CIDAS/clipseg-rd64")
model_seg = CLIPSegForImageSegmentation.from_pretrained("CIDAS/clipseg-rd64").to("cuda:1")
def draw_segments(image, masks):
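    """Composite each mask over the image in solid red (debug helper; not called below)."""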
red_image = Image.new('RGB', image.size, (255, 0, 0))
for mask in masks:
mask = ImageOps.fit(mask, image.size, method=Image.NEAREST)
image = Image.composite(red_image, image, mask)
return image
def get_connected_components(binary_mask):
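    """Binarize the mask to uint8, dump it for debugging, and label connected components with OpenCV."""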
binary_mask_np = (binary_mask.cpu().numpy() * 255).astype(np.uint8)
pil_image = Image.fromarray(binary_mask_np, mode='L')
pil_image.save("clipsegonlymask.jpg")
num_labels, labels_im = cv2.connectedComponents(binary_mask_np)
return num_labels, labels_im
def visualize_connected_components(image, labels_im, num_labels, scale_factor_w, scale_factor_h):
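    """Draw a red bounding box for every connected component, rescaled to the original image size."""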
overlay = image.copy()
draw = ImageDraw.Draw(overlay)
for label in range(1, num_labels):
component_mask = (labels_im == label).astype(np.uint8)
x, y, w, h = cv2.boundingRect(component_mask)
x, y, w, h = int(x * scale_factor_w), int(y * scale_factor_h), int(w * scale_factor_w), int(h * scale_factor_h)
draw.rectangle([x, y, x + w, y + h], outline="red", width=2)
return overlay
def crop_objects(image, labels_im, num_labels, scale_factor_w, scale_factor_h):
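    """Crop each connected component's bounding box from the original image and return the crops plus [x, y, w, h] boxes."""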
cropped_images = []
bboxes = []
for label in range(1, num_labels):
component_mask = (labels_im == label).astype(np.uint8)
x, y, w, h = cv2.boundingRect(component_mask)
x, y, w, h = int(x * scale_factor_w), int(y * scale_factor_h), int(w * scale_factor_w), int(h * scale_factor_h)
cropped_image = image.crop((x, y, x + w, y + h))
cropped_images.append(cropped_image)
bboxes.append([x, y, w, h])
return cropped_images, bboxes
def segment_and_crop(image_path, text, confidence=0.2):
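    """Run CLIPSeg on the image with the given text prompt, threshold the heatmap,
    and return the crops and bounding boxes of the resulting connected components."""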
# Load and process the image
image = Image.open(image_path).convert("RGB")
original_size = image.size
inputs = processor_seg(text=[text], images=image, return_tensors="pt").to("cuda:1")
# Perform inference
with torch.inference_mode():
outputs = model_seg(**inputs)
# Convert logits to probabilities using sigmoid
logits = outputs.logits.cpu().detach()
probs = torch.sigmoid(logits)
# Apply the confidence threshold to get binary masks
binary_mask = probs > confidence
# Find connected components in the binary mask
num_labels, labels_im = get_connected_components(binary_mask.squeeze(0))
# Calculate the scale factors between the original image and the processed image
resized_size = (352, 352) # Assuming the model resizes the image to 352x352
scale_factor_w = original_size[0] / resized_size[0]
scale_factor_h = original_size[1] / resized_size[1]
# Debug: visualize connected components
image_with_components = visualize_connected_components(image, labels_im, num_labels, scale_factor_w, scale_factor_h)
image_with_components.save("checkclipseg_components.jpg")
# Crop the objects from the image
cropped_images, bboxes = crop_objects(image, labels_im, num_labels, scale_factor_w, scale_factor_h)
# Clean up
del inputs
del outputs
gc.collect()
torch.cuda.empty_cache()
return cropped_images, bboxes
def check_image_description(image, description, model, processor):
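    """Ask PaliGemma whether the crop matches the description; return True if the answer contains 'yes'."""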
prompt = f"Is the image fitting the description: '{description}'?"
print(image.size)
model_inputs = processor(text=prompt, images=image, return_tensors="pt").to(model.device)
input_len = model_inputs["input_ids"].shape[-1]
with torch.inference_mode():
generation = model.generate(**model_inputs, max_new_tokens=100, do_sample=False)
generation = generation[0][input_len:]
decoded = processor.decode(generation, skip_special_tokens=True)
return 'yes' in decoded.lower()
def process_image(image_path, general_description, detailed_description):
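    """Segment with the general description, drop tiny boxes, and keep only crops
    that PaliGemma confirms match the detailed description."""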
cropped_images, bboxes = segment_and_crop(image_path, general_description)
fitting_crops = []
fitting_bboxes = []
for cropped_image, bbox in zip(cropped_images, bboxes):
#print(bbox)
if (bbox[2] * bbox[3]) < 20 or bbox[2] < 3 or bbox[3] < 3:
continue
#cropped_image.save("wtf.jpg")
#print(type(cropped_image), cropped_image.size)
if check_image_description(cropped_image, detailed_description, model, processor):
fitting_crops.append(cropped_image)
fitting_bboxes.append(bbox)
return fitting_crops, fitting_bboxes
def draw_boxes_on_image(image, bboxes, detailed_description):
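    """Draw red boxes labeled with the detailed description onto the image."""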
draw = ImageDraw.Draw(image)
try:
# Load a truetype or opentype font file, and set the font size
font = ImageFont.truetype("arial.ttf", 20) # You can adjust the font size as needed
except IOError:
font = ImageFont.load_default() # Use the default font if the specified font is not available
for bbox in bboxes:
x, y, w, h = bbox
draw.rectangle([x, y, x + w, y + h], outline="red", width=2)
draw.text((x, y), detailed_description, fill="red", font=font)
return image
def process_dataset(dataset_path, description_pairs, output_path):
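    """Walk the dataset, run the segment-then-verify pipeline on every JPEG, save
    crops and annotated images, and write COCO-style annotations.json."""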
os.makedirs(output_path, exist_ok=True)
crops_path = os.path.join(output_path, "crops")
os.makedirs(crops_path, exist_ok=True)
annotated_images_path = os.path.join(output_path, "annotated_images")
os.makedirs(annotated_images_path, exist_ok=True)
annotations_path = os.path.join(output_path, "annotations")
os.makedirs(annotations_path, exist_ok=True)
annotations = {
"info": {
"description": "Generated annotations",
"version": "1.0",
"year": 2024
},
"licenses": [],
"images": [],
"annotations": [],
"categories": []
}
    # NOTE: the category is hard-coded; change the name to match the
    # description pairs you pass in (the example below uses "gun").
    category_id = 1
    annotations["categories"].append({
        "id": category_id,
        "name": "damaged pill",
        "supercategory": "pill"
    })
image_id = 0
annotation_id = 0
for root, _, files in tqdm(os.walk(dataset_path), desc="Processing dataset"):
for filename in files:
            if not filename.lower().endswith((".jpg", ".jpeg")):
                continue
image_path = os.path.join(root, filename)
for general_description, detailed_description in description_pairs:
fitting_crops, fitting_bboxes = process_image(image_path, general_description, detailed_description)
image = Image.open(image_path).convert("RGB")
if fitting_crops:
annotations["images"].append({
"id": image_id,
"file_name": os.path.relpath(image_path, dataset_path),
"height": image.height,
"width": image.width
})
for i, (crop, bbox) in enumerate(zip(fitting_crops, fitting_bboxes)):
crop_path = os.path.join(crops_path, f"{image_id}_{i}.png")
crop.save(crop_path)
annotations["annotations"].append({
"id": annotation_id,
"image_id": image_id,
"category_id": category_id,
"bbox": bbox,
"area": bbox[2] * bbox[3],
"segmentation": [], # Optional, add segmentation if needed
"iscrowd": 0
})
annotation_id += 1
# Always save the annotated image
annotated_image = draw_boxes_on_image(image.copy(), fitting_bboxes, detailed_description)
annotated_image_save_path = os.path.join(annotated_images_path, os.path.relpath(image_path, dataset_path))
os.makedirs(os.path.dirname(annotated_image_save_path), exist_ok=True)
annotated_image.save(annotated_image_save_path)
image_id += 1
with open(os.path.join(annotations_path, "annotations.json"), "w") as f:
json.dump(annotations, f, indent=2)
# Example usage
dataset_path = "alldata/gunsss"
description_pairs = [("gun", "gun")]
output_path = "alldata/gunsss_annotated"
process_dataset(dataset_path, description_pairs, output_path)
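
# Quick sanity check on the generated annotations (a minimal sketch; assumes the
# example paths above and the COCO-style layout written by process_dataset).
with open(os.path.join(output_path, "annotations", "annotations.json")) as f:
    coco = json.load(f)
print(f"{len(coco['images'])} images, {len(coco['annotations'])} annotations")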