import csv
import json
import os

import cv2 as cv
import numpy as np

from tracking.visualisation import compute_bounding_box


class DetectedObject:
    """A single detection in one frame, optionally tagged with a tracking id.

    The class attribute ``max_num_objects`` is a counter shared by all instances: it holds the
    highest object id handed out so far by ``start_tracking_object``.
    """

    max_num_objects = 0

    def __init__(self, frame_id, keypoints, detection_score, bounding_box):
        """Store the detection's data; the detection has no object id until tracking starts."""
        self.frame_id = frame_id
        self.keypoints = keypoints
        self.bounding_box = bounding_box
        self.detection_score = detection_score
        self.object_id = None

    def start_tracking_object(self):
        """Give this detection the next unused object id (the first id handed out is 1)."""
        next_id = DetectedObject.max_num_objects + 1
        DetectedObject.max_num_objects = next_id
        self.object_id = next_id


def load_pose_data(pose_data_path, video_resolution, keypoint_confidence_threshold=0.0, min_keypoints=9):
    """Load pose data from a .json file and compute their bounding boxes.

    Argument(s):
        pose_data_path: Path to the .json file containing the pose data. Each record in the .json file contains
            the frame id (key: 'image_id'); the keypoints as a list of [x_1, y_1, c_1, ..., x_k, y_k, c_k], where
            x_k and y_k are the coordinates and c_k is the confidence score for the detection of the k-th keypoint
            (key: 'keypoints'); and the overall score of the detected keypoints (key: 'score').
        video_resolution: A list containing the width and the height of the video from where the detected objects
            came from.
        keypoint_confidence_threshold: Any keypoint with a detected confidence below this threshold is considered to
            be missing (its coordinates are set to 0).
        min_keypoints: Any detection with less than min_keypoints keypoints detected is ignored. A keypoint counts
            as detected when its x coordinate is greater than 0.

    Return(s):
            A dictionary mapping the ids of the frames to a list of detections in those frames. Frames in which
            every detection was ignored do not appear in the dictionary.
    """
    frame_width, frame_height = video_resolution

    # Read the whole file first so that it is closed before the records are processed.
    with open(pose_data_path, mode='r') as f:
        json_pose_data = json.load(f)

    pose_data = {}
    for record in json_pose_data:
        frame_id = record['image_id']
        try:
            frame_id = int(frame_id)
        except ValueError:
            frame_id = int(record['image_id'].split(sep='.')[0])  # e.g. '0001.png' -> 1

        # Parse the flat [x, y, c, ...] list once; coordinates and confidences are views of the same array.
        keypoints_with_confidences = np.array(record['keypoints'], dtype=np.float32).reshape(-1, 3)
        keypoints = keypoints_with_confidences[:, :2]
        keypoints_confidences = keypoints_with_confidences[:, 2]

        # Low-confidence keypoints are marked as missing by zeroing their coordinates.
        keypoints[keypoints_confidences < keypoint_confidence_threshold] = 0.0
        if np.sum(keypoints[:, 0] > 0) < min_keypoints:
            continue

        detection = DetectedObject(
            frame_id=frame_id,
            keypoints=keypoints,
            detection_score=record['score'],
            bounding_box=compute_bounding_box(keypoints, frame_width=frame_width, frame_height=frame_height),
        )
        pose_data.setdefault(frame_id, []).append(detection)

    return pose_data


def remove_poor_detections(pose_data, quantile=25):
    """Discard the detections whose score is not above a percentile of all detection scores.

    Argument(s):
        pose_data: A dictionary mapping frame ids to lists of detections. Each detection must expose a
            ``detection_score`` attribute.
        quantile: Percentile (0-100) of all detection scores used as the threshold. Only detections with a
            score strictly greater than the threshold are kept. If 0, nothing is filtered.

    Return(s):
            ``pose_data`` itself when ``quantile`` is 0; otherwise a new dictionary with the same frame ids,
            each mapped to a new list holding the detections that were kept (possibly empty).
    """
    if quantile == 0:
        return pose_data

    detection_scores = [detection.detection_score
                        for detections in pose_data.values()
                        for detection in detections]
    if not detection_scores:
        # A percentile of an empty sample is undefined; there is nothing to filter, so keep the (empty) frames.
        return {frame_id: [] for frame_id in pose_data}

    detection_threshold = np.percentile(detection_scores, quantile)

    return {
        frame_id: [detection for detection in detections if detection.detection_score > detection_threshold]
        for frame_id, detections in pose_data.items()
    }


def fill_pose_data_with_empty_frames(pose_data, num_frames):
    """Make sure every frame id in ``range(num_frames)`` has an entry in ``pose_data``.

    Argument(s):
        pose_data: A dictionary mapping frame ids to lists of detections. It is modified in place.
        num_frames: Frame ids 0, ..., num_frames - 1 that are missing are added with an empty list.

    Return(s):
            The same ``pose_data`` dictionary that was passed in.
    """
    for missing_candidate in range(num_frames):
        # setdefault leaves existing entries untouched and only inserts for absent frame ids.
        pose_data.setdefault(missing_candidate, [])

    return pose_data


