# Untitled
#
# mail@pastecode.io avatar
# unknown
# plain_text
# 20 days ago
# 6.0 kB
# 2
# Indexable
# Never
import os
import sys
import time
from datetime import datetime
from pathlib import Path
import matplotlib
matplotlib.use("Agg")
import numpy as np
import cv2
import json
import torch
from dotenv import load_dotenv
load_dotenv("/home/cctv/plitter/camera_config.env")

# Import StrongSORT
from strongsort.strong_sort import StrongSORT

def get_slice_bboxes(image_height, image_width, slice_height, slice_width, overlap_height_ratio, overlap_width_ratio):
    """Tile an image into overlapping rectangular slices.

    Slices are produced row by row, left to right. Neighbouring slices overlap
    by ``int(overlap_*_ratio * slice_*)`` pixels. A slice that would run past
    the right or bottom edge is shifted back so it ends exactly on that edge
    (its start is clamped at 0). Edge slices therefore keep the full slice size
    whenever the image is at least that large.

    Args:
        image_height, image_width: Size of the full image in pixels.
        slice_height, slice_width: Size of each slice in pixels; must be > 0.
        overlap_height_ratio, overlap_width_ratio: Fraction of the slice size
            shared with the neighbouring slice; must be in [0, 1).

    Returns:
        List of ``[x_min, y_min, x_max, y_max]`` boxes in pixel coordinates.

    Raises:
        ValueError: If a slice size is not positive or an overlap ratio is
            outside [0, 1). A non-positive size or a ratio >= 1 means the
            window never advances (infinite loop). A negative ratio leaves
            uncovered gaps between slices.
    """
    if slice_height <= 0 or slice_width <= 0:
        raise ValueError(f"slice size must be positive, got {slice_width}x{slice_height}")
    for ratio in (overlap_height_ratio, overlap_width_ratio):
        if not 0 <= ratio < 1:
            raise ValueError(f"overlap ratio must be in [0, 1), got {ratio}")

    y_overlap = int(overlap_height_ratio * slice_height)
    x_overlap = int(overlap_width_ratio * slice_width)

    slice_bboxes = []
    y_min = y_max = 0
    while y_max < image_height:
        y_max = y_min + slice_height
        x_min = x_max = 0
        while x_max < image_width:
            x_max = x_min + slice_width
            if y_max > image_height or x_max > image_width:
                # Overflowing slice: pull it back so it ends on the image edge.
                x_end = min(image_width, x_max)
                y_end = min(image_height, y_max)
                slice_bboxes.append([max(0, x_end - slice_width), max(0, y_end - slice_height), x_end, y_end])
            else:
                slice_bboxes.append([x_min, y_min, x_max, y_max])
            # Step back by the overlap so consecutive slices share pixels.
            x_min = x_max - x_overlap
        y_min = y_max - y_overlap
    return slice_bboxes

# Drawing colours (OpenCV BGR tuples) picked by track id; 4 base colours x 20 = 80 entries.
colors = [(0, 255, 255), (0, 0, 255), (255, 0, 0), (0, 255, 0)] * 20


def env_flag(name, default):
    """Return environment variable *name* interpreted as a boolean.

    Unset or blank values yield *default*. Otherwise the result is False when
    the value is (case-insensitively) ``false``, ``0``, ``no`` or ``off``, and
    True for anything else.
    """
    raw = os.getenv(name)
    if raw is None or not raw.strip():
        return default
    return raw.strip().lower() not in ('false', '0', 'no', 'off')


# Project root: the parent of the directory holding this script, unless the
# `root_dir` env var overrides it. pathlib keeps this portable across path separators.
root_dir = os.getenv('root_dir', str(Path(os.path.abspath(__file__)).parent.parent))
yolo_weights = Path(root_dir) / 'models' / os.getenv('weights', 'pLitterFloat_800x752_to_640x640.pt')

FRAME_WIDTH = int(os.getenv('frame_width', 1920))
FRAME_HEIGHT = int(os.getenv('frame_height', 1080))
interval = int(os.getenv('interval', 10))  # seconds slept between captures
# os.getenv returns strings, so parse explicitly: "false"/"0"/"no"/"off" all disable night work.
work_in_night = env_flag('work_in_night', True)
weights_url = os.getenv('weights_url', None)

print(root_dir, yolo_weights, FRAME_WIDTH, FRAME_HEIGHT, interval, work_in_night)

