import torch
import numpy as np
import json
import pickle
from collections import defaultdict
import copy
import itertools
from copy import deepcopy
from functools import partial
from multiprocessing.pool import ThreadPool
from scipy.spatial import ConvexHull
from nuscenes.eval.detection.evaluate import DetectionEval
from nuscenes.eval.detection.data_classes import DetectionConfig

# The ten detection class names handled by this module.  Positions in this
# list are used as class-column indices by the Bayesian score fusion.
DETECTION_NAMES = [
    'car',
    'truck',
    'bus',
    'trailer',
    'construction_vehicle',
    'pedestrian',
    'motorcycle',
    'bicycle',
    'traffic_cone',
    'barrier',
]



class CONFIG(object):
    """Holds paths and nuScenes detection-evaluation settings for fusion runs.

    Args:
        cp_path: path stored as ``cp_path``; not read by this class.
        fcos_path: path stored as ``fcos_path``; not read by this class.
        dataroot: dataset root, stored as ``dataroot_``.
        config_path: path to a JSON detection config.  An empty value
            (``''`` or ``None``) selects the stock ``'detection_cvpr_2019'``
            config instead.
        output_dir: directory forwarded to ``DetectionEval``.
        eval_set: split name forwarded to ``DetectionEval``.
        version: dataset version string, stored as ``version_``.
    """

    def __init__(self, cp_path, fcos_path, dataroot, config_path, output_dir, eval_set = 'val', version = 'v1.0-trainval'):
        self.cp_path = cp_path
        self.fcos_path = fcos_path
        self.output_dir_ = output_dir
        self.eval_set_ = eval_set
        self.dataroot_ = dataroot
        self.version_ = version
        self.config_path = config_path
        # Rendering / logging knobs; of these only ``verbose_`` is used in this class.
        self.plot_examples_ = 10
        self.render_curves_ = True
        self.verbose_ = True

        if not config_path:
            self.cfg_ = self._default_config()
        else:
            with open(config_path, 'r') as _f:
                self.cfg_ = DetectionConfig.deserialize(json.load(_f))

    @staticmethod
    def _default_config():
        """Return the stock ``'detection_cvpr_2019'`` detection config.

        ``config_factory`` was referenced but never imported in this module,
        so the default-config path used to raise NameError.  Its location
        differs between nuscenes-devkit releases, hence the fallback import.
        """
        try:
            from nuscenes.eval.common.config import config_factory
        except ImportError:  # older nuscenes-devkit layout
            from nuscenes.eval.detection.config import config_factory
        return config_factory('detection_cvpr_2019')

    def load_eval_object(self, nusc_, path):
        """Build a ``DetectionEval`` for the result file at ``path`` using this config."""
        return DetectionEval(nusc_, config=self.cfg_, result_path=path, eval_set=self.eval_set_,
                             output_dir=self.output_dir_, verbose=self.verbose_)



def fuse_all_boxes(input_samples,
                   primary_detection_object,
                   primary_conf_thresh,
                   secondary_detection_object,
                   secondary_conf_thresh: float = 0.0,
                   fusion_type: str = 'iou',
                   iou_thresh: float = 0.1,
                   distance_thresh: float = 0.5,
                   model=None):
    """
    This is the entrypoint function to the fusion land.

    For every sample token, the boxes of both detectors are filtered by
    confidence and then fused with the function selected by ``fusion_type``.

    Args:
        input_samples ([list]): iterable of nuScenes sample tokens to process.
        primary_detection_object ([DetectionEval]): object whose
            ``pred_boxes[token]`` yields the primary detector's boxes.
        primary_conf_thresh ([float]): primary boxes are kept when
            ``detection_score >= primary_conf_thresh``.
        secondary_detection_object ([DetectionEval]): object whose
            ``pred_boxes[token]`` yields the secondary detector's boxes.
        secondary_conf_thresh (float, optional): secondary boxes are kept when
            ``detection_score > secondary_conf_thresh`` (strict). Defaults to 0.0.
        fusion_type (str, optional): one of 'iou' (NMS), 'distance'
            (centre-distance NMS), 'adaptive_nms' (AdaNMS), 'bayesian'
            (Bayesian fusion), 'nms_max', 'nn_fusion' (learned fusion, needs
            ``model``). Defaults to 'iou'.
        iou_thresh (float, optional): IoU threshold forwarded to the fusion
            function. Defaults to 0.1.
        distance_thresh (float, optional): currently unused. Defaults to 0.5.
        model ([type], optional): scoring model, required for 'nn_fusion'.

    Returns:
        defaultdict(list): sample token -> list of serialized fused boxes.
        Every token of ``input_samples`` gets an entry, possibly empty (the
        old code silently dropped samples whose fused list was empty).

    Raises:
        ValueError: unknown ``fusion_type``, or 'nn_fusion' without a model.
    """
    if fusion_type == 'nn_fusion' and model is None:
        raise ValueError("fusion_type='nn_fusion' requires a model")

    if fusion_type == 'iou':
        fuse_fn = nms
    elif fusion_type == 'distance':
        fuse_fn = nms_distance
    elif fusion_type == 'adaptive_nms':
        fuse_fn = distance_adaptive_nms
    elif fusion_type == 'bayesian':
        fuse_fn = nms_bayesian
    elif fusion_type == 'nms_max':
        fuse_fn = nms_max
    elif fusion_type == 'nn_fusion':
        fuse_fn = NN_fusion
    else:
        # Previously an unknown type surfaced later as UnboundLocalError.
        raise ValueError(
            f"Unknown fusion_type {fusion_type!r}; expected one of "
            "'iou', 'distance', 'adaptive_nms', 'bayesian', 'nms_max', 'nn_fusion'")

    result_boxes = defaultdict(list)

    for sample_token in input_samples:
        primary_boxes = [box for box in primary_detection_object.pred_boxes[sample_token]
                         if box.detection_score >= primary_conf_thresh]
        secondary_boxes = [box for box in secondary_detection_object.pred_boxes[sample_token]
                           if box.detection_score > secondary_conf_thresh]

        if fusion_type == 'nn_fusion':
            # NN fusion keeps the two detectors' boxes separate.
            fused_boxes = fuse_fn(model, primary_boxes, secondary_boxes, iou_threshold=iou_thresh)
        else:
            # The NMS variants operate on the pooled boxes of both detectors.
            fused_boxes, _ = fuse_fn(primary_boxes + secondary_boxes,
                                     confidence_thresh=0.0, iou_threshold=iou_thresh)

        # Indexing the defaultdict creates the key even when nothing survived.
        result_boxes[sample_token].extend(box.serialize() for box in fused_boxes)

    return result_boxes
    


