import os
import sys
import time
from datetime import datetime
from pathlib import Path
import matplotlib
matplotlib.use("Agg")
import numpy as np
import cv2
import json
import torch
from yolov5.models.common import DetectMultiBackend
from yolov5.utils.general import (non_max_suppression)
# Assuming StrongSORT is installed and imported properly
from strongsort import StrongSORT # Update the import based on your StrongSORT implementation
def get_slice_bboxes(image_height, image_width, slice_height, slice_width, overlap_height_ratio, overlap_width_ratio):
    """Return [xmin, ymin, xmax, ymax] boxes that tile the image with the given overlap ratios."""
    slice_bboxes = []
    y_max = y_min = 0
    y_overlap = int(overlap_height_ratio * slice_height)
    x_overlap = int(overlap_width_ratio * slice_width)
    while y_max < image_height:
        x_min = x_max = 0
        y_max = y_min + slice_height
        while x_max < image_width:
            x_max = x_min + slice_width
            if y_max > image_height or x_max > image_width:
                # Clamp the last slice in a row/column to the image border
                xmax = min(image_width, x_max)
                ymax = min(image_height, y_max)
                xmin = max(0, xmax - slice_width)
                ymin = max(0, ymax - slice_height)
                slice_bboxes.append([xmin, ymin, xmax, ymax])
            else:
                slice_bboxes.append([x_min, y_min, x_max, y_max])
            x_min = x_max - x_overlap
        y_min = y_max - y_overlap
    return slice_bboxes
# Setup environment
root_dir = os.getenv('root_dir', '/'.join(os.path.abspath(__file__).split('/')[0:-2]))
yolo_weights = Path(root_dir) / 'models' / os.getenv('weights', 'pLitterFloat_800x752_to_640x640.pt')
# Camera and model settings
FRAME_WIDTH = int(os.getenv('frame_width', 1920))
FRAME_HEIGHT = int(os.getenv('frame_height', 1080))
interval = int(os.getenv('interval', 10))
work_in_night = os.getenv('work_in_night', 'True').lower() in ('true', '1', 'yes')  # env vars are strings, so parse explicitly
slice_width = int(os.getenv("slice_width", 800))
slice_height = int(os.getenv("slice_height", 752))
slice_boxes = get_slice_bboxes(FRAME_HEIGHT, FRAME_WIDTH, slice_height, slice_width, 0.04, 0.04)
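# With the defaults above (1080x1920 frame, 752x800 slices, 4% overlap) the slicer
# yields a 2x3 grid of six overlapping boxes, roughly:
#   [[0, 0, 800, 752], [768, 0, 1568, 752], [1120, 0, 1920, 752],
#    [0, 328, 800, 1080], [768, 328, 1568, 1080], [1120, 328, 1920, 1080]]
# (illustrative only; the layout changes with the frame/slice env settings)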
device = torch.device('cuda:0' if torch.cuda.is_available() else 'cpu')
half = device.type != 'cpu'  # FP16 inference only on GPU
# Load YOLOv5 model
if not os.path.isfile(yolo_weights):
    weights_url = os.getenv('weights_url', None)
    try:
        torch.hub.download_url_to_file(weights_url, str(yolo_weights))
    except Exception:
        # Fall back to the stock YOLOv5s weights if no URL is set or the download fails
        yolo_weights = Path(root_dir) / 'models/yolov5s.pt'
model = DetectMultiBackend(yolo_weights, device=device, fp16=half)
model.eval()
# Initialize StrongSORT
strongsort = StrongSORT()
data_dir = os.path.join(root_dir, 'data')
os.makedirs(data_dir, exist_ok=True)  # ensure the output directory exists before saving
imgsz = 640
cap = cv2.VideoCapture(0)
cap.set(cv2.CAP_PROP_FRAME_WIDTH, FRAME_WIDTH)
cap.set(cv2.CAP_PROP_FRAME_HEIGHT, FRAME_HEIGHT)
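# Note (assumption): not every camera backend honors the requested capture size.
# Since slice_boxes were computed from FRAME_WIDTH x FRAME_HEIGHT, a quick sanity
# check like this can catch a mismatch before the main loop starts:
actual_w = int(cap.get(cv2.CAP_PROP_FRAME_WIDTH))
actual_h = int(cap.get(cv2.CAP_PROP_FRAME_HEIGHT))
if (actual_w, actual_h) != (FRAME_WIDTH, FRAME_HEIGHT):
    print(f"Warning: camera delivers {actual_w}x{actual_h}, expected {FRAME_WIDTH}x{FRAME_HEIGHT}")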
# Processing loop
with torch.no_grad():
    while True:
        current_time = datetime.now().strftime("%H:%M:%S")
        # Night mode check (time-based)
        if not work_in_night and (current_time >= '18:00:00' or current_time < '06:00:00'):
            # Pause overnight capture when night operation is disabled
            print('Night mode active. Sleeping...')
            time.sleep(60)
            continue
        im_name = datetime.now().strftime("%Y%m%d_%H%M%S")
        pred_json = {'image': im_name + '.jpg', 'preds': []}
        ret, img0 = cap.read()
        if not ret or img0 is None:
            print("Check if the camera is connected and run again.")
            continue
        preds = torch.empty((0, 6), dtype=torch.float16 if half else torch.float32)  # xyxy, conf, cls
        # Process each slice of the image
        for box in slice_boxes:
            img = img0[box[1]:box[3], box[0]:box[2], :]
            h, w, _ = img.shape
            img = cv2.resize(img, (imgsz, imgsz), interpolation=cv2.INTER_LINEAR)
            img = img.transpose((2, 0, 1))[::-1]  # HWC to CHW, BGR to RGB
            img = np.ascontiguousarray(img)
            img = torch.from_numpy(img).to(device)
            img = img.half() if half else img.float()  # Convert to FP16 or FP32
            img /= 255.0  # Normalize to [0, 1]
            if img.ndimension() == 3:
                img = img.unsqueeze(0)
            pred = model(img)
            pred = non_max_suppression(pred, 0.4, 0.5)
            # Rescale detections from resized-slice coordinates back to the full frame
            proc_pred = pred[0].cpu()
            for det in proc_pred:
                det[0] = det[0] * (w / imgsz) + box[0]
                det[1] = det[1] * (h / imgsz) + box[1]
                det[2] = det[2] * (w / imgsz) + box[0]
                det[3] = det[3] * (h / imgsz) + box[1]
            preds = torch.cat((preds, proc_pred), 0)
        # Track the objects using StrongSORT
        # (the exact update() signature and return format depend on the StrongSORT
        # implementation in use; here it is assumed to return one track ID per detection)
        track_ids = strongsort.update(preds, img0)
        # Prepare JSON output
        for pred, track_id in zip(preds, track_ids):
            pred = pred.tolist()
            bbox = [pred[0], pred[1], pred[2] - pred[0], pred[3] - pred[1]]  # xyxy -> xywh
            seg = [[pred[0], pred[1], pred[2], pred[1], pred[2], pred[3], pred[0], pred[3]]]  # box as polygon
            pred_json['preds'].append({
                'category': model.names[int(pred[5])],
                'bbox': bbox,
                'segmentation': seg,
                'track_id': track_id  # Include the tracking ID
            })
        # Save image and JSON
        cv2.imwrite(os.path.join(data_dir, im_name + '.jpg'), img0)
        with open(os.path.join(data_dir, im_name + '.json'), 'w') as json_file:
            json.dump(pred_json, json_file)
        print(f"Saved {im_name}.jpg and {im_name}.json")
        time.sleep(interval)
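# For reference, each saved JSON file follows this shape (illustrative values only;
# category names depend on the loaded weights, and track_id on the tracker output):
# {
#   "image": "20240101_120000.jpg",
#   "preds": [
#     {"category": "<class name>",
#      "bbox": [x, y, width, height],
#      "segmentation": [[x1, y1, x2, y1, x2, y2, x1, y2]],
#      "track_id": 3}
#   ]
# }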