import json
from nuscenes import NuScenes
import numpy as np
import mmcv
from typing import Dict, Tuple

import os

# nuScenes database handle shared by the whole script. The defaults reproduce the original
# setup; set NUSCENES_VERSION / NUSCENES_DATAROOT to run elsewhere without editing the file.
version_ = os.environ.get('NUSCENES_VERSION', 'v1.0-trainval')
verbose_ = 1
dataroot_ = os.environ.get('NUSCENES_DATAROOT', '/project_data/ramanan/shubham/nuscenes')
nusc_ = NuScenes(version=version_, verbose=verbose_, dataroot=dataroot_)

from nuscenes.eval.common.data_classes import EvalBoxes
from nuscenes.eval.detection.render import visualize_sample
from nuscenes.eval.detection.data_classes import DetectionBox
from nuscenes.eval.common.loaders import load_prediction, load_gt, add_center_dist
from typing import Callable
from nuscenes.eval.common.utils import center_distance, scale_iou, yaw_diff, velocity_l2, attr_acc, cummean
from nuscenes.eval.detection.data_classes import DetectionMetricData

import numpy as np
from pyquaternion import Quaternion

from nuscenes import NuScenes
from nuscenes.eval.common.data_classes import EvalBoxes
from nuscenes.eval.detection.data_classes import DetectionBox
from nuscenes.eval.detection.utils import category_to_detection_name
from nuscenes.eval.tracking.data_classes import TrackingBox
from nuscenes.eval.tracking.utils import category_to_tracking_name
from nuscenes.utils.data_classes import Box
from nuscenes.utils.geometry_utils import points_in_box
from nuscenes.utils.splits import create_splits_scenes
from nuscenes.eval.detection.constants import TP_METRICS


# Ground-truth boxes of the 'val' split, each annotated with its distance to the ego vehicle
# (add_center_dist fills in `ego_dist`, which the far-range filters below rely on).
gt_boxes = add_center_dist(nusc_, load_gt(nusc_, 'val', DetectionBox, verbose=True))

# All sample tokens covered by the ground truth.
sample_tokens = gt_boxes.sample_tokens

def load_json(filepath):
    """
    Read and parse a JSON file.

    :param filepath: Path to the JSON file.
    :return: The decoded JSON value (dict, list, ...).
    """
    # JSON is UTF-8 by specification; don't depend on the platform's default encoding.
    with open(filepath, "r", encoding="utf-8") as f:
        return json.load(f)

import os

# Scene name -> list of sample tokens; run() iterates it to pick the samples to evaluate.
scene_samples_dict = load_json("scene_to_samples_test.json")


# Detection results JSON. "/your/json/here" is a placeholder: replace it, or set the
# NUSCENES_RESULTS_PATH environment variable to point at the results file.
results_path = os.environ.get("NUSCENES_RESULTS_PATH", "/your/json/here")
# 500 = maximum number of boxes allowed per sample (load_prediction's max_boxes_per_sample).
pred_boxes, meta = load_prediction(results_path, 500, DetectionBox, verbose=True)
# Attach `ego_dist` to every prediction so the distance-band filters can use it.
pred_boxes = add_center_dist(nusc_, pred_boxes)

def filter_gt_boxes_sample(nusc: NuScenes,
                           eval_boxes: list,
                           sample_token: str,
                           verbose: bool = False,
                           min_dist: float = 50.0,
                           max_dist: float = 80.0) -> list:
    """
    Prepare the ground-truth boxes of one sample for far-range evaluation.

    Unlike the prediction filter, GT boxes outside the distance band are NOT removed.
    Instead, every box gets a ``dontcare`` attribute: False when
    ``min_dist <= ego_dist < max_dist``, True otherwise (including a NaN ego_dist).
    After that, two kinds of boxes are removed: boxes with zero lidar points, and
    bicycles/motorcycles whose center lies inside a bike rack.

    :param nusc: An instance of the NuScenes class.
    :param eval_boxes: The GT boxes of ``sample_token`` (e.g. ``gt_boxes[sample_token]``).
        NOTE: the box objects themselves are modified in place (``dontcare`` is set on
        every input box), and callers rely on that flag later.
    :param sample_token: Token of the sample the boxes belong to; used to look up bike racks.
    :param verbose: Unused; kept for interface compatibility.
    :param min_dist: Inclusive lower bound of the evaluated ego-distance band (default 50).
    :param max_dist: Exclusive upper bound of the evaluated ego-distance band (default 80).
    :return: A new list containing the boxes that survived the point and bike-rack filters.
    """
    class_field = 'detection_name'

    # Flag (rather than drop) GT boxes outside the distance band, so that predictions
    # matching them can be ignored instead of being counted as false positives.
    for box in eval_boxes:
        box.dontcare = not (min_dist <= box.ego_dist < max_dist)

    # Remove boxes with zero points in them. Eval boxes have -1 points by default and are kept.
    eval_boxes = [box for box in eval_boxes if box.num_pts != 0]

    # Bike-rack filtering: fetch each annotation record once and build the rack boxes.
    rack_boxes = []
    for ann_token in nusc.get('sample', sample_token)['anns']:
        ann = nusc.get('sample_annotation', ann_token)
        if ann['category_name'] == 'static_object.bicycle_rack':
            rack_boxes.append(Box(ann['translation'], ann['size'], Quaternion(ann['rotation'])))

    def in_a_bikerack(box):
        # The box center as a single 3x1 point, as expected by points_in_box.
        center = np.expand_dims(np.array(box.translation), axis=1)
        return any(np.sum(points_in_box(rack, center)) > 0 for rack in rack_boxes)

    return [box for box in eval_boxes
            if not (getattr(box, class_field) in ('bicycle', 'motorcycle') and in_a_bikerack(box))]