def multithreaded_fusion(sample_tokens,
                         primary_detection_object,
                         primary_conf_thresh,
                         secondary_detection_object,
                         secondary_conf_thresh,
                         fusion_type,
                         iou_thresh,
                         distance_thresh = 0.5,
                         n_threads = 10,
                         model = None):
    """Run ``fuse_all_boxes`` over ``sample_tokens`` split across a thread pool.

    All arguments except ``sample_tokens`` and ``n_threads`` are forwarded to
    ``fuse_all_boxes`` unchanged (``distance_thresh`` and ``model`` were
    previously dropped, which made 'nn_fusion' unusable through this entry
    point).  ThreadPool workers share the interpreter lock, so the speed-up is
    limited to work that releases it.

    Returns:
        dict: sample token -> list of serialized fused boxes, merged from all chunks.
    """
    # array_split tolerates more threads than tokens (it yields empty chunks);
    # tolist() hands plain Python strings rather than NumPy scalars to workers.
    sample_slices = [chunk.tolist() for chunk in np.array_split(sample_tokens, n_threads)]

    worker = partial(fuse_all_boxes,
                     primary_detection_object=primary_detection_object,
                     primary_conf_thresh=primary_conf_thresh,
                     secondary_detection_object=secondary_detection_object,
                     secondary_conf_thresh=secondary_conf_thresh,
                     fusion_type=fusion_type,
                     iou_thresh=iou_thresh,
                     distance_thresh=distance_thresh,
                     model=model)

    # The context manager tears the pool down even if a worker raises.
    with ThreadPool(n_threads) as pool:
        results = pool.map(worker, sample_slices)

    result_boxes = {}
    for chunk_result in results:
        result_boxes.update(chunk_result)
    return result_boxes
    
def distance(box_a, box_b):
    """Planar (x, y) Euclidean distance between two boxes' ``translation``.

    NOTE: this module defines ``distance`` again further down (using
    ``ego_translation``); that later definition replaces this one when the
    module is imported, so module-level callers resolve to it.
    """
    delta_xy = np.asarray(box_a.translation[:2]) - np.asarray(box_b.translation[:2])
    return np.linalg.norm(delta_xy)
    
 
