import os

from tracking.data import load_pose_data, remove_poor_detections, fill_pose_data_with_empty_frames, write_tracking
from tracking.data import load_img
from tracking.utils import stack_all_detections_seen, best_matching_hungarian_opt_flow
from tracking.utils import best_matching_hungarian_bounding_box


def track(args):
    """Run the tracking pipeline for one video and optionally write the result.

    The pose data is loaded with ``load_pose_data``, passed through ``remove_poor_detections`` and
    ``fill_pose_data_with_empty_frames``, and then handed to ``_track`` to assign object ids. A completion
    message is printed at the end.

    Args:
        args: Object exposing the following attributes:
            pose_data: Path forwarded to ``load_pose_data``.
            frames: Directory holding the video frames. The number of entries in it is used as the
                frame count.
            use_bounding_box_strategy: Forwarded to ``_track`` to select the matching strategy.
            video_resolution: String of 'x'-separated numbers (e.g. ``'1920x1080'``), parsed into a
                list of floats.
            detection_quantile: Forwarded to ``remove_poor_detections`` as ``quantile``.
            keypoint_confidence_threshold: Forwarded to ``load_pose_data``.
            minimum_keypoints: Forwarded to ``load_pose_data`` as ``min_keypoints``.
            look_back: Forwarded to ``_track``.
            bounding_box_edge_ratio: Forwarded to ``_track`` as ``bb_edge_ratio``.
            matching_threshold: Forwarded to ``_track``.
            num_top_keypoints: Forwarded to ``_track``.
            write_tracking: Output path for ``write_tracking``; nothing is written when it is None.
    """
    pose_data_path = args.pose_data
    frames_path = args.frames
    use_bounding_box_strategy = args.use_bounding_box_strategy
    video_resolution = [float(measurement) for measurement in args.video_resolution.split('x')]
    detection_quantile = args.detection_quantile
    keypoint_confidence_threshold = args.keypoint_confidence_threshold
    min_keypoints = args.minimum_keypoints
    look_back = args.look_back
    bb_edge_ratio = args.bounding_box_edge_ratio
    matching_threshold = args.matching_threshold
    num_top_keypoints = args.num_top_keypoints
    write_tracking_path = args.write_tracking

    # Every directory entry is counted as one frame.
    num_frames = len(os.listdir(frames_path))

    pose_data = load_pose_data(pose_data_path, video_resolution=video_resolution,
                               keypoint_confidence_threshold=keypoint_confidence_threshold,
                               min_keypoints=min_keypoints)
    pose_data = remove_poor_detections(pose_data, quantile=detection_quantile)
    pose_data = fill_pose_data_with_empty_frames(pose_data, num_frames=num_frames)
    pose_data = _track(pose_data, frames_path, use_bounding_box_strategy=use_bounding_box_strategy, look_back=look_back,
                       bb_edge_ratio=bb_edge_ratio, matching_threshold=matching_threshold,
                       num_top_keypoints=num_top_keypoints)

    if write_tracking_path is not None:
        write_tracking(write_tracking_path, pose_data)

    # Normalise first: basename() of a path with a trailing separator ('videos/clip1/') is '',
    # which would print an empty video name.
    video_name = os.path.basename(os.path.normpath(frames_path))
    print('Finished tracking video %s' % video_name)


def _track(pose_data, frames_path, use_bounding_box_strategy=False, look_back=5, bb_edge_ratio=0.2,
           matching_threshold=0.0, num_top_keypoints=7):
    is_first_detection = True

    frame_ids = sorted(pose_data.keys())  # 0, 1, ..., num_frames - 1
    frame_names = sorted(os.listdir(frames_path))

    frames_seen = {}

    for frame_id in frame_ids[:-1]:
        next_frame_id = frame_id + 1

        if is_first_detection:
            if not pose_data[frame_id]:  # No detections in this frame
                continue
            is_first_detection = False
            for detection in pose_data[frame_id]:
                detection.start_tracking_object()

        if not pose_data[next_frame_id]:  # No detections in the next frame -> nothing to track
            continue

        last_frame_idx = max(frame_id - look_back, -1) + 1
        look_back_frame_ids = list(reversed(frame_ids[last_frame_idx:(frame_id + 1)]))
        max_num_objects = pose_data[next_frame_id][0].max_num_objects
        all_detections_seen, detection_seen_is_from_current_frame = \
            stack_all_detections_seen(pose_data,
                                      frame_ids=look_back_frame_ids,
                                      max_num_objects=max_num_objects)

        # No detections in the past look_back frames -> all detections in the next frame are new
        if not all_detections_seen:
            for detection in pose_data[next_frame_id]:
                if detection.object_id is None:
                    detection.start_tracking_object()
            continue

        if use_bounding_box_strategy:
            best_matching_indexes, matching_scores = best_matching_hungarian_bounding_box(pose_data[next_frame_id],
                                                                                          all_detections_seen,
                                                                                          bb_edge_ratio,
                                                                                          num_top_keypoints)
        else:
            next_frame = load_img(next_frame_id, frames_path, frame_names)
            best_matching_indexes, matching_scores = best_matching_hungarian_opt_flow(pose_data[next_frame_id],
                                                                                      all_detections_seen,
                                                                                      frames_seen, next_frame,
                                                                                      frames_path, frame_names,
                                                                                      bb_edge_ratio)
            frames_seen.clear()  # Avoid memory issues
            frames_seen[next_frame_id] = next_frame  # Cache the current frame of the next iteration

        for detection1_idx, detection2_idx in best_matching_indexes:
            if matching_scores[detection1_idx, detection2_idx] > matching_threshold:
                pose_data[next_frame_id][detection2_idx].object_id = all_detections_seen[detection1_idx].object_id

        for detection in pose_data[next_frame_id]:
            if detection.object_id is None:
                detection.start_tracking_object()

    return pose_data
