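"""Sliced YOLOv5 inference with StrongSORT tracking on a live camera feed.

Each captured frame is split into overlapping tiles, every tile is run through a
YOLOv5 model, per-tile detections are mapped back to full-frame coordinates and
merged, the merged detections are tracked with StrongSORT, and each frame is saved
as a JPEG alongside a JSON file of predictions and track IDs.
"""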
import os
import sys
import time
from datetime import datetime
from pathlib import Path
import matplotlib
matplotlib.use("Agg")
import numpy as np
import cv2
import json
import torch
from yolov5.models.common import DetectMultiBackend
from yolov5.utils.general import non_max_suppression

# Assuming StrongSORT is installed and imported properly
from strongsort import StrongSORT  # Update the import based on your StrongSORT implementation

def get_slice_bboxes(image_height, image_width, slice_height, slice_width, overlap_height_ratio, overlap_width_ratio):
    slice_bboxes = []
    y_max = y_min = 0
    y_overlap = int(overlap_height_ratio * slice_height)
    x_overlap = int(overlap_width_ratio * slice_width)
    while y_max < image_height:
        x_min = x_max = 0
        y_max = y_min + slice_height
        while x_max < image_width:
            x_max = x_min + slice_width
            if y_max > image_height or x_max > image_width:
                xmax = min(image_width, x_max)
                ymax = min(image_height, y_max)
                xmin = max(0, xmax - slice_width)
                ymin = max(0, ymax - slice_height)
                slice_bboxes.append([xmin, ymin, xmax, ymax])
            else:
                slice_bboxes.append([x_min, y_min, x_max, y_max])
            x_min = x_max - x_overlap
        y_min = y_max - y_overlap
    return slice_bboxes
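
# For illustration: with the defaults used below (1920x1080 frames, 800x752
# slices, 4% overlap), the call
#     get_slice_bboxes(1080, 1920, 752, 800, 0.04, 0.04)
# produces six overlapping tiles:
#     [[0, 0, 800, 752], [768, 0, 1568, 752], [1120, 0, 1920, 752],
#      [0, 328, 800, 1080], [768, 328, 1568, 1080], [1120, 328, 1920, 1080]]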

# Setup environment
root_dir = os.getenv('root_dir', os.path.dirname(os.path.dirname(os.path.abspath(__file__))))
yolo_weights = Path(root_dir) / 'models' / os.getenv('weights', 'pLitterFloat_800x752_to_640x640.pt')

# Camera and model settings
FRAME_WIDTH = int(os.getenv('frame_width', 1920))
FRAME_HEIGHT = int(os.getenv('frame_height', 1080))
interval = int(os.getenv('interval', 10))
work_in_night = str(os.getenv('work_in_night', True)).lower() in ('1', 'true', 'yes')  # env vars arrive as strings

slice_width = int(os.getenv("slice_width", 800))
slice_height = int(os.getenv("slice_height", 752))
slice_boxes = get_slice_bboxes(FRAME_HEIGHT, FRAME_WIDTH, slice_height, slice_width, 0.04, 0.04)

device = torch.device('cuda:0' if torch.cuda.is_available() else 'cpu')
half = device.type != 'cpu'  # FP16 inference is only supported on GPU

# Load YOLOv5 model
if not os.path.isfile(yolo_weights):
    weights_url = os.getenv('weights_url', None)
    try:
        torch.hub.download_url_to_file(weights_url, str(yolo_weights))
    except Exception:
        print(f'Could not download weights from {weights_url}, falling back to yolov5s.pt')
        yolo_weights = Path(root_dir) / 'models' / 'yolov5s.pt'

model = DetectMultiBackend(yolo_weights, device=device, fp16=half)
model.eval()

# Initialize StrongSORT
strongsort = StrongSORT()
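# Note (assumption): depending on the StrongSORT package in use, the constructor may
# need extra arguments, e.g. a ReID weights path, device and FP16 flag
# (StrongSORT(reid_weights, device=device, fp16=half)); check the signature of the
# implementation imported above before relying on the bare call.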

data_dir = os.path.join(root_dir, 'data')
os.makedirs(data_dir, exist_ok=True)  # ensure the output directory exists
imgsz = 640

cap = cv2.VideoCapture(0)
cap.set(cv2.CAP_PROP_FRAME_WIDTH, FRAME_WIDTH)
cap.set(cv2.CAP_PROP_FRAME_HEIGHT, FRAME_HEIGHT)
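# Note: some cameras ignore resolution requests made via cap.set(); the size actually
# delivered can be checked with cap.get(cv2.CAP_PROP_FRAME_WIDTH) and
# cap.get(cv2.CAP_PROP_FRAME_HEIGHT) if needed.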

# Processing loop
with torch.no_grad():
    while True:
        current_time = datetime.now().strftime("%H:%M:%S")
        # Night mode check (time-based): sleep through the night unless night operation is enabled
        if not work_in_night and (current_time >= '18:00:00' or current_time < '06:00:00'):
            print('Night mode active. Sleeping...')
            time.sleep(60)
            continue
        
        im_name = datetime.now().strftime("%Y%m%d_%H%M%S")
        pred_json = {'image': im_name + '.jpg', 'preds': []}
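        # The JSON written per frame ends up looking like (values illustrative):
        # {"image": "20240101_120000.jpg",
        #  "preds": [{"category": "<class name>",
        #             "bbox": [x, y, width, height],
        #             "segmentation": [[x1, y1, x2, y1, x2, y2, x1, y2]],
        #             "track_id": 3}]}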

        ret, img0 = cap.read()
        if not ret or img0 is None:
            print("Check if the camera is connected and run again.")
            time.sleep(5)  # avoid busy-looping while the camera is unavailable
            continue
        
        preds = torch.zeros((0, 6))  # accumulated detections: (x1, y1, x2, y2, conf, cls)
        
        # Process each slice of the image
        for box in slice_boxes:
            img = img0[box[1]:box[3], box[0]:box[2], :]
            h, w, _ = img.shape
            img = cv2.resize(img, (imgsz, imgsz), interpolation=cv2.INTER_LINEAR)
            img = img.transpose((2, 0, 1))[::-1]  # HWC to CHW, BGR to RGB
            img = np.ascontiguousarray(img)
            img = torch.from_numpy(img).to(device)
            img = img.half() if half else img.float()  # Convert to FP16 or FP32
            img /= 255.0  # Normalize to [0, 1]
            if img.ndimension() == 3:
                img = img.unsqueeze(0)

            pred = model(img)
            pred = non_max_suppression(pred, conf_thres=0.4, iou_thres=0.5)

            # Process predictions: rows of the NMS output are (x1, y1, x2, y2, conf, cls)
            proc_pred = pred[0].float().cpu()
            for det in proc_pred:
                # Rescale box coordinates from the 640x640 slice back to the
                # full frame: undo the resize, then add the slice origin
                det[0] = det[0] * (w / imgsz) + box[0]
                det[1] = det[1] * (h / imgsz) + box[1]
                det[2] = det[2] * (w / imgsz) + box[0]
                det[3] = det[3] * (h / imgsz) + box[1]
            preds = torch.cat((preds, proc_pred), 0)
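
        # Worked example (illustrative): for the slice [768, 0, 1568, 752]
        # (w=800, h=752), a detection at (100, 50, 300, 250) on the resized
        # 640x640 slice maps back to
        # (100*800/640 + 768, 50*752/640, 300*800/640 + 768, 250*752/640)
        # = (893.0, 58.75, 1143.0, 293.75) in full-frame coordinates.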

        # Track the objects using StrongSORT. Note: the return format of update()
        # varies between StrongSORT implementations; this code assumes it returns
        # one track ID per detection, in the same order as preds.
        track_ids = strongsort.update(preds, img0)

        # Prepare JSON output
        for pred, track_id in zip(preds, track_ids):
            pred = pred.tolist()
            bbox = [pred[0], pred[1], pred[2] - pred[0], pred[3] - pred[1]]
            seg = [[pred[0], pred[1], pred[2], pred[1], pred[2], pred[3], pred[0], pred[3]]]
            pred_json['preds'].append({
                'category': model.names[int(pred[5])],
                'bbox': bbox,
                'segmentation': seg,
                'track_id': int(track_id)  # cast so the ID is JSON-serializable even if it is a tensor/NumPy scalar
            })

        # Save image and JSON
        cv2.imwrite(os.path.join(data_dir, im_name + '.jpg'), img0)
        with open(os.path.join(data_dir, im_name + '.json'), 'w') as json_file:
            json.dump(pred_json, json_file)

        print(f"Saved {im_name}.jpg and {im_name}.json")
        time.sleep(interval)