Untitled
unknown
plain_text
20 days ago
6.0 kB
2
Indexable
Never
import os
import sys
import time
from datetime import datetime
from pathlib import Path

import matplotlib

matplotlib.use("Agg")
import numpy as np
import cv2
import json
import torch
from dotenv import load_dotenv

load_dotenv("/home/cctv/plitter/camera_config.env")

# Import StrongSORT
from strongsort.strong_sort import StrongSORT


def get_slice_bboxes(image_height, image_width, slice_height, slice_width,
                     overlap_height_ratio, overlap_width_ratio):
    """Tile an image into overlapping slices.

    Slices are laid out row by row, left to right. Consecutive slices overlap
    by ``int(ratio * slice_size)`` pixels on each axis. A slice that would run
    past the right/bottom edge is shifted back so that it ends exactly on the
    edge (and starts at no less than 0).

    Args:
        image_height, image_width: size of the full frame in pixels.
        slice_height, slice_width: nominal size of each slice in pixels.
        overlap_height_ratio, overlap_width_ratio: fraction of the slice size
            shared with the neighbouring slice.

    Returns:
        A list of ``[xmin, ymin, xmax, ymax]`` integer boxes.
    """
    slice_bboxes = []
    y_max = y_min = 0
    y_overlap = int(overlap_height_ratio * slice_height)
    x_overlap = int(overlap_width_ratio * slice_width)
    while y_max < image_height:
        x_min = x_max = 0
        y_max = y_min + slice_height
        while x_max < image_width:
            x_max = x_min + slice_width
            if y_max > image_height or x_max > image_width:
                # Clamp to the image edge and slide the slice back so it keeps
                # its nominal size where the image is large enough.
                xmax = min(image_width, x_max)
                ymax = min(image_height, y_max)
                xmin = max(0, xmax - slice_width)
                ymin = max(0, ymax - slice_height)
                slice_bboxes.append([xmin, ymin, xmax, ymax])
            else:
                slice_bboxes.append([x_min, y_min, x_max, y_max])
            x_min = x_max - x_overlap
        y_min = y_max - y_overlap
    return slice_bboxes


def _flag_is_false(value):
    """Return True when an env-style flag is explicitly set to a false-like value.

    Accepts the boolean ``False`` as well as strings such as ``"False"``,
    ``"false"``, ``"0"``, ``"no"`` and ``"off"`` (case-insensitive).
    """
    return str(value).strip().lower() in ('false', '0', 'no', 'off')


colors = [(0, 255, 255), (0, 0, 255), (255, 0, 0), (0, 255, 0)] * 20
root_dir = os.getenv('root_dir', '/'.join(os.path.abspath(__file__).split('/')[0:-2]))
yolo_weights = Path(root_dir) / 'models' / os.getenv('weights', 'pLitterFloat_800x752_to_640x640.pt')
FRAME_WIDTH = int(os.getenv('frame_width', 1920))
FRAME_HEIGHT = int(os.getenv('frame_height', 1080))
interval = int(os.getenv('interval', 10))
work_in_night = os.getenv('work_in_night', True)
weights_url = os.getenv('weights_url', None)
print(root_dir, yolo_weights, FRAME_WIDTH, FRAME_HEIGHT, interval, work_in_night)

if os.path.join(root_dir, 'Yolov5_StrongSORT_OSNet') not in sys.path:
    sys.path.append(os.path.join(root_dir, 'Yolov5_StrongSORT_OSNet'))
if os.path.join(root_dir, 'Yolov5_StrongSORT_OSNet/yolov5') not in sys.path:
    sys.path.append(os.path.join(root_dir, 'Yolov5_StrongSORT_OSNet/yolov5'))

from yolov5.models.common import DetectMultiBackend
# xyxy2xywh was previously used below without being imported (NameError).
from yolov5.utils.general import (non_max_suppression, check_img_size, cv2, xyxy2xywh)

data_dir = os.path.join(root_dir, 'data')
# cv2.imwrite silently returns False when the target directory is missing.
os.makedirs(data_dir, exist_ok=True)

slice_width = int(os.getenv("slice_width", 800))
slice_height = int(os.getenv("slice_height", 752))
slice_boxes = get_slice_bboxes(FRAME_HEIGHT, FRAME_WIDTH, slice_height, slice_width, 0.04, 0.04)

# Prefer the first GPU; fall back to CPU (where fp16 inference is not usable).
device = torch.device('cuda:0' if torch.cuda.is_available() else 'cpu')
half = device.type != 'cpu'

# Load YOLO model; download the configured weights if they are not on disk,
# otherwise fall back to the stock yolov5s checkpoint.
if not os.path.isfile(yolo_weights):
    try:
        if not weights_url:
            raise ValueError("weights_url is not set")
        torch.hub.download_url_to_file(weights_url, str(yolo_weights))
    except Exception as err:  # narrow from bare except: keep Ctrl-C working
        print(f"Could not download {yolo_weights} ({err}); falling back to yolov5s.pt")
        yolo_weights = Path(root_dir) / 'models/yolov5s.pt'
