import numpy as np

def save_in_pc(xyz, color=None, title='point'):
    """Save a set of points as ``./code/img_save/<title>.ply`` using Open3D.

    Args:
        xyz: point coordinates, passed to ``o3d.utility.Vector3dVector``
            (presumably an N x 3 array -- confirm against callers).
        color: optional per-point colours stored on the point cloud when
            given. NOTE(review): Open3D documents colours as N x 3 RGB floats
            in [0, 1]; confirm callers pass that layout.
        title: file name stem of the written ``.ply`` file.
    """
    import os

    import open3d as o3d

    out_path = "./code/img_save/{}.ply".format(title)
    # Make sure the target folder exists so the write does not fail on a
    # fresh checkout.
    os.makedirs(os.path.dirname(out_path), exist_ok=True)

    pcd = o3d.geometry.PointCloud()
    pcd.points = o3d.utility.Vector3dVector(xyz)
    if color is not None:
        # Previously this argument was accepted but silently dropped.
        pcd.colors = o3d.utility.Vector3dVector(
            np.asarray(color, dtype=np.float64))
    o3d.io.write_point_cloud(out_path, pcd)

def loadswc(filepath):
    '''
    Load swc file as a N X 7 numpy array.

    Lines starting with ``#`` and blank lines are skipped. Fields may be
    separated by any run of whitespace (spaces or tabs); lines that do not
    have exactly 7 fields are ignored.

    input:
        filepath, path of the swc file
    output:
        float array of shape (N, 7); shape (0, 7) when no data row is found
    raises:
        ValueError if a 7-field line contains a non-numeric field
    '''
    rows = []
    with open(filepath) as f:
        for line in f:
            line = line.strip()
            if not line or line.startswith('#'):
                continue
            # split() without an argument tolerates tabs and repeated spaces,
            # which split(' ') would turn into spurious empty fields.
            cells = line.split()
            if len(cells) == 7:
                rows.append([float(c) for c in cells])

    # reshape keeps the documented N x 7 layout even when N == 0.
    return np.array(rows, dtype=float).reshape(-1, 7)

def tracing_from_soma(save_book, A, curr_id, par_id):
    '''
        Depth-first trace of the graph, recording [node, parent] pairs.

        Starting at curr_id, neighbours (columns j with A[node, j] == 1) are
        visited in ascending index order; every newly reached node is
        appended to the book together with the node it was reached from.
        The walk uses an explicit stack instead of recursion, so long
        branches cannot hit Python's recursion limit.

        input:
            save_book, K x 2 array of [node, parent] rows already recorded;
                       nodes in its first column count as visited
            A, 2-D adjacency matrix (entries equal to 1 mark an edge)
            curr_id, node to start tracing from
            par_id, parent recorded for curr_id; that node is never entered
                    directly from curr_id
        output:
            save_book with the newly traced rows appended in visiting order
            (the input array itself when nothing new was reached)

        Note: a node is only recorded if its own row of A contains a 1.
    '''
    visited = set(save_book[:, 0].tolist())
    new_rows = []
    stack = []

    def visit(node, parent):
        neighbours = np.flatnonzero(A[node, :] == 1).tolist()
        if neighbours and node not in visited:
            visited.add(node)
            new_rows.append([node, parent])
        # Keep a live iterator so the scan resumes where it left off once
        # the subtree below a neighbour has been fully traced.
        stack.append((node, parent, iter(neighbours)))

    visit(int(curr_id), int(par_id))
    while stack:
        node, parent, pending = stack[-1]
        for nxt in pending:
            if nxt != parent and nxt not in visited:
                visit(nxt, node)
                break
        else:
            # All neighbours handled: backtrack.
            stack.pop()

    if not new_rows:
        return save_book
    # Single concatenation instead of one per visited node.
    return np.concatenate((save_book, np.array(new_rows)), axis=0)

def tracing_reconstruct_graph(A, soma_loc, skel_xyz):
    '''
        assign parent-child relation to the reconstructed graph
        input:
            A, adjacency matrix; must support .squeeze().cpu().detach().numpy()
               (presumably a torch tensor -- confirm against callers)
            soma_loc, soma location (reshaped to rows of 3 coordinates;
                      the first row is used)
            skel_xyz, xyz of skeleton points, same tensor interface as A
        output:
            save_book: K x 2 integer array of [node, parent] rows; row 0 is
                       the [-2, -2] seed row and the root's parent is -1
        raises:
            ValueError if A contains no entry equal to 1
    '''
    from scipy.spatial.distance import cdist

    print("Reconstructing the point connections recursively...")
    A = A.squeeze().cpu().detach().numpy()
    skel_xyz = skel_xyz.squeeze().cpu().detach().numpy()

    # Only nodes with at least one edge take part; nodes without any edge
    # are left out of the trace.
    linked_node = np.unique(np.where(A == 1)[0])
    if linked_node.size == 0:
        raise ValueError("adjacency matrix A contains no edges; nothing to trace")
    linked_node_xyz = skel_xyz[linked_node, :]

    # The root is the linked node closest to the soma. argmin on the distance
    # row also works when there is a single linked node, where the previous
    # argsort(...).squeeze()[0] failed on a 0-d array.
    soma_to_nodes = cdist(np.array(soma_loc).reshape(-1, 3), linked_node_xyz)
    root_node = linked_node[np.argmin(soma_to_nodes[0])]

    save_book = np.array([[-2, -2]])
    save_book = tracing_from_soma(save_book, A, curr_id=root_node, par_id=-1)

    return save_book

