import torch
import torch.nn as nn
import numpy as np
import sys
import os

# Per-class weights for the two objectness classes (label 0 = not near any GT
# node, label 1 = near a GT node); passed as the `weight` argument of
# nn.CrossEntropyLoss in compute_objectness_loss.
OBJECTNESS_CLS_WEIGHTS = [0.5,0.5] 

def nn_distance(pc1, pc2, l1smooth=False):
    """
    Nearest-neighbour squared distances between two batched point sets.

    Input:
        pc1: (B,N,C) torch tensor
        pc2: (B,M,C) torch tensor
        l1smooth: bool, kept for backward compatibility; it is currently
            ignored and the squared L2 distance is always used
    Output:
        dist1: (B,N) torch tensor, squared distance (plus 1e-6) from every
            pc1 point to its closest pc2 point
        idx1: (B,N) torch int64 tensor, index into pc2 of that closest point
        dist2: (B,M) torch tensor, squared distance (plus 1e-6) from every
            pc2 point to its closest pc1 point
        idx2: (B,M) torch int64 tensor, index into pc1 of that closest point
    """
    # Broadcasting (B,N,1,C) against (B,1,M,C) produces the same (B,N,M,C)
    # difference tensor as tiling both inputs with repeat(), without
    # materialising the two tiled copies.
    pc_diff = pc1.unsqueeze(2) - pc2.unsqueeze(1)

    # The small constant keeps a later sqrt() differentiable at zero distance.
    pc_dist = torch.sum(pc_diff**2, dim=-1) + 1e-6 # (B,N,M)

    dist1, idx1 = torch.min(pc_dist, dim=2) # (B,N)
    dist2, idx2 = torch.min(pc_dist, dim=1) # (B,M)

    return dist1, idx1, dist2, idx2

def compute_proposal_loss(end_points):
    """
    Two-way nearest-neighbour distance loss between proposals and GT centers.

    Reads:
        end_points['proposal_xyz']: predicted points, indexed as (B, num_pred, C)
        end_points['center_label']: GT centers, indexed as (B, num_gt, C); rows
            whose entries are all zero are treated as padding and dropped
    Returns:
        scalar tensor: batch mean of, per sample, the summed distance from each
        prediction to its closest GT plus the summed distance from each GT to
        its closest prediction.

    NOTE(review): a sample with no non-zero GT row produces an empty GT set,
    which nn_distance is not expected to handle - confirm callers always
    provide at least one GT center per sample.
    """
    pred_center = end_points['proposal_xyz']
    gt_center = end_points['center_label']
    B = pred_center.shape[0]
    # True for GT rows that have at least one non-zero coordinate.
    flag = ~torch.all(gt_center == 0, dim=2)

    # Allocate next to the predictions instead of hard-coding CUDA so the loss
    # also runs on CPU and on non-default GPUs.
    center_loss_list = torch.zeros(B, device=pred_center.device)
    for batch_id in range(B):
        skel = pred_center[batch_id].unsqueeze(0)
        gt = gt_center[batch_id, flag[batch_id], :].unsqueeze(0)

        # dist1: (1, num_pred), dist2: (1, num_valid_gt); both are squared.
        dist1, _, dist2, _ = nn_distance(skel, gt)
        center_loss_list[batch_id] = (
            torch.sum(torch.sqrt(dist1 + 1e-6)) + torch.sum(torch.sqrt(dist2 + 1e-6))
        )

    return torch.mean(center_loss_list)