def nms(original_boxes, confidence_thresh, iou_threshold):
    """Greedy, class-wise non-maximum suppression using rotated 3D IoU.

    Boxes are visited in descending ``detection_score`` order.  Each box that
    is still alive suppresses every later box of the same ``detection_name``
    whose centre lies within 10 (units of ``distance``) and whose 3D IoU with
    it exceeds ``iou_threshold``.

    Args:
        original_boxes: iterable of detection boxes.
        confidence_thresh: unused; kept for a uniform fusion-function signature.
        iou_threshold: IoU above which a lower-scored box is suppressed.

    Returns:
        (preserved_boxes, suppressed_box_indices): surviving boxes as a NumPy
        array in descending-score order, and the suppressed positions
        (indices into the score-sorted order) in suppression order.
    """
    ranked = sorted(original_boxes, key=lambda box: box.detection_score, reverse=True)
    # Boolean flags replace the O(n) ``in list`` scans of the previous loop.
    is_suppressed = [False] * len(ranked)
    suppressed_box_indices = []

    for anchor_idx, anchor in enumerate(ranked):
        if is_suppressed[anchor_idx]:
            continue
        for cand_idx in range(anchor_idx + 1, len(ranked)):
            if is_suppressed[cand_idx]:
                continue
            candidate = ranked[cand_idx]
            if candidate.detection_name != anchor.detection_name:
                continue
            # Cheap centre-distance gate before the expensive polygon IoU.
            if distance(anchor, candidate) > 10:
                continue
            if rotated_iou(anchor, candidate)[0] > iou_threshold:
                is_suppressed[cand_idx] = True
                suppressed_box_indices.append(cand_idx)

    preserved_boxes = np.delete(ranked, suppressed_box_indices, axis=0)
    return preserved_boxes, suppressed_box_indices


def nms_max(original_boxes, confidence_thresh, iou_threshold):
    """Greedy class-wise suppression of only the single best-overlapping box.

    Boxes are visited in descending ``detection_score`` order.  For each box
    that is still alive, among the later alive boxes of the same
    ``detection_name`` with centre within 10 (units of ``distance``), the one
    with the highest 3D IoU above ``iou_threshold`` is suppressed (ties keep
    the first, i.e. higher-scored, candidate).

    The previous version suppressed ``box_indices[i]`` with ``i`` being the
    loop variable left over from the scan, i.e. the *last* remaining box,
    instead of the max-IoU match it had just found.

    Args:
        original_boxes: iterable of detection boxes.
        confidence_thresh: unused; kept for a uniform fusion-function signature.
        iou_threshold: minimum IoU (exclusive) for a box to be suppressed.

    Returns:
        (preserved_boxes, suppressed_box_indices) as in ``nms``.
    """
    ranked = sorted(original_boxes, key=lambda box: box.detection_score, reverse=True)
    is_suppressed = [False] * len(ranked)
    suppressed_box_indices = []

    for anchor_idx, anchor in enumerate(ranked):
        if is_suppressed[anchor_idx]:
            continue
        best_iou, best_idx = 0, -1
        for cand_idx in range(anchor_idx + 1, len(ranked)):
            if is_suppressed[cand_idx]:
                continue
            candidate = ranked[cand_idx]
            if candidate.detection_name != anchor.detection_name:
                continue
            if distance(anchor, candidate) > 10:
                continue
            iou = rotated_iou(anchor, candidate)[0]
            if iou > iou_threshold and iou > best_iou:
                best_iou, best_idx = iou, cand_idx
        if best_idx != -1:
            is_suppressed[best_idx] = True
            suppressed_box_indices.append(best_idx)

    preserved_boxes = np.delete(ranked, suppressed_box_indices, axis=0)
    return preserved_boxes, suppressed_box_indices