def save_swc_file(path, swc):
    '''
    Write the rows of swc to a text file, one node per line.

    input:
        path, destination file (opened in 'w' mode, so it is overwritten)
        swc, 2-D array; columns 0, 1 and 6 are written with no decimals,
             columns 2-5 with three decimals, separated by single spaces
    '''
    row_template = '{:.0f} {:.0f} {:.3f} {:.3f} {:.3f} {:.3f} {:.0f}\n'
    # Build the whole text first so nothing is written if formatting fails.
    text = ''.join(
        row_template.format(swc[r, 0], swc[r, 1], swc[r, 2], swc[r, 3],
                            swc[r, 4], swc[r, 5], swc[r, 6])
        for r in range(swc.shape[0])
    )

    with open(path, 'w') as out:
        out.write(text)

def cal_distance_child_par(swc):
    '''
    Print statistics of the child-to-parent Euclidean distances of an swc.

    Row 0 is skipped. For every other row the parent id is read from the
    last column and the parent is looked up at row ``parent_id - 1``;
    coordinates are taken from columns 2-4.

    input:
        swc, 2-D array with xyz in columns 2-4 and the parent id in the
             last column
    output:
        None; max, argmax, average and median distance are printed
    '''
    if swc.shape[0] < 2:
        # No child rows: avoid np.max failing on an empty array.
        print("no child-parent pairs to measure")
        return

    child_loc = swc[1:, 2:5]
    par_row = swc[1:, -1].astype(int) - 1
    par_loc = swc[par_row, 2:5]
    # One vectorised norm replaces a per-row cdist call whose (1, 1) result
    # was being assigned to a scalar slot.
    dist = np.linalg.norm(child_loc - par_loc, axis=1)

    print("max dist: {}".format(np.max(dist)))
    print(np.argmax(dist))
    print("average dist: {}".format(np.average(dist)))
    print("median dist: {}".format(np.median(dist)))

    return
         

def gen_swc_file(skel_xyz, skel_rad, A, save_book, id):
    '''
        build the swc table from the traced graph and save it to
        ./result/neuron_reconst.swc
        input:
            skel_xyz, xyz of skeleton points; must support
                      .squeeze().cpu().detach().numpy() (presumably a torch
                      tensor -- confirm against callers)
            skel_rad, radius of skeleton points, same tensor interface
            A, adjacency matrix; accepted for interface compatibility, not
               used here
            save_book, K x 2 integer array of [node, parent] rows whose first
                       row is a placeholder and is dropped
            id, currently unused: the output file name contains no
                placeholder, so every call writes the same file
                NOTE(review): confirm whether a per-id file name was intended
        output:
            None; the swc file is written and a message is printed
        raises:
            ValueError if save_book holds no row besides the placeholder
    '''
    import os

    save_book = save_book[1:, :]
    skel_xyz = skel_xyz.squeeze().cpu().detach().numpy()
    skel_rad = skel_rad.squeeze().cpu().detach().numpy()
    n = save_book.shape[0]
    if n == 0:
        raise ValueError("save_book contains no traced nodes")

    # Gather coordinates and radii of the traced nodes in book order.
    node_ids = save_book[:, 0]
    xyz = np.asarray(skel_xyz[node_ids, :], dtype=np.float64)
    rad = np.asarray(skel_rad[node_ids], dtype=np.float64).reshape(n, 1)

    # Renumber nodes 1..n in book order and translate each parent node index
    # into the new numbering. A parent is only resolved among earlier rows;
    # an unresolved parent keeps id 0 and the first row gets parent -1.
    nodes = save_book[:, 0].tolist()
    parents = save_book[:, 1].tolist()
    id_book = np.zeros_like(save_book)
    id_book[:, 0] = np.arange(1, n + 1)
    id_book[0, 1] = -1
    sample_of_node = {nodes[0]: 1}
    for row in range(1, n):
        id_book[row, 1] = sample_of_node.get(parents[row], 0)
        sample_of_node[nodes[row]] = row + 1

    sampleID = id_book[:, [0]]
    typeID = np.zeros((n, 1))  # structure type is left as 0 for every node
    parentID = id_book[:, [1]]
    swc = np.concatenate((sampleID, typeID, xyz, rad, parentID), axis=1)

    swc_fp = './result/neuron_reconst.swc'
    # Create the output folder if needed so the write cannot fail on a
    # missing directory.
    os.makedirs(os.path.dirname(swc_fp), exist_ok=True)
    save_swc_file(swc_fp, swc)
    print("The reconstructed swc file has been saved in -result- folder.")

    return