import os

import cv2 as cv
import numpy as np


# Per-track drawing colours; a track's colour is COLOURS[object_id % len(COLOURS)].
# NOTE: OpenCV drawing functions on images loaded with cv.imread interpret these tuples as
# (B, G, R). The names below describe each tuple read as RGB, so e.g. "Red" (255, 0, 0)
# actually renders blue. All 21 entries are distinct so neighbouring ids stay distinguishable.
COLOURS = {0: (0, 0, 0),  # Black
           1: (255, 0, 0),  # Red
           2: (0, 255, 0),  # Lime
           3: (0, 0, 255),  # Blue
           4: (255, 255, 0),  # Yellow
           5: (0, 255, 255),  # Cyan / Aqua
           6: (255, 0, 255),  # Magenta / Fuchsia
           7: (128, 128, 128),  # Gray
           8: (128, 0, 0),  # Maroon
           9: (128, 128, 0),  # Olive
           10: (0, 128, 0),  # Green
           11: (128, 0, 128),  # Purple
           12: (0, 128, 128),  # Teal
           13: (0, 0, 128),  # Navy
           14: (255, 255, 255),  # White (was (0, 0, 0), a duplicate of Black)
           15: (192, 192, 192),  # Silver
           16: (220, 20, 60),  # Crimson
           17: (255, 140, 0),  # Dark Orange
           18: (184, 134, 11),  # Dark Golden Rod
           19: (189, 183, 107),  # Dark Khaki
           20: (0, 100, 0)}  # Dark Green

# Pairs of keypoint indices that draw_skeleton joins with a line segment.
# The indices 0-16 look like the COCO 17-keypoint order (0 nose, 1/2 eyes, 3/4 ears,
# 5/6 shoulders, 7/8 elbows, 9/10 wrists, 11/12 hips, 13/14 knees, 15/16 ankles) —
# TODO confirm against the pose estimator that produced the trajectories.
JOINT_CONNECTIONS = [
    # face (under the COCO assumption above)
    (0, 1), (0, 2), (1, 3), (2, 4),
    # arms
    (5, 7), (7, 9), (6, 8), (8, 10),
    # legs
    (11, 13), (13, 15), (12, 14), (14, 16),
    # ears to shoulders, then the torso outline
    (3, 5), (4, 6), (5, 6), (5, 11), (6, 12), (11, 12),
]


def visualise_tracking(args):
    """Load the tracked trajectories named in ``args``, render them, and report the output path.

    Reads four attributes from ``args`` (presumably an argparse namespace — confirm with the
    caller): ``tracked_trajectories`` (comma-separated trajectory file), ``frames`` (directory
    of frame images), ``draw_only_bounding_boxes`` (skip skeleton drawing when truthy) and
    ``write_rendered_tracking`` (output directory).
    """
    csv_path, frames_dir = args.tracked_trajectories, args.frames
    boxes_only = args.draw_only_bounding_boxes
    output_dir = args.write_rendered_tracking

    # ndmin=2 keeps a single-row file two-dimensional, so column indexing still works.
    trajectories = np.loadtxt(csv_path, dtype=np.float32, delimiter=',', ndmin=2)
    _visualise_tracking(output_dir, trajectories, frames_dir, boxes_only)

    # Video id = file name up to its first dot, e.g. .../01_001.csv -> 01_001.
    video_id = os.path.basename(csv_path).partition('.')[0]
    print('Rendered tracking for video %s written to %s' % (video_id, output_dir))


