import os
import sys
import time
from datetime import datetime
from pathlib import Path
import matplotlib
import numpy as np
import cv2
import json
import torch
from dotenv import load_dotenv

# Import StrongSORT
from strongsort.strong_sort import StrongSORT

def get_slice_bboxes(image_height, image_width, slice_height, slice_width, overlap_height_ratio, overlap_width_ratio):
    """Tile an image into overlapping slice boxes.

    Slices are laid out row by row, left to right. Consecutive slices overlap by
    ``int(ratio * slice_size)`` pixels on each axis. A slice that would extend
    past the right or bottom edge is shifted back so it ends exactly on that
    edge, and is clipped to the image when the image is smaller than a slice.

    Args:
        image_height: Image height in pixels.
        image_width: Image width in pixels.
        slice_height: Slice height in pixels (must be > 0).
        slice_width: Slice width in pixels (must be > 0).
        overlap_height_ratio: Vertical overlap as a fraction of ``slice_height``.
        overlap_width_ratio: Horizontal overlap as a fraction of ``slice_width``.

    Returns:
        List of ``[x_min, y_min, x_max, y_max]`` boxes in pixel coordinates.

    Raises:
        ValueError: If a slice size is not positive, or an overlap is at least
            as large as the slice. Either case would keep the tiling loop from
            ever advancing.
    """
    if slice_height <= 0 or slice_width <= 0:
        raise ValueError("slice_height and slice_width must be positive")
    y_overlap = int(overlap_height_ratio * slice_height)
    x_overlap = int(overlap_width_ratio * slice_width)
    if y_overlap >= slice_height or x_overlap >= slice_width:
        raise ValueError("overlap must be smaller than the slice size")

    slice_bboxes = []
    y_max = y_min = 0
    while y_max < image_height:
        x_min = x_max = 0
        y_max = y_min + slice_height
        while x_max < image_width:
            x_max = x_min + slice_width
            if y_max > image_height or x_max > image_width:
                # Edge slice: pull it back inside the image, keeping the full
                # slice size whenever the image is large enough.
                xmax = min(image_width, x_max)
                ymax = min(image_height, y_max)
                xmin = max(0, xmax - slice_width)
                ymin = max(0, ymax - slice_height)
                slice_bboxes.append([xmin, ymin, xmax, ymax])
            else:
                # Interior slice: already fully inside the image.
                slice_bboxes.append([x_min, y_min, x_max, y_max])
            x_min = x_max - x_overlap
        y_min = y_max - y_overlap
    return slice_bboxes

# Colour tuples (OpenCV channel order) cycled by track ID when drawing boxes.
colors = [(0, 255, 255), (0, 0, 255), (255, 0, 0), (0, 255, 0)] * 20


def _env_flag(name, default):
    """Read a boolean environment variable.

    Returns ``default`` when the variable is unset. Otherwise the values
    '0', 'false', 'no' and 'off' (case-insensitive, whitespace ignored)
    give False, and any other value gives True.
    """
    value = os.getenv(name)
    if value is None:
        return default
    return value.strip().lower() not in ('0', 'false', 'no', 'off')


# Project root: the directory two levels above this file, unless overridden
# by $root_dir. os.path is used instead of splitting on '/' so this also
# works with Windows path separators.
root_dir = os.getenv('root_dir', os.path.dirname(os.path.dirname(os.path.abspath(__file__))))
yolo_weights = Path(root_dir) / 'models' / os.getenv('weights', 'pLitterFloat_800x752_to_640x640.pt')

FRAME_WIDTH = int(os.getenv('frame_width', 1920))
FRAME_HEIGHT = int(os.getenv('frame_height', 1080))
interval = int(os.getenv('interval', 10))
# Always a bool, so checks against False are reliable for any spelling.
work_in_night = _env_flag('work_in_night', True)
weights_url = os.getenv('weights_url', None)

print(root_dir, yolo_weights, FRAME_WIDTH, FRAME_HEIGHT, interval, work_in_night)

# Put the vendored Yolov5_StrongSORT_OSNet checkout and its nested yolov5
# package on the import path (each at most once) before importing from them.
for _vendored in ('Yolov5_StrongSORT_OSNet', 'Yolov5_StrongSORT_OSNet/yolov5'):
    _vendored_path = os.path.join(root_dir, _vendored)
    if _vendored_path not in sys.path:
        sys.path.append(_vendored_path)
del _vendored, _vendored_path

from yolov5.models.common import DetectMultiBackend
from yolov5.utils.general import check_img_size, cv2, non_max_suppression

# Folder that annotated snapshots are written to.
data_dir = os.path.join(root_dir, 'data')

# Tile size (pixels) that each frame is cut into before inference.
slice_width = int(os.getenv("slice_width", 800))
slice_height = int(os.getenv("slice_height", 752))

# Precompute the tiling of a full frame once, with 4% overlap on both axes.
slice_boxes = get_slice_bboxes(
    image_height=FRAME_HEIGHT,
    image_width=FRAME_WIDTH,
    slice_height=slice_height,
    slice_width=slice_width,
    overlap_height_ratio=0.04,
    overlap_width_ratio=0.04,
)

# Use the first GPU when one is present. FP16 inference is enabled only on
# CUDA, because half-precision convolutions are not supported on CPU.
device = torch.device('cuda:0' if torch.cuda.is_available() else 'cpu')
half = device.type != 'cpu'

# Load YOLO model. If the configured weights are missing locally, download
# them when $weights_url is set and use the downloaded file. Fall back to the
# stock yolov5s checkpoint only when no URL is configured.
if not os.path.isfile(yolo_weights):
    if weights_url:
        Path(yolo_weights).parent.mkdir(parents=True, exist_ok=True)
        torch.hub.download_url_to_file(weights_url, str(yolo_weights))
    else:
        yolo_weights = Path(root_dir) / 'models' / 'yolov5s.pt'