def filter_eval_boxes_sample(nusc: NuScenes,
                             eval_boxes: list,
                             sample_token: str,
                             verbose: bool = False,
                             min_dist: float = 50.0,
                             max_dist: float = 80.0) -> list:
    """
    Filter the predicted boxes of one sample for far-range evaluation.

    The filters are applied in this order:
      1. Keep only boxes with ``min_dist <= ego_dist < max_dist``.
      2. Drop boxes with zero points. Predictions default to -1 points, so this is
         normally a no-op for them.
      3. Drop bicycles/motorcycles whose center lies inside a bike rack of the sample.

    :param nusc: An instance of the NuScenes class.
    :param eval_boxes: The boxes of ``sample_token`` (e.g. ``pred_boxes[sample_token]``).
        The input list is not modified.
    :param sample_token: Token of the sample the boxes belong to; used to look up bike racks.
    :param verbose: Unused; kept for interface compatibility.
    :param min_dist: Inclusive lower bound of the evaluated ego-distance band (default 50).
    :param max_dist: Exclusive upper bound of the evaluated ego-distance band (default 80).
    :return: A new list with the surviving boxes.
    """
    class_field = 'detection_name'

    # Distance band filter.
    eval_boxes = [box for box in eval_boxes if min_dist <= box.ego_dist < max_dist]

    # Remove boxes with zero points in them. Eval boxes have -1 points by default.
    eval_boxes = [box for box in eval_boxes if box.num_pts != 0]

    # Bike-rack filtering: fetch each annotation record once and build the rack boxes.
    rack_boxes = []
    for ann_token in nusc.get('sample', sample_token)['anns']:
        ann = nusc.get('sample_annotation', ann_token)
        if ann['category_name'] == 'static_object.bicycle_rack':
            rack_boxes.append(Box(ann['translation'], ann['size'], Quaternion(ann['rotation'])))

    def in_a_bikerack(box):
        # The box center as a single 3x1 point, as expected by points_in_box.
        center = np.expand_dims(np.array(box.translation), axis=1)
        return any(np.sum(points_in_box(rack, center)) > 0 for rack in rack_boxes)

    return [box for box in eval_boxes
            if not (getattr(box, class_field) in ('bicycle', 'motorcycle') and in_a_bikerack(box))]

def calc_ap(prec, min_recall=0.1, min_precision=0.1) -> float:
    """
    Compute average precision from a precision curve sampled at 1% recall steps.

    Index ``i`` of ``prec`` is taken to correspond to recall ``i / 100``. Every bin up to
    and including ``min_recall`` is discarded. Precision is lowered by ``min_precision``,
    floored at 0, averaged, and rescaled by ``1 / (1 - min_precision)`` so that a perfect
    curve scores 1.
    """
    assert 0 <= min_precision < 1
    assert 0 <= min_recall <= 1

    # +1 so the bin exactly at min_recall is excluded as well.
    first_bin = round(100 * min_recall) + 1
    kept = np.copy(prec)[first_bin:]
    kept -= min_precision
    np.maximum(kept, 0, out=kept)  # Precision below the floor contributes nothing.
    return float(np.mean(kept)) / (1.0 - min_precision)


