import torch
from torch.utils.data import Dataset, DataLoader
import torch.nn as nn
from scipy.spatial.distance import cdist
import numpy as np
from tqdm import tqdm
import argparse

from dataset.skeleton_prediction_demo_dataset import SPDemoDataset

from utils.pc_utils import nms_swc_sphere, fps
from utils.kmeans import KMeans
from utils.det_soma import det_soma_pos

from dgcnn_model.dgcnn import DGCNN_SkelPred
from gae_model.link_pred_model import GAE
from gae_model.init_graph import build_link_with_swc
from gae_model.interp_graph import tracing_reconstruct_graph, gen_swc_file

"""
    Aim: Demonstrate the end-to-end PointNeuron workflow: voxel-to-point transformation, neuron skeleton prediction, and connectivity reconstruction.
"""

# Pretrained networks used by the demo. Both are placed on the GPU and
# switched to inference mode. nn.Module.cuda() and nn.Module.eval() return
# the module itself, so the calls are chained onto each constructor.
preds_model = DGCNN_SkelPred(k=20).cuda().eval()  # skeleton prediction module
predc_model = GAE().cuda().eval()                 # connectivity prediction module

# Skeleton-prediction weights. The checkpoint is a dict whose
# 'model_state_dict' entry holds the parameters.
save_dgcnn_path = './checkpoints/pretrained_skeleton_pred_module.pth'
print("Loading pretrained DGCNN encoder from {}".format(save_dgcnn_path))
checkpoint = torch.load(save_dgcnn_path)
preds_model.load_state_dict(checkpoint['model_state_dict'])

# Connectivity-prediction weights, stored in the same checkpoint layout.
save_gae_path = './checkpoints/pretrained_connectivity_pred_module.pth'
print("Loading pretrained GAE from {}".format(save_gae_path))
predc_model.load_state_dict(
    torch.load(save_gae_path)['model_state_dict']
)

def loadtiff3d(filepath):
    """Read a TIFF stack with tifffile and return it as a numpy array.

    Every element of the array returned by ``tifffile.imread`` (one page,
    assuming a 3D stack -- TODO confirm for multi-channel files) is rotated
    and the rotated pages are stacked along the last axis with
    ``np.dstack``, so page ``i`` ends up at ``out[:, :, i]``.

    Args:
        filepath: path of the TIFF file to read.

    Returns:
        The stacked numpy array.
    """
    import tifffile as tiff

    volume = tiff.imread(filepath)

    # flipud followed by fliplr is a 180-degree rotation; a further
    # 90-degree counter-clockwise turn totals 270 degrees counter-clockwise,
    # which is a single 90-degree clockwise rotation (k=-1).
    rotated_pages = [np.rot90(page, k=-1) for page in volume]

    return np.dstack(rotated_pages)

def nn_distance(pc1, pc2, l1smooth=False):

    N = pc1.shape[1]
    M = pc2.shape[1]
    pc1_expand_tile = pc1.unsqueeze(2).repeat(1,1,M,1)
    pc2_expand_tile = pc2.unsqueeze(1).repeat(1,N,1,1)
    pc_diff = pc1_expand_tile - pc2_expand_tile

    pc_dist = torch.sum(pc_diff**2, dim=-1) + 1e-6 # (B,N,M)

    dist1, idx1 = torch.min(pc_dist, dim=2) # (B,N)
    dist2, idx2 = torch.min(pc_dist, dim=1) # (B,M)

    return dist1, idx1, dist2, idx2

def process_nms(end_points):
    """Pick a compact subset of predicted skeletal spheres with 3D sphere NMS.

    The demo deliberately uses a low overlap threshold (0.05) so the
    resulting neuron skeleton is compact. A looser setting would require the
    soma position to be inserted manually, and soma detection is not the
    focus of this work.

    Args:
        end_points: dict of network outputs. Reads the 'center', 'radius' and
            'objectness_scores' entries (torch tensors).

    Returns:
        The result of ``nms_swc_sphere``. The caller uses it to index axis 1
        of the network outputs.
    """
    def as_array(tensor):
        # Drop size-1 axes and move the data to a CPU numpy array.
        return tensor.squeeze().cpu().detach().numpy()

    # Class-1 probability of the softmax taken over dim 2 of the objectness
    # scores (presumably the "is skeleton point" class -- confirm).
    class1_prob = torch.softmax(end_points['objectness_scores'], dim=2)[:, :, 1]

    return nms_swc_sphere(
        radius=as_array(end_points['radius']),
        obj_score=as_array(class1_prob),
        xyz=as_array(end_points['center']),
        overlap_threshold=0.05,
    )