# Make the vendored Yolov5_StrongSORT_OSNet checkout and its yolov5 package
# importable by the yolov5 imports that follow.
for _extra_path in (
    os.path.join(root_dir, 'Yolov5_StrongSORT_OSNet'),
    os.path.join(root_dir, 'Yolov5_StrongSORT_OSNet/yolov5'),
):
    if _extra_path not in sys.path:
        sys.path.append(_extra_path)

from yolov5.models.common import DetectMultiBackend
from yolov5.utils.general import (non_max_suppression, check_img_size, cv2)

# Where annotated frames are saved. Create it up front: cv2.imwrite reports a
# missing directory only through a False return value, not an exception.
data_dir = os.path.join(root_dir, 'data')
os.makedirs(data_dir, exist_ok=True)

# Size in pixels of each tile the frame is cut into before inference.
slice_width = int(os.getenv("slice_width", 800))
slice_height = int(os.getenv("slice_height", 752))

# Tile layout for the configured frame size, with 4% overlap between neighbours.
slice_boxes = get_slice_bboxes(FRAME_HEIGHT, FRAME_WIDTH, slice_height, slice_width, 0.04, 0.04)

# Use the first GPU when present, otherwise fall back to CPU instead of crashing.
# fp16 is enabled only on CUDA, since many CPU ops lack half-precision kernels.
device = torch.device('cuda:0' if torch.cuda.is_available() else 'cpu')
half = device.type == 'cuda'

# Load the YOLOv5 detector. If the configured weights file is missing, try to
# download it from `weights_url`. If there is no URL or the download fails,
# fall back to the stock models/yolov5s.pt.
if not os.path.isfile(yolo_weights):
    if weights_url:
        try:
            # download_url_to_file stages a temp file in the destination dir, so it must exist.
            Path(yolo_weights).parent.mkdir(parents=True, exist_ok=True)
            torch.hub.download_url_to_file(weights_url, str(yolo_weights))
        except Exception as err:  # network / HTTP / filesystem errors: fall back, don't crash
            print(f"Failed to download weights from {weights_url}: {err}")
    if not os.path.isfile(yolo_weights):
        yolo_weights = Path(root_dir) / 'models/yolov5s.pt'
        print(f"Falling back to {yolo_weights}")

model = DetectMultiBackend(yolo_weights, device=device, fp16=half)
stride, names, pt = model.stride, model.names, model.pt
# Slices are resized straight to imgsz x imgsz (no letterboxing), so make sure
# it is a multiple of the model stride.
imgsz = check_img_size(640, s=stride)

# Multi-object tracker that keeps persistent IDs for detections across updates.
# The main loop calls update() once per captured frame, i.e. once every
# `interval` seconds, so the frame counts below are counts of those updates.
# NOTE(review): keyword names must match the installed `strongsort` package's
# StrongSORT constructor — confirm against the pinned version.
_tracker_settings = {
    "reid_weights": "osnet_x0_25_msmt17.pt",  # ReID weights file; point this at your own copy
    "device": device,
    "max_dist": 0.2,  # distance-metric matching threshold
    "max_iou_distance": 0.7,  # IoU matching threshold
    "max_age": 30,  # max frames a lost object is kept in memory
    "n_init": 3,  # frames needed before a track is confirmed
}
strongsort_tracker = StrongSORT(**_tracker_settings)

# Open the default camera (device index 0).
cap = cv2.VideoCapture(0)
if not cap.isOpened():
    # Don't abort here; the main loop prints "Camera not connected" for every failed read.
    print("Warning: could not open camera device 0")
# Resolution settings are requests; the driver may deliver a different size.
cap.set(cv2.CAP_PROP_FRAME_WIDTH, int(FRAME_WIDTH))
cap.set(cv2.CAP_PROP_FRAME_HEIGHT, int(FRAME_HEIGHT))

# Daytime window as zero-padded "HH:MM:SS" strings. The main loop compares them
# lexicographically with the current time, which orders correctly for this format.
start = '06:00:00'
end = '18:00:00'
timer = time.time()  # NOTE(review): not read anywhere in the visible code

# `xyxy2xywh` converts detections for the tracker below. It was missing from
# the yolov5 import list, which made the loop fail with NameError on the first frame.
from yolov5.utils.general import xyxy2xywh

