import os
import sys
sys.path.append(os.getcwd())
import copy
import math
import numpy as np
import torch
import torch.nn as nn
import torch.nn.functional as F
import argparse

from dgcnn_model.proposal_module import ProposalModule

"""
    Aim: To use DGCNN as an encoder to generate rich geometric features for neuron points.
    Referring to https://github.com/antao97/dgcnn.pytorch
"""

def knn(x, k):
    '''
        Find the k nearest neighbours of every point inside its own cloud.

        Input:
            x: tensor of shape (batch_size, num_dims, num_points)
            k: number of neighbours to keep per point
        Output:
            index tensor of shape (batch_size, num_points, k), ordered from the
            closest neighbour outwards. A point's distance to itself is zero, so
            it is normally returned as its own first neighbour.
    '''
    gram = torch.matmul(x.transpose(2, 1), x)
    sq_norm = (x ** 2).sum(dim=1, keepdim=True)

    # Negated squared Euclidean distance: -(|a|^2 - 2 a.b + |b|^2).
    # Negating lets topk (which picks the largest values) return the closest points.
    neg_sq_dist = -sq_norm - (-2 * gram) - sq_norm.transpose(2, 1)

    _, neighbour_idx = torch.topk(neg_sq_dist, k=k, dim=-1)
    return neighbour_idx

def get_graph_feature(x, k=20, idx=None, dim4=False):
    '''
        Build edge features for every point from its neighbours.

        Input:
            x:    tensor of shape (batch_size, num_dims, num_points)
            k:    number of neighbours to search for when `idx` is not given
            idx:  optional precomputed neighbour indices of shape
                  (batch_size, num_points, K); when given, K is used as the
                  neighbourhood size and `k` is ignored
            dim4: when True, the neighbour search only looks at the first three
                  channels of `x` (the features themselves still use all channels)
        Output:
            tensor of shape (batch_size, 2*num_dims, num_points, K) where, for each
            point/neighbour pair, the first num_dims channels hold
            (neighbour - point) and the last num_dims channels hold the point itself.
    '''
    batch_size = x.size(0)
    num_points = x.size(2)
    x = x.view(batch_size, -1, num_points)
    num_dims = x.size(1)

    if idx is None:
        if not dim4:
            idx = knn(x, k=k)   # (batch_size, num_points, k)
        else:
            idx = knn(x[:, 0:3], k=k)
    else:
        # A caller-supplied index fixes the neighbourhood size; relying on the
        # default k here would make the reshape below fail whenever K != k.
        k = idx.size(-1)

    # Offset each cloud's indices so they address rows of the flattened
    # (batch_size*num_points, num_dims) point table.
    idx_base = torch.arange(0, batch_size, device=x.device).view(-1, 1, 1) * num_points
    flat_idx = (idx + idx_base).reshape(-1)

    x = x.transpose(2, 1).contiguous()   # (batch_size, num_points, num_dims)
    feature = x.view(batch_size * num_points, -1)[flat_idx, :]
    feature = feature.view(batch_size, num_points, k, num_dims)

    # expand() yields a broadcast view, so the centre point is not copied k times.
    x = x.view(batch_size, num_points, 1, num_dims).expand(-1, -1, k, -1)

    feature = torch.cat((feature - x, x), dim=3).permute(0, 3, 1, 2).contiguous()

    return feature      # (batch_size, 2*num_dims, num_points, k)

def cat_near_pts_feature(xyz, x, k=20):
    '''
        Collect, for every point, the features of its k nearest neighbours.
        Neighbours are selected with knn() on `xyz`; the gathered values come from `x`.

        Input:
            xyz: tensor passed to knn(), expected as (batch_size, dims, num_points)
            x:   feature tensor of shape (batch_size, num_dims, num_points)
            k:   number of neighbours per point
        Output:
            tensor of shape (batch_size, num_dims, num_points, k)
    '''
    batch_size, num_points = x.size(0), x.size(2)
    x = x.view(batch_size, -1, num_points)
    num_dims = x.size(1)

    neighbour_idx = knn(xyz, k=k)   # (batch_size, num_points, k)

    # Shift every cloud's indices so they point into the flattened point table.
    batch_offset = torch.arange(0, batch_size, device=x.device).view(-1, 1, 1) * num_points
    flat_idx = (neighbour_idx + batch_offset).view(-1)

    # (batch_size, num_dims, num_points) -> (batch_size*num_points, num_dims)
    point_table = x.transpose(2, 1).contiguous().view(batch_size * num_points, -1)

    gathered = point_table[flat_idx, :].view(batch_size, num_points, k, num_dims)
    return gathered.permute(0, 3, 1, 2).contiguous()