def _visualise_tracking(write_path, tracked_trajectories, frames_path, draw_only_bounding_boxes=False):
    """Draw each tracked object onto its frame and write the rendered frames to ``write_path``.

    Args:
        write_path: Output directory; created if it does not exist. Each rendered frame keeps
            the file name of its source frame.
        tracked_trajectories: 2-D array with one row per (frame, object). Column 0 is the frame
            index, column 1 the object id, and the remaining columns are flattened (x, y)
            keypoint pairs.
        frames_path: Directory containing the frame images. Files are processed in sorted name
            order, and a file's position in that order is the frame index matched against
            column 0.
        draw_only_bounding_boxes: If True, draw only the boxes and skip the skeletons.

    Raises:
        IOError: If a file in ``frames_path`` cannot be decoded as an image, or a rendered
            frame cannot be written.
    """
    os.makedirs(write_path, exist_ok=True)
    frame_ids = tracked_trajectories[:, 0]  # hoisted: compared against every frame index below

    for frame_id, frame_name in enumerate(sorted(os.listdir(frames_path))):
        frame_file = os.path.join(frames_path, frame_name)
        frame = cv.imread(frame_file)
        if frame is None:  # cv.imread reports failure by returning None instead of raising
            raise IOError('Could not read frame image %s' % frame_file)
        frame_height, frame_width = frame.shape[:2]

        for row in tracked_trajectories[frame_ids == frame_id]:
            kps = row[2:].reshape(-1, 2)
            # An object with no usable keypoint (every point has a 0.0 or non-finite
            # coordinate) has no bounding box, so there is nothing to draw for it.
            if not np.any(np.all(np.isfinite(kps) & (kps != 0.0), axis=1)):
                continue

            left, right, top, bottom = compute_bounding_box(kps, frame_width=frame_width,
                                                            frame_height=frame_height)
            # The id may be stored as a float (e.g. loaded with dtype=float32); convert it
            # so it indexes the int-keyed colour table.
            colour = COLOURS[int(row[1]) % len(COLOURS)]

            cv.rectangle(frame, (left, top), (right, bottom), color=colour, thickness=3)
            if not draw_only_bounding_boxes:
                draw_skeleton(frame, kps, colour)

        output_file = os.path.join(write_path, frame_name)
        if not cv.imwrite(output_file, frame):  # imwrite returns False rather than raising
            raise IOError('Could not write rendered frame to %s' % output_file)


def draw_skeleton(frame, kps, colour, thickness=2):
    """Draw every limb listed in ``JOINT_CONNECTIONS`` onto ``frame`` in place.

    Args:
        frame: Image passed to ``cv.line``, which draws on it in place.
        kps: Keypoints indexable by keypoint id, each giving an (x, y) pair (e.g. an (N, 2)
            array).
        colour: Colour tuple passed straight to ``cv.line``.
        thickness: Line thickness in pixels.

    A limb is skipped when either endpoint is missing. An endpoint is missing if any of its
    coordinates is exactly 0.0, or is NaN/inf, which cannot be converted to a pixel position.
    """
    for start_id, end_id in JOINT_CONNECTIONS:
        x1, y1 = kps[start_id]
        x2, y2 = kps[end_id]
        coords = np.array([x1, y1, x2, y2], dtype=np.float64)
        if np.any(coords == 0.0) or not np.all(np.isfinite(coords)):
            continue
        cv.line(frame, (int(round(x1)), int(round(y1))), (int(round(x2)), int(round(y2))),
                color=colour, thickness=thickness)


def compute_bounding_box(kps, frame_width, frame_height, expansion_factor=0.1):
    """Return an expanded, frame-clipped bounding box around the detected keypoints.

    A keypoint counts as missing when either coordinate is exactly 0.0 or non-finite. This
    matches the rule ``draw_skeleton`` uses to decide which limbs to draw.

    Args:
        kps: Array of shape (num_keypoints, 2) holding (x, y) pixel coordinates.
        frame_width: Frame width in pixels; x bounds are clipped to [0, frame_width].
        frame_height: Frame height in pixels; y bounds are clipped to [0, frame_height].
        expansion_factor: Fraction of the keypoints' width (height) added to each horizontal
            (vertical) side of the box.

    Returns:
        ``(left, right, top, bottom)`` as Python ints. ``right``/``bottom`` may equal
        ``frame_width``/``frame_height``. That makes them valid exclusive slice ends
        (``frame[top:bottom, left:right]``), and OpenCV clips drawing to the image anyway.

    Raises:
        ValueError: If no keypoint is present.
    """
    kps = np.asarray(kps)
    present = np.all(np.isfinite(kps) & (kps != 0.0), axis=1)
    if not np.any(present):
        raise ValueError('No keypoints present (all missing); cannot compute a bounding box.')
    visible = kps[present]

    left, top = visible.min(axis=0)
    right, bottom = visible.max(axis=0)
    kps_width = right - left
    kps_height = bottom - top

    left = np.round(np.clip(left - expansion_factor * kps_width, 0, frame_width))
    right = np.round(np.clip(right + expansion_factor * kps_width, 0, frame_width))
    top = np.round(np.clip(top - expansion_factor * kps_height, 0, frame_height))
    bottom = np.round(np.clip(bottom + expansion_factor * kps_height, 0, frame_height))

    return int(left), int(right), int(top), int(bottom)