# `work_in_night` values (env string or bool) that disable capture outside [start, end).
_NIGHT_OFF_VALUES = ('false', '0', 'no', 'off')


def _night_work_disabled():
    """Return True when night-time capture is switched off via `work_in_night`."""
    return str(work_in_night).strip().lower() in _NIGHT_OFF_VALUES


def _detect_sliced(frame):
    """Run the YOLO model on every slice of ``frame`` and merge the detections.

    Each slice is resized to ``imgsz`` x ``imgsz`` (aspect ratio not preserved).
    It is run through ``model`` and ``non_max_suppression``, and its boxes are
    mapped back to full-frame pixel coordinates.

    Returns:
        A float32 CPU tensor with one row per detection. Columns 0-3 are
        x1, y1, x2, y2 in frame pixels. Columns 4 and 5 are passed to the
        tracker as confidence and class. NOTE(review): the meaning of 4/5 is
        assumed from how they are consumed — confirm against yolov5's NMS
        output layout.
    """
    frame_h, frame_w = frame.shape[:2]
    boxes = slice_boxes
    if (frame_h, frame_w) != (FRAME_HEIGHT, FRAME_WIDTH):
        # cap.set() is only a request. Re-tile for the size actually delivered
        # so no slice falls outside the frame.
        boxes = get_slice_bboxes(frame_h, frame_w, slice_height, slice_width, 0.04, 0.04)

    chunks = []
    for x1, y1, x2, y2 in boxes:
        tile = frame[y1:y2, x1:x2, :]
        tile_h, tile_w = tile.shape[:2]
        tile = cv2.resize(tile, (imgsz, imgsz))
        tile = np.ascontiguousarray(tile.transpose((2, 0, 1))[::-1])  # HWC BGR -> CHW RGB
        tensor = torch.from_numpy(tile).to(device)
        tensor = tensor.half() if half else tensor.float()
        tensor /= 255.0
        tensor = tensor.unsqueeze(0)  # add batch dimension

        # Force float32. If NMS returned fp16, the intermediate `x * tile_w`
        # (up to 640 * 800) would overflow fp16's maximum of 65504.
        det = non_max_suppression(model(tensor), 0.4, 0.5)[0].float().cpu()
        det[:, [0, 2]] = det[:, [0, 2]] * tile_w / imgsz + x1
        det[:, [1, 3]] = det[:, [1, 3]] * tile_h / imgsz + y1
        chunks.append(det)

    if not chunks:
        return torch.zeros((0, 6))
    return torch.cat(chunks, 0)


def _draw_tracks(frame, outputs):
    """Draw each tracked box and its ID label onto ``frame`` in place.

    NOTE(review): assumes each tracker output row starts with
    [x1, y1, x2, y2, track_id, ...] — confirm against the installed strongsort.
    """
    for output in outputs:
        x1, y1, x2, y2 = (int(v) for v in output[0:4])
        # Rows may be float arrays; list indexing and the label need an int id.
        track_id = int(output[4])
        color = colors[track_id % len(colors)]
        cv2.rectangle(frame, (x1, y1), (x2, y2), color, 2)
        cv2.putText(frame, f'ID {track_id}', (x1, y1 - 10), cv2.FONT_HERSHEY_SIMPLEX, 0.75, color, 2)


with torch.no_grad():
    while True:
        current_time = datetime.now().strftime("%H:%M:%S")
        if (current_time >= end or current_time < start) and _night_work_disabled():
            print('night mode turning off')
            time.sleep(60)
            continue

        ret, img0 = cap.read()
        if not ret or img0 is None:
            print("Camera not connected")
            time.sleep(1)  # don't spin at 100% CPU while the camera is unavailable
            continue

        preds = _detect_sliced(img0)

        # Convert to the tracker's input format and update tracks.
        # NOTE(review): update() is still called when there are no detections,
        # as before; confirm the installed tracker handles empty input.
        xywhs = xyxy2xywh(preds[:, :4])
        confs = preds[:, 4]
        clss = preds[:, 5]
        outputs = strongsort_tracker.update(xywhs.cpu(), confs.cpu(), clss.cpu(), img0)

        _draw_tracks(img0, outputs)

        im_name = datetime.now().strftime("%Y%m%d_%H%M%S")
        im_save = cv2.imwrite(f"{data_dir}/{im_name}.jpg", img0)
        print(f"Saved: {im_save}")

        time.sleep(interval)
# Leave a Comment