def distance_adaptive_nms(original_boxes, confidence_thresh, iou_threshold):
    """Class-wise greedy NMS whose IoU threshold shrinks with the anchor's range.

    Same visiting order and gating as ``nms``, but the threshold used for an
    anchor box is ``iou_threshold / ((anchor.ego_dist + 20) // 20)``: the full
    threshold for ``ego_dist`` < 20, half of it for [20, 40), a third for
    [40, 60), and so on (``ego_dist`` is presumably the box's distance from
    the ego vehicle; confirm against the box class).

    Args:
        original_boxes: iterable of detection boxes (need ``ego_dist``).
        confidence_thresh: unused; kept for a uniform fusion-function signature.
        iou_threshold: base IoU threshold before range scaling.

    Returns:
        (preserved_boxes, suppressed_box_indices) as in ``nms``.
    """
    ranked = sorted(original_boxes, key=lambda box: box.detection_score, reverse=True)
    is_suppressed = [False] * len(ranked)
    suppressed_box_indices = []

    for anchor_idx, anchor in enumerate(ranked):
        if is_suppressed[anchor_idx]:
            continue
        adaptive_thresh = iou_threshold / ((anchor.ego_dist + 20) // 20)
        for cand_idx in range(anchor_idx + 1, len(ranked)):
            if is_suppressed[cand_idx]:
                continue
            candidate = ranked[cand_idx]
            if candidate.detection_name != anchor.detection_name:
                continue
            if distance(anchor, candidate) > 10:
                continue
            if rotated_iou(anchor, candidate)[0] > adaptive_thresh:
                is_suppressed[cand_idx] = True
                suppressed_box_indices.append(cand_idx)

    preserved_boxes = np.delete(ranked, suppressed_box_indices, axis=0)
    return preserved_boxes, suppressed_box_indices

    

def nms_distance(original_boxes, confidence_thresh, iou_threshold):
    """Class-wise greedy suppression based on centre distance instead of IoU.

    Boxes are visited in descending ``detection_score`` order.  Each alive box
    suppresses every later alive box of the same ``detection_name`` whose
    centre distance ``d`` satisfies ``d < 0.04 * d + 2.0``.

    NOTE(review): the threshold is computed from the same pair distance it is
    compared against, so the test reduces to a constant ``d < 2.0 / 0.96``
    (about 2.08, in the units of ``distance``).  If a range-adaptive threshold
    was intended, the anchor's ``ego_dist`` was probably meant in place of
    ``d`` on the right-hand side; behavior is kept as-is pending confirmation.

    Args:
        original_boxes: iterable of detection boxes.
        confidence_thresh: unused; kept for a uniform fusion-function signature.
        iou_threshold: unused by this variant.

    Returns:
        (preserved_boxes, suppressed_box_indices) as in ``nms``.
    """
    ranked = sorted(original_boxes, key=lambda box: box.detection_score, reverse=True)
    is_suppressed = [False] * len(ranked)
    suppressed_box_indices = []

    for anchor_idx, anchor in enumerate(ranked):
        if is_suppressed[anchor_idx]:
            continue
        for cand_idx in range(anchor_idx + 1, len(ranked)):
            if is_suppressed[cand_idx]:
                continue
            candidate = ranked[cand_idx]
            if candidate.detection_name != anchor.detection_name:
                continue
            selected_distance = distance(anchor, candidate)
            dist_threshold = selected_distance * 0.04 + 2.0
            if selected_distance < dist_threshold:
                is_suppressed[cand_idx] = True
                suppressed_box_indices.append(cand_idx)

    preserved_boxes = np.delete(ranked, suppressed_box_indices, axis=0)
    return preserved_boxes, suppressed_box_indices





def polygon_clip(subjectPolygon, clipPolygon):
    """Clip one polygon against a convex polygon (Sutherland-Hodgman).

    Ref: https://rosettacode.org/wiki/Sutherland-Hodgman_polygon_clipping#Python

    Args:
        subjectPolygon: sequence of (x, y) points, any polygon.
        clipPolygon: sequence of (x, y) points; must be *convex*.

    Note:
        Points of both polygons must be in counter-clockwise order.

    Returns:
        List of (x, y) vertices of the intersection polygon.  Subject vertices
        are passed through as given; newly computed crossings are 2-element
        lists.  Returns None as soon as a clipping pass leaves no vertex.
    """
    def _left_of(point, edge_from, edge_to):
        # Strictly left of the directed edge; points on the edge count as outside.
        return ((edge_to[0] - edge_from[0]) * (point[1] - edge_from[1])
                > (edge_to[1] - edge_from[1]) * (point[0] - edge_from[0]))

    def _crossing(edge_from, edge_to, seg_from, seg_to):
        # Intersection of the clip-edge line with the line through the segment.
        dc = [edge_from[0] - edge_to[0], edge_from[1] - edge_to[1]]
        dp = [seg_from[0] - seg_to[0], seg_from[1] - seg_to[1]]
        n1 = edge_from[0] * edge_to[1] - edge_from[1] * edge_to[0]
        n2 = seg_from[0] * seg_to[1] - seg_from[1] * seg_to[0]
        n3 = 1.0 / (dc[0] * dp[1] - dc[1] * dp[0])
        return [(n1 * dp[0] - n2 * dc[0]) * n3, (n1 * dp[1] - n2 * dc[1]) * n3]

    clipped = subjectPolygon
    edge_from = clipPolygon[-1]

    for edge_to in clipPolygon:
        pending = clipped
        clipped = []
        prev = pending[-1]
        prev_in = _left_of(prev, edge_from, edge_to)
        for cur in pending:
            cur_in = _left_of(cur, edge_from, edge_to)
            if cur_in != prev_in:
                # The segment prev -> cur crosses the clip edge.
                clipped.append(_crossing(edge_from, edge_to, prev, cur))
            if cur_in:
                clipped.append(cur)
            prev, prev_in = cur, cur_in
        edge_from = edge_to
        if not clipped:
            return None

    return clipped

def poly_area(x, y):
    """Polygon area via the shoelace formula.

    ``x`` and ``y`` are the vertex coordinate arrays, in boundary order.
    Ref: http://stackoverflow.com/questions/24467972/calculate-area-of-polygon-given-x-y-coordinates
    """
    x_prev = np.roll(x, 1)
    y_prev = np.roll(y, 1)
    twice_signed_area = np.dot(x, y_prev) - np.dot(y, x_prev)
    return 0.5 * np.abs(twice_signed_area)

def convex_hull_intersection(p1, p2):
    """Compute the intersection polygon of two convex hulls and its area.

    Args:
        p1, p2: lists of (x, y) hull vertices, counter-clockwise (as required
            by ``polygon_clip``).

    Returns:
        (inter_p, area): the intersection vertices (None when the polygons do
        not overlap) and the intersection area.  Degenerate intersections
        (fewer than three vertices, or flat/collinear ones, e.g. boxes that
        merely touch) have area 0.0; previously ``ConvexHull`` raised
        QhullError on them.
    """
    inter_p = polygon_clip(p1, p2)
    if inter_p is None or len(inter_p) < 3:
        return inter_p, 0.0

    try:
        from scipy.spatial import QhullError
    except ImportError:  # older SciPy only exposes it from the qhull submodule
        from scipy.spatial.qhull import QhullError

    try:
        hull_inter = ConvexHull(inter_p)
    except QhullError:
        return inter_p, 0.0
    # For 2-D input, ConvexHull.volume is the enclosed area.
    return inter_p, hull_inter.volume

def box3d_vol(corners):
    """Volume of a box given its (8, 3) corner array.

    Multiplies the lengths of edges 0-1, 1-2 and 0-4; no assumption is made
    about which axis points up.
    """
    def _edge_length(i, j):
        delta = corners[i, :] - corners[j, :]
        return np.sqrt(np.sum(delta ** 2))

    edge_01 = _edge_length(0, 1)
    edge_12 = _edge_length(1, 2)
    edge_04 = _edge_length(0, 4)
    return edge_01 * edge_12 * edge_04

def is_clockwise(p):
    """Return True when the 2-D polygon ``p`` (N x 2 array) is ordered clockwise.

    The expression below is minus twice the signed (counter-clockwise
    positive) area, so a positive value means clockwise ordering.
    """
    xs, ys = p[:, 0], p[:, 1]
    negated_twice_area = np.dot(xs, np.roll(ys, 1)) - np.dot(ys, np.roll(xs, 1))
    return negated_twice_area > 0

def box3d_iou(corners1, corners2):
    ''' Compute 3D bounding box IoU.

    Input:
        corners1: numpy array (8,3) in the corner layout produced by
            get_3d_box: axis 1 is the vertical axis, corner 0 lies on the face
            with the larger axis-1 coordinate and corner 4 on the other.
        corners2: numpy array (8,3), same layout.
    Output:
        iou: 3D bounding box IoU
        iou_2d: bird's eye view IoU in the (axis 0, axis 2) plane
    Degenerate boxes whose union area / volume is not positive give 0.0
    instead of a nan from dividing by zero.
    '''
    # Reading corners 3, 2, 1, 0 gives counter-clockwise BEV rectangles for the
    # get_3d_box layout, which polygon_clip requires.
    rect1 = [(corners1[i, 0], corners1[i, 2]) for i in range(3, -1, -1)]
    rect2 = [(corners2[i, 0], corners2[i, 2]) for i in range(3, -1, -1)]

    area1 = poly_area(np.array(rect1)[:, 0], np.array(rect1)[:, 1])
    area2 = poly_area(np.array(rect2)[:, 0], np.array(rect2)[:, 1])

    _, inter_area = convex_hull_intersection(rect1, rect2)
    union_area = area1 + area2 - inter_area
    iou_2d = inter_area / union_area if union_area > 0 else 0.0

    # Overlap along the vertical axis (axis 1) between the two face levels.
    ymax = min(corners1[0, 1], corners2[0, 1])
    ymin = max(corners1[4, 1], corners2[4, 1])
    inter_vol = inter_area * max(0.0, ymax - ymin)

    vol1 = box3d_vol(corners1)
    vol2 = box3d_vol(corners2)
    union_vol = vol1 + vol2 - inter_vol
    iou = inter_vol / union_vol if union_vol > 0 else 0.0
    return iou, iou_2d


def get_3d_box(box_size, heading_angle, center):
    """Return the (8, 3) corner array of a box rotated about axis 1.

    Args:
        box_size: (length, width, height); length lies along axis 0, height
            along axis 1 and width along axis 2 before rotation.
        heading_angle: rotation angle in radians about axis 1.
        center: (c0, c1, c2) added to the rotated corners.

    Returns:
        numpy array of shape (8, 3).  Corners 0-3 sit at +height/2 and
        corners 4-7 at -height/2 along axis 1.
    """
    length, width, height = box_size
    cos_h = np.cos(heading_angle)
    sin_h = np.sin(heading_angle)
    rotation = np.array([[cos_h, 0, sin_h],
                         [0, 1, 0],
                         [-sin_h, 0, cos_h]])

    half_l, half_w, half_h = length / 2, width / 2, height / 2
    local_corners = np.vstack([
        [half_l, half_l, -half_l, -half_l, half_l, half_l, -half_l, -half_l],
        [half_h, half_h, half_h, half_h, -half_h, -half_h, -half_h, -half_h],
        [half_w, -half_w, -half_w, half_w, half_w, -half_w, -half_w, half_w],
    ])

    world_corners = np.dot(rotation, local_corners)
    for axis in range(3):
        world_corners[axis, :] += center[axis]
    return world_corners.T


def rotated_iou(box1, box2):
    """Rotated 3D IoU and BEV IoU of two nuScenes-style detection boxes.

    Boxes are expected to follow the nuScenes result conventions:
    ``translation`` = (x, y, z) with z up, ``size`` = (width, length, height)
    and ``rotation`` = quaternion (w, x, y, z).

    The previous implementation passed ``rotation[2]`` (the quaternion's y
    component, ~0 for yaw-only boxes) as the heading, and fed the nuScenes
    translation/size unchanged into ``get_3d_box``, whose frame puts the
    vertical axis at index 1 and expects (length, width, height).  The yaw was
    therefore ignored and the extents were laid along the wrong axes.

    Returns:
        (iou_3d, iou_bev) as returned by ``box3d_iou``.
    """
    def _corners(box):
        qw, qx, qy, qz = box.rotation
        # Yaw about +z; this form does not require a normalised quaternion.
        yaw = np.arctan2(2.0 * (qw * qz + qx * qy), qw * qw + qx * qx - qy * qy - qz * qz)
        width, length, height = box.size
        x, y, z = box.translation
        # get_3d_box works in an (x, up, y') frame, so map global (x, y, z) to
        # (x, z, y).  Its rotation turns +x towards -y' (clockwise in the
        # global x-y plane) for a positive angle, hence -yaw.
        return get_3d_box((length, width, height), -yaw, (x, z, y))

    return box3d_iou(_corners(box1), _corners(box2))



def distance(box_a, box_b):
    """Planar (x, y) Euclidean distance between two boxes' ``ego_translation``.

    This definition comes after an earlier ``distance`` in the module and
    replaces it, so it is the one the NMS / fusion helpers call.
    """
    delta_xy = np.asarray(box_a.ego_translation[:2]) - np.asarray(box_b.ego_translation[:2])
    return np.linalg.norm(delta_xy)
 

def calc_ap(true_positives, false_positives, conf, num_positives):
    """Computes AP:
    Code references https://github.com/nutonomy/nuscenes-devkit

    Args:
        true_positives: per-detection TP indicators (0/1).
        false_positives: per-detection FP indicators (0/1).
        conf: per-detection confidence used for ranking (descending).
        num_positives: number of ground-truth positives.

    Returns:
        (mean interpolated precision over 100 recall points, mean recall).
        ``(0, 0)`` for no detections.  When ``num_positives`` is 0 the first
        value is the mean raw precision and the second is 0.0.
    """
    true_positives = np.asarray(true_positives)
    false_positives = np.asarray(false_positives)
    conf = np.asarray(conf)
    if len(true_positives) == 0:
        return 0, 0

    # Rank detections by descending confidence.
    order = np.argsort(conf)[::-1]
    # np.float was removed in NumPy 1.24; np.float64 is the same dtype.
    tp_cum = np.cumsum(true_positives[order]).astype(np.float64)
    fp_cum = np.cumsum(false_positives[order]).astype(np.float64)
    prec = tp_cum / (fp_cum + tp_cum)

    if num_positives == 0:
        return np.mean(prec), 0.0
    rec = tp_cum / float(num_positives)
    rec_interp = np.linspace(0, 1, 100)
    # Precision beyond the highest achieved recall counts as 0.
    prec_interp = np.interp(rec_interp, rec, prec, right=0)

    return np.mean(prec_interp), np.mean(rec)



def bayesian_fusion_multiclass(match_boxes, pred_class, class_names=None, background_score=0.1):
    """Fuse the scores of matched boxes into a posterior for ``pred_class``.

    Each box contributes a per-class likelihood vector holding its
    ``detection_score`` at its own class and ``background_score`` elsewhere.
    The vectors are multiplied across boxes and normalised over classes, and
    the entry for ``pred_class`` is returned.

    The product is formed in log space and shifted by its maximum before
    exponentiating.  The previous direct ``exp(sum(log))`` could underflow to
    0/0 = nan when many boxes are matched.

    Args:
        match_boxes: sequence (list or 1-D object array) of boxes with
            ``detection_name`` and ``detection_score``.
        pred_class: class whose posterior is returned.
        class_names: class list defining the columns; defaults to
            ``DETECTION_NAMES``.
        background_score: likelihood assigned to non-predicted classes.

    Returns:
        float: normalised fused score for ``pred_class``.

    Raises:
        ValueError: a box's class or ``pred_class`` is not in ``class_names``.
    """
    names = DETECTION_NAMES if class_names is None else list(class_names)

    scores = np.full((len(match_boxes), len(names)), background_score, dtype=np.float64)
    for row, box in enumerate(match_boxes):
        scores[row, names.index(box.detection_name)] = box.detection_score

    log_joint = np.sum(np.log(scores), axis=0)
    log_joint = log_joint - np.max(log_joint)  # shift for numerical stability
    joint = np.exp(log_joint)
    return joint[names.index(pred_class)] / np.sum(joint)




def nms_bayesian(original_boxes, confidence_thresh, iou_threshold):
    """Class-agnostic greedy NMS that fuses suppressed scores into the survivor.

    Boxes are visited in descending ``detection_score`` order.  Each alive
    anchor suppresses every later alive box (of any class) whose centre is
    within 10 (units of ``distance``) and whose 3D IoU exceeds
    ``iou_threshold / ((anchor.ego_dist + 30) // 30)``.  If anything was
    suppressed, a copy of the anchor is kept with ``detection_score`` replaced
    by ``bayesian_fusion_multiclass`` over the suppressed boxes plus the
    anchor, evaluated for the anchor's class.  Otherwise an unchanged copy is
    kept.

    Args:
        original_boxes: iterable of detection boxes (need ``ego_dist``).
        confidence_thresh: unused; kept for a uniform fusion-function signature.
        iou_threshold: base IoU threshold before range scaling.

    Returns:
        (preserved_boxes, suppressed_box_indices): deep copies of the
        surviving boxes as a NumPy array in descending-score order, and the
        suppressed positions (indices into the score-sorted order).
    """
    ranked = sorted(original_boxes, key=lambda box: box.detection_score, reverse=True)
    is_suppressed = [False] * len(ranked)
    suppressed_box_indices = []
    preserved_boxes = []

    for anchor_idx, anchor in enumerate(ranked):
        if is_suppressed[anchor_idx]:
            continue
        adaptive_thresh = iou_threshold / ((anchor.ego_dist + 30) // 30)

        matched = []
        for cand_idx in range(anchor_idx + 1, len(ranked)):
            if is_suppressed[cand_idx]:
                continue
            candidate = ranked[cand_idx]
            if distance(anchor, candidate) > 10:
                continue
            if rotated_iou(anchor, candidate)[0] > adaptive_thresh:
                is_suppressed[cand_idx] = True
                suppressed_box_indices.append(cand_idx)
                matched.append(candidate)

        fused = deepcopy(anchor)
        if matched:
            # Object array (suppressed boxes in ascending order, then the
            # anchor), matching the layout the fusion helper indexes by shape.
            group = np.empty(len(matched) + 1, dtype=object)
            for slot, box in enumerate(matched + [anchor]):
                group[slot] = box
            fused.detection_score = bayesian_fusion_multiclass(group, anchor.detection_name)
        preserved_boxes.append(fused)

    return np.array(preserved_boxes), suppressed_box_indices

def predict(model, inp):
    """Call ``model`` on the NumPy array ``inp`` converted to a float32 CPU tensor."""
    features = torch.from_numpy(inp).float()
    return model(features)

def NN_fusion(model, box_list1, box_list2, iou_threshold):
    """Fuse two detectors' boxes, rescoring overlapping pairs with ``model``.

    Both lists are sorted by descending ``detection_score``.  For each
    primary box (``box_list1``), every not-yet-consumed secondary box of the
    same class within 10 (units of ``distance``) and with positive 3D IoU is
    consumed.  If at least one of them has IoU above ``iou_threshold``, a copy
    of the primary box is emitted whose score is the maximum model output over
    all consumed boxes.  The features are
    ``[max_iou, primary score, secondary score, pair distance / 10,
    primary ego_dist / 100]``.  Otherwise the primary box is emitted
    unchanged.  Secondary boxes never consumed are appended at the end.

    ``max_iou`` is now the true maximum IoU above the threshold; the previous
    loop kept the *last* such IoU despite its name.
    NOTE(review): every consumed pair is scored with the same ``max_iou``
    rather than its own IoU; kept as-is because the model's training
    features are not visible here.

    Args:
        model: callable taking a float32 tensor of shape (1, 5); its output
            must support ``.item()``.
        box_list1: primary boxes.
        box_list2: secondary boxes.
        iou_threshold: IoU a pair must exceed to trigger rescoring.

    Returns:
        list of boxes: primary boxes (fused copies or originals) in score
        order, then the unconsumed secondary boxes in score order.
    """
    primary = sorted(box_list1, key=lambda box: box.detection_score, reverse=True)
    secondary = sorted(box_list2, key=lambda box: box.detection_score, reverse=True)

    consumed = set()  # indices into ``secondary`` absorbed by some primary box
    fused_boxes = []

    for anchor in primary:
        overlapping = []
        best_iou, best_idx = 0, -1
        for cand_idx, candidate in enumerate(secondary):
            if cand_idx in consumed or candidate.detection_name != anchor.detection_name:
                continue
            if distance(candidate, anchor) > 10:
                continue
            iou = rotated_iou(anchor, candidate)[0]
            if iou > 0:
                consumed.add(cand_idx)
                overlapping.append(cand_idx)
                if iou > iou_threshold and iou > best_iou:
                    best_iou, best_idx = iou, cand_idx

        if best_idx == -1:
            fused_boxes.append(anchor)
            continue

        best_score = 0
        # Pure inference: no autograd graph is needed since only .item() is used.
        with torch.no_grad():
            for cand_idx in overlapping:
                candidate = secondary[cand_idx]
                pair_dist = distance(anchor, candidate)
                feature = [[best_iou, anchor.detection_score, candidate.detection_score,
                            pair_dist / 10, anchor.ego_dist / 100]]
                best_score = max(predict(model, np.array(feature)).item(), best_score)

        fused = deepcopy(anchor)
        fused.detection_score = best_score
        fused_boxes.append(fused)

    fused_boxes.extend(box for idx, box in enumerate(secondary) if idx not in consumed)
    return fused_boxes




def dump_pickle(dictionary, filename):
    """Write ``dictionary`` to ``filename`` with the highest available pickle protocol."""
    with open(filename, mode='wb') as sink:
        pickle.dump(dictionary, sink, protocol=pickle.HIGHEST_PROTOCOL)

def load_pickle(filename):
    """Load and return the object pickled in ``filename``.

    Security: unpickling can execute arbitrary code; only use this on files
    from a trusted source.
    """
    with open(filename, mode='rb') as source:
        return pickle.load(source)


def dump_results(result_boxes, metadata, filepath, args = None):
    """Write fused detections to a nuScenes-style results JSON file.

    The output has a ``'meta'`` entry (a copy of ``metadata['meta']`` with
    ``'use_camera'`` set to True) and a ``'results'`` entry mapping each
    sample token to its list of boxes.  The caller's ``metadata`` is no
    longer modified in place.

    Args:
        result_boxes: mapping sample token -> iterable of serialized boxes.
        metadata: dict with a ``'meta'`` mapping.
        filepath: output path; when ``args`` is given, the part before the
            first ``'.json'`` is extended with a suffix encoding the run
            settings.
        args: optional namespace with ``nms_type``, ``evaluate_selected``,
            ``calibrated``, ``fcos``, ``secondary_conf_thresh``, ``cp``,
            ``cp_thresh`` and ``iou_thresh``.

    Returns:
        str: the path actually written.
    """
    meta = dict(metadata['meta'])
    meta['use_camera'] = True
    fusion_results_dump = {
        'meta': meta,
        'results': {token: list(boxes) for token, boxes in result_boxes.items()},
    }

    if args is not None:
        filepath = str(filepath.split('.json')[0] + '_fusion_' + args.nms_type
                       + '_selected_' + str(args.evaluate_selected)
                       + '_calibrated_' + str(args.calibrated)
                       + '_' + args.fcos + '_' + str(args.secondary_conf_thresh)
                       + '_' + args.cp + '_' + str(args.cp_thresh)
                       + '_iou_thresh_' + str(args.iou_thresh) + '.json')

    with open(filepath, 'w') as f:
        json.dump(fusion_results_dump, f)

    return filepath
    

def save_metrics_results(results, json_filename, args = None):
    """Write evaluation metrics (and optionally the run arguments) to JSON.

    Args:
        results: JSON-serializable metrics object.
        json_filename: output path; when ``args`` is given, the part before
            the first ``'.json'`` is extended with a suffix encoding the run
            settings.
        args: optional namespace (see ``dump_results`` for the attributes
            used).  Its ``vars()`` are also written next to the metrics file
            as ``<metrics name>_args.json``.  Previously the intended
            ``args_path`` was computed but the file was written as
            ``<metrics name>args.json``, without the underscore.

    Returns:
        str: the metrics path actually written.
    """
    if args is not None:
        json_filename = str(json_filename.split('.json')[0] + '_fusion_' + args.nms_type
                            + '_selected_' + str(args.evaluate_selected)
                            + '_calibrated_' + str(args.calibrated)
                            + '_' + args.fcos + '_' + str(args.secondary_conf_thresh)
                            + '_' + args.cp + '_' + str(args.cp_thresh)
                            + '_iou_thresh_' + str(args.iou_thresh) + '.json')

        args_path = json_filename.split('.json')[0] + '_args.json'
        with open(args_path, 'w') as fp:
            json.dump(vars(args), fp)

    with open(json_filename, 'w') as fp:
        json.dump(results, fp)

    return json_filename