def calc_maxF1Conf(prec, rec, conf, min_recall = 0.1, min_precision = 0.1) -> Tuple[float, float]:
    """
    Find the best F1 score along a precision/recall curve and the confidence where it occurs.

    :param prec: Precision values, aligned element-wise with ``rec`` and ``conf``.
    :param rec: Recall values.
    :param conf: Confidence threshold at each point of the curve.
    :param min_recall: Validated but otherwise unused; kept for interface compatibility.
    :param min_precision: Validated but otherwise unused; kept for interface compatibility.
    :return: ``(max_f1, conf_at_max_f1)``. When several points tie, the first one wins.
    :raises ValueError: If the inputs are empty.
    """
    assert 0 <= min_precision < 1
    assert 0 <= min_recall <= 1

    prec = np.asarray(prec, dtype=float)
    rec = np.asarray(rec, dtype=float)
    conf = np.asarray(conf)

    # F1 is defined as 0 where precision + recall is 0. Without this, 0/0 produced NaN plus
    # runtime warnings, and an all-zero curve made nanargmax raise ValueError.
    denom = prec + rec
    with np.errstate(divide='ignore', invalid='ignore'):
        f1_score = np.where(denom > 0, 2 * prec * rec / denom, 0.0)

    max_ind = np.nanargmax(f1_score)
    return float(f1_score[max_ind]), float(conf[max_ind])

def max_recall_ind(confidence):
    """
    Return the index of the maximum recall achieved.

    The last position holding a non-zero confidence marks the highest recall reached. When
    every confidence is zero (no matches at all), 0 is returned.
    """
    nonzero_positions = np.nonzero(confidence)[0]
    if not len(nonzero_positions):
        return 0
    return nonzero_positions[-1]

def calc_tp(class_name, match_data, min_recall=0.1):
    """
    Average each true-positive error metric over the recall range above ``min_recall``.

    For every metric in TP_METRICS the result is:
      * NaN if the metric is undefined for the class (traffic cones: attr/vel/orient errors;
        barriers: attr/vel errors);
      * 1.0 if the maximum achieved recall does not exceed ``min_recall``;
      * otherwise, the mean of ``match_data[metric]`` from the first bin after
        ``min_recall`` up to and including the max-recall bin.

    ``match_data`` must contain a ``"conf"`` entry plus one entry per metric. Each entry is
    presumably resampled to the 101-point recall grid (index <-> recall * 100); confirm with
    the caller.
    """
    undefined_metrics = {
        'traffic_cone': ('attr_err', 'vel_err', 'orient_err'),
        'barrier': ('attr_err', 'vel_err'),
    }.get(class_name, ())

    # Both bounds are the same for every metric, so compute them once.
    start = round(100 * min_recall) + 1            # +1 to exclude the error at min recall.
    stop = max_recall_ind(match_data["conf"])      # Index of max achieved recall.

    tp_metrics = {}
    for metric_name in TP_METRICS:
        if metric_name in undefined_metrics:
            tp_metrics[metric_name] = np.nan
        elif stop < start:
            # If this happens for all classes, the score for that TP metric will be 0.
            tp_metrics[metric_name] = 1.0
        else:
            # stop + 1 so the error at max recall is included.
            tp_metrics[metric_name] = float(np.mean(match_data[metric_name][start: stop + 1]))
    return tp_metrics

def _far_match_threshold(ego_dist):
    """
    Center-distance threshold for accepting a match at ``ego_dist`` from the ego vehicle.

    The quadratic schedule below is the active mode. Alternatives tried in this script:
      * fixed threshold: ``4``
      * linear threshold: ``ego_dist / 12.5``
    """
    return 0.25 + 0.0125 * ego_dist + 0.00125 * (ego_dist ** 2)


def _collect_far_boxes(pred_boxes, selected_scenes):
    """
    Run the far-range filters once over every sample of the selected scenes.

    :param pred_boxes: Predictions indexed by sample token.
    :param selected_scenes: Container of scene names to evaluate.
    :return: ``(gt_by_sample, preds)``. ``gt_by_sample`` maps each sample token to its
        filtered GT boxes (carrying the ``dontcare`` flag set by filter_gt_boxes_sample).
        ``preds`` is the flat list of filtered predictions.
    """
    gt_by_sample = {}
    preds = []
    for scene, samples in scene_samples_dict.items():
        if scene not in selected_scenes:
            continue
        for sample in samples:
            if sample in gt_by_sample:
                continue  # Never count the same sample twice.
            gt_by_sample[sample] = filter_gt_boxes_sample(nusc_, gt_boxes[sample], sample, False)
            preds.extend(filter_eval_boxes_sample(nusc_, pred_boxes[sample], sample, False))
    return gt_by_sample, preds


