import cv2
import numpy as np

from yolo_model import YoloSegmentation


class CarMaskSplitter:
    def __init__(self):
        self.front_color = [220, 40, 255]
        self.back_color = [220, 255, 40]
        self.text_color = [255, 255, 255]
        self.ESCAPE_KEY = 27
        self.ENTER_KEY = 13
        self.choose_instructions = [
            "Click to select front of the car.",
            "Press 'K' or Enter to confirm.",
            "Press 'Q' or ESC to cancel and skip.",
            "Purple Mask = Front | Cyan Mask = Back"
        ]
        self.confirm_instructions = [
            "Purple Mask = Front | Cyan Mask = Back",
            "Press 'K' or Enter to confirm.",
            "Press 'R' to Redo.",
        ]
        self.tolerance = 1e-8

    def infer_annotate(self, model: YoloSegmentation, image: np.ndarray, frame_name: str = ""):
        _, mask = model.segment_image(image)
        return self._annotate_core(image, mask, frame_name)

    def annotate(self, image: np.ndarray, mask: np.ndarray, frame_name: str = "") -> tuple[np.ndarray, np.ndarray]:
        return self._annotate_core(image, mask, frame_name)

    def _annotate_core(self, image: np.ndarray, mask: np.ndarray, frame_name: str = ""):
        is_image_done = False
        front_mask = np.zeros(mask.shape, dtype=np.uint8)
        back_mask = np.zeros(mask.shape, dtype=np.uint8)

        while not is_image_done:
            front_point = self.get_user_click(image, mask, frame_name)
            if front_point is None:
                print("No points selected, skipping image.")
                is_image_done = True
                continue
            front_mask, back_mask = self.geometric_split_mask(mask, front_point)
            display_image = self.draw_instructions(
                self.overlay(image, front_mask, back_mask), self.confirm_instructions
            )
            user_confirmed = False
            while not user_confirmed:
                cv2.imshow("Overlay", display_image)
                key = cv2.waitKey(1)
                if key == ord("r"):
                    print("Reset Requested by user.")
                    user_confirmed = True
                    continue
                elif key in {ord("k"), self.ENTER_KEY}:
                    user_confirmed = True
                    is_image_done = True
                    continue

        cv2.destroyAllWindows()
        return front_mask, back_mask

    def geometric_split_mask(self, mask: np.ndarray, front_point: tuple[int, int]) -> tuple[np.ndarray, np.ndarray]:
        """
        Splits the car mask into front and back parts based on the user front_point coordinates.
        """
        mask_points = np.column_stack(np.where(mask > 0))

        M = cv2.moments(mask)
        if M["m00"] == 0:
            return mask, np.zeros_like(mask)
        center_x = int(M["m10"] / M["m00"])
        center_y = int(M["m01"] / M["m00"])
        center = np.array([center_y, center_x])

        if mask[front_point[1], front_point[0]] == 0:
            print("Warning: Front point not within mask, finding closest mask point")
            distances = np.linalg.norm(mask_points - np.array([front_point[1], front_point[0]]), axis=1)
            closest_idx = np.argmin(distances)
            front_point = (mask_points[closest_idx][1], mask_points[closest_idx][0])
            print(f"Using closest point: {front_point}")

        front = np.array([front_point[1], front_point[0]])
        direction = front - center
        direction_unit = direction / (np.linalg.norm(direction) + self.tolerance)

        front_mask = np.zeros_like(mask, dtype=np.uint8)
        back_mask = np.zeros_like(mask, dtype=np.uint8)

        for (y, x) in mask_points:
            point_vec = np.array([y, x]) - center
            dot_product = np.dot(point_vec, direction_unit)
            if dot_product > 0:
                front_mask[y, x] = 255
            else:
                back_mask[y, x] = 255

        return front_mask, back_mask

    def overlay(self, image: np.ndarray, front_mask: np.ndarray, back_mask: np.ndarray) -> np.ndarray:
        color_mask = np.zeros_like(image)
        color_mask[front_mask > 0] = self.front_color
        color_mask[back_mask > 0] = self.back_color
        return cv2.addWeighted(image, 0.7, color_mask, 0.3, 0)

    def get_user_click(self, image: np.ndarray, mask: np.ndarray, frame_name: str = "") -> tuple[int, int] | None:
        """
        Displays an image and lets the user click on it. Shows visual feedback for clicks.
        Returns the most recent clicked point when 'k' or Enter is pressed.
        """
        clicked_points = []
        display_image = self.draw_instructions(image, self.choose_instructions)

        def mouse_callback(event, x, y, flags, param):
            if event == cv2.EVENT_LBUTTONDOWN:
                print(f"Selected point: {x}, {y}")
                clicked_points.append((x, y))

        frame_name = f" - {frame_name}"
        window_name = f"Click on the front of the car{frame_name}"
        cv2.namedWindow(window_name)
        cv2.setMouseCallback(window_name, mouse_callback)

        user_confirmed = False
        while not user_confirmed:
            user_image = np.array(display_image)
            if len(clicked_points) > 0:
                last_point = clicked_points[-1]
                front_mask, back_mask = self.geometric_split_mask(mask, last_point)
                user_image = self.overlay(user_image, front_mask, back_mask)
                cv2.circle(user_image, last_point, 5, (0, 0, 255), -1)

            cv2.imshow(window_name, user_image)

            key = cv2.waitKey(1)
            if key in {ord("q"), self.ESCAPE_KEY}:
                print("Cancelled by used.")
                clicked_points.clear()
                break
            elif key in {ord("k"), self.ENTER_KEY}:
                if len(clicked_points) > 0:
                    user_confirmed = True
                    continue
                print("You must select a point in the image before confirming.")

        cv2.destroyAllWindows()
        if len(clicked_points) > 0:
            return clicked_points[-1]

        print("No click detected.")
        return None

    def draw_instructions(self, image: np.ndarray, instructions: list[str]) -> np.ndarray:
        overlay = np.array(image)
        y0 = 20
        for i, text in enumerate(instructions):
            y = y0 + i * 25
            cv2.putText(overlay, text, (10, y), cv2.FONT_HERSHEY_SIMPLEX, 0.3, self.text_color, 1, cv2.LINE_AA)
        return overlay