def save_in_pc(xyz, color=None, title='point'):
    """Write a point cloud to ``./result/<title>.ply`` using Open3D.

    Args:
        xyz: array-like of shape (N, 3) with point coordinates.
        color: optional array-like of shape (N, 3) with per-point RGB values
            in [0, 1] (Open3D's colour convention). It was previously
            accepted but silently ignored.
        title: file stem of the output ``.ply`` file.

    Returns:
        The boolean success flag returned by ``open3d.io.write_point_cloud``.
    """
    import os
    import open3d as o3d

    out_dir = "./result"
    # Make sure the output directory exists. Otherwise the write fails and
    # the caller's "saved in /result" message would be wrong.
    os.makedirs(out_dir, exist_ok=True)

    pcd = o3d.geometry.PointCloud()
    # Vector3dVector stores float64 data, so convert once explicitly.
    pcd.points = o3d.utility.Vector3dVector(
        np.ascontiguousarray(xyz, dtype=np.float64))
    if color is not None:
        pcd.colors = o3d.utility.Vector3dVector(
            np.ascontiguousarray(color, dtype=np.float64))

    return o3d.io.write_point_cloud(
        os.path.join(out_dir, "{}.ply".format(title)), pcd)

def find_train_img_indx(id, list_file='./train.txt'):
    """Return the line position of image ``id`` in a training list file.

    Each line of ``list_file`` is stripped. The characters from index 6
    onward must parse as an integer, which is taken as that line's image id
    (e.g. a name like ``image_042`` -> 42; assumed naming scheme, confirm
    against train.txt).

    Args:
        id: integer image id to look up. The name shadows the builtin, but it
            is kept for keyword-argument compatibility.
        list_file: path of the list file. Defaults to the previously
            hard-coded ``./train.txt``.

    Returns:
        int: zero-based index of the single line whose id equals ``id``.

    Raises:
        ValueError: if ``id`` matches zero or several lines, or a line's
            suffix is not an integer.
    """
    with open(list_file, encoding='utf-8') as f:
        img_ids = [int(line.strip()[6:]) for line in f]

    matches = [pos for pos, img_id in enumerate(img_ids) if img_id == id]
    if len(matches) != 1:
        raise ValueError(
            "expected exactly one entry with id {} in {}, found {}".format(
                id, list_file, len(matches)))
    return matches[0]

def voxel_2_point_transf(img_fp, threshold=0.2):
    """Convert a 3D microscopy image into a set of foreground points.

    The volume is min-max normalized to [0, 1]. Every voxel whose normalized
    intensity exceeds ``threshold`` becomes a point ``(x, y, z, intensity)``,
    where x, y, z are its voxel indices.

    Args:
        img_fp: path of the TIFF stack, loaded with ``loadtiff3d``.
        threshold: foreground cut-off on the normalized intensity
            (default 0.2, the previously hard-coded value).

    Returns:
        (img, pts): the normalized volume, and an (N, 4) float array holding
        the voxel indices and normalized intensity of each foreground voxel.
        A constant-valued volume yields an all-zero ``img`` and no points.
    """
    img = loadtiff3d(img_fp)
    if not np.issubdtype(img.dtype, np.floating):
        # Work in floating point. Subtracting the minimum in a narrow signed
        # integer type (e.g. int16) can overflow before the division.
        img = img.astype(np.float64)

    lo, hi = img.min(), img.max()
    if hi > lo:
        img = (img - lo) / (hi - lo)  # min-max normalization to [0, 1]
    else:
        # A constant volume has no foreground; avoid the 0/0 division that
        # would otherwise fill the whole array with NaN.
        img = np.zeros_like(img)

    x, y, z = np.nonzero(img > threshold)  # coordinates of foreground voxels
    intensity = img[x, y, z]
    pts = np.column_stack((x, y, z, intensity))

    return img, pts
    