def write_tracking(file_name, pose_data):
    """Write the detections, ordered by frame id, to a .csv file.

    Each row has the form [frame_id, object_id, x_1, y_1, ..., x_k, y_k]. The row is assembled with
    ``np.hstack``, so all of its values end up in a single array before being converted to a list.

    Argument(s):
        file_name: Path of the .csv file to write. An existing file is overwritten.
        pose_data: A dictionary mapping frame ids to lists of detections. Each detection must expose the
            attributes ``object_id`` and ``keypoints`` (an array that is flattened into the row).

    Return(s):
            The list of rows that was written to the file.
    """
    output = [
        np.hstack(([frame_id, detection.object_id], detection.keypoints.ravel())).tolist()
        for frame_id in sorted(pose_data.keys())
        for detection in pose_data[frame_id]
    ]

    # newline='' lets the csv module control line endings itself; without it, platforms that translate
    # '\n' on write (e.g. Windows) end up with an extra '\r' and blank lines between the rows.
    with open(file_name, mode='w', newline='') as f:
        csv.writer(f).writerows(output)

    return output


def write_trajectories(dir_name, trajectories_split_by_object):
    """Save every trajectory to its own .csv file inside ``dir_name``.

    Argument(s):
        dir_name: Directory in which the files are written.
        trajectories_split_by_object: A dictionary mapping object ids to trajectory arrays. The file of an
            object is named after its id, zero-padded to four digits (e.g. 7 -> '0007.csv').
    """
    for object_id, trajectory in trajectories_split_by_object.items():
        csv_name = ('%.4d' % object_id) + '.csv'
        np.savetxt(os.path.join(dir_name, csv_name), X=trajectory, delimiter=',')


def split_and_pad(args):
    """Split a tracked trajectories file by object id, pad missing frames and optionally save the result.

    Argument(s):
        args: An object with the attributes ``tracked_trajectories`` (path to a comma-separated file that is
            loaded as a 2-D float32 array) and ``write_split_trajectories`` (directory in which one .csv file
            per object is written, or None to skip writing).
    """
    trajectories_path = args.tracked_trajectories
    output_dir = args.write_split_trajectories

    loaded = np.loadtxt(trajectories_path, dtype=np.float32, delimiter=',', ndmin=2)
    per_object = _split_and_pad(loaded)

    if output_dir is not None:
        write_trajectories(output_dir, per_object)

    # The video id is the file name up to its first dot, e.g. .../01_001.csv -> 01_001
    video_id = os.path.basename(trajectories_path).partition('.')[0]
    print('Finished splitting the trajectories by object id and padding missing frames for video %s' % video_id)


def _split_and_pad(tracked_trajectories):
    object_ids = np.unique(tracked_trajectories[:, 1]).astype(np.int32)
    num_cols = tracked_trajectories.shape[1]
    padded_trajectories = {}
    for object_id in object_ids:
        mask = tracked_trajectories[:, 1] == object_id
        trajectory = tracked_trajectories[mask]

        frames = trajectory[:, 0].astype(np.int32)
        first_frame, last_frame = frames[[0, -1]]
        num_frames = last_frame - first_frame + 1
        padded_trajectory = np.zeros((num_frames, num_cols - 1), dtype=np.float32)
        padded_trajectory[:, 0] = np.arange(first_frame, last_frame + 1, dtype=np.float32)
        padded_trajectory[frames - first_frame, 1:] = trajectory[:, 2:]

        padded_trajectories[object_id] = padded_trajectory

    return padded_trajectories


def load_imgs(frame_ids, frames_path, frame_names):
    """Load the images of several frames.

    Argument(s):
        frame_ids: An iterable with the ids of the frames to load.
        frames_path: Directory containing the frame images.
        frame_names: A collection that maps a frame id to the file name of its image.

    Return(s):
            A dictionary mapping each frame id to the image returned by ``load_img`` for that frame.
    """
    return {frame_id: load_img(frame_id, frames_path, frame_names) for frame_id in frame_ids}


def load_img(frame_id, frames_path, frame_names):
    """Read the image of a frame from disk and convert it from BGR to grayscale.

    Argument(s):
        frame_id: Index/key used to look up the file name of the frame in ``frame_names``.
        frames_path: Directory containing the frame images.
        frame_names: A collection that maps a frame id to the file name of its image.

    Return(s):
            The grayscale image produced by ``cv.cvtColor``.
    """
    img_path = os.path.join(frames_path, frame_names[frame_id])
    # NOTE(review): the result of cv.imread is passed to cv.cvtColor without being checked.
    return cv.cvtColor(cv.imread(img_path), code=cv.COLOR_BGR2GRAY)