def run_example_test(images_paths: list[str], masks_paths: list[str], is_infer: bool = True):
    """Interactively split each image's car mask and save the results.

    Args:
        images_paths: Paths of the input images.
        masks_paths: Paths of the masks, paired with ``images_paths`` by
            position. They are only read when ``is_infer`` is False.
        is_infer: If True, masks come from the YOLO segmentation model;
            otherwise they are loaded from ``masks_paths``.

    For each pair index ``i``, writes ``front_{i}.png`` and ``back_{i}.png``
    to the current working directory. Unreadable inputs are skipped, and no
    files are written for them.
    """
    import os

    annotator = CarMaskSplitter()
    yolo_model = None
    if is_infer:
        # Imported here so the function also works when this module is
        # imported; otherwise `Paths` exists only when run as a script.
        from path_manager import Paths

        yolo_model = YoloSegmentation(model_path=Paths.YOLO_SEGMENTATION_MODEL_V2_PATH)

    if len(images_paths) != len(masks_paths):
        print(
            f"Warning: got {len(images_paths)} images but {len(masks_paths)} masks; "
            "unpaired entries are ignored."
        )

    for index, (img_path, mask_path) in enumerate(zip(images_paths, masks_paths)):
        frame = cv2.imread(img_path)
        if frame is None:
            print(f"Could not read image '{img_path}', skipping.")
            continue
        frame_name = os.path.basename(img_path)

        if is_infer:
            front_mask, back_mask = annotator.infer_annotate(
                model=yolo_model, image=frame, frame_name=frame_name
            )
        else:
            mask = cv2.imread(mask_path, cv2.IMREAD_UNCHANGED)
            if mask is None:
                print(f"Could not read mask '{mask_path}', skipping.")
                continue
            front_mask, back_mask = annotator.annotate(image=frame, mask=mask, frame_name=frame_name)

        cv2.imwrite(f"front_{index}.png", front_mask)
        cv2.imwrite(f"back_{index}.png", back_mask)


if __name__ == '__main__':
    # `Paths` is bound at module level here, so code that looks it up as a
    # global can find it when the file runs as a script.
    from path_manager import Paths, get_path_of_files

    image_files = get_path_of_files(directory=Paths.TEST_IMAGES_DIR, file_type=".png")
    mask_files = get_path_of_files(directory=Paths.TEST_MASKS_DIR, file_type=".png")

    # First split the stored masks, then masks inferred by the model.
    # NOTE(review): both passes write front_<i>.png / back_<i>.png into the
    # same directory, so the inference pass overwrites the first pass's output.
    for use_inference in (False, True):
        run_example_test(images_paths=image_files, masks_paths=mask_files, is_infer=use_inference)
