Untitled

mail@pastecode.io avatar
unknown
plain_text
16 days ago
2.1 kB
3
Indexable
Never
import torch
import torchvision
from torchvision import transforms as T
from PIL import Image
import matplotlib.pyplot as plt
import matplotlib.patches as patches
import numpy as np
import os
import glob

# Faster R-CNN detector (ResNet-50 + FPN backbone) with pretrained weights.
# .eval() returns the module itself and puts it in inference mode, in which
# torchvision detection models return per-image predictions (boxes / labels /
# scores) instead of training losses.
# NOTE(review): `pretrained=True` is deprecated since torchvision 0.13 in favour
# of `weights=...`; kept as-is so older torchvision releases still work.
model = torchvision.models.detection.fasterrcnn_resnet50_fpn(pretrained=True).eval()

# PIL image -> float tensor (C, H, W) scaled to [0, 1]; no resizing or
# normalisation here, the detector handles that internally.
transform = T.Compose([T.ToTensor()])

# Source folder scanned for images, and destination folder for the crops.
image_dir = 'promotion-staff'
output_dir = 'cropped_images'
os.makedirs(output_dir, exist_ok=True)  # no error if the folder already exists

# Detections whose score is not above this confidence are ignored entirely.
SCORE_THRESHOLD = 0.5
# Only detections with this label id are cropped and saved. Label 1 is
# "person" in the COCO category list used by torchvision's detection models.
CROP_LABEL = 1
# Set to True to display each annotated image (blocks until the window closes).
SHOW_PLOTS = False

# Process each image in the directory. sorted() makes the processing order
# deterministic; glob() alone yields files in arbitrary filesystem order.
for image_path in sorted(glob.glob(os.path.join(image_dir, '*.jpg'))):
    print(f"Processing {image_path}...")
    # The context manager closes the source file once the RGB copy exists.
    with Image.open(image_path) as src:
        image = src.convert("RGB")
    image_tensor = transform(image).unsqueeze(0)  # add a batch dimension

    # Run inference without building an autograd graph.
    with torch.no_grad():
        prediction = model(image_tensor)[0]  # one image in -> one result dict

    # splitext strips only the final extension: "jane.doe.jpg" -> "jane.doe".
    # (split('.')[0] would give "jane" and let crops of different files collide.)
    stem = os.path.splitext(os.path.basename(image_path))[0]

    # Figure showing the original image; boxes and captions are drawn on top.
    fig, ax = plt.subplots(1, figsize=(12, 9))
    ax.imshow(np.array(image))

    detections = zip(prediction['boxes'], prediction['labels'], prediction['scores'])
    # `i` indexes the full prediction list (including skipped low-score boxes),
    # matching the numbering used in the saved file names.
    for i, (box, label_t, score_t) in enumerate(detections):
        score = score_t.item()
        if not score > SCORE_THRESHOLD:  # written this way so NaN is skipped too
            continue
        x1, y1, x2, y2 = box.tolist()
        label = label_t.item()

        # Draw the bounding box and a "label: score" caption on the plot.
        rect = patches.Rectangle(
            (x1, y1), x2 - x1, y2 - y1,
            linewidth=2, edgecolor='g', facecolor='none',
        )
        ax.add_patch(rect)
        ax.text(
            x1, y1 - 10, f"{label}: {score:.2f}",
            color='g', fontsize=12, bbox=dict(facecolor='white', alpha=0.5),
        )

        # Crop and save only detections of the target class.
        if label != CROP_LABEL:
            continue
        cropped_image = image.crop((x1, y1, x2, y2))
        cropped_image_path = os.path.join(
            output_dir,
            f"{stem}_cropped_{i}_label_{label}_score_{score:.2f}.jpg",
        )
        cropped_image.save(cropped_image_path)

    if SHOW_PLOTS:
        plt.show()
    # Release the figure: pyplot keeps every figure alive until closed, so
    # without this each image leaks one figure (and matplotlib warns past 20).
    plt.close(fig)
Leave a Comment