import numpy as np
from scipy.spatial.distance import cdist

def crop_processing(data_dict, num_near_pts=4096, centroid_xyz=None):
    """Crop ``data_dict['input']`` to the points nearest a given centroid.

    Points are ranked by the Euclidean distance between their first three
    columns and ``centroid_xyz``. The single closest point is skipped
    (presumably the centroid itself when it was taken from the cloud), and
    the next ``num_near_pts`` points are kept, nearest first, with all of
    their columns.

    Args:
        data_dict: dict holding an ``'input'`` array of shape ``(N, C)`` with
            C >= 3. It is updated in place.
        num_near_pts: number of neighbours to keep. Fewer are kept if the
            cloud has fewer than ``num_near_pts + 1`` points.
        centroid_xyz: 3-element centre of the crop (array-like). Required
            despite the ``None`` default.

    Returns:
        The same ``data_dict`` with ``'input'`` replaced by the cropped,
        always 2-D, array.

    Raises:
        ValueError: if ``centroid_xyz`` is None.
    """
    if centroid_xyz is None:
        raise ValueError("crop_processing requires centroid_xyz")

    input_pts = data_dict['input']
    centre = np.asarray(centroid_xyz).reshape(1, 3)

    # (1, N) distances from the centre to every point's xyz.
    dists = cdist(centre, input_pts[:, 0:3])

    # Take row 0 explicitly instead of squeeze(): squeezing a one-element
    # selection yields a 0-d index and would turn the crop into a 1-D array.
    near_indx = np.argsort(dists, axis=1)[0, 1:1 + num_near_pts]

    data_dict['input'] = input_pts[near_indx, :]

    return data_dict

def cal_inter_sphere(r1, r2, xyz1, xyz2):
    """Volume of the intersection between one sphere and a set of spheres.

    Args:
        r1: radius of the reference sphere (expected to be a scalar).
        r2: radii of the other spheres, one per centre in ``xyz2``.
            It is not modified.
        xyz1: centre of the reference sphere. It is reshaped to ``(-1, 3)``
            and is assumed to be a single point.
        xyz2: centres of the other spheres, reshaped to ``(-1, 3)``.

    Returns:
        ``(inter_vol, num_partial)``.

        ``inter_vol`` has one entry per sphere in ``xyz2``:

        - 0 when the spheres are disjoint or externally tangent;
        - the lens volume when their surfaces cross;
        - the smaller ball's volume when one sphere lies inside the other
          (including identical, coincident spheres).

        ``num_partial`` counts only the spheres whose surfaces strictly
        cross, i.e. ``|r1 - r2| < d < r1 + r2``.
    """
    r2 = np.asarray(r2)
    # Centre distances, kept in float32 as in the original implementation.
    d = np.float32(cdist(xyz1.reshape(-1, 3), xyz2.reshape(-1, 3)).reshape(-1))
    gap = np.abs(r1 - r2)

    partial = np.logical_and(gap < d, d < (r1 + r2))
    contained = d <= gap

    # Lens volume (math.stackexchange.com/questions/2705706). Outside the
    # partial-overlap regime, use a harmless divisor so that d == 0 never
    # divides by zero; those entries are masked out right after.
    d_safe = np.where(partial, d, 1.0)
    lens = (np.pi * (r1 + r2 - d_safe) ** 2
            * (d_safe ** 2 + 2 * d_safe * (r1 + r2)
               - 3 * (r1 ** 2 + r2 ** 2) + 6 * r1 * r2)) / (12 * d_safe)
    inter_vol = np.where(partial, lens, 0.0)

    # One sphere entirely inside the other: the overlap is the smaller ball.
    inter_vol = np.where(contained,
                         (4.0 / 3.0) * np.pi * np.minimum(r1, r2) ** 3,
                         inter_vol)

    return inter_vol, partial.sum()

def nms_swc_sphere(radius, obj_score, xyz, overlap_threshold=0.25):
    """Greedy non-maximum suppression over spheres (SWC nodes).

    Spheres are visited in descending order of ``obj_score``. Each step
    removes the current best sphere from the candidate queue, together with
    every remaining candidate whose IoU with it exceeds
    ``overlap_threshold``. The best sphere is kept only when it partially
    overlaps more than one of the remaining candidates, as counted by
    ``cal_inter_sphere``.

    Args:
        radius: array of per-sphere radii, indexed by sphere id.
        obj_score: per-sphere scores. Any shape that flattens to one score
            per sphere is accepted, e.g. ``(N,)`` or ``(N, 1)``.
        xyz: sphere centres, indexed as ``xyz[i, :]``.
        overlap_threshold: IoU above which a candidate is suppressed.

    Returns:
        int64 array of kept sphere indices, in the order they were picked.
        It is empty (but still integer-typed) when nothing is kept.
    """
    # Volume of every ball, used in the IoU denominator.
    vol = (4.0 / 3.0) * np.pi * (radius ** 3)

    # Ascending order, so the highest remaining score is always at the end.
    # reshape(-1) (not squeeze) keeps a single score 1-D.
    order = np.argsort(np.asarray(obj_score).reshape(-1))
    pick = []

    while order.size != 0:
        best = order[-1]
        rest = order[:-1]

        # IoU of two intersected spheres, see
        # https://math.stackexchange.com/questions/2705706/volume-of-the-overlap-between-2-spheres
        inter, num_interact = cal_inter_sphere(
            radius[best], radius[rest], xyz[best, :], xyz[rest, :])
        iou = inter / (vol[best] + vol[rest] - inter)

        # Remove the current best (it sits at position rest.size in `order`)
        # plus every candidate it overlaps by more than the threshold. The
        # positions in `rest` coincide with positions in `order`.
        suppressed = np.concatenate(([rest.size], np.flatnonzero(iou > overlap_threshold)))
        order = np.delete(order, suppressed)

        if num_interact > 1:
            pick.append(best)

    return np.asarray(pick, dtype=np.int64)

def fps(pc, npoint):
    """Farthest point sampling.

    Inputs:
            pc: point cloud of shape (N, D), e.g. [x, y, z]. Distances are
                computed over all D columns.
            npoint: number of samples to draw.
        Output:
            centroids: int64 array of ``npoint`` row indices into ``pc``
            (indices, not the points themselves). The first index is chosen
            uniformly at random from all N points. If npoint > N, indices
            repeat once every point has been taken.
    """
    num_pts = pc.shape[0]
    centroids = np.zeros(npoint, dtype=np.int64)
    # Squared distance from each point to its nearest already-sampled point.
    min_dist = np.full(num_pts, np.inf)
    # Seed from the whole cloud; drawing from [0, npoint) biased the start and
    # could index past the end of pc when npoint > N.
    farthest = np.random.randint(0, num_pts)
    for i in range(npoint):
        centroids[i] = farthest
        dist = np.sum((pc - pc[farthest, :]) ** 2, axis=-1)
        np.minimum(min_dist, dist, out=min_dist)
        farthest = int(np.argmax(min_dist))

    return centroids