model = DetectMultiBackend(yolo_weights, device=device, fp16=half)
stride, names, pt = model.stride, model.names, model.pt
# Square network input size; every slice is resized to imgsz x imgsz.
imgsz = 640

# Initialize StrongSORT.
# NOTE(review): constructor arguments differ between StrongSORT releases (some
# also require a device and an fp16 flag). Confirm these keyword names against
# the installed `strongsort` package.
strongsort_tracker = StrongSORT(
    reid_weights="osnet_x0_25_msmt17.pt",  # Update with your ReID weights file path
    max_dist=0.2,  # Distance metric threshold
    max_iou_distance=0.7,  # IoU threshold
    max_age=30,  # Max frames to keep object in memory
    n_init=3,  # Minimum frames to confirm the object
)

# Open the default camera (device index 0). Request the frame size that
# slice_boxes was computed for; the driver may ignore the request and deliver
# a different resolution.
cap = cv2.VideoCapture(0)
cap.set(cv2.CAP_PROP_FRAME_WIDTH, FRAME_WIDTH)
cap.set(cv2.CAP_PROP_FRAME_HEIGHT, FRAME_HEIGHT)

# Daytime window as zero-padded HH:MM:SS strings, so plain string comparison
# against datetime.now().strftime("%H:%M:%S") orders them chronologically.
start = '06:00:00'
end = '18:00:00'
# Wall-clock reference timestamp in seconds since the epoch.
timer = time.time()

# xyxy2xywh lives in the vendored yolov5 package, which only becomes importable
# once root_dir/Yolov5_StrongSORT_OSNet/yolov5 is on sys.path.
from yolov5.utils.general import xyxy2xywh

# cv2.imwrite returns False instead of raising when the folder is missing,
# so create it up front.
os.makedirs(data_dir, exist_ok=True)

# $work_in_night values (compared case-insensitively) that pause processing
# outside the start/end window.
NIGHT_OFF_VALUES = ('false', '0', 'no', 'off')

try:
    with torch.no_grad():
        while True:
            # Outside the daytime window, skip processing unless night work is
            # enabled, and idle for `interval` seconds between checks.
            current_time = datetime.now().strftime("%H:%M:%S")
            if current_time >= end or current_time < start:
                if str(work_in_night).strip().lower() in NIGHT_OFF_VALUES:
                    print('night mode turning off')
                    time.sleep(interval)
                    continue

            ret, img0 = cap.read()
            if not ret or img0 is None:
                # No frame means nothing to slice. Back off briefly and retry
                # rather than crashing on the slicing below.
                print("Camera not connected")
                time.sleep(1)
                continue

            frame_dets = []
            for x1, y1, x2, y2 in slice_boxes:
                crop = img0[y1:y2, x1:x2, :]
                h, w = crop.shape[:2]
                if h == 0 or w == 0:
                    # The slice lies outside the delivered frame, which happens
                    # when the camera returns a smaller resolution than the one
                    # slice_boxes was computed for.
                    continue

                # HWC crop -> CHW with channels reversed (BGR -> RGB), scaled to
                # [0, 1] and given a leading batch dimension.
                im = cv2.resize(crop, (imgsz, imgsz))
                im = np.ascontiguousarray(im.transpose((2, 0, 1))[::-1])
                im = torch.from_numpy(im).to(device)
                im = im.half() if half else im.float()
                im /= 255.0
                im = im.unsqueeze(0)

                # conf_thres=0.4, iou_thres=0.5; [0] picks the only image in the batch.
                det = non_max_suppression(model(im), 0.4, 0.5)[0].cpu().float()
                # Map box corners from network-input pixels back to full-frame pixels.
                det[:, [0, 2]] = det[:, [0, 2]] * w / imgsz + x1
                det[:, [1, 3]] = det[:, [1, 3]] * h / imgsz + y1
                frame_dets.append(det)

            # Collect per-slice detections in a single float32 tensor. No
            # cross-slice NMS is applied, so objects inside overlap regions may
            # be reported more than once.
            preds = torch.cat(frame_dets) if frame_dets else torch.zeros((0, 6))

            # Update tracker with YOLOv5 detections.
            xywhs = xyxy2xywh(preds[:, :4])  # Convert bbox format
            confs = preds[:, 4]  # Confidence scores
            clss = preds[:, 5]  # Class labels
            outputs = strongsort_tracker.update(xywhs.cpu(), confs.cpu(), clss.cpu(), img0)

            # Draw boxes and track IDs. The ID may arrive as a float, so cast it
            # before using it as a list index.
            for output in outputs:
                bx1, by1, bx2, by2 = (int(v) for v in output[0:4])
                track_id = int(output[4])
                color = colors[track_id % len(colors)]
                cv2.rectangle(img0, (bx1, by1), (bx2, by2), color, 2)
                cv2.putText(img0, f'ID {track_id}', (bx1, by1 - 10), cv2.FONT_HERSHEY_SIMPLEX, 0.75, color, 2)

            # Write at most one snapshot per `interval` seconds (`timer` holds
            # the time of the last write). File names have only one-second
            # resolution, so writing every frame would overwrite files.
            now = time.time()
            if now - timer < interval:
                continue
            timer = now

            im_name = datetime.now().strftime("%Y%m%d_%H%M%S")
            im_path = f"{data_dir}/{im_name}.jpg"
            if cv2.imwrite(im_path, img0):
                print(f"Saved: {im_path}")
            else:
                print(f"Failed to save: {im_path}")
finally:
    # Release the camera even when the loop exits via an exception or Ctrl-C.
    cap.release()

