Untitled
unknown
plain_text
a year ago
2.1 kB
5
Indexable
"""Detect objects in a folder of JPEG images and save crops of people.

Runs torchvision's COCO-trained Faster R-CNN (ResNet-50 FPN) detector on every
file in ``image_dir`` matching ``IMAGE_GLOB``. Detections scoring above
``SCORE_THRESHOLD`` are optionally drawn on a matplotlib preview, and those
whose label equals ``PERSON_LABEL`` are cropped out of the original image and
saved as JPEGs in ``output_dir``.
"""
import glob
import os

import matplotlib.patches as patches
import matplotlib.pyplot as plt
import numpy as np
import torch
import torchvision
from PIL import Image
from torchvision import transforms as T

# Directory containing the images to process
image_dir = 'promotion-staff'
output_dir = 'cropped_images'

# Glob pattern (relative to image_dir) selecting the images to process.
IMAGE_GLOB = '*.jpg'
# Detections whose score is not strictly above this value are ignored.
SCORE_THRESHOLD = 0.5
# Category id 1 is "person" in the COCO label set used by torchvision's
# COCO-trained detectors; only detections with this label are cropped.
PERSON_LABEL = 1
# Set to True to display each annotated image (plt.show() blocks until the
# window is closed). When False no figure is created at all.
SHOW_PLOTS = False

# Define the transform (PIL image -> float tensor in [0, 1], CHW)
transform = T.Compose([
    T.ToTensor(),
])


def load_model():
    """Return the COCO-pretrained Faster R-CNN ResNet-50 FPN model in eval mode."""
    detection = torchvision.models.detection
    weights_enum = getattr(detection, 'FasterRCNN_ResNet50_FPN_Weights', None)
    if weights_enum is not None:
        # torchvision >= 0.13: ``pretrained=True`` is deprecated in favour of
        # the weights enum (DEFAULT is the same COCO_V1 checkpoint).
        model = detection.fasterrcnn_resnet50_fpn(weights=weights_enum.DEFAULT)
    else:
        # Older torchvision has no weights enum; use the legacy flag.
        model = detection.fasterrcnn_resnet50_fpn(pretrained=True)
    model.eval()
    return model


def _crop_path(image_path, index, label, score):
    """Build the output file path for detection ``index`` of ``image_path``."""
    # splitext keeps dots inside the stem ("a.b.jpg" -> "a.b"), unlike
    # split('.')[0], which would make "a.b.jpg" and "a.c.jpg" collide.
    stem = os.path.splitext(os.path.basename(image_path))[0]
    filename = f"{stem}_cropped_{index}_label_{label}_score_{score:.2f}.jpg"
    return os.path.join(output_dir, filename)


def process_image(model, image_path):
    """Run detection on one image and save crops of confident person detections.

    Args:
        model: a torchvision detection model in eval mode.
        image_path: path of the image file to process.
    """
    print(f"Processing {image_path}...")
    # Context manager guarantees the underlying file handle is closed.
    with Image.open(image_path) as raw:
        image = raw.convert("RGB")
    image_tensor = transform(image).unsqueeze(0)  # add a batch dimension

    # Get the predictions (one dict per input image; we passed a single image)
    with torch.no_grad():
        prediction = model(image_tensor)[0]

    fig = ax = None
    if SHOW_PLOTS:
        fig, ax = plt.subplots(1, figsize=(12, 9))
        ax.imshow(np.array(image))

    detections = zip(prediction['boxes'], prediction['labels'], prediction['scores'])
    # ``i`` is the index among *all* predictions, kept in the output filename.
    for i, (box, label_t, score_t) in enumerate(detections):
        score = score_t.item()
        if score <= SCORE_THRESHOLD:
            continue
        x1, y1, x2, y2 = box.cpu().numpy()
        label = label_t.item()

        if ax is not None:
            rect = patches.Rectangle((x1, y1), x2 - x1, y2 - y1, linewidth=2,
                                     edgecolor='g', facecolor='none')
            ax.add_patch(rect)
            ax.text(x1, y1 - 10, f"{label}: {score:.2f}", color='g', fontsize=12,
                    bbox=dict(facecolor='white', alpha=0.5))

        # Only people are cropped; other confident detections are just drawn.
        if label != PERSON_LABEL:
            continue
        cropped_image = image.crop((x1, y1, x2, y2))
        cropped_image.save(_crop_path(image_path, i, label, score))

    if fig is not None:
        plt.show()
        # Close explicitly so figures do not accumulate across images.
        plt.close(fig)


def main():
    """Process every image in ``image_dir`` matching ``IMAGE_GLOB``."""
    os.makedirs(output_dir, exist_ok=True)
    image_paths = sorted(glob.glob(os.path.join(image_dir, IMAGE_GLOB)))
    if not image_paths:
        print(f"No images matching {IMAGE_GLOB!r} found in {image_dir!r}.")
        return
    model = load_model()
    for image_path in image_paths:
        process_image(model, image_path)


if __name__ == "__main__":
    main()
Editor is loading...
Leave a Comment