def _select_centroids(centroids_selection, img_xyz, num_centroids):
    """Pick the centroids around which the skeleton network is evaluated."""
    if centroids_selection == 'fps':
        return img_xyz[fps(img_xyz[:, 0:3], npoint=num_centroids), 0:3]
    # 'kmeans' (validated by the caller).
    # NOTE(review): the KMeans.predict output is used directly as the centroid
    # list, and it is fitted on all columns of img_xyz (including intensity),
    # unlike the fps branch which uses xyz only. Confirm this against
    # utils.kmeans.KMeans.
    k = KMeans(K=num_centroids, max_iters=150, plot_steps=True)
    return k.predict(img_xyz)


def _downsample_indices(downsampling, end_points, num_downsample_pts):
    """Return indices (along axis 1 of the outputs) of the skeletal points to keep."""
    if downsampling == '3dnms':
        return process_nms(end_points)
    if downsampling == 'fps':
        return fps(end_points['center'].squeeze().cpu().detach().numpy(),
                   npoint=num_downsample_pts)
    # 'uniform': random indices over all predicted centers, duplicates removed.
    # The upper bound is the real number of centers rather than a hard-coded
    # 500, so indices can never run past the end of the prediction.
    num_centers = end_points['center'].shape[1]
    return np.unique(
        np.random.uniform(0, num_centers, size=[num_downsample_pts]).astype(np.int64))


def pred__neuron_skel(centroids_selection, num_skel_pts, downsampling, img_xyz,
                      num_centroids=180, num_downsample_pts=50):
    """Predict skeletal spheres around a set of centroids and aggregate them.

    For each selected centroid, an SPDemoDataset is built from ``img_xyz``
    and run through ``preds_model``. The predicted skeletal points of each
    run are downsampled, then concatenated across all centroids.

    Args:
        centroids_selection: 'fps' or 'kmeans'.
        num_skel_pts: forwarded to SPDemoDataset.
        downsampling: '3dnms', 'fps' or 'uniform'.
        img_xyz: point array, e.g. the (x, y, z, intensity) output of
            voxel_2_point_transf. The fps branch uses its first three columns.
        num_centroids: number of centroids to select (default 180, the
            previously hard-coded value).
        num_downsample_pts: point budget for 'fps' / sample count for
            'uniform' downsampling (default 50). Ignored by '3dnms'.

    Returns:
        (skel, feature, radius) numpy arrays concatenated along axis 0, or
        (None, None, None) if nothing was predicted.

    Raises:
        ValueError: on an unknown centroids_selection or downsampling method.
            Previously these surfaced later as an unbound-variable error.
    """
    if centroids_selection not in ('fps', 'kmeans'):
        raise ValueError("unknown centroids_selection {!r}; expected 'fps' or 'kmeans'"
                         .format(centroids_selection))
    if downsampling not in ('3dnms', 'fps', 'uniform'):
        raise ValueError("unknown downsampling {!r}; expected '3dnms', 'fps' or 'uniform'"
                         .format(downsampling))

    centroids_list = _select_centroids(centroids_selection, img_xyz, num_centroids)
    print("centroids selection compeleted.")

    # Collect per-centroid pieces and concatenate once at the end; repeated
    # np.concatenate inside the loop is quadratic in the total size.
    skel_parts, feature_parts, radius_parts = [], [], []
    for centroid in tqdm(centroids_list):
        dataset = SPDemoDataset(input_pts=img_xyz, centroid_xyz=centroid, num_skel_pts=num_skel_pts)
        dataloader = DataLoader(dataset=dataset, batch_size=1, shuffle=False)

        with torch.no_grad():
            for batch in dataloader:
                end_points = preds_model(batch.cuda().float())
                pick_indx = _downsample_indices(downsampling, end_points, num_downsample_pts)

                # Drop only the batch axis (batch_size=1). A bare .squeeze()
                # would also collapse the point axis when a single point is
                # picked, which breaks the concatenation below.
                skel_parts.append(
                    end_points['center'][0, pick_indx, :].detach().cpu().numpy())
                feature_parts.append(
                    end_points["input_feature"][0, pick_indx, :].detach().cpu().numpy())
                # Flatten so the radius is always 1D per centroid.
                radius_parts.append(
                    end_points["radius"][:, pick_indx].detach().cpu().numpy().reshape(-1))

    if not skel_parts:
        return None, None, None

    return (np.concatenate(skel_parts, axis=0),
            np.concatenate(feature_parts, axis=0),
            np.concatenate(radius_parts, axis=0))