def compute_objectness_loss(input, end_points):
    """
    Build objectness labels from GT node radii and compute the objectness loss.

    A predicted center is labelled positive (1) when its distance to the
    closest valid GT node is smaller than that node's radius.

    Args:
        input: unused here; kept so the signature matches the caller.
        end_points: dict providing
            'center': predicted centers, indexed as (B, K, C)
            'extra_nodes': GT nodes, indexed as (B, num_gt, >=4) with xyz in
                channels 0:3 and the radius in channel 3; all-zero rows are
                treated as padding
            'objectness_scores': class scores; transposed with (2, 1) before
                the cross entropy, so presumably (B, K, 2) - TODO confirm
    Returns:
        objectness_loss: per-element (unreduced) cross-entropy loss
        objectness_label: (B, K) long tensor of 0/1 labels
        objectness_mask: (B, K) float tensor, 1 where the center is either
            closer than the radius or farther than twice the radius; it is
            returned but not applied to the loss here
        objectness_loss_without_mask: scalar, sum of the loss divided by B*K
    """
    aggregated_vote_xyz = end_points['center']
    extra_nodes = end_points['extra_nodes']
    gt_extra_center = extra_nodes[:, :, 0:3]
    radius = extra_nodes[:, :, 3]
    # True for GT rows that are not all-zero padding.
    flag = ~torch.all(extra_nodes == 0, dim=2)
    B = gt_extra_center.shape[0]
    K = aggregated_vote_xyz.shape[1]

    # Allocate on the prediction's device rather than hard-coding CUDA.
    device = aggregated_vote_xyz.device
    objectness_label = torch.zeros((B, K), dtype=torch.long, device=device)
    objectness_mask = torch.zeros((B, K), device=device)
    for batch_id in range(B):
        zero_flag = flag[batch_id]
        agg_vote_xyz = aggregated_vote_xyz[batch_id].unsqueeze(0)
        gt_extra_xyz = gt_extra_center[batch_id, zero_flag, :].unsqueeze(0)
        gt_radius = radius[batch_id, zero_flag]

        dist1, ind1, _, _ = nn_distance(agg_vote_xyz, gt_extra_xyz)
        euclidean_dist1 = torch.sqrt(dist1 + 1e-6).squeeze(0)  # (K,)
        # Radius of the GT node closest to each predicted center. squeeze(0)
        # only drops the batch axis, so this stays 1-D even when K == 1.
        near_threshold = gt_radius[ind1.squeeze(0)]  # (K,)

        is_near = euclidean_dist1 < near_threshold
        is_far = euclidean_dist1 > near_threshold * 2
        objectness_label[batch_id, is_near] = 1
        objectness_mask[batch_id, is_near | is_far] = 1

    # Compute objectness loss
    objectness_scores = end_points['objectness_scores']
    class_weights = torch.tensor(
        OBJECTNESS_CLS_WEIGHTS, dtype=torch.float32, device=objectness_scores.device
    )
    criterion = nn.CrossEntropyLoss(class_weights, reduction='none')
    objectness_loss = criterion(objectness_scores.transpose(2, 1), objectness_label)
    objectness_loss_without_mask = torch.sum(objectness_loss) / (B * K)

    return objectness_loss, objectness_label, objectness_mask, objectness_loss_without_mask

def compute_center_loss(end_points):
    """
    Two-way nearest-neighbour distance loss between predicted and GT centers.

    Reads:
        end_points['center']: predicted centers, indexed as (B, num_pred, C)
        end_points['center_label']: GT centers, indexed as (B, num_gt, C); rows
            whose entries are all zero are treated as padding and dropped
    Returns:
        scalar tensor: batch mean of, per sample, the summed distance from each
        prediction to its closest GT plus the summed distance from each GT to
        its closest prediction.

    NOTE(review): a sample with no non-zero GT row produces an empty GT set,
    which nn_distance is not expected to handle - confirm callers always
    provide at least one GT center per sample.
    """
    pred_center = end_points['center']
    gt_center = end_points['center_label']
    B = pred_center.shape[0]
    # True for GT rows that have at least one non-zero coordinate.
    flag = ~torch.all(gt_center == 0, dim=2)

    # Allocate next to the predictions instead of hard-coding CUDA so the loss
    # also runs on CPU and on non-default GPUs.
    center_loss_list = torch.zeros(B, device=pred_center.device)
    for batch_id in range(B):
        skel = pred_center[batch_id].unsqueeze(0)
        gt = gt_center[batch_id, flag[batch_id], :].unsqueeze(0)

        # dist1: (1, num_pred), dist2: (1, num_valid_gt); both are squared.
        dist1, _, dist2, _ = nn_distance(skel, gt)
        center_loss_list[batch_id] = (
            torch.sum(torch.sqrt(dist1 + 1e-6)) + torch.sum(torch.sqrt(dist2 + 1e-6))
        )

    return torch.mean(center_loss_list)

