import heapq

import cv2 as cv
from munkres import Munkres
import numpy as np

from tracking.data import load_img


def stack_all_detections_seen(pose_data, frame_ids, max_num_objects):
    """Collect at most one detection per object id by scanning a sequence of frames in order.

    Object ids ``1..max_num_objects`` are considered. For each id, the first detection carrying it (in the order
    given by ``frame_ids`` and then by ``pose_data[frame_id]``) is kept; later detections of the same id are ignored.

    Argument(s):
        pose_data: Indexable by frame id; each entry is an iterable of detections exposing ``object_id``.
        frame_ids: Ordered frame ids to scan. ``frame_ids[0]`` is treated as the current frame.
            NOTE(review): presumably ordered from newest to oldest — confirm against callers.
        max_num_objects: Number of object ids (starting at 1) to look for.

    Return(s):
        A pair ``(detections, from_current_frame)`` of equal-length lists: the selected detections, and for each
        one whether it came from ``frame_ids[0]``.
    """
    remaining_ids = list(range(1, max_num_objects + 1))
    detections = []
    from_current_frame = []

    for frame_id in frame_ids:
        for candidate in pose_data[frame_id]:
            # Stop as soon as every id has been assigned and another detection shows up.
            if len(remaining_ids) == 0:
                return detections, from_current_frame

            if candidate.object_id not in remaining_ids:
                continue

            remaining_ids.remove(candidate.object_id)
            detections.append(candidate)
            from_current_frame.append(frame_id == frame_ids[0])

    return detections, from_current_frame


def compute_bounding_boxes_iou(bb1, bb2):
    """Compute the intersection over union (IoU) of two axis-aligned bounding boxes.

    Box extents are treated as inclusive pixel coordinates, so each side length is ``max - min + 1``. A small epsilon
    is added to the union to avoid a division by zero.

    Argument(s):
        bb1: A list containing the minimum and maximum values for the x and y coordinates of the bounding box. The
            order is x_min, x_max, y_min, y_max.
        bb2: Same as bb1.

    Return(s):
        The ratio of the intersection area to the union area. It is 0.0 when the boxes do not overlap; boxes that
        only touch along an edge (equal overlap bounds) also count as not overlapping.
    """
    eps = 1e-5
    ax_min, ax_max, ay_min, ay_max = bb1
    bx_min, bx_max, by_min, by_max = bb2

    left, right = max(ax_min, bx_min), min(ax_max, bx_max)
    top, bottom = max(ay_min, by_min), min(ay_max, by_max)

    # Strict comparison: a zero-width or zero-height overlap is reported as no overlap.
    if not (left < right and top < bottom):
        return 0.0

    intersection = (right - left + 1) * (bottom - top + 1)
    area_a = (ax_max - ax_min + 1) * (ay_max - ay_min + 1)
    area_b = (bx_max - bx_min + 1) * (by_max - by_min + 1)
    union = area_a + area_b - intersection
    return intersection / (union + eps)


def compute_detections_keypoints_iou(kps1, kps2, bb_edge_width, bb_edge_height, num_top_keypoints):
    """Score two detections by the overlap of fixed-size boxes centred on their corresponding keypoints.

    Keypoints are paired by position (``zip``). A pair is skipped when any coordinate of either keypoint is 0,
    which this module uses to mark missing keypoints. For every remaining pair, a box of size
    ``bb_edge_width x bb_edge_height`` is centred on each keypoint and their IoU is computed.

    Argument(s):
        kps1, kps2: Iterables of (x, y) keypoints.
        bb_edge_width, bb_edge_height: Width and height of the box placed around each keypoint.
        num_top_keypoints: How many of the highest per-keypoint IoUs are averaged.

    Return(s):
        The mean of the ``num_top_keypoints`` largest per-keypoint IoUs, or 0.0 when no pair is valid.
    """
    valid_pairs = []
    for (x1, y1), (x2, y2) in zip(kps1, kps2):
        if all([x1, y1, x2, y2]):
            valid_pairs.append((x1, y1, x2, y2))

    if not valid_pairs:
        return 0.0

    half_w = bb_edge_width / 2
    half_h = bb_edge_height / 2
    per_keypoint_iou = [
        compute_bounding_boxes_iou([x1 - half_w, x1 + half_w, y1 - half_h, y1 + half_h],
                                   [x2 - half_w, x2 + half_w, y2 - half_h, y2 + half_h])
        for x1, y1, x2, y2 in valid_pairs
    ]

    return np.mean(heapq.nlargest(num_top_keypoints, per_keypoint_iou))