def _match_class(class_name, gt_by_sample, preds, dist_fcn):
    """
    Greedily match one class's predictions to GT boxes, highest confidence first.

    :return: ``(tp, fp)`` lists with one 0/1 entry per prediction that is not ignored. A
        prediction matched to a don't-care GT box is ignored: it counts as neither TP nor FP.
    """
    class_preds = [box for box in preds if box.detection_name == class_name]
    scores = [box.detection_score for box in class_preds]
    # Descending confidence; ties broken by higher index first (same as the nuScenes devkit).
    order = [i for (_, i) in sorted((v, i) for (i, v) in enumerate(scores))][::-1]

    tp, fp = [], []
    taken = set()  # (sample_token, gt_idx) pairs that are already matched.
    for ind in order:
        pred_box = class_preds[ind]
        # Only GT boxes that survived filtering are candidates. Those are exactly the boxes
        # counted in npos, so recall can never exceed 1.
        sample_gt = gt_by_sample.get(pred_box.sample_token, ())
        min_dist = np.inf
        match_gt_idx = None
        for gt_idx, gt_box in enumerate(sample_gt):
            if gt_box.detection_name == class_name and (pred_box.sample_token, gt_idx) not in taken:
                this_distance = dist_fcn(gt_box, pred_box)
                if this_distance < min_dist:
                    min_dist = this_distance
                    match_gt_idx = gt_idx

        if min_dist < _far_match_threshold(pred_box.ego_dist):
            taken.add((pred_box.sample_token, match_gt_idx))
            if sample_gt[match_gt_idx].dontcare:
                continue  # Matched a don't-care GT box: ignore this prediction.
            tp.append(1)
            fp.append(0)
        else:
            # No match. Mark this as a false positive.
            tp.append(0)
            fp.append(1)
    return tp, fp


def run(pred_boxes):
    """
    Compute and print per-class far-range (50-80 m) AP and its mean.

    Only the scenes listed in ``far_nuScenes.json`` are evaluated.
    * GT boxes outside the band are kept as "don't care". Predictions matched to them are
      ignored.
    * Predictions outside the band are dropped.
    * A class with GT in the band but no true positives scores AP = 0. Previously such a
      class was silently left out, which inflated the mean.
    * Classes without any GT in the band are left out (their AP is undefined).

    :param pred_boxes: Predictions indexed by sample token (e.g. the EvalBoxes returned by
        load_prediction, with ``ego_dist`` added).
    :return: Dict mapping class name to AP. The same dict is also printed.
    """
    dist_fcn = center_distance
    selected_scenes = load_json("far_nuScenes.json")
    class_names = ['car', 'truck', 'bus', 'trailer', 'construction_vehicle',
                   'pedestrian', 'motorcycle', 'bicycle', 'traffic_cone', 'barrier']

    # Filtering does not depend on the class, so do it once instead of once per class.
    gt_by_sample, preds = _collect_far_boxes(pred_boxes, selected_scenes)

    class_APs = {}
    for class_name in class_names:
        npos = sum(1 for boxes in gt_by_sample.values() for box in boxes
                   if box.detection_name == class_name and not box.dontcare)
        if npos == 0:
            continue

        tp, fp = _match_class(class_name, gt_by_sample, preds, dist_fcn)
        if not any(tp):
            # Recall is 0 everywhere, so the clipped precision curve is all zeros.
            class_APs[class_name] = 0.0
            continue

        # Accumulate. (Builtin float: the np.float alias was removed in NumPy 1.24.)
        tp = np.cumsum(tp).astype(float)
        fp = np.cumsum(fp).astype(float)

        # Calculate precision and recall, then resample to 101 recall steps (0% .. 100%).
        prec = tp / (fp + tp)
        rec = tp / float(npos)
        rec_interp = np.linspace(0, 1, DetectionMetricData.nelem)
        prec = np.interp(rec_interp, rec, prec, right=0)

        class_APs[class_name] = calc_ap(prec)

    print(class_APs)
    print(float(np.mean(list(class_APs.values()))) if class_APs else float('nan'))
    return class_APs


# Run the evaluation only when executed as a script, not when this module is imported.
if __name__ == "__main__":
    print("Predictions:")
    run(pred_boxes)


'''
Most of the code is adapted from the nuScenes devkit repository:
https://github.com/nutonomy/nuscenes-devkit

@article{nuscenes2019,
  title={nuScenes: A multimodal dataset for autonomous driving},
  author={Holger Caesar and Varun Bankiti and Alex H. Lang and Sourabh Vora and 
          Venice Erin Liong and Qiang Xu and Anush Krishnan and Yu Pan and 
          Giancarlo Baldan and Oscar Beijbom},
  journal={arXiv preprint arXiv:1903.11027},
  year={2019}
}

'''