def pred_connectivity(skel_xyz_agg, skel_feature_agg, skel_radius_agg, id,
                      img_fp='./data/sample.tif'):
    """Predict skeleton connectivity with the GAE and write the neuron as SWC.

    Args:
        skel_xyz_agg: 2D numpy array (K, 3) of skeletal sphere centers.
        skel_feature_agg: 2D numpy array (K, F) of per-sphere features.
        skel_radius_agg: 1D numpy array (K,) of sphere radii.
        id: identifier forwarded unchanged to ``gen_swc_file``. The name
            shadows the builtin, but it is kept for caller compatibility.
        img_fp: image used for soma detection. Defaults to the demo sample,
            which was previously hard-coded here.

    Returns:
        None. The result is written by ``gen_swc_file``.
    """
    with torch.no_grad():
        # Add a leading batch axis; the radius also gets a trailing feature
        # axis so all three can be concatenated along dim 2.
        skel_xyz = torch.from_numpy(skel_xyz_agg)[None, :, :].cuda().float()
        skel_feature = torch.from_numpy(skel_feature_agg)[None, :, :].cuda().float()
        skel_radius = torch.from_numpy(skel_radius_agg)[None, :, None].cuda().float()

        A_init, valid_mask = build_link_with_swc(skel_xyz=skel_xyz, is_train=False)

        skel_node_features = torch.cat([skel_feature, skel_xyz, skel_radius], 2)

        A_pred = predc_model(skel_node_features, A_init)
        A_final = predc_model.recover_A(A_raw=A_pred, A_mask=valid_mask, t=0, skel_xyz=skel_xyz)

        soma_loc = det_soma_pos(filepath=img_fp)
        save_book = tracing_reconstruct_graph(A=A_final, soma_loc=soma_loc, skel_xyz=skel_xyz)
        gen_swc_file(skel_xyz=skel_xyz, skel_rad=skel_radius, A=A_final, save_book=save_book, id=id)

if __name__ == '__main__':

    parser = argparse.ArgumentParser(description='PointNeuron')
    parser.add_argument('--centroids_select', type=str, default='fps',
                        choices=['fps', 'kmeans'],
                        help='method used to pick centroids on the input points')
    parser.add_argument('--num_skel_pts', type=int, default=512,
                        help='num_skel_pts passed to SPDemoDataset')
    parser.add_argument('--downsample_method', type=str, default='3dnms',
                        choices=['3dnms', 'fps', 'uniform'],
                        help='how predicted skeletal points are downsampled')
    # The original script passed the builtin function ``id`` to
    # pred_connectivity by mistake; this explicit value replaces it.
    # NOTE(review): confirm the type gen_swc_file expects for ``id``.
    parser.add_argument('--id', type=str, default='sample',
                        help='identifier forwarded to the SWC writer')

    args = parser.parse_args()
    centroids_selection = args.centroids_select
    num_skel_pts = args.num_skel_pts
    ds_method = args.downsample_method
    sample_id = args.id

    print("##### Voxel-to-points Transformation")

    img_fp = './data/sample.tif'
    neuron_img, neuron_pts = voxel_2_point_transf(img_fp=img_fp)
    save_in_pc(xyz=neuron_pts[:,0:3], title='neuron_input_pts')
    print('Transmitting {} raw 3D microscopy image into {} neuron spatial points.'.format(neuron_img.shape, neuron_pts.shape[0]))
    print('The input points are saved in /result folder.')
    print('\n')

    print("##### Neuron Skeleton Prediction #####")
    skel, geom_feature, rad = pred__neuron_skel(centroids_selection=centroids_selection, num_skel_pts=num_skel_pts, downsampling=ds_method, img_xyz=neuron_pts)
    if skel is None:
        # Nothing was predicted (e.g. no centroids); stop with a clear message
        # instead of crashing on skel.shape below.
        raise SystemExit("No skeletal points were predicted; nothing to save or reconstruct.")
    print("# Producing {} neuron skeletal spheres after aggregation and downsampling.".format(skel.shape[0]))

    save_in_pc(xyz=skel, title='neuron_skel')
    print("# The compact neuron skeleton has been saved in /result folder.")
    print("\n")

    print("##### Connectivity Reconstruction #####")
    pred_connectivity(skel, geom_feature, rad, sample_id)
    print("The reconstructed neuron is saved in /result folder as SWC file.")
    print("Done")