import os
import sys
import time
from datetime import datetime
from pathlib import Path
import matplotlib
matplotlib.use("Agg")  # headless backend; yolov5's utilities import matplotlib
import numpy as np
import cv2
import json
import torch
import sqlite3
import uuid
from dotenv import load_dotenv
load_dotenv("/home/cctv/plitter/camera_config.env")
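
# Example camera_config.env (illustrative values only; the keys mirror the
# os.getenv calls below, but the real file's contents are not shown here):
#   root_dir=/home/cctv/plitter
#   weights=pLitterFloat_800x752_to_640x640.pt
#   weights_url=<URL to download the weights from>
#   frame_width=1920
#   frame_height=1280
#   slice_width=800
#   slice_height=752
#   interval=10
#   work_in_night=True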

def get_slice_bboxes(image_height, image_width, slice_height, slice_width, overlap_height_ratio, overlap_width_ratio):
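    """Compute overlapping tile coordinates for sliced (SAHI-style) inference.

    Walks the frame in slice_height x slice_width windows with the given
    overlap ratios; edge tiles are shifted back inside the frame so every
    tile keeps the full slice size. Returns [xmin, ymin, xmax, ymax] lists."""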
    slice_bboxes = []
    y_max = y_min = 0
    y_overlap = int(overlap_height_ratio * slice_height)
    x_overlap = int(overlap_width_ratio * slice_width)
    while y_max < image_height:
        x_min = x_max = 0
        y_max = y_min + slice_height
        while x_max < image_width:
            x_max = x_min + slice_width
            if y_max > image_height or x_max > image_width:
                xmax = min(image_width, x_max)
                ymax = min(image_height, y_max)
                xmin = max(0, xmax - slice_width)
                ymin = max(0, ymax - slice_height)
                slice_bboxes.append([xmin, ymin, xmax, ymax])
            else:
                slice_bboxes.append([x_min, y_min, x_max, y_max])
            x_min = x_max - x_overlap
        y_min = y_max - y_overlap
    return slice_bboxes
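
# Sketch: with the defaults used below (1280x1920 frame, 752x800 tiles, 4% overlap),
# get_slice_bboxes(1280, 1920, 752, 800, 0.04, 0.04) yields six tiles:
#   [0, 0, 800, 752],    [768, 0, 1568, 752],    [1120, 0, 1920, 752],
#   [0, 528, 800, 1280], [768, 528, 1568, 1280], [1120, 528, 1920, 1280]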

colors = [(0, 255, 255), (0, 0, 255), (255, 0, 0), (0, 255, 0)] * 20

def draw_boxes_on_image(image, boxes, classes, class_ids, scores, use_normalized_coordinates=False, min_score_thresh=.3):
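    """Draw boxes with 'class:score' labels for detections above min_score_thresh.

    Note: normalized boxes are read as [y1, x1, y2, x2] fractions of the image,
    while pixel boxes are read as [x1, y1, x2, y2]."""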
    assert len(boxes) == len(scores)
    for i in range(len(boxes)):
        box = boxes[i]
        category = str(classes[i])
        class_id = int(class_ids[i])
        score = scores[i]
        if score >= min_score_thresh:
            if use_normalized_coordinates:
                h, w, _ = image.shape
                y1 = int(box[0] * h)
                x1 = int(box[1] * w)
                y2 = int(box[2] * h)
                x2 = int(box[3] * w)
            else:
                x1, y1, x2, y2 = int(box[0]), int(box[1]), int(box[2]), int(box[3])
            image = cv2.rectangle(image, (x1, y1), (x2, y2), colors[class_id], 2)
            cv2.putText(image, category + ':' + str(round(score, 2)), (x1, y1 - 5), cv2.FONT_HERSHEY_SIMPLEX, 0.5, colors[class_id], 1)
    return image
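
# Usage sketch (hypothetical values; pixel-coordinate boxes):
#   frame = draw_boxes_on_image(frame, [[10, 20, 110, 220]], ['litter'], [0], [0.9])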

# Project root: $root_dir env var, defaulting to the parent of this script's directory
root_dir = os.getenv('root_dir', '/'.join(os.path.abspath(__file__).split('/')[0:-2]))

yolo_weights = Path(root_dir) / 'models' / os.getenv('weights', 'pLitterFloat_800x752_to_640x640.pt')
reid_weights = Path(root_dir) / 'models' / os.getenv('reid_weights', 'osnet_x0_25_msmt17.pt')

FRAME_WIDTH = int(os.getenv('frame_width', 1920))
FRAME_HEIGHT = int(os.getenv('frame_height', 1280))
interval = int(os.getenv('interval', 10))
work_in_night = os.getenv('work_in_night', 'True') != 'False'  # env values are strings; parse to bool
weights_url = os.getenv('weights_url', None)

print(root_dir, yolo_weights, reid_weights, FRAME_WIDTH, FRAME_HEIGHT, interval, work_in_night)

for subdir in ('Yolov5_StrongSORT_OSNet',
               'Yolov5_StrongSORT_OSNet/yolov5',
               'Yolov5_StrongSORT_OSNet/trackers/strong_sort'):
    path = os.path.join(root_dir, subdir)
    if path not in sys.path:
        print(path)
        sys.path.append(path)

from yolov5.models.common import DetectMultiBackend
from yolov5.utils.general import non_max_suppression

from trackers.strong_sort.utils.parser import get_config
from trackers.strong_sort.strong_sort import StrongSORT

db_dir = os.path.join(root_dir, 'db')
data_dir = os.path.join(root_dir, 'data')
os.makedirs(db_dir, exist_ok=True)
os.makedirs(data_dir, exist_ok=True)

# Generate a unique string for saving tracking ids (defined here but not referenced below)
uid = str(uuid.uuid4())

# Database paths
detections_dbpath = os.path.join(db_dir, 'detections.db')
images_dbpath = os.path.join(db_dir, 'images.db')

conn = sqlite3.connect(detections_dbpath, isolation_level=None)
conn.execute("VACUUM")
cur = conn.cursor()
cur.execute("""CREATE TABLE IF NOT EXISTS detections(id INTEGER PRIMARY KEY, track_id TEXT, date_time TEXT, category TEXT, bbox TEXT, segmentation TEXT)""")
conn.commit()

im_conn = sqlite3.connect(images_dbpath, isolation_level=None)
im_conn.execute("VACUUM")
im_cur = im_conn.cursor()
im_cur.execute("""CREATE TABLE IF NOT EXISTS images(id INTEGER PRIMARY KEY, file_name TEXT UNIQUE, uploaded BOOLEAN)""")
im_conn.commit()
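
# How the two tables relate: each captured frame is logged once in images
# (file_name is the capture timestamp; 'uploaded' is presumably maintained by a
# separate sync job not shown in this script), while detections gets one row
# per tracked box. Illustrative detections row (class name is hypothetical):
#   ('7', '2024-01-15T08:30:12', 'litter', '{"x1": 412, "y1": 208, "x2": 455, "y2": 240}', '{}')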

slice_width = int(os.getenv("slice_width", 800))
slice_height = int(os.getenv("slice_height", 752))

slice_boxes = get_slice_bboxes(FRAME_HEIGHT, FRAME_WIDTH, slice_height, slice_width, 0.04, 0.04)

device = torch.device('cuda:0' if torch.cuda.is_available() else 'cpu')
half = device.type != 'cpu'  # FP16 inference only makes sense on GPU

# Load YOLOv5 model, downloading weights if missing; fall back to stock yolov5s weights
if not os.path.isfile(yolo_weights):
    try:
        torch.hub.download_url_to_file(weights_url, yolo_weights)
    except Exception:
        yolo_weights = Path(root_dir) / 'models/yolov5s.pt'

model = DetectMultiBackend(yolo_weights, device=device, fp16=half)
stride, names, pt = model.stride, model.names, model.pt
print(model.names)

# Load StrongSORT configuration
cfg = get_config()
cfg.merge_from_file(os.path.join(root_dir, 'Yolov5_StrongSORT_OSNet/trackers/strong_sort/configs/strong_sort.yaml'))

# Initialize StrongSORT tracker
tracker = StrongSORT(
    reid_weights,
    device,
    half,
    max_dist=cfg.STRONGSORT.MAX_DIST,
    max_iou_distance=cfg.STRONGSORT.MAX_IOU_DISTANCE,
    max_age=cfg.STRONGSORT.MAX_AGE,
    max_unmatched_preds=99,
    n_init=0,
    nn_budget=cfg.STRONGSORT.NN_BUDGET,
    mc_lambda=cfg.STRONGSORT.MC_LAMBDA,
    ema_alpha=cfg.STRONGSORT.EMA_ALPHA,
)
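
# Note: n_init=0 confirms tracks on their first detection, and
# max_unmatched_preds=99 overrides the YAML config; plausibly chosen because
# frames are sampled only every `interval` seconds, so tracks must survive
# long gaps between observations.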

tracker.model.warmup()

cap = cv2.VideoCapture(0)
cap.set(cv2.CAP_PROP_FRAME_WIDTH, int(FRAME_WIDTH))
cap.set(cv2.CAP_PROP_FRAME_HEIGHT, int(FRAME_HEIGHT))

imgsz = 640

cur = conn.cursor()
im_cur = im_conn.cursor()

# Working hours; outside this window the loop idles unless work_in_night is enabled
start = '06:00:00'
end = '18:00:00'

with torch.no_grad():
    while True:
        current_time = datetime.now().strftime("%H:%M:%S")
        if current_time >= end or current_time < start:
            if not work_in_night:
                print('Outside working hours, sleeping')
                time.sleep(60)
                continue

        st = time.time()
        im_name = datetime.now().strftime("%Y%m%d_%H%M%S")

        ret, img0 = cap.read()

        if not ret or img0 is None:
            print("Check if camera is connected and run again \n")
            continue

        if not img0.any():  # skip all-black frames
            continue

        # Pooled detections from all tiles, kept on the inference device;
        # rows are [x1, y1, x2, y2, conf, cls] as returned by NMS
        preds = torch.zeros((0, 6), dtype=torch.float16 if half else torch.float32, device=device)

        for box in slice_boxes:
            img = img0[box[1]:box[3], box[0]:box[2], :]
            h, w, _ = img.shape
            h_r = h / imgsz  # vertical scale from model input back to this tile
            w_r = w / imgsz  # horizontal scale from model input back to this tile
            img = cv2.resize(img, (imgsz, imgsz), interpolation=cv2.INTER_LINEAR)
            img = img.transpose((2, 0, 1))[::-1]  # HWC to CHW, BGR to RGB
            img = np.ascontiguousarray(img)  # contiguous

            img = torch.from_numpy(img).to(device)
            img = img.half() if half else img.float()
            img /= 255  # uint8 0-255 -> float 0.0-1.0, as yolov5 expects
            img = img.unsqueeze(0)  # add batch dimension

            pred = model(img, augment=False, visualize=False)[0]  # inference

            pred = non_max_suppression(pred, 0.25, 0.45, agnostic=False)[0]  # NMS

            if pred is not None and len(pred):
                # Scale boxes from the 640x640 model input back to tile pixels
                # (the tile was resized without letterboxing, so scale each axis
                # directly), then offset by the tile origin into frame coordinates
                pred[:, [0, 2]] = pred[:, [0, 2]] * w_r + box[0]
                pred[:, [1, 3]] = pred[:, [1, 3]] * h_r + box[1]
                preds = torch.cat((preds, pred), dim=0)

        if len(preds):
            # StrongSORT takes [x1, y1, x2, y2, conf, cls] detections plus the
            # frame itself (used for ReID features) and returns confirmed tracks
            # as [x1, y1, x2, y2, track_id, cls, conf] rows
            outputs = tracker.update(preds.cpu(), img0)
            for output in outputs:
                x1, y1, x2, y2, track_id, cls, conf = output[:7]
                x1, y1, x2, y2, track_id = int(x1), int(y1), int(x2), int(y2), int(track_id)

                # Prepare detection data for database
                bbox = {"x1": x1, "y1": y1, "x2": x2, "y2": y2}
                segmentation = {}
                im_cur.execute("INSERT OR IGNORE INTO images(file_name, uploaded) VALUES (?, ?)", (im_name, False))
                conn.execute("INSERT INTO detections(track_id, date_time, category, bbox, segmentation) VALUES (?, ?, ?, ?, ?)",
                             (track_id, datetime.now().isoformat(), names[int(cls)], json.dumps(bbox), json.dumps(segmentation)))
                conn.commit()

        img0 = draw_boxes_on_image(img0, preds[:, :4], [names[int(c)] for c in preds[:, 5]], preds[:, 5], preds[:, 4])
        cv2.imshow("image", img0)
        if cv2.waitKey(1) == ord('q'):
            break
        # Pace the loop so frames are sampled roughly every `interval` seconds
        time.sleep(max(0, interval - (time.time() - st)))

cap.release()
cv2.destroyAllWindows()
conn.close()
im_conn.close()