class DGCNN_SkelPred(nn.Module):
    '''
        DGCNN-style point encoder whose per-point features are handed to a ProposalModule.

        Three edge-convolution stages, each building its own k-nearest-neighbour graph,
        produce 64-channel per-point features. Their concatenation is lifted to a
        1024-channel descriptor, max-pooled over all points and appended back to
        every point before the proposal head is applied.

        k: number of neighbours used for every graph construction.
    '''

    @staticmethod
    def _conv_bn_act(conv, norm):
        # Convolution -> normalisation -> LeakyReLU(0.2), the pattern shared by every stage.
        return nn.Sequential(conv, norm, nn.LeakyReLU(negative_slope=0.2))

    def __init__(self, k=20):
        super(DGCNN_SkelPred, self).__init__()

        # ----------------------------------------DGCNN----------------------------------------
        self.k = k

        # Normalisation layers are registered under their own names and also
        # reused inside the conv stacks below. No running statistics are tracked.
        self.bn1 = nn.BatchNorm2d(64, track_running_stats=False)
        self.bn2 = nn.BatchNorm2d(64, track_running_stats=False)
        self.bn3 = nn.BatchNorm2d(64, track_running_stats=False)
        self.bn4 = nn.BatchNorm2d(64, track_running_stats=False)
        self.bn5 = nn.BatchNorm2d(64, track_running_stats=False)
        self.bn6 = nn.BatchNorm1d(1024, track_running_stats=False)

        # Stage 1 consumes the edge features of the 4-channel input (2*4 = 8 channels).
        self.conv1 = self._conv_bn_act(nn.Conv2d(8, 64, kernel_size=1, bias=False), self.bn1)
        self.conv2 = self._conv_bn_act(nn.Conv2d(64, 64, kernel_size=1, bias=False), self.bn2)
        # Stages 2 and 3 consume edge features of 64-channel maps (2*64 channels).
        self.conv3 = self._conv_bn_act(nn.Conv2d(64*2, 64, kernel_size=1, bias=False), self.bn3)
        self.conv4 = self._conv_bn_act(nn.Conv2d(64, 64, kernel_size=1, bias=False), self.bn4)
        self.conv5 = self._conv_bn_act(nn.Conv2d(64*2, 64, kernel_size=1, bias=False), self.bn5)
        # Lifts the three concatenated 64-channel maps to the 1024-channel embedding.
        self.conv6 = self._conv_bn_act(nn.Conv1d(192, 1024, kernel_size=1, bias=False), self.bn6)

        # ----------------------------------------MLP----------------------------------------
        # Per-point input of the head: pooled embedding plus the three stage outputs.
        input_channels = 1024 + 64*3

        self.pnet = ProposalModule(in_channels=input_channels, out_channels=6)

    def forward(self, x):
        '''
            Input:
                x: tensor of shape (batch_size, num_points, 4); the first three
                   channels are used as coordinates for the first neighbour search
            Output:
                whatever self.pnet returns for the per-point features and the
                `end_points` dict, which is pre-filled here with
                'input_xyz'     (batch_size, num_points, 3) and
                'input_feature' (batch_size, num_points, 1024+64*3)
        '''
        end_points = {}

        pts = x.transpose(2, 1)                 # (batch_size, 4, num_points)
        num_points = pts.size(2)
        end_points['input_xyz'] = pts[:, 0:3, :].transpose(2, 1)

        # Stage 1: graph built on the first three channels of the raw input.
        edge = get_graph_feature(pts, k=self.k, dim4=True)  # (batch_size, 4*2, num_points, k)
        edge = self.conv2(self.conv1(edge))                 # (batch_size, 64, num_points, k)
        x1 = edge.max(dim=-1, keepdim=False)[0]             # (batch_size, 64, num_points)

        # Stage 2: graph rebuilt on the stage-1 features.
        edge = get_graph_feature(x1, k=self.k)              # (batch_size, 64*2, num_points, k)
        edge = self.conv4(self.conv3(edge))                 # (batch_size, 64, num_points, k)
        x2 = edge.max(dim=-1, keepdim=False)[0]             # (batch_size, 64, num_points)

        # Stage 3: graph rebuilt on the stage-2 features.
        edge = get_graph_feature(x2, k=self.k)              # (batch_size, 64*2, num_points, k)
        edge = self.conv5(edge)                             # (batch_size, 64, num_points, k)
        x3 = edge.max(dim=-1, keepdim=False)[0]             # (batch_size, 64, num_points)

        # Global descriptor: max over all points of the lifted multi-scale features.
        stacked = torch.cat((x1, x2, x3), dim=1)            # (batch_size, 64*3, num_points)
        pooled = self.conv6(stacked).max(dim=-1, keepdim=True)[0]   # (batch_size, 1024, 1)

        # Give every point a copy of the global descriptor next to its local features.
        tiled = pooled.repeat(1, 1, num_points)             # (batch_size, 1024, num_points)
        fused = torch.cat((tiled, x1, x2, x3), dim=1)       # (batch_size, 1024+64*3, num_points)

        end_points['input_feature'] = fused.transpose(2, 1)
        return self.pnet(fused, end_points)

if __name__ == '__main__':
    # Smoke test: push one random cloud through the network and report what comes out.

    parser = argparse.ArgumentParser(description='Point Cloud Recognition')
    parser.add_argument('--k', type=int, default=20, metavar='N',
                        help='Num of nearest neighbors to use')
    # NOTE: --dropout and --emb_dims are accepted for command-line compatibility
    # but DGCNN_SkelPred has no corresponding constructor arguments.
    parser.add_argument('--dropout', type=float, default=0.5,
                        help='dropout rate')
    parser.add_argument('--emb_dims', type=int, default=1024, metavar='N',
                        help='Dimension of embeddings')
    args = parser.parse_args()

    # Fall back to the CPU so the check also runs on machines without a GPU.
    device = torch.device('cuda' if torch.cuda.is_available() else 'cpu')

    points = torch.randn(1, 512, 4, dtype=torch.float32).to(device)
    # DGCNN_SkelPred only accepts `k`; passing any other keyword raises TypeError.
    model = DGCNN_SkelPred(k=args.k).to(device)

    with torch.no_grad():
        output = model(points)

    # forward() hands an `end_points` dict to the proposal head and returns the
    # head's result, so a dict is the expected output here (TODO confirm against
    # ProposalModule). Report every entry instead of assuming a single tensor.
    if isinstance(output, dict):
        for name, value in output.items():
            print(name, tuple(value.shape) if torch.is_tensor(value) else value)
    else:
        print(output.shape)
    