model = DetectMultiBackend(yolo_weights, device=device, fp16=half)
stride, names, pt = model.stride, model.names, model.pt
imgsz = 640  # every slice is resized to imgsz x imgsz before inference

# Initialize StrongSORT
# NOTE(review): constructor keyword names and the update() signature differ
# between StrongSORT releases — confirm against the installed `strongsort` version.
strongsort_tracker = StrongSORT(
    reid_weights="osnet_x0_25_msmt17.pt",  # Update with your ReID weights file path
    device=device,
    max_dist=0.2,  # Distance metric threshold
    max_iou_distance=0.7,  # IoU threshold
    max_age=30,  # Max frames to keep object in memory
    n_init=3,  # Minimum frames to confirm the object
)


def _detect_on_slices(frame):
    """Run the detector on every slice of *frame*.

    Each slice is resized to ``imgsz x imgsz``, passed through the model and
    NMS, and its boxes are mapped back into full-frame pixel coordinates.

    Returns:
        A float32 CPU tensor of shape ``(N, 6)`` (``N`` may be 0) whose
        columns are the rows returned by yolov5's ``non_max_suppression``:
        x1, y1, x2, y2, confidence, class.
    """
    per_slice = []
    for box in slice_boxes:
        crop = frame[box[1]:box[3], box[0]:box[2], :]
        h, w, _ = crop.shape
        img = cv2.resize(crop, (imgsz, imgsz))
        img = img.transpose((2, 0, 1))[::-1]  # HWC -> CHW, BGR -> RGB
        img = np.ascontiguousarray(img)
        img = torch.from_numpy(img).to(device)
        img = img.half() if half else img.float()
        img /= 255.0
        if img.ndimension() == 3:
            img = img.unsqueeze(0)
        pred = model(img)
        pred = non_max_suppression(pred, 0.4, 0.5)
        # float32: fp16 only resolves ~1 px above x=1024 in full-frame coords.
        det = pred[0].cpu().float()
        # Undo the per-axis resize, then shift into full-frame coordinates.
        det[:, [0, 2]] = det[:, [0, 2]] * w / imgsz + box[0]
        det[:, [1, 3]] = det[:, [1, 3]] * h / imgsz + box[1]
        per_slice.append(det)
    if not per_slice:
        return torch.zeros((0, 6))
    # Concatenating 2-D tensors keeps the result 2-D even with no detections,
    # so preds[:, :4] below never fails on an empty frame.
    return torch.cat(per_slice, 0)


cap = cv2.VideoCapture(0)
cap.set(cv2.CAP_PROP_FRAME_WIDTH, int(FRAME_WIDTH))
cap.set(cv2.CAP_PROP_FRAME_HEIGHT, int(FRAME_HEIGHT))

# Daytime window; outside it the script idles unless work_in_night is enabled.
start = '06:00:00'
end = '18:00:00'

try:
    with torch.no_grad():
        while True:
            current_time = datetime.now().strftime("%H:%M:%S")
            if current_time >= end or current_time < start:
                if _flag_is_false(work_in_night):
                    print('night mode turning off')
                    time.sleep(60)
                    continue

            ret, img0 = cap.read()
            if not ret or img0 is None:
                print("Camera not connected")
                time.sleep(1)  # avoid busy-spinning while the camera is absent
                continue

            preds = _detect_on_slices(img0)

            # Update tracker with YOLOv5 detections
            xywhs = xyxy2xywh(preds[:, :4])  # Convert bbox format
            confs = preds[:, 4]  # Confidence scores
            clss = preds[:, 5]  # Class labels

            # Get StrongSORT tracking output
            outputs = strongsort_tracker.update(xywhs.cpu(), confs.cpu(), clss.cpu(), img0)

            # Draw boxes and track IDs
            for output in outputs:
                bbox = output[0:4]
                # Tracker output rows may be float arrays; a list index must be int.
                track_id = int(output[4])
                color = colors[track_id % len(colors)]
                cv2.rectangle(img0, (int(bbox[0]), int(bbox[1])), (int(bbox[2]), int(bbox[3])), color, 2)
                cv2.putText(img0, f'ID {track_id}', (int(bbox[0]), int(bbox[1] - 10)),
                            cv2.FONT_HERSHEY_SIMPLEX, 0.75, color, 2)

            im_name = datetime.now().strftime("%Y%m%d_%H%M%S")
            im_save = cv2.imwrite(f"{data_dir}/{im_name}.jpg", img0)
            print(f"Saved: {im_save}")
            time.sleep(interval)
finally:
    cap.release()
Leave a Comment