def compute_sparse_opt_flow(past_frame, next_frame, kps_past_frame, win_size=(15, 15), max_level=2,
                            criteria=(cv.TERM_CRITERIA_EPS | cv.TERM_CRITERIA_COUNT, 10, 0.03)):
    """Track keypoints from one frame into another with pyramidal Lucas-Kanade sparse optical flow.

    Argument(s):
        past_frame: Image in which ``kps_past_frame`` were observed.
        next_frame: Image into which the keypoints are tracked.
        kps_past_frame: NumPy array of keypoint coordinates. It is reshaped to ``(-1, 1, 2)`` float32 points, so
            its total size must be even.
        win_size, max_level, criteria: Forwarded to ``cv.calcOpticalFlowPyrLK`` as ``winSize``, ``maxLevel`` and
            ``criteria``.

    Return(s):
        An ``(N, 2)`` array with the tracked position of each input point. The status and error outputs of
        ``cv.calcOpticalFlowPyrLK`` are discarded, so points that failed to track are not flagged.
    """
    prev_points = kps_past_frame.reshape(-1, 1, 2).astype(np.float32)

    tracked_points, _status, _error = cv.calcOpticalFlowPyrLK(
        prevImg=past_frame,
        nextImg=next_frame,
        prevPts=prev_points,
        nextPts=None,
        winSize=win_size,
        maxLevel=max_level,
        criteria=criteria,
    )

    return tracked_points.reshape(-1, 2)


def compute_keypoints_proportion_inside_bounding_box(kps, bb):
    """Return the fraction of valid keypoints that lie inside a bounding box.

    Argument(s):
        kps: Iterable of (x, y) keypoints. A keypoint with either coordinate equal to 0 is treated as missing and
            ignored.
        bb: The bounding box as [x_min, x_max, y_min, y_max]. Bounds are inclusive.

    Return(s):
        The number of valid keypoints inside ``bb`` divided by the number of valid keypoints. Returns 0.0 when no
        keypoint is valid, instead of raising ZeroDivisionError. This matches
        compute_keypoints_proportion_inside_keypoints_bounding_boxes.
    """
    x_min, x_max, y_min, y_max = bb
    valid_kps = [(x, y) for x, y in kps if all([x, y])]

    if not valid_kps:
        return 0.0

    kps_inside = sum(1 for x, y in valid_kps if x_min <= x <= x_max and y_min <= y <= y_max)
    return kps_inside / len(valid_kps)


def compute_keypoints_proportion_inside_keypoints_bounding_boxes(kps_opt_flow, kps_next_frame, bb_width, bb_height):
    """Return the fraction of tracked keypoints that land near their counterpart in the next frame.

    Keypoints are paired by position (``zip``). A pair is skipped when any of its coordinates is 0 (missing
    keypoint). For each remaining pair, a ``bb_width x bb_height`` box (inclusive bounds) is centred on the
    next-frame keypoint, and the pair counts as a hit if the tracked keypoint falls inside it.

    Argument(s):
        kps_opt_flow: Iterable of (x, y) keypoints, e.g. positions predicted by optical flow.
        kps_next_frame: Iterable of (x, y) keypoints detected in the next frame.
        bb_width, bb_height: Size of the box centred on each next-frame keypoint.

    Return(s):
        hits / valid pairs, or 0.0 when there is no valid pair.
    """
    valid_pairs = []
    for (x_of, y_of), (x_nf, y_nf) in zip(kps_opt_flow, kps_next_frame):
        if all([x_of, y_of, x_nf, y_nf]):
            valid_pairs.append((x_of, y_of, x_nf, y_nf))

    if not valid_pairs:
        return 0.0

    half_w = bb_width / 2
    half_h = bb_height / 2
    hits = sum(
        1
        for x_of, y_of, x_nf, y_nf in valid_pairs
        if x_nf - half_w <= x_of <= x_nf + half_w and y_nf - half_h <= y_of <= y_nf + half_h
    )
    return hits / len(valid_pairs)


