import json
import os
import time
from datetime import datetime
from pathlib import Path  # needed for the Path(...) usage below

import cv2
import numpy as np
import torch
import torchvision  # used for the cross-slice NMS added below
from yolov5.models.common import DetectMultiBackend
from yolov5.utils.general import non_max_suppression
from strongsort.strong_sort import StrongSORT

# Your get_slice_bboxes function goes here; a minimal sketch follows.
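
# A minimal sketch of the helper, following SAHI's tiling scheme: it returns
# [xmin, ymin, xmax, ymax] windows that cover the frame with the requested
# overlap; windows touching the right/bottom edge are shifted inward so every
# slice is exactly slice_width x slice_height.
def get_slice_bboxes(image_height, image_width, slice_height, slice_width,
                     overlap_height_ratio, overlap_width_ratio):
    slice_bboxes = []
    y_max = y_min = 0
    y_overlap = int(overlap_height_ratio * slice_height)
    x_overlap = int(overlap_width_ratio * slice_width)
    while y_max < image_height:
        x_min = x_max = 0
        y_max = y_min + slice_height
        while x_max < image_width:
            x_max = x_min + slice_width
            if y_max > image_height or x_max > image_width:
                xmax = min(image_width, x_max)
                ymax = min(image_height, y_max)
                xmin = max(0, xmax - slice_width)
                ymin = max(0, ymax - slice_height)
                slice_bboxes.append([xmin, ymin, xmax, ymax])
            else:
                slice_bboxes.append([x_min, y_min, x_max, y_max])
            x_min = x_max - x_overlap
        y_min = y_max - y_overlap
    return slice_bboxes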

# Setup
root_dir = os.getenv('root_dir', '/path/to/your/root')
yolo_weights = Path(root_dir) / 'models' / os.getenv('weights', 'yolov5s.pt')
data_dir = os.path.join(root_dir, 'data')
os.makedirs(data_dir, exist_ok=True)  # frames and JSON predictions land here

slice_width = int(os.getenv("slice_width", 800))
slice_height = int(os.getenv("slice_height", 752))
FRAME_WIDTH = int(os.getenv('frame_width', 1920))
FRAME_HEIGHT = int(os.getenv('frame_height', 1080))
interval = int(os.getenv('interval', 10))
imgsz = 640
device = torch.device('cuda:0' if torch.cuda.is_available() else 'cpu')
half = device.type != 'cpu'  # FP16 inference is only useful on GPU
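
# Example invocation overriding the defaults above (the script name is hypothetical):
#   root_dir=/srv/capture weights=yolov5m.pt interval=5 python sliced_tracker.py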

# Create slice boxes
slice_boxes = get_slice_bboxes(FRAME_HEIGHT, FRAME_WIDTH, slice_height, slice_width, 0.04, 0.04)

# Load YOLO model
model = DetectMultiBackend(yolo_weights, device=device, fp16=half)
names = model.names  # class-index -> label map, used when writing the JSON below

# Initialize StrongSORT (constructor arguments vary between releases; recent
# versions also expect the device and an fp16 flag, assumed here)
tracker = StrongSORT(model_weights='osnet_x0_25_msmt17.pt', device=device, fp16=half,
                     max_dist=0.2, max_iou_distance=0.7)

# Video Capture
cap = cv2.VideoCapture(0)
if not cap.isOpened():
    raise RuntimeError('could not open capture device 0')
cap.set(cv2.CAP_PROP_FRAME_WIDTH, FRAME_WIDTH)
cap.set(cv2.CAP_PROP_FRAME_HEIGHT, FRAME_HEIGHT)

with torch.no_grad():
    while True:
        current_time = datetime.now().strftime("%H:%M:%S")
        start, end = '06:00:00', '18:00:00'
        if current_time >= end or current_time < start:
            time.sleep(60)
            continue

        im_name = datetime.now().strftime("%Y%m%d_%H%M%S")
        pred_json = {'image': im_name+'.jpg', 'preds': []}
        
        ret, img0 = cap.read()
        if not ret or img0 is None:
            continue
        
        dets = []  # per-slice detections, merged after the slice loop
        
        for box in slice_boxes:
            # Crop the slice and keep the ratios needed to map boxes back
            img = img0[box[1]:box[3], box[0]:box[2], :]
            h, w, _ = img.shape
            h_r = h / imgsz
            w_r = w / imgsz
            img = cv2.resize(img, (imgsz, imgsz), interpolation=cv2.INTER_LINEAR)
            img = img.transpose((2, 0, 1))[::-1]  # HWC -> CHW, BGR -> RGB
            img = np.ascontiguousarray(img)
            img = torch.from_numpy(img).to(device)
            img = img.half() if half else img.float()
            img /= 255.0  # 0-255 -> 0.0-1.0
            if img.ndimension() == 3:
                img = img.unsqueeze(0)

            pred = model(img)
            pred = non_max_suppression(pred, 0.4, 0.5)  # per-slice NMS: conf 0.4, IoU 0.5
            proc_pred = pred[0].cpu()
            if not len(proc_pred):
                continue
            # Map xyxy coords from the resized slice back into full-frame coords
            proc_pred[:, [0, 2]] = proc_pred[:, [0, 2]] * w_r + box[0]
            proc_pred[:, [1, 3]] = proc_pred[:, [1, 3]] * h_r + box[1]
            dets.append(proc_pred)

        # Merge detections from every slice into one (N, 6) tensor
        preds = torch.cat(dets, dim=0) if dets else torch.zeros((0, 6))
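        # Overlapping slices can report the same object more than once; run a
        # second, global NMS over the merged boxes to drop those duplicates.
        # This pass is an addition, not part of the original pipeline, and the
        # 0.5 IoU threshold is an assumption worth tuning.
        if len(preds):
            keep = torchvision.ops.nms(preds[:, :4].float(), preds[:, 4].float(), 0.5)
            preds = preds[keep]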
        
        # Hand the merged detections (x1, y1, x2, y2, conf, cls) to StrongSORT
        outputs = tracker.update(preds, img0)
        
        for output in outputs:
            x1, y1, x2, y2, track_id = [float(v) for v in output[:5]]  # plain floats so json.dump works
            bbox = [x1, y1, x2 - x1, y2 - y1]  # COCO-style xywh
            seg = [[x1, y1, x2, y1, x2, y2, x1, y2]]  # box corners as a polygon
            pred_json['preds'].append({
                'category': names[int(output[5])],
                'bbox': bbox,
                'segmentation': seg,
                'track_id': int(track_id),
            })
        
        # Save image and JSON
        cv2.imwrite(os.path.join(data_dir, im_name+'.jpg'), img0)
        with open(os.path.join(data_dir, im_name+'.json'), 'w') as f:
            json.dump(pred_json, f)
        
        time.sleep(interval)