# Untitled
#
# Paste-site metadata captured along with this script (pastecode.io):
#   mail@pastecode.io avatar
#   unknown
#   plain_text
#   17 days ago
#   3.6 kB
#   2
#   Indexable
#   Never
import os
import time
from datetime import datetime
import numpy as np
import cv2
import json
import torch
from yolov5.models.common import DetectMultiBackend
from yolov5.utils.general import non_max_suppression
from strongsort.strong_sort import StrongSORT

# Your get_slice_bboxes function goes here...

# Setup: every tunable below can be overridden through an environment variable.
root_dir = os.getenv('root_dir', '/path/to/your/root')
# Plain string paths, built the same way for the weights and the output dir.
yolo_weights = os.path.join(root_dir, 'models', os.getenv('weights', 'yolov5s.pt'))
data_dir = os.path.join(root_dir, 'data')

# Size in pixels of the tiles each frame is cut into before inference.
slice_width = int(os.getenv("slice_width", 800))
slice_height = int(os.getenv("slice_height", 752))
# Capture resolution requested from the camera.
FRAME_WIDTH = int(os.getenv('frame_width', 1920))
FRAME_HEIGHT = int(os.getenv('frame_height', 1080))
# Seconds to sleep after each processed frame.
interval = int(os.getenv('interval', 10))
# Square side length every tile is resized to before being fed to the model.
imgsz = 640
# Defaults to the first GPU; set e.g. device=cpu to run without CUDA.
device = torch.device(os.getenv('device', 'cuda:0'))
# FP16 inference is only enabled on CUDA devices; half precision is
# typically unsupported or slow for CPU convolution kernels.
half = device.type == 'cuda'

# Video Capture: open the default camera and request the configured size.
cap = cv2.VideoCapture(0)
if not cap.isOpened():
    # Fail fast instead of letting the main loop spin on failed reads.
    raise RuntimeError('Could not open video capture device 0')
cap.set(cv2.CAP_PROP_FRAME_WIDTH, FRAME_WIDTH)
cap.set(cv2.CAP_PROP_FRAME_HEIGHT, FRAME_HEIGHT)

# Create slice boxes. The camera may not honour the requested resolution, so
# tile the size the capture reports (falling back to the requested size when
# the backend reports 0); otherwise tiles could extend past the real frame.
# Each box is indexed as [x_min, y_min, x_max, y_max] by the main loop.
# The two 0.04 arguments are presumably the height/width overlap ratios --
# confirm against get_slice_bboxes.
actual_width = int(cap.get(cv2.CAP_PROP_FRAME_WIDTH)) or FRAME_WIDTH
actual_height = int(cap.get(cv2.CAP_PROP_FRAME_HEIGHT)) or FRAME_HEIGHT
slice_boxes = get_slice_bboxes(actual_height, actual_width, slice_height, slice_width, 0.04, 0.04)

# Load YOLO model
model = DetectMultiBackend(yolo_weights, device=device, fp16=half)

# Initialize StrongSORT.
# NOTE(review): some StrongSORT releases also require device/fp16 constructor
# arguments -- verify this call against the installed package.
tracker = StrongSORT(model_weights='osnet_x0_25_msmt17.pt', max_dist=0.2, max_iou_distance=0.7)

# Frames are only processed inside this local-time window. Zero-padded
# "HH:MM:SS" strings order correctly under plain string comparison.
active_start, active_end = '06:00:00', '18:00:00'

# cv2.imwrite reports a missing directory only via a False return value,
# so make sure the output directory exists up front.
os.makedirs(data_dir, exist_ok=True)

try:
    with torch.no_grad():
        while True:
            current_time = datetime.now().strftime("%H:%M:%S")
            if current_time >= active_end or current_time < active_start:
                time.sleep(60)
                continue

            ret, img0 = cap.read()
            if not ret or img0 is None:
                # Back off briefly rather than spinning on a failing camera.
                time.sleep(1)
                continue

            # Name the outputs after the moment the frame was captured.
            im_name = datetime.now().strftime("%Y%m%d_%H%M%S")
            pred_json = {'image': im_name + '.jpg', 'preds': []}

            # Run the detector on every tile and map boxes back to full-frame
            # pixel coordinates.
            # NOTE(review): tiles presumably overlap (see the 0.04 ratios used to
            # build slice_boxes); no cross-tile NMS is applied, so an object in
            # an overlap region may be reported twice.
            slice_dets = []
            for box in slice_boxes:
                x_off, y_off = box[0], box[1]
                crop = img0[box[1]:box[3], box[0]:box[2], :]
                h, w = crop.shape[:2]
                # The tile is stretched (not letterboxed) to imgsz x imgsz, so
                # each axis is rescaled independently.
                w_r, h_r = w / imgsz, h / imgsz

                img = cv2.resize(crop, (imgsz, imgsz), interpolation=cv2.INTER_LINEAR)
                img = np.ascontiguousarray(img.transpose((2, 0, 1))[::-1])  # HWC BGR -> CHW RGB
                img = torch.from_numpy(img).to(device)
                img = img.half() if half else img.float()
                img /= 255.0
                img = img.unsqueeze(0)  # add the batch dimension

                # conf_thres=0.4, iou_thres=0.5; take the single image's result.
                det = non_max_suppression(model(img), 0.4, 0.5)[0]
                # Rescale in float32 on the CPU: fp16 only resolves whole pixels
                # above 1024, which would quantize full-frame coordinates.
                det = det.float().cpu()
                det[:, [0, 2]] = det[:, [0, 2]] * w_r + x_off
                det[:, [1, 3]] = det[:, [1, 3]] * h_r + y_off
                slice_dets.append(det)

            preds = torch.cat(slice_dets, 0) if slice_dets else torch.zeros((0, 6))

            # Apply StrongSORT tracking
            outputs = tracker.update(preds, img0)

            # Rows are indexed as [x1, y1, x2, y2, track_id, class_id, ...].
            # Cast to built-in types so json.dump never sees numpy scalars.
            for output in outputs:
                x1, y1, x2, y2 = (float(v) for v in output[:4])
                track_id, cls_id = int(output[4]), int(output[5])
                pred_json['preds'].append({
                    'category': model.names[cls_id],
                    'bbox': [x1, y1, x2 - x1, y2 - y1],
                    'segmentation': [[x1, y1, x2, y1, x2, y2, x1, y2]],
                    'track_id': track_id,
                })

            # Save image and JSON
            cv2.imwrite(os.path.join(data_dir, im_name + '.jpg'), img0)
            with open(os.path.join(data_dir, im_name + '.json'), 'w') as f:
                json.dump(pred_json, f)

            time.sleep(interval)
finally:
    # Reached on KeyboardInterrupt or any error; frees the camera device.
    cap.release()
# Leave a Comment