nahrixt
unknown
plain_text
5 months ago
4.6 kB
5
Indexable
import cv2 import torch import numpy as np from ultralytics import YOLO from deep_sort_realtime.deepsort_tracker import DeepSort class_names = [ "motorbike", #0 "DHelmet", #1 "DNoHelmet", #2 "P1Helmet", #3 "P1NoHelmet", #4 "P2Helmet", #5 "P2NoHelmet", #6 "P0Helmet", #7 "P0NoHelmet", #8 ] # Danh sách tên class minority_class = [3, 5, 6, 7, 8] minority_conf = 0.0007 conf_threshold = 0.5 tracking_class = None # None: track all tracker = DeepSort( max_age=20, n_init=3, max_cosine_distance=0.3, max_iou_distance=1, nms_max_overlap=1, ) # tracker = DeepSort( # max_age=20, # Faster updates # n_init=2, # Quick track confirmation # max_cosine_distance=0.3, # Looser feature vector matching # max_iou_distance=0.8, # Higher IoU for track association # nms_max_overlap=0.5, # Stricter NMS threshold # ) model = YOLO("C:\\study\\ComputerVision\\HelmetDetection\\yolo_tracking\\weights\\v10m.pt") colors = np.random.randint(0, 255, size=(len(class_names), 3)) tracks = [] frame_count = 0 detect_interval = 1 # Số frame giữa các lần YOLO detect tracking = {} # cap = cv2.VideoCapture("./10fps.mp4") cap = cv2.VideoCapture("C:\\study\\ComputerVision\\HelmetDetection\\yolo_tracking\\source\\camera.mp4") if not cap.isOpened(): print("Error: Cannot open video file.") exit() if model is None: print("Error: YOLO model could not be loaded.") exit() # Tiến hành đọc từng frame từ video while True: # Đọc ret, frame = cap.read() if not ret: # continue break # Đưa qua model để detect results = model(frame) # print("results : ", results) # Kiểm tra kết quả từ mô hình (results có thể là một danh sách) for result in results: # Truy cập vào các thuộc tính của từng kết quả boxes = result.boxes # Đối tượng Boxes chứa các bounding box detections = [] for det in boxes: # Lấy các giá trị bounding box x1, y1, x2, y2 = det.xyxy[0] # Lấy tọa độ bounding box conf = det.conf[0] # Độ tin cậy của đối tượng cls = det.cls[0] # ID của lớp đối tượng # Kiểm tra độ tin cậy if cls in minority_class: if conf < minority_conf: continue elif conf < conf_threshold: continue # if conf < conf_threshold: # continue # Thêm phát hiện vào danh sách detect = [[int(x1), int(y1), int(x2 - x1), int(y2 - y1)], conf, int(cls)] detections.append(detect) # det_array = detections if len(detections) > 0 else np.empty((0, 5)) # tracks = tracker.update_tracks(det_array, frame=frame) tracks = tracker.update_tracks(detections, frame=frame) # Vẽ lên màn hình các khung chữ nhật kèm ID for track in tracks: if track.is_confirmed(): track_id = track.track_id # Lấy toạ độ, class_id để vẽ lên hình ảnh ltrb = track.to_ltrb() class_id = track.get_det_class() x1, y1, x2, y2 = map(int, ltrb) color = colors[class_id] B, G, R = map(int,color) if track_id not in tracking: tracking[track_id] = [] tracking[track_id].append(class_id) tracking[track_id].append(B) tracking[track_id].append(G) tracking[track_id].append(R) else: class_id = tracking[track_id][0] B, G, R = tracking[track_id][1:] # label = "{}-{}".format(class_names[class_id], track_id) label = "{}-{}".format(class_names[class_id], track_id) cv2.rectangle(frame, (x1, y1), (x2, y2), (B, G, R), 2) cv2.rectangle(frame, (x1 - 1, y1 - 20), (x1 + len(label) * 12, y1), (B, G, R), -1) cv2.putText(frame, label, (x1 + 5, y1 - 8), cv2.FONT_HERSHEY_SIMPLEX, 0.5, (255, 255, 255), 2) frame_count += 1 # Show hình ảnh lên màn hình cv2.imshow("OT", frame) cv2.waitKey(3) # Bấm Q thì thoát if cv2.waitKey(1) == ord("q"): break cap.release() cv2.destroyAllWindows()
Editor is loading...
Leave a Comment