from dataclasses import dataclass
from pathlib import Path

import cv2
import numpy as np
from ultralytics import YOLO
from ultralytics.engine.results import Results

from path_manager import Paths


@dataclass(frozen=True)
class Target:
    """
    Immutable description of a single detection.

    Only ``confidence`` is required; with every other field left at its default
    the instance represents a "nothing detected" result (``is_lost`` is True).
    """
    # Confidence score of the detection.
    confidence: float
    # (x, y) midpoint of the bounding box, in pixels.
    center: tuple[int, int] | None = None
    # (width, height) of the bounding box, in pixels.
    size: tuple[int, int] | None = None
    # Bounding box corners as (x1, y1, x2, y2).
    box: tuple[int, int, int, int] | None = None
    # True when no detection backs this target.
    is_lost: bool = True


class YoloSegmentation:
    """
    Thin wrapper around an Ultralytics YOLO segmentation model.

    Provides helpers to run inference on a frame, pick the most confident
    detection, and draw masks / bounding boxes onto the image.
    """

    def __init__(self, model_path: str | Path = Paths.YOLO_SEGMENTATION_MODEL_PATH, confidence_threshold: float = 0.7):
        """
        Load the model and set the drawing parameters.

        :param model_path: Path of the YOLO weights to load.
        :param confidence_threshold: Minimum confidence required by ``annotate_frame``
            before the best detection is drawn.
        """
        self.model = YOLO(model_path)
        self.confidence_threshold = confidence_threshold

        # Weight used when blending the mask color into the image.
        self.alpha = 0.5
        # Threshold applied to mask values when binarizing them.
        self.mask_confidence = 0.5
        # Returned by find_best_target_box when there is nothing to report.
        self._default_target = Target(confidence=-1.0)
        self.segmentation_color = (0, 200, 0)
        self.bounding_box_color = (0, 255, 0)
        self.text_color = (36, 255, 12)

    def detect(self, frame: np.ndarray) -> Results:
        """
        Run the model on a single frame.

        :param frame: Image to run inference on.
        :return: The first (only) element of the prediction list.
        """
        return self.model.predict(frame, stream=False, verbose=False)[0]

    def find_best_target_box(self, results: Results) -> Target:
        """
        Extract the most confident detection from the YOLO results.

        :param results: YOLO prediction.
        :return: Target dataclass with detection info, or the default "lost"
            target when the prediction contains no boxes.
        """
        boxes = results.boxes
        if not boxes or boxes.conf is None or len(boxes.conf) == 0:
            return self._default_target

        best_index = int(boxes.conf.argmax())
        x1, y1, x2, y2 = boxes.xyxy[best_index].int().tolist()

        return Target(
            confidence=boxes.conf[best_index].item(),
            center=((x1 + x2) // 2, (y1 + y2) // 2),
            size=(x2 - x1, y2 - y1),
            box=(x1, y1, x2, y2),
            is_lost=False,
        )

    def segment_image(
            self, frame: np.ndarray, are_results_returned: bool = False
    ) -> tuple[np.ndarray, np.ndarray] | tuple[np.ndarray, np.ndarray, Results]:
        """
        Runs segmentation on an image using YOLO and returns both the annotated image
        and the first binary mask.

        When nothing is segmented the input frame itself (not a copy) and an
        all-zero mask are returned.

        :param frame: Image to be segmented.
        :param are_results_returned: Will also return YOLO results.
        :return: Tuple of (annotated_frame, mask), with the YOLO results appended
            as a third element when ``are_results_returned`` is True.
        """
        results = self.detect(frame)
        frame_height, frame_width = frame.shape[:2]
        default_mask = np.zeros((frame_height, frame_width), dtype=np.uint8)

        if results.masks is None or len(results.masks.data) == 0:
            return (frame, default_mask, results) if are_results_returned else (frame, default_mask)

        annotated_frame = np.array(frame)
        masks = []

        for mask in results.masks.data:
            binary_mask = self._process_mask(mask, frame_width, frame_height)
            masks.append(binary_mask)

            # Blend against a copy of the current image so that pixels outside
            # the mask are left untouched. Blending against a black overlay
            # would dim the whole frame by `alpha` once per mask.
            overlay = annotated_frame.copy()
            overlay[binary_mask > 0] = self.segmentation_color
            annotated_frame = cv2.addWeighted(annotated_frame, self.alpha, overlay, 1 - self.alpha, 0)

        self._draw_boxes(annotated_frame, results)

        output = (annotated_frame, masks[0] if masks else default_mask)
        return (*output, results) if are_results_returned else output

    def segment_image_v2(self, frame: np.ndarray) -> tuple[np.ndarray, np.ndarray]:
        """
        Alternative segmentation path that calls ``predict`` with explicit
        ``conf`` / ``iou`` values and only uses the first mask of each result.

        :param frame: Image to be segmented (copied before drawing).
        :return: Tuple of (annotated_frame, mask); the mask is all zeros when
            nothing was segmented.
        """
        frame = np.array(frame)
        results = self.model.predict(frame, conf=0.25, iou=0.7)
        frame_height, frame_width = frame.shape[:2]
        default_mask = np.zeros((frame_height, frame_width), dtype=np.uint8)

        masks = []
        for result in results:
            if result.masks is None or len(result.masks.data) == 0:
                continue
            binary_mask = self._process_mask(result.masks.data[0], frame_width, frame_height)
            masks.append(binary_mask)

            # This path paints a fixed pure green rather than `segmentation_color`
            # and blends the result into `frame` in place.
            overlay = frame.copy()
            overlay[binary_mask > 0] = [0, 255, 0]
            cv2.addWeighted(overlay, self.alpha, frame, 1 - self.alpha, 0, frame)

            self._draw_boxes(frame, result)

        return frame, (masks[0] if masks else default_mask)

    def annotate_frame(self, frame: np.ndarray) -> np.ndarray:
        """
        Fill the polygon of the most confident detection with the segmentation
        color, provided its confidence reaches ``confidence_threshold``.

        :param frame: Image to annotate (copied before drawing).
        :return: The annotated copy, or an unmodified copy when there is no
            detection, no mask data, or the confidence is too low.
        """
        frame = np.array(frame)
        results = self.detect(frame)

        boxes = results.boxes
        if not boxes or boxes.conf is None or len(boxes.conf) == 0:
            return frame
        # Boxes can be present without any mask data; nothing to fill then.
        if results.masks is None:
            return frame

        best_index = int(boxes.conf.argmax())
        if boxes.conf[best_index].item() < self.confidence_threshold:
            return frame

        polygon = np.asarray(results.masks.xy[best_index]).astype(np.int32)
        # Skip empty outlines instead of handing them to fillPoly.
        if len(polygon) > 0:
            cv2.fillPoly(frame, pts=[polygon], color=self.segmentation_color)

        return frame

    def _draw_boxes(self, frame: np.ndarray, results: Results) -> None:
        """Draw bounding boxes and class labels on the frame (in place)."""
        boxes = results.boxes
        if boxes is None or len(boxes) == 0:
            return

        for box, cls_id, conf in zip(boxes.xyxy, boxes.cls, boxes.conf):
            x1, y1, x2, y2 = map(int, box)
            label = f"{results.names[int(cls_id)]}: {conf:.2f}"

            cv2.rectangle(frame, (x1, y1), (x2, y2), self.bounding_box_color, 2)
            cv2.putText(frame, label, (x1, y1 - 10), cv2.FONT_HERSHEY_SIMPLEX, 0.5, self.text_color, 2)

    def _process_mask(self, mask_tensor, width: int, height: int) -> np.ndarray:
        """
        Convert a YOLO mask tensor to a binary mask (uint8).

        :param mask_tensor: Single mask tensor taken from ``results.masks.data``.
        :param width: Target width in pixels.
        :param height: Target height in pixels.
        :return: Array of shape (height, width) holding 255 where the mask
            exceeds ``mask_confidence`` and 0 elsewhere.
        """
        # NOTE(review): the mask is stretched straight to the frame size, which
        # assumes it covers exactly the frame's extent (no inference padding) -
        # confirm against the Ultralytics version in use.
        mask = mask_tensor.cpu().numpy()
        mask = cv2.resize(mask, (width, height), interpolation=cv2.INTER_NEAREST)
        return (mask > self.mask_confidence).astype(np.uint8) * 255


def compare_yolo_models(frame: np.ndarray, model1: YoloSegmentation, model2: YoloSegmentation):
    """
    Run ``segment_image`` of two models on the same frame, print how long each
    took, and show the annotated images and masks side by side.

    Blocks until a key is pressed in one of the OpenCV windows.

    :param frame: Image to segment with both models.
    :param model1: First model; shown on the left.
    :param model2: Second model; shown on the right.
    """
    # Imported here so the function also works when this module is imported
    # rather than executed as a script.
    import time

    start_time = time.perf_counter()
    segmentation, mask = model1.segment_image(frame)
    end_time = time.perf_counter()
    print(f"Finished segmentation in: {end_time - start_time} seconds")

    start_time = time.perf_counter()
    segmentation_v2, mask_v2 = model2.segment_image(frame)
    end_time = time.perf_counter()
    print(f"Finished segmentation in: {end_time - start_time} seconds")

    print(f"Shape mask: {mask.shape} - {mask_v2.shape}")
    cv2.imshow("frame", frame)
    cv2.imshow("frame_seg_v1 vs frame_seg_v2", np.hstack((segmentation, segmentation_v2)))
    cv2.imshow("mask_v1 vs mask_v2", np.hstack((mask, mask_v2)))
    cv2.waitKey(0)
    cv2.destroyAllWindows()


if __name__ == '__main__':
    # Script-only imports.
    import time
    from path_manager import get_path_of_files

    yolo_model = YoloSegmentation(model_path=Paths.YOLO_SEGMENTATION_MODEL_PATH)
    yolo_model_v2 = YoloSegmentation(model_path=Paths.YOLO_SEGMENTATION_MODEL_V2_PATH)

    paths = get_path_of_files(directory=Paths.TEST_IMAGES_DIR, file_type=".png")

    for path in paths:
        # cv2.imread returns None instead of raising when a file cannot be
        # read, so skip such files rather than feeding None to the models.
        f = cv2.imread(str(path))
        if f is None:
            print(f"Skipping unreadable image: {path}")
            continue
        compare_yolo_models(f, yolo_model, yolo_model_v2)