def best_matching_hungarian_opt_flow(pose_data_next_frame, all_detections_seen, frames_seen, next_frame,
                                     frames_path, frame_names, bb_edge_ratio):
    """Match previously seen detections to next-frame detections using sparse optical flow.

    Each seen detection's keypoints are tracked from the frame they were observed in into ``next_frame``. The
    score of a (seen, next-frame) pair is the mean of two fractions of tracked keypoints:
      * those inside the next-frame detection's bounding box, and
      * those inside boxes centred on the next-frame detection's keypoints. Each box is sized as ``bb_edge_ratio``
        times that detection's bounding-box width and height.
    The Hungarian algorithm (Munkres) then picks the assignment that maximizes the total score.

    Argument(s):
        pose_data_next_frame: Detections of the next frame. Each exposes ``keypoints`` and ``bounding_box``
            (``[left, right, top, bottom]``).
        all_detections_seen: Previously seen detections. Each exposes ``keypoints`` (a NumPy-compatible array of
            (x, y) points, where a zero coordinate marks a missing keypoint) and ``frame_id``.
        frames_seen: Cache mapping frame id to loaded image. Frames not yet present are loaded with ``load_img``
            and stored here, so this dict is mutated.
        next_frame: Image of the next frame.
        frames_path, frame_names: Forwarded to ``load_img``.
        bb_edge_ratio: Per-keypoint box size as a fraction of the next-frame bounding-box width and height.

    Return(s):
        A pair ``(best_matching_indexes, cost_matrix)``:
          * ``best_matching_indexes`` is the (seen index, next-frame index) assignment from ``Munkres.compute``,
            or ``[]`` when either side has no detections.
          * ``cost_matrix`` is the ``(len(all_detections_seen), len(pose_data_next_frame))`` float32 matrix of
            scores. Despite its name, higher values mean better matches.
    """
    num_detections_seen = len(all_detections_seen)
    num_detections_next_frame = len(pose_data_next_frame)
    cost_matrix = np.zeros((num_detections_seen, num_detections_next_frame), dtype=np.float32)

    # Nothing to match. Also, Munkres.compute raises on a matrix with zero rows.
    if num_detections_seen == 0 or num_detections_next_frame == 0:
        return [], cost_matrix

    for detection_seen_id, detection_seen in enumerate(all_detections_seen):
        if detection_seen.frame_id not in frames_seen:
            frames_seen[detection_seen.frame_id] = load_img(detection_seen.frame_id, frames_path, frame_names)
        frame_seen = frames_seen[detection_seen.frame_id]

        # The tracked keypoints depend only on the seen detection. Compute them once per seen detection
        # rather than once per (seen, next-frame) pair.
        keypoints_detection_seen = np.asarray(detection_seen.keypoints)
        keypoints_opt_flow_next_frame = compute_sparse_opt_flow(frame_seen, next_frame,
                                                                kps_past_frame=keypoints_detection_seen)

        # Missing keypoints (any zero coordinate) are still fed to optical flow, which may return a non-zero
        # position for them. Reset them to (0, 0) so the proportion helpers keep treating them as missing.
        missing = np.any(keypoints_detection_seen.reshape(-1, 2) == 0, axis=1)
        if missing.all():
            continue  # No usable keypoints: this detection's scores stay at 0.
        keypoints_opt_flow_next_frame[missing] = 0

        for detection_next_frame_id, detection_next_frame in enumerate(pose_data_next_frame):
            bounding_box_detection_next_frame = detection_next_frame.bounding_box
            left, right, top, bottom = bounding_box_detection_next_frame
            bb_width, bb_height = bb_edge_ratio * (right - left), bb_edge_ratio * (bottom - top)

            proportion_inside_bb = compute_keypoints_proportion_inside_bounding_box(
                keypoints_opt_flow_next_frame, bounding_box_detection_next_frame)
            proportion_inside_keypoints_bbs = compute_keypoints_proportion_inside_keypoints_bounding_boxes(
                keypoints_opt_flow_next_frame, detection_next_frame.keypoints, bb_width, bb_height)

            cost_matrix[detection_seen_id, detection_next_frame_id] = \
                (proportion_inside_bb + proportion_inside_keypoints_bbs) / 2

    # Munkres minimizes the total cost, so negate the scores to obtain the maximum-score assignment.
    best_matching_indexes = Munkres().compute((-cost_matrix).tolist())

    return best_matching_indexes, cost_matrix