def compute_radius_loss(end_points):
    """
    L1 loss between predicted radii and the radius of the closest GT center.

    Reads:
        end_points['radius']: predicted radius per predicted center
        end_points['center']: predicted centers, indexed as (B, K, C)
        end_points['center_label']: GT centers; only channels 0:3 are used
        end_points['radius_label']: GT radius per GT center, indexed as
            (B, num_gt)
    Returns:
        scalar tensor: mean absolute error between the predicted radius and the
        radius of the GT center closest to each predicted center.

    NOTE(review): unlike the other losses in this file, all-zero (padding) GT
    rows are not filtered out before the nearest-neighbour search - confirm
    this is intended.
    """
    radius = end_points['radius']
    aggregated_vote_xyz = end_points['center']
    gt_center = end_points['center_label'][:, :, 0:3]

    # ind1[b, k] is the index of the GT center closest to predicted center k.
    _, ind1, _, _ = nn_distance(aggregated_vote_xyz, gt_center)

    # Batched lookup radius_label[b, ind1[b, k]]; replaces the per-sample loop
    # that filled a CUDA-only buffer. The cast mirrors the default-dtype buffer
    # the labels were previously copied into.
    radius_label = torch.gather(end_points['radius_label'], 1, ind1).to(
        device=radius.device, dtype=torch.get_default_dtype()
    )

    criterion = nn.L1Loss()
    return criterion(radius, radius_label)

def confidence_score_loss(end_points):
    """
    Binary confidence loss: a predicted center is positive when it lies within
    the radius of its closest valid GT node.

    Reads:
        end_points['center']: predicted centers, indexed as (B, K, C)
        end_points['extra_nodes']: GT nodes with xyz in channels 0:3 and the
            radius in channel 3; all-zero rows are treated as padding
        end_points['objectness_scores']: logits fed to BCEWithLogitsLoss
    Returns:
        loss: scalar BCE-with-logits loss
        conf_score_label: (B, K) float32 tensor of 0/1 targets

    NOTE(review): the target is (B, K), so 'objectness_scores' must have a
    matching shape for this loss - confirm before enabling it alongside the
    two-class objectness head.
    """
    pred_xyz = end_points['center']
    extra_nodes = end_points['extra_nodes']
    gt_extra_center = extra_nodes[:, :, 0:3]
    radius = extra_nodes[:, :, 3]
    # True for GT rows that are not all-zero padding.
    flag = ~torch.all(extra_nodes == 0, dim=2)
    B = gt_extra_center.shape[0]
    K = pred_xyz.shape[1]

    # Allocate on the prediction's device rather than hard-coding CUDA.
    conf_score_label = torch.zeros((B, K), dtype=torch.float32, device=pred_xyz.device)
    for batch_id in range(B):
        zero_flag = flag[batch_id]
        pred_xyz_one_batch = pred_xyz[batch_id].unsqueeze(0)
        gt_xyz_one_batch = gt_extra_center[batch_id, zero_flag, :].unsqueeze(0)
        gt_radius = radius[batch_id, zero_flag]

        dist1, ind1, _, _ = nn_distance(pred_xyz_one_batch, gt_xyz_one_batch)
        euclidean_dist1 = torch.sqrt(dist1 + 1e-6).squeeze(0)  # (K,)
        # squeeze(0) only drops the batch axis, so this stays 1-D when K == 1.
        near_threshold = gt_radius[ind1.squeeze(0)]  # (K,)

        conf_score_label[batch_id, euclidean_dist1 < near_threshold] = 1

    criterion = nn.BCEWithLogitsLoss()
    loss = criterion(end_points['objectness_scores'], conf_score_label)

    return loss, conf_score_label

