import torch
import torch.nn as nn
import torch.nn.functional as F

class SPE(nn.Module):
    """Point-feature head that regresses a 3-value output per batch element.

    The module combines query/reference point features with an embedding of
    the reference 3D points, optionally appends a Jacobian term ("D branch")
    and a stored-history term ("I branch"), projects the result to
    ``args.hidden_dim`` channels, pools over the point axis and maps the
    pooled vector to three ``tanh``-bounded values.

    Fields read from ``args``: ``encoder``, ``level``, ``hidden_dim``,
    ``pid_k``, ``kp``/``ki``/``kd`` (when ``pid_k == 'const'``),
    ``base_optimizer``, ``Optimizer_input``, ``require_jac``, ``j_ver``,
    ``require_int``, ``integral_num``, ``pool`` and, for the ``embed_aap2``
    pooling, ``max_points``/``max_out_points``.

    NOTE(review): when ``pid_k == 'learnable'`` and
    ``base_optimizer == 'pid_adam'`` the forward pass calls ``self.pid()``,
    which is not defined in this class; it has to be provided elsewhere
    (e.g. by a subclass) -- confirm.
    """

    # Channel widths of the three feature levels, keyed by ``args.level``.
    _VGG_INPUT_DIMS = {
        3: (512, 128, 16),
        3.1: (512, 128, 64),
        3.2: (256, 128, 16),
    }
    # Output width of the positional embedding ``linearp``.
    _POINT_EMBED_DIM = 128

    def __init__(self, args):
        """Build all layers from the configuration object ``args``.

        Raises:
            ValueError: if ``args.encoder`` or ``args.level`` is not one of
                the supported values. (Previously this surfaced later as an
                unrelated ``UnboundLocalError``/``AttributeError``.)
        """
        super().__init__()
        self.args = args
        self.initialize_integral()

        if self.args.encoder != 'vgg':
            raise ValueError(
                f"SPE only supports encoder='vgg', got {self.args.encoder!r}")
        input_dim = self._VGG_INPUT_DIMS.get(self.args.level)
        if input_dim is None:
            raise ValueError(
                f"unsupported level {self.args.level!r}; "
                f"expected one of {sorted(self._VGG_INPUT_DIMS)}")
        self.cin = list(input_dim)
        self.cout = self.args.hidden_dim

        # PID coefficients (fixed gains taken from the configuration)
        if self.args.pid_k == 'const':
            self.kp = torch.tensor(self.args.kp, dtype=torch.float32)
            self.ki = torch.tensor(self.args.ki, dtype=torch.float32)
            self.kd = torch.tensor(self.args.kd, dtype=torch.float32)

        # input dimension: 'concat' stacks query+ref, 'resconcat' additionally
        # stacks the residual, any other value feeds the residual alone.
        if self.args.Optimizer_input in ['concat']:
            self.cin = [c * 2 for c in self.cin]
        elif self.args.Optimizer_input in ['resconcat']:
            self.cin = [c * 3 for c in self.cin]

        # D branch: one extra copy of the feature width for the Jacobian term.
        # NOTE(review): the multiplier is fixed to 1, so forward() must deliver
        # exactly input_dim[i] Jacobian channels per point; j_ver values that
        # keep more than one Jacobian column would not match -- confirm.
        if self.args.require_jac:
            jc = 1
            self.cin = [c + input_dim[i] * jc for i, c in enumerate(self.cin)]

        # I branch: integral_num extra copies of the feature width.
        if self.args.require_int:
            self.cin = [c + input_dim[i] * self.args.integral_num
                        for i, c in enumerate(self.cin)]

        # PE: positional embedding layer (3 coordinates -> pointc channels)
        pointc = self._POINT_EMBED_DIM
        self.linearp = nn.Sequential(nn.Linear(3, 16),
                                     nn.ReLU(inplace=False),
                                     nn.Linear(16, pointc),
                                     nn.ReLU(inplace=False),
                                     nn.Linear(pointc, pointc))
        self.cin = [c + pointc for c in self.cin]

        # channel projection layer, one per feature level
        self.linear0 = nn.Sequential(nn.ReLU(inplace=True),
                                     nn.Linear(self.cin[0], self.cout))
        self.linear1 = nn.Sequential(nn.ReLU(inplace=True),
                                     nn.Linear(self.cin[1], self.cout))
        self.linear2 = nn.Sequential(nn.ReLU(inplace=True),
                                     nn.Linear(self.cin[2], self.cout))

        if self.args.pool == 'embed_aap2':
            # Learned pooling over the point axis: it maps a fixed number of
            # points down to 16 values per channel.
            max_points = self.args.max_points + self.args.max_out_points
            self.pooling = nn.Sequential(nn.ReLU(inplace=False),
                                         nn.Linear(max_points, 256),
                                         nn.ReLU(inplace=False),
                                         nn.Linear(256, 64),
                                         nn.ReLU(inplace=False),
                                         nn.Linear(64, 16)
                                         )
            self.cout *= 16

        self.mapping = nn.Sequential(nn.ReLU(inplace=False),
                                     nn.Linear(self.cout, 128),
                                     nn.ReLU(inplace=False),
                                     nn.Linear(128, 32),
                                     nn.ReLU(inplace=False),
                                     nn.Linear(32, 3),
                                     nn.Tanh())

    def initialize_integral(self, scale=-1):
        """Reset the history used by the I branch.

        With the default ``scale=-1`` every slot (keys 0, 1, 2) is replaced by
        an empty list; otherwise only the slot ``2 - scale`` is cleared.
        Nothing in this class appends to these lists, so callers are expected
        to fill them before running ``forward`` with ``require_int`` enabled.
        """
        if scale == -1:
            self.integral = {0: [], 1: [], 2: []}
        else:
            self.integral[2 - scale] = []

    def forward(self, query_feat, ref_feat, p3D_query, p3D_ref, J, scale=2):
        """Regress the 3-value output from point features.

        Args:
            query_feat: tensor of shape ``[B, N, C]``.
            ref_feat: tensor compatible with ``query_feat`` for subtraction
                and for concatenation along the last axis.
            p3D_query: unused by this method.
            p3D_ref: tensor of shape ``[B, N, 3]``; it is standardised over
                the point axis before being embedded.
            J: Jacobian tensor, only read when ``args.require_jac`` is set.
                It is reshaped to ``[B, N, -1]``; for ``j_ver`` in
                ``long``/``lat``/``rot`` its last axis is indexed at 0/1/2.
            scale: selects the history slot ``self.integral[2 - scale]`` for
                the I branch.

        Returns:
            Tensor of shape ``[B, 3]`` with values bounded by ``tanh``.

        Raises:
            ValueError: if the assembled channel count matches none of the
                three projection layers.
        """
        B, N, _ = query_feat.size()

        # pid coefficient
        if self.args.pid_k == 'learnable' and self.args.base_optimizer == 'pid_adam':
            kp, ki, kd = self.pid()
        else:
            kp, ki, kd = self.kp, self.ki, self.kd

        # input on pose estimator (P term: scaled feature residual)
        res = kp * (query_feat - ref_feat)
        if self.args.Optimizer_input == 'concat':
            r = torch.cat([query_feat, ref_feat], dim=-1)
        elif self.args.Optimizer_input in ['resconcat']:    # default
            r = torch.cat([query_feat, ref_feat, res], dim=-1)
        else:
            r = res  # [B, N, C]

        # point input on pose estimator: z-score over the point axis, then embed
        p3D_ref = (p3D_ref - p3D_ref.mean(dim=1, keepdim=True)) / (p3D_ref.std(dim=1, keepdim=True) + 1e-6)
        p3D_ref = p3D_ref.contiguous()
        p3D_ref_feat = self.linearp(p3D_ref)
        r = torch.cat([r, p3D_ref_feat], dim=-1)

        # D branch
        if self.args.require_jac:
            if self.args.j_ver == 'long':
                J = J[..., 0].unsqueeze(-1)
            elif self.args.j_ver == 'lat':
                J = J[..., 1].unsqueeze(-1)
            elif self.args.j_ver == 'rot':
                J = J[..., 2].unsqueeze(-1)
            elif self.args.j_ver == 'sum':
                J = J.sum(dim=-1, keepdim=True)
            J = J.view(B, N, -1)
            J = F.normalize(J, dim=-1)
            r = torch.cat([r, kd * J], dim=-1)

        # I branch
        if self.args.require_int:
            # self.integral is a dict of per-level lists (see
            # initialize_integral); torch.cat needs the list for this level,
            # not the dict itself. A plain sequence assigned from outside is
            # still accepted unchanged.
            history = self.integral
            if isinstance(history, dict):
                history = history[2 - scale]
            integral = torch.cat(history, dim=-1)
            r = torch.cat([r, ki * integral], dim=-1)

        # channel projection: the layer is chosen by the assembled channel
        # count, first match wins.
        B, N, C = r.shape
        projections = (self.linear0, self.linear1, self.linear2)
        for width, projection in zip(self.cin, projections):
            if C == width:
                x = projection(r)
                break
        else:
            raise ValueError(
                f"input has {C} channels, expected one of {self.cin}")

        # point pooling
        if self.args.pool == 'max':
            x = torch.max(x, 1, keepdim=True)[0]
        elif 'embed' in self.args.pool:
            # pooling acts on the last axis, so move the point axis there;
            # this requires N to equal the size the pooling layer was built for
            x = x.contiguous().permute(0, 2, 1).contiguous()
            x = self.pooling(x)
        elif self.args.pool == 'gap':
            x = torch.mean(x, 1, keepdim=True)
        elif self.args.pool == 'all_ap':
            # map every point first, then average the per-point outputs
            y = self.mapping(x)
            y = y.mean(dim=1)
            return y

        x = x.view(B, -1)
        y = self.mapping(x)  # [B, 3]

        return y

    def sample_top_k_points(self, J_norm, k=1000):
        """Return the indices of the ``k`` largest entries along dim 1.

        Args:
            J_norm: tensor of shape ``[B, N]``.
            k: number of indices to keep per row; clamped to ``N`` so that
                inputs with fewer than ``k`` points no longer raise.

        Returns:
            Integer tensor of shape ``[B, min(k, N)]``, ordered from the
            largest to the smallest value.
        """
        k = min(k, J_norm.size(1))
        _, top_k_indices = torch.topk(J_norm, k, dim=1, largest=True, sorted=True)
        return top_k_indices