import torch
import copy
import numpy as np

def nn_distance(pc1, pc2, k, is_max):
    '''
    Return, for every point in pc1, the indices of k points in pc2 ranked
    by Euclidean distance.

    pc1: tensor of shape [B, N, C] (query points)
    pc2: tensor of shape [B, M, C] (reference points)
    k: number of indices to return per query point
    is_max: if True return the k farthest points (farthest first),
            otherwise the k nearest points (nearest first)

    Returns the index tensor produced by torch.topk along the last axis;
    the indices refer to the point axis of pc2.
    '''
    B, N, _ = pc1.shape
    _, M, _ = pc2.shape

    # Expanded form of the squared distance: |a - b|^2 = |a|^2 - 2 a.b + |b|^2
    sq_dist = -2 * torch.matmul(pc1, pc2.permute(0, 2, 1))
    sq_dist += torch.sum(pc1 ** 2, -1).view(B, N, 1)
    sq_dist += torch.sum(pc2 ** 2, -1).view(B, 1, M)
    sq_dist += 1e-6

    # Floating-point cancellation in the expanded form can leave a slightly
    # negative value for (near-)coincident points. sqrt would turn that into
    # NaN, which topk ranks as the largest value, so the closest point would
    # be reported as the farthest. Clamping keeps such entries at distance 0.
    pc_dist = torch.sqrt(torch.clamp(sq_dist, min=0.0))

    _, k_nn = torch.topk(pc_dist, k, dim=2, largest=is_max)

    return k_nn

def loadswc(filepath):
    '''
    Load swc file as a N X 7 numpy array

    Blank lines and lines whose first non-blank character is '#' are
    skipped, as are rows with fewer than seven fields. Fields may be
    separated by any run of whitespace (spaces or tabs). Every field of a
    kept row is converted to float, so a row with more than seven fields
    keeps all of its values.

    Raises ValueError if a field of a kept row is not a valid number.
    '''
    swc = []
    with open(filepath) as f:
        for line in f:
            line = line.strip()
            if not line or line.startswith('#'):
                continue
            # split() with no argument collapses repeated/trailing whitespace;
            # split(' ') would produce empty cells that float() rejects.
            cells = line.split()
            if len(cells) >= 7:
                swc.append([float(c) for c in cells])

    return np.array(swc)

def build_link_with_swc(skel_xyz, is_train,
                        swc_path="./gae_model/graph_init_from_app2.txt",
                        valid_k=10):
    '''
        simply build link between two near swc sphere

        Each SWC node is matched to a distinct skeleton point (its nearest
        one that has not been taken by an earlier node), and every
        parent/child relation of the SWC table becomes a symmetric edge
        between the matched skeleton points.

        skel_xyz: skeleton points, [batch_size, num_skel_pts, 3]
        is_train: if True no valid mask is built (None is returned for it)
        swc_path: space-delimited SWC table; columns 2:5 are read as the node
                  coordinates and column 6 as the 1-based parent id
        valid_k: each skeleton point is additionally linked in the valid mask
                 to its (valid_k - 1) nearest other skeleton points

        Returns (A, valid_mask): A is the adjacency matrix
        [batch_size, num_skel_pts, num_skel_pts]; valid_mask is A plus the
        nearest-neighbour links, or None when is_train is True.

        NOTE: the SWC nodes form a batch of one, so the matching step
        presumably expects batch_size == 1 -- confirm against callers.

        Raises ValueError if the SWC table has more nodes than there are
        skeleton points (a one-to-one matching is then impossible).
    '''

    bn, pn = skel_xyz.size()[0], skel_xyz.size()[1]  # bn: batch, pn: # of skel_pts
    device = skel_xyz.device  # keep every tensor on the same device as the input

    print("Loading the graph adjacency matrix produced by APP2 algorithm...")
    swc_from_app2 = np.genfromtxt(swc_path, delimiter=' ').astype(np.float32)
    swc = torch.from_numpy(swc_from_app2).to(device).float().unsqueeze(0)
    num_swc = swc.shape[1]

    if num_swc > pn:
        raise ValueError(
            f"SWC table has {num_swc} nodes but only {pn} skeleton points are "
            "available; cannot match every node to a distinct skeleton point"
        )

    # Rank all skeleton points by distance for every SWC node in one call,
    # then assign greedily in SWC order so no skeleton point is used twice.
    ranked = nn_distance(pc1=swc[:, :, 2:5], pc2=skel_xyz, k=pn, is_max=False)
    ranked = ranked[0].cpu().numpy()  # [num_swc, pn], nearest first

    linked_node = np.full(num_swc, -1, dtype=np.int64)
    taken = set()  # explicit bookkeeping instead of scanning an uninitialised buffer
    for i in range(num_swc):
        for candidate in ranked[i]:
            candidate = int(candidate)
            if candidate not in taken:
                taken.add(candidate)
                linked_node[i] = candidate
                break

    # Adjacency Matrix [batch_size, num_skel_pts, num_skel_pts]
    A = torch.zeros((bn, pn, pn), dtype=torch.float32, device=device)

    # initialize the adjacency matrix depending on the par_child relation
    parents = swc_from_app2[:, 6].astype(np.int64) - 1  # 0-based parent index
    for curr_indx in range(num_swc):
        par_indx = parents[curr_indx]
        # Valid 0-based parents are 0 .. num_swc - 1; stop at the first
        # parent that points past the table.
        if par_indx >= num_swc:
            break
        if par_indx >= 0:
            child_pt = int(linked_node[curr_indx])
            parent_pt = int(linked_node[par_indx])
            A[:, child_pt, parent_pt] = 1
            A[:, parent_pt, child_pt] = 1

    # build valid mask and known mask
    if is_train:
        valid_mask = None
    else:
        # Nearest-first ranking: index 0 is expected to be the point itself,
        # so the slice 1:valid_k selects the nearest other skeleton points.
        # (Ranking farthest-first here would link each point to distant ones.)
        knn_skel = nn_distance(pc1=skel_xyz, pc2=skel_xyz, k=min(valid_k, pn), is_max=False)
        neighbours = knn_skel[:, :, 1:valid_k]
        batch_idx = torch.arange(bn, device=device)[:, None, None]
        point_idx = torch.arange(pn, device=device)[None, :, None]

        valid_mask = A.clone()
        valid_mask[batch_idx, point_idx, neighbours] = 1  # link the nearest skel pts
        valid_mask[batch_idx, neighbours, point_idx] = 1  # keep the mask symmetric

    return A, valid_mask