# --- Residue copied from the web page this script was pasted from (not code) ---
# Untitled
# unknown
# plain_text
# a year ago
# 3.6 kB
# 10
# Indexable
import json
import os
import time
from datetime import datetime
from pathlib import Path

import numpy as np
import cv2
import torch

from yolov5.models.common import DetectMultiBackend
from yolov5.utils.general import non_max_suppression
from strongsort.strong_sort import StrongSORT
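# TODO: define get_slice_bboxes() here (e.g. copy the one from sahi.slicing).
# It must return a list of [xmin, ymin, xmax, ymax] slice boxes in frame pixels.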
# Setup: most values below can be overridden with an environment variable of the
# given name (e.g. `interval=5 python script.py`).
root_dir = os.getenv('root_dir', '/path/to/your/root')
yolo_weights = Path(root_dir) / 'models' / os.getenv('weights', 'yolov5s.pt')
data_dir = os.path.join(root_dir, 'data')  # output folder for the saved .jpg/.json pairs
slice_width = int(os.getenv("slice_width", 800))
slice_height = int(os.getenv("slice_height", 752))
FRAME_WIDTH = int(os.getenv('frame_width', 1920))
FRAME_HEIGHT = int(os.getenv('frame_height', 1080))
interval = int(os.getenv('interval', 10))  # seconds to sleep after each saved frame
imgsz = 640  # every slice is resized to imgsz x imgsz before inference
# Use the first GPU when one is available, otherwise run on CPU.
# FP16 inference is only enabled on CUDA; half precision is not usable on CPU here.
device = torch.device('cuda:0' if torch.cuda.is_available() else 'cpu')
half = device.type != 'cpu'
# Split the FRAME_WIDTH x FRAME_HEIGHT frame into slice_width x slice_height tiles
# with a 4% overlap in each direction; each tile is run through YOLO separately.
# The overlaps are passed by keyword because SAHI-style implementations of
# get_slice_bboxes place `auto_slice_resolution` before the overlap ratios, so
# positional values could be assigned to the wrong parameter.
# NOTE(review): assumes the local get_slice_bboxes uses SAHI's parameter names — confirm.
slice_boxes = get_slice_bboxes(
    FRAME_HEIGHT,
    FRAME_WIDTH,
    slice_height,
    slice_width,
    overlap_height_ratio=0.04,
    overlap_width_ratio=0.04,
)
# Load the YOLOv5 detector from `yolo_weights` onto `device`;
# `fp16=half` selects half-precision inference (matching the input tensors below).
model = DetectMultiBackend(weights=yolo_weights, device=device, fp16=half)
# Initialize the StrongSORT tracker with its appearance (ReID) weights file.
# NOTE(review): StrongSORT packages derived from yolov8_tracking take `device` and
# `fp16` in their constructor; they are passed here so the ReID model runs on the
# same device/precision as the detector — confirm against the installed version.
tracker = StrongSORT(
    model_weights='osnet_x0_25_msmt17.pt',
    device=device,
    fp16=half,
    max_dist=0.2,
    max_iou_distance=0.7,
)
# Open camera index 0 and request the configured frame size.
cap = cv2.VideoCapture(0)
if not cap.isOpened():
    # Fail fast: otherwise the main loop would spin forever on empty reads.
    raise RuntimeError('Could not open video capture device 0')
cap.set(cv2.CAP_PROP_FRAME_WIDTH, FRAME_WIDTH)
cap.set(cv2.CAP_PROP_FRAME_HEIGHT, FRAME_HEIGHT)
# cap.set() is only a request; the driver may deliver another size, while
# slice_boxes were computed for FRAME_WIDTH x FRAME_HEIGHT.
_actual_size = (int(cap.get(cv2.CAP_PROP_FRAME_WIDTH)), int(cap.get(cv2.CAP_PROP_FRAME_HEIGHT)))
if _actual_size != (FRAME_WIDTH, FRAME_HEIGHT):
    print(f'Warning: camera reports {_actual_size[0]}x{_actual_size[1]}, '
          f'expected {FRAME_WIDTH}x{FRAME_HEIGHT}; slice boxes may not match the frames')
# Main capture loop: inside the daily [CAPTURE_START, CAPTURE_END) window, read a
# frame, run YOLOv5 on each slice, track the merged detections with StrongSORT and
# save the frame plus a JSON annotation file to `data_dir`, then wait `interval` s.

names = model.names  # class index -> label, taken from the loaded YOLOv5 weights
# cv2.imwrite returns False (it does not raise) when the target folder is missing.
os.makedirs(data_dir, exist_ok=True)
CAPTURE_START, CAPTURE_END = '06:00:00', '18:00:00'


def _prepare_slice(frame, box):
    """Crop one slice from ``frame`` and convert it into a model input tensor.

    ``box`` is indexed as [xmin, ymin, xmax, ymax] in frame pixels. The crop is
    resized to ``imgsz`` x ``imgsz`` (aspect ratio is not preserved), its channel
    order is reversed (OpenCV BGR -> RGB), it is moved to ``device``, divided by
    255 and given a leading batch dimension.

    Returns:
        ``(tensor, w_ratio, h_ratio)``; the ratios map model-space x/y
        coordinates back to slice pixels.
    """
    crop = frame[box[1]:box[3], box[0]:box[2], :]
    crop_h, crop_w = crop.shape[:2]
    resized = cv2.resize(crop, (imgsz, imgsz), interpolation=cv2.INTER_LINEAR)
    chw = np.ascontiguousarray(resized.transpose((2, 0, 1))[::-1])  # HWC -> CHW, BGR -> RGB
    tensor = torch.from_numpy(chw).to(device)
    tensor = tensor.half() if half else tensor.float()
    tensor /= 255.0
    return tensor.unsqueeze(0), crop_w / imgsz, crop_h / imgsz


def _detect_on_slices(frame):
    """Run YOLOv5 + NMS on every slice of ``frame`` and merge the results.

    Returns a float32 CPU tensor whose rows are the NMS output rows with the box
    columns (0-3) converted to full-frame pixel coordinates; it has shape (0, 6)
    when there are no slice boxes.
    """
    per_slice = []
    for box in slice_boxes:
        tensor, w_ratio, h_ratio = _prepare_slice(frame, box)
        det = non_max_suppression(model(tensor), 0.4, 0.5)[0]
        # Rescale in float32: fp16 cannot represent sub-pixel offsets near x=1920.
        det = det.float().cpu()
        det[:, [0, 2]] = det[:, [0, 2]] * w_ratio + box[0]
        det[:, [1, 3]] = det[:, [1, 3]] * h_ratio + box[1]
        per_slice.append(det)
    if not per_slice:
        return torch.zeros((0, 6), dtype=torch.float32)
    return torch.cat(per_slice, 0)


def _tracks_to_json(outputs, image_name):
    """Build the JSON-serializable annotation dict for one saved frame.

    Each tracker output row is read as x1, y1, x2, y2, track_id, class_index.
    Values are cast to builtin float/int because json.dump rejects numpy.float32
    and torch scalars.
    """
    record = {'image': image_name + '.jpg', 'preds': []}
    for output in outputs:
        x1, y1, x2, y2 = (float(v) for v in output[:4])
        track_id = int(output[4])
        bbox = [x1, y1, x2 - x1, y2 - y1]  # x, y, width, height
        seg = [[x1, y1, x2, y1, x2, y2, x1, y2]]  # box corners as a polygon
        record['preds'].append({
            'category': names[int(output[5])],
            'bbox': bbox,
            'segmentation': seg,
            'track_id': track_id,
        })
    return record


try:
    with torch.no_grad():
        while True:
            # Zero-padded HH:MM:SS strings order correctly under text comparison.
            current_time = datetime.now().strftime("%H:%M:%S")
            if current_time >= CAPTURE_END or current_time < CAPTURE_START:
                time.sleep(60)
                continue
            im_name = datetime.now().strftime("%Y%m%d_%H%M%S")
            # NOTE(review): many capture backends buffer frames, so after the sleep
            # at the end of this loop the read may return an older frame — confirm.
            ret, img0 = cap.read()
            if not ret or img0 is None:
                time.sleep(1)  # avoid a 100% CPU spin while the camera is unavailable
                continue
            preds = _detect_on_slices(img0)
            outputs = tracker.update(preds, img0)
            pred_json = _tracks_to_json(outputs, im_name)
            # Save image and JSON; skip the JSON if the image could not be written
            # so the two files always come in pairs.
            image_path = os.path.join(data_dir, im_name + '.jpg')
            if cv2.imwrite(image_path, img0):
                with open(os.path.join(data_dir, im_name + '.json'), 'w', encoding='utf-8') as f:
                    json.dump(pred_json, f)
            else:
                print(f'Warning: could not write {image_path}; skipping its JSON')
            time.sleep(interval)
finally:
    cap.release()
# --- Residue copied from the web page footer (not code) ---
# Editor is loading...
# Leave a Comment