def best_matching_hungarian_bounding_box(pose_data_next_frame, all_detections_seen, bb_edge_ratio, num_top_keypoints):
    """Match previously seen detections to next-frame detections by bounding-box and keypoint overlap.

    The score of a (seen, next-frame) pair is the mean of two terms:
      * the IoU of the two detections' bounding boxes, and
      * the keypoint IoU from ``compute_detections_keypoints_iou``. Per-keypoint boxes are sized as
        ``bb_edge_ratio`` times the next-frame detection's bounding-box width and height, and the
        ``num_top_keypoints`` best per-keypoint IoUs are averaged.
    The Hungarian algorithm (Munkres) then picks the assignment that maximizes the total score.

    Argument(s):
        pose_data_next_frame: Detections of the next frame. Each exposes ``keypoints`` and ``bounding_box``
            (``[left, right, top, bottom]``).
        all_detections_seen: Previously seen detections, exposing the same attributes.
        bb_edge_ratio: Per-keypoint box size as a fraction of the next-frame bounding-box width and height.
        num_top_keypoints: Number of best per-keypoint IoUs averaged into the keypoint term.

    Return(s):
        A pair ``(best_matching_indexes, cost_matrix)``:
          * ``best_matching_indexes`` is the (seen index, next-frame index) assignment from ``Munkres.compute``,
            or ``[]`` when either side has no detections.
          * ``cost_matrix`` is the ``(len(all_detections_seen), len(pose_data_next_frame))`` float32 matrix of
            scores. Despite its name, higher values mean better matches.
    """
    num_detections_seen = len(all_detections_seen)
    num_detections_next_frame = len(pose_data_next_frame)
    cost_matrix = np.zeros((num_detections_seen, num_detections_next_frame), dtype=np.float32)

    # Nothing to match. Also, Munkres.compute raises on a matrix with zero rows.
    if num_detections_seen == 0 or num_detections_next_frame == 0:
        return [], cost_matrix

    for detection_seen_id, detection_seen in enumerate(all_detections_seen):
        keypoints_detection_seen = detection_seen.keypoints
        bounding_box_detection_seen = detection_seen.bounding_box

        for detection_next_frame_id, detection_next_frame in enumerate(pose_data_next_frame):
            bounding_box_detection_next_frame = detection_next_frame.bounding_box
            left, right, top, bottom = bounding_box_detection_next_frame
            # Per-keypoint box size is a fraction of the next-frame detection's bounding box.
            width, height = bb_edge_ratio * (right - left), bb_edge_ratio * (bottom - top)

            detections_bounding_boxes_iou = compute_bounding_boxes_iou(bounding_box_detection_seen,
                                                                       bounding_box_detection_next_frame)
            detections_keypoints_iou = compute_detections_keypoints_iou(keypoints_detection_seen,
                                                                        detection_next_frame.keypoints,
                                                                        width, height, num_top_keypoints)

            cost_matrix[detection_seen_id, detection_next_frame_id] = \
                (detections_bounding_boxes_iou + detections_keypoints_iou) / 2

    # Munkres minimizes the total cost, so negate the scores to obtain the maximum-score assignment.
    best_matching_indexes = Munkres().compute((-cost_matrix).tolist())

    return best_matching_indexes, cost_matrix
