# Paste-site header (web-page residue, not part of the script):
# Untitled | unknown | plain_text | a year ago | 2.1 kB | 10 | Indexable
import torch
import torchvision
from torchvision import transforms as T
from PIL import Image
import matplotlib.pyplot as plt
import matplotlib.patches as patches
import numpy as np
import os
import glob
# Detect objects in every JPEG under `image_dir` with a pre-trained Faster
# R-CNN, annotate each confident detection on a matplotlib figure, and save a
# crop of every confident label-1 detection into `output_dir`.

# Load the pre-trained detector once and switch it to inference mode.
# NOTE(review): newer torchvision releases deprecate `pretrained=True` in
# favour of the `weights=` argument -- confirm against the installed version.
model = torchvision.models.detection.fasterrcnn_resnet50_fpn(pretrained=True)
model.eval()

# Convert PIL images into tensors the model can consume.
transform = T.Compose([
    T.ToTensor(),
])

# A detection must score strictly above this value to be drawn or cropped.
SCORE_THRESHOLD = 0.5
# Only detections carrying this label are cropped and saved.
# NOTE(review): label 1 is "person" in the COCO category list used by
# torchvision's pre-trained detection models -- confirm for other weights.
PERSON_LABEL = 1

# Directory containing the images to process, and where crops are written.
image_dir = 'promotion-staff'
output_dir = 'cropped_images'
os.makedirs(output_dir, exist_ok=True)

for image_path in glob.glob(os.path.join(image_dir, '*.jpg')):
    print(f"Processing {image_path}...")

    # Close the file handle promptly; convert() returns an independent copy.
    with Image.open(image_path) as source_image:
        image = source_image.convert("RGB")
    image_tensor = transform(image).unsqueeze(0)

    # splitext strips only the final extension, so "a.1.jpg" and "a.2.jpg"
    # keep distinct stems and their crops cannot overwrite each other.
    image_stem = os.path.splitext(os.path.basename(image_path))[0]

    # Inference only: no gradients are needed.
    with torch.no_grad():
        predictions = model(image_tensor)
    detections = predictions[0]

    # One figure per image showing the original picture under the boxes.
    fig, ax = plt.subplots(1, figsize=(12, 9))
    ax.imshow(np.array(image))

    candidates = zip(
        detections['boxes'], detections['labels'], detections['scores']
    )
    for i, (box, label, score) in enumerate(candidates):
        score = score.item()
        if not score > SCORE_THRESHOLD:
            continue
        label = label.item()
        x1, y1, x2, y2 = box.cpu().numpy()

        # Draw the bounding box and its "label: score" caption on this
        # image's own axes rather than on the implicit current axes.
        rect = patches.Rectangle(
            (x1, y1), x2 - x1, y2 - y1,
            linewidth=2, edgecolor='g', facecolor='none',
        )
        ax.add_patch(rect)
        ax.text(
            x1, y1 - 10, f"{label}: {score:.2f}",
            color='g', fontsize=12, bbox=dict(facecolor='white', alpha=0.5),
        )

        # Every confident detection is drawn, but only label-1 ones are saved.
        if label != PERSON_LABEL:
            continue
        cropped_image = image.crop((x1, y1, x2, y2))
        cropped_image_path = os.path.join(
            output_dir,
            f"{image_stem}_cropped_{i}_label_{label}_score_{score:.2f}.jpg",
        )
        cropped_image.save(cropped_image_path)

    # To inspect the annotated image interactively, call plt.show() here.
    # The figure is closed either way so that figures do not accumulate in
    # memory while the whole directory is processed.
    plt.close(fig)
# Paste-site footer (web-page residue, not part of the script):
# Editor is loading... | Leave a Comment