# Untitled
# unknown
# plain_text
# a year ago
# 6.0 kB
# 12
# Indexable
import os
import sys
import time
from datetime import datetime
from pathlib import Path
import matplotlib
matplotlib.use("Agg")
import numpy as np
import cv2
import json
import torch
from dotenv import load_dotenv
load_dotenv("/home/cctv/plitter/camera_config.env")
# Import StrongSORT
from strongsort.strong_sort import StrongSORT
def get_slice_bboxes(image_height, image_width, slice_height, slice_width, overlap_height_ratio, overlap_width_ratio):
    """Tile an image into fixed-size, overlapping slice boxes.

    The image is covered row by row, left to right. A slice that would run
    past the right or bottom edge is shifted back inside the image (and
    clipped at 0 when the image is smaller than a slice), so edge slices
    overlap their neighbours more instead of being smaller.

    Args:
        image_height: Height of the full image in pixels.
        image_width: Width of the full image in pixels.
        slice_height: Height of each slice in pixels; must be positive.
        slice_width: Width of each slice in pixels; must be positive.
        overlap_height_ratio: Vertical overlap between consecutive rows, as a
            fraction of ``slice_height``.
        overlap_width_ratio: Horizontal overlap between consecutive columns,
            as a fraction of ``slice_width``.

    Returns:
        A list of ``[x_min, y_min, x_max, y_max]`` boxes in row-major order.
        The list is empty when either image dimension is not positive.

    Raises:
        ValueError: If a slice dimension is not positive, or if the resulting
            pixel overlap is at least as large as the slice. In both cases
            the sliding window would never advance and the loop would not
            terminate.
    """
    if slice_height <= 0 or slice_width <= 0:
        raise ValueError(
            f"slice dimensions must be positive, got {slice_width}x{slice_height}"
        )

    y_overlap = int(overlap_height_ratio * slice_height)
    x_overlap = int(overlap_width_ratio * slice_width)
    if y_overlap >= slice_height or x_overlap >= slice_width:
        raise ValueError(
            "overlap must be smaller than the slice size "
            f"(overlap {x_overlap}x{y_overlap}, slice {slice_width}x{slice_height})"
        )

    slice_bboxes = []
    y_max = y_min = 0
    while y_max < image_height:
        x_min = x_max = 0
        y_max = y_min + slice_height
        while x_max < image_width:
            x_max = x_min + slice_width
            # Clamp to the image and pull the window back so that it keeps its
            # full size whenever the image is large enough. For windows that
            # are already inside the image this is a no-op.
            clamped_x_max = min(image_width, x_max)
            clamped_y_max = min(image_height, y_max)
            slice_bboxes.append([
                max(0, clamped_x_max - slice_width),
                max(0, clamped_y_max - slice_height),
                clamped_x_max,
                clamped_y_max,
            ])
            x_min = x_max - x_overlap
        y_min = y_max - y_overlap
    return slice_bboxes
# Colour tuples handed to the OpenCV drawing calls; indexed by track id modulo
# the list length.
colors = [(0, 255, 255), (0, 0, 255), (255, 0, 0), (0, 255, 0)] * 20

# ---------------------------------------------------------------------------
# Runtime configuration. Every value can be overridden through environment
# variables; the literals below are the defaults.
# ---------------------------------------------------------------------------
root_dir = os.getenv('root_dir', '/'.join(os.path.abspath(__file__).split('/')[0:-2]))
yolo_weights = Path(root_dir) / 'models' / os.getenv('weights', 'pLitterFloat_800x752_to_640x640.pt')
FRAME_WIDTH = int(os.getenv('frame_width', 1920))
FRAME_HEIGHT = int(os.getenv('frame_height', 1080))
interval = int(os.getenv('interval', 10))
# Environment values are strings, so a plain truthiness test would treat
# "False" as enabled. Parse the common "off" spellings into a real bool.
work_in_night = str(os.getenv('work_in_night', True)).strip().lower() not in ('false', '0', 'no', 'off')
weights_url = os.getenv('weights_url', None)
print(root_dir, yolo_weights, FRAME_WIDTH, FRAME_HEIGHT, interval, work_in_night)