def compute_cover_ratio(input=None, end_points=None):
    """
    Average, over the batch, of the number of points lying within the radius
    of their closest valid GT node, divided by K (the number of predicted
    centers).

    Args:
        input: optional point cloud indexed as (B, N, >=3). When given, its xyz
            channels are evaluated. When None, the predicted centers whose
            argmax objectness class is non-zero are evaluated instead, and
            samples with fewer than 5 such centers are skipped.
        end_points: dict providing 'extra_nodes' (xyz in channels 0:3, radius
            in channel 3, all-zero rows are padding), 'center' and, when
            `input` is None, 'objectness_scores'.
    Returns:
        float: mean cover ratio over the evaluated samples, or 0.0 when every
        sample was skipped (previously this produced NaN with a warning).

    NOTE(review): the count is divided by K even when `input` is evaluated, so
    the value can exceed 1 if the input has more than K points - confirm this
    normalisation is intended.
    """
    extra_nodes = end_points['extra_nodes']
    gt = extra_nodes[:, :, 0:3]
    radius = extra_nodes[:, :, 3]
    # True for GT rows that are not all-zero padding.
    flag = ~torch.all(extra_nodes == 0, dim=2)
    centers = end_points['center']
    B = centers.shape[0]
    K = centers.shape[1]

    # The selection mask does not depend on the sample, so compute it once.
    selected = None
    if input is None:
        selected = torch.argmax(end_points['objectness_scores'], 2).bool()

    value_list = []
    for batch_idx in range(B):
        if input is None:
            pc = centers[batch_idx, selected[batch_idx], :].reshape(1, -1, 3)
            if pc.shape[1] < 5:
                continue
        else:
            pc = input[batch_idx, :, 0:3].reshape(1, -1, 3)
        zero_flag = flag[batch_idx]
        gt_xyz = gt[batch_idx, zero_flag, :].reshape(1, -1, 3)
        gt_radius = radius[batch_idx, zero_flag]

        dist1, ind1, _, _ = nn_distance(pc, gt_xyz)
        euclidean_dist1 = torch.sqrt(dist1 + 1e-6).squeeze(0)
        # squeeze(0) only drops the batch axis, so a single-point cloud still
        # yields a 1-D index.
        near_threshold = gt_radius[ind1.squeeze(0)]
        num_covered = (euclidean_dist1 < near_threshold).sum().item()
        # Divide in Python so the result is a true ratio regardless of how the
        # installed torch version divides integer tensors.
        value_list.append(num_covered / K)

    if not value_list:
        return 0.0
    return np.average(value_list)

def dgcnn_get_loss(input, end_points):
    """
    Compute the total training loss and bookkeeping metrics.

    The individual losses and metrics are written into `end_points` under the
    keys 'objectness_loss', 'objectness_label', 'objectness_mask',
    'objectness_loss_without_mask', 'center_loss', 'radius_loss',
    'total_loss', 'obj_acc', 'num_skel_pred', 'num_skel_label',
    'input_cover_ratio' and 'pred_cover_ratio'.

    Args:
        input: input point cloud, forwarded to the objectness loss and to the
            input cover-ratio metric.
        end_points: dict of network outputs and labels; updated in place.
    Returns:
        (loss, end_points): the total loss and the updated dict.
    """
    # Objectness: unreduced loss, labels, mask and the mean (unmasked) loss.
    obj_loss, obj_label, obj_mask, obj_loss_mean = compute_objectness_loss(input, end_points)
    end_points.update({
        'objectness_loss': obj_loss,
        'objectness_label': obj_label,
        'objectness_mask': obj_mask,
        'objectness_loss_without_mask': obj_loss_mean,
    })

    center_term = compute_center_loss(end_points)
    end_points['center_loss'] = center_term

    radius_term = compute_radius_loss(end_points)
    end_points['radius_loss'] = radius_term

    # Total: objectness is weighted by 10, then the whole sum is scaled by 10.
    # (Ablations tried earlier dropped either the center or objectness term.)
    combined = 10*obj_loss_mean + center_term + radius_term
    loss = combined * 10
    end_points['total_loss'] = loss

    scores = end_points['objectness_scores']
    batch_size = scores.shape[0]
    predicted_cls = torch.argmax(scores, 2) # B,K

    # Objectness accuracy over all B*K predictions.
    num_correct = (predicted_cls == obj_label.long()).sum().float()
    end_points['obj_acc'] = num_correct / (obj_label.shape[0]*obj_label.shape[1])

    # Average number of positive predictions / positive labels per sample.
    end_points['num_skel_pred'] = predicted_cls.sum() / batch_size
    end_points['num_skel_label'] = end_points['objectness_label'].sum() / batch_size

    end_points['input_cover_ratio'] = compute_cover_ratio(input, end_points)
    end_points['pred_cover_ratio'] = compute_cover_ratio(end_points=end_points)

    return loss, end_points