# The YOLOv5 sources live inside the project tree, so they must be on sys.path
# before the imports below can resolve.
for _subdir in ('Yolov5_StrongSORT_OSNet', 'Yolov5_StrongSORT_OSNet/yolov5'):
    _extra_path = os.path.join(root_dir, _subdir)
    if _extra_path not in sys.path:
        sys.path.append(_extra_path)
from yolov5.models.common import DetectMultiBackend
from yolov5.utils.general import (non_max_suppression, check_img_size, cv2)

data_dir = os.path.join(root_dir, 'data')
slice_width = int(os.getenv("slice_width", 800))
slice_height = int(os.getenv("slice_height", 752))
slice_boxes = get_slice_bboxes(FRAME_HEIGHT, FRAME_WIDTH, slice_height, slice_width, 0.04, 0.04)

# Prefer the first GPU, but fall back to the CPU so the script still starts on
# hosts without CUDA. Half precision is only enabled on the GPU.
device = torch.device('cuda:0' if torch.cuda.is_available() else 'cpu')
half = device.type != 'cpu'

# Load YOLO model
if not os.path.isfile(yolo_weights):
    _fallback_weights = Path(root_dir) / 'models/yolov5s.pt'
    if not weights_url:
        # Nothing to download from: go straight to the fallback.
        print(f"Weights {yolo_weights} not found and no weights_url set; using {_fallback_weights}")
        yolo_weights = _fallback_weights
    else:
        try:
            yolo_weights.parent.mkdir(parents=True, exist_ok=True)
            torch.hub.download_url_to_file(weights_url, str(yolo_weights))
        except Exception as exc:
            # Best effort: a failed download must not stop the service, but
            # the reason is reported and Ctrl-C is no longer swallowed.
            print(f"Could not download weights from {weights_url}: {exc}; using {_fallback_weights}")
            yolo_weights = _fallback_weights
model = DetectMultiBackend(yolo_weights, device=device, fp16=half)
stride, names, pt = model.stride, model.names, model.pt
imgsz = 640  # side length (pixels) each slice is resized to before inference

# Initialize StrongSORT
# NOTE(review): keyword names differ between StrongSORT releases; confirm that
# the installed version accepts `reid_weights` and `max_iou_distance`.
strongsort_tracker = StrongSORT(
    reid_weights="osnet_x0_25_msmt17.pt",  # Update with your ReID weights file path
    device=device,
    max_dist=0.2,  # Distance metric threshold
    max_iou_distance=0.7,  # IoU threshold
    max_age=30,  # Max frames to keep object in memory
    n_init=3,  # Minimum frames to confirm the object
)

# Open the first local camera and request the configured resolution.
cap = cv2.VideoCapture(0)
cap.set(cv2.CAP_PROP_FRAME_WIDTH, int(FRAME_WIDTH))
cap.set(cv2.CAP_PROP_FRAME_HEIGHT, int(FRAME_HEIGHT))

# Daytime window as zero-padded HH:MM:SS strings; outside it counts as night.
start = '06:00:00'
end = '18:00:00'
timer = time.time()
# `xyxy2xywh` is used below but is not part of the file's earlier
# `yolov5.utils.general` import, so bring it into scope here.
from yolov5.utils.general import xyxy2xywh


def _detect_slices(frame):
    """Run the detector on every slice of *frame* and merge the results.

    Each slice from ``slice_boxes`` is resized to ``imgsz`` x ``imgsz``, passed
    through the model and NMS, and its boxes are mapped back to full-frame
    pixel coordinates.

    Returns:
        A float32 CPU tensor of shape ``(N, 6)``. Columns 0-3 are the box
        corners, column 4 the confidence and column 5 the class, matching how
        the caller reads them. ``N`` is 0 when nothing was detected.

    NOTE: detections falling in the overlap between slices are not
    de-duplicated.
    """
    merged = []
    for x0, y0, x1, y1 in slice_boxes:
        tile = frame[y0:y1, x0:x1, :]
        if tile.size == 0:
            # The camera delivered a smaller frame than configured; this slice
            # lies completely outside of it.
            continue
        tile_h, tile_w = tile.shape[:2]
        tile = cv2.resize(tile, (imgsz, imgsz))
        # HWC -> CHW with the channel axis reversed (OpenCV frames are BGR).
        tile = np.ascontiguousarray(tile.transpose((2, 0, 1))[::-1])
        tensor = torch.from_numpy(tile).to(device)
        tensor = tensor.half() if half else tensor.float()
        tensor /= 255.0
        if tensor.ndimension() == 3:
            tensor = tensor.unsqueeze(0)
        # Positional thresholds: confidence 0.4, IoU 0.5.
        dets = non_max_suppression(model(tensor), 0.4, 0.5)[0].cpu().float()
        if len(dets) == 0:
            continue
        # Undo the resize and shift into full-frame coordinates. Done in
        # float32 because float16 cannot represent every pixel above 1024.
        dets[:, [0, 2]] = dets[:, [0, 2]] * tile_w / imgsz + x0
        dets[:, [1, 3]] = dets[:, [1, 3]] * tile_h / imgsz + y0
        merged.append(dets)
    if not merged:
        return torch.empty((0, 6))
    return torch.cat(merged, 0)


# cv2.imwrite only returns False when the target directory is missing, so make
# sure it exists before the first frame is saved.
os.makedirs(data_dir, exist_ok=True)

try:
    with torch.no_grad():
        while True:
            # Zero-padded HH:MM:SS strings compare in chronological order.
            current_time = datetime.now().strftime("%H:%M:%S")
            is_night = current_time >= end or current_time < start
            if is_night and work_in_night in (False, 'False'):
                print('night mode turning off')
                time.sleep(60)
                continue

            ret, img0 = cap.read()
            if not ret or img0 is None:
                print("Camera not connected")
                # Back off instead of spinning at full speed while the camera
                # is unavailable.
                time.sleep(1)
                continue

            preds = _detect_slices(img0)

            # Update tracker with YOLOv5 detections
            if len(preds):
                xywhs = xyxy2xywh(preds[:, :4])  # Convert bbox format
                confs = preds[:, 4]  # Confidence scores
                clss = preds[:, 5]  # Class labels
                # Get StrongSORT tracking output
                outputs = strongsort_tracker.update(xywhs.cpu(), confs.cpu(), clss.cpu(), img0)
            else:
                # No detections in this frame: skip the tracker update rather
                # than indexing into an empty tensor.
                outputs = []

            # Draw boxes and track IDs
            for output in outputs:
                bbox = output[0:4]
                # The tracker may return ids as floats (assumption, depends on
                # the StrongSORT version); list indexing needs an int.
                track_id = int(output[4])
                color = colors[track_id % len(colors)]
                top_left = (int(bbox[0]), int(bbox[1]))
                bottom_right = (int(bbox[2]), int(bbox[3]))
                cv2.rectangle(img0, top_left, bottom_right, color, 2)
                cv2.putText(img0, f'ID {track_id}', (int(bbox[0]), int(bbox[1] - 10)), cv2.FONT_HERSHEY_SIMPLEX, 0.75, color, 2)

            im_name = datetime.now().strftime("%Y%m%d_%H%M%S")
            im_save = cv2.imwrite(f"{data_dir}/{im_name}.jpg", img0)
            print(f"Saved: {im_save}")
            time.sleep(interval)
finally:
    # Release the camera on any exit (Ctrl-C, crash) so that the device can be
    # reopened without restarting the host.
    cap.release()
# Editor is loading...
# Leave a Comment