import numpy as np, math
import torch
from torch import nn
import numpy as np
from PIL import Image
from mmengine import MODELS
import torchvision.transforms as transforms
import cv2
import open3d as o3d
import matplotlib.pyplot as plt
import torch.nn.functional as F

from ..encoder.gaussianformer.utils import safe_get_quaternion, batch_quaternion_multiply, get_rotation_matrix, safe_sigmoid

LOGIT_MAX = 0.99


def safe_inverse_sigmoid(tensor):
    """Inverse sigmoid (logit) that stays finite at the ends of (0, 1).

    The input is first clamped into ``[1 - LOGIT_MAX, LOGIT_MAX]`` and then
    ``log(p / (1 - p))`` is returned, so values of exactly 0 or 1 map to
    large-but-finite logits instead of +/-inf.
    """
    p = tensor.clamp(1 - LOGIT_MAX, LOGIT_MAX)
    odds = p / (1 - p)
    return odds.log()

def bin_depths(depth_map, mode, depth_min, depth_max, num_bins, target=False):
    """Convert a depth map into depth-bin indices.

    The discretisation schemes follow CaDDN
    (https://arxiv.org/pdf/2005.13423.pdf):

    * ``"UD"``  - uniform discretisation (equal-width bins);
    * ``"LID"`` - linear-increasing discretisation (bin width grows linearly);
    * ``"SID"`` - spacing-increasing discretisation (bins uniform in
      ``log(1 + depth)``).

    Args:
        depth_map (torch.Tensor): Depth values, e.g. of shape (H, W).
        mode (str): One of ``"UD"``, ``"LID"`` or ``"SID"``.
        depth_min (float): Depth mapped to index 0.
        depth_max (float): Depth mapped to index ``num_bins``.
        num_bins (int): Number of depth bins.
        target (bool): If True, indices that are negative, larger than
            ``num_bins`` or non-finite are replaced by ``num_bins``, so every
            entry is an index in ``[0, num_bins]`` usable as a loss target.

    Returns:
        torch.Tensor: int64 bin indices with the shape of ``depth_map``. The
        continuous index is truncated towards zero by ``.long()`` whatever the
        value of ``target``.

    Raises:
        NotImplementedError: If ``mode`` is not a supported scheme.
    """

    def _uniform():
        bin_width = (depth_max - depth_min) / num_bins
        return (depth_map - depth_min) / bin_width

    def _linear_increasing():
        first_bin = 2 * (depth_max - depth_min) / (num_bins * (1 + num_bins))
        return -0.5 + 0.5 * torch.sqrt(1 + 8 * (depth_map - depth_min) / first_bin)

    def _spacing_increasing():
        log_lo = math.log(1 + depth_min)
        log_hi = math.log(1 + depth_max)
        return num_bins * (torch.log(1 + depth_map) - log_lo) / (log_hi - log_lo)

    schemes = {
        "UD": _uniform,
        "LID": _linear_increasing,
        "SID": _spacing_increasing,
    }
    compute = schemes.get(mode)
    if compute is None:
        raise NotImplementedError
    indices = compute()

    if target:
        # Out-of-range (-2, -1, ..., num_bins + 1) and inf/nan entries all
        # collapse onto the overflow index num_bins.
        invalid = (indices < 0) | (indices > num_bins) | (~torch.isfinite(indices))
        indices = indices.masked_fill(invalid, num_bins).to(torch.int64)
    return indices.long()

def sample_3d_feature(feature_3d, pix_xy, pix_z, fov_mask):
    """Gather voxel features at the projected points selected by ``fov_mask``.

    Args:
        feature_3d (torch.Tensor): Feature volume, shape (C, D, H, W).
        pix_xy (torch.Tensor): Integer pixel coordinates ``(x, y)``, shape (N, 2).
        pix_z (torch.Tensor): Depth index per point, shape (N,); it is cast to
            the dtype of ``pix_xy`` before indexing.
        fov_mask (torch.Tensor): Selector over the N points (presumably a
            boolean mask of shape (N,) - confirm with callers).

    Returns:
        torch.Tensor: Sampled features, shape (M, C), one row per selected point.
    """
    visible_xy = pix_xy[fov_mask]
    cols = visible_xy[:, 0]
    rows = visible_xy[:, 1]
    slices = pix_z[fov_mask].to(rows.dtype)
    gathered = feature_3d[:, slices, rows, cols]  # (C, M)
    return gathered.permute(1, 0)

@MODELS.register_module()
class GaussianLifterOnline(nn.Module):
    """Build per-frame Gaussian anchors from a depth map and an online Gaussian pool.

    ``forward`` produces anchors from two sources:

    * new anchors: depth-map pixels are back-projected to 3D and kept when they
      fall inside the global scene box. Points whose global voxel is set in
      ``mask_in_global_from_this`` but not in ``global_mask_thistime`` are
      preferred; other in-box points only top up the budget.
    * reused anchors: Gaussians from ``gaussian_pool`` (stored in world
      coordinates) that lie inside this frame's voxel volume, re-expressed in
      the current camera frame.

    Anchor rows are ``[xyz(3), scale(3), rotation(4), opacity, semantics]``.
    xyz, scale and opacity are stored in logit space (``safe_inverse_sigmoid``),
    with xyz first normalised to [0, 1] by ``metas[0]['cam_vox_range']``.
    """

    def __init__(
        self,
        embed_dims,  # e.g. 96
        num_anchor=25600,  # e.g. 21600
        anchor=None,
        anchor_grad=False,
        feat_grad=False,
        semantic_dim=0,  # e.g. 13
        include_opa=True,
        include_v=False,
        img_width=640,
        img_height=480,
    ):
        """
        Args:
            embed_dims (int): Output width of the per-anchor instance feature.
            num_anchor (int): Number of random anchors drawn when ``anchor`` is None.
            anchor (str | list | tuple | array-like | None): Initial anchors.
                A ``str`` is loaded with ``np.load``, a list/tuple goes through
                ``np.array``, and ``None`` draws ``num_anchor`` random anchors.
            anchor_grad (bool): Whether the anchor parameter is trainable.
            feat_grad (bool): Unused; kept for config compatibility.
            semantic_dim (int): Number of semantic channels per anchor.
            include_opa (bool): Whether anchors carry an opacity channel.
            include_v (bool): Unused; kept for config compatibility.
            img_width (int): Image width in pixels (frustum test and depth grid).
            img_height (int): Image height in pixels (frustum test and depth grid).
        """
        super().__init__()
        self.embed_dims = embed_dims
        # Set unconditionally: forward() reads it whatever the anchor source is.
        self.semantic_dim = semantic_dim
        self.img_width = img_width
        self.img_height = img_height

        if isinstance(anchor, str):
            anchor = np.load(anchor)
        elif isinstance(anchor, (list, tuple)):
            anchor = np.array(anchor)
        elif anchor is None:
            anchor = self._random_anchor(num_anchor, semantic_dim, include_opa)

        # Derived from the anchors actually used, so it is defined for every source.
        self.num_anchor = len(anchor)
        self.anchor = nn.Parameter(
            torch.as_tensor(anchor, dtype=torch.float32).clone(),
            requires_grad=anchor_grad,
        )
        self.anchor_init = anchor
        self.instance_feature_layer = nn.Linear(
            3 + 3 + 4 + int(include_opa) + semantic_dim, embed_dims)

    @staticmethod
    def _random_anchor(num_anchor, semantic_dim, include_opa):
        """Draw ``num_anchor`` random anchors, with xyz/scale/opacity in logit space.

        xyz and scale are uniform in (0, 1) before the inverse sigmoid. The
        rotation has column 0 set to 1 and the rest 0 (the identity quaternion
        if stored w-first). Opacity is 0.1 and the semantics are standard-normal.
        """
        xyz = safe_inverse_sigmoid(torch.rand(num_anchor, 3, dtype=torch.float))
        scale = safe_inverse_sigmoid(torch.rand_like(xyz))
        rots = torch.zeros(num_anchor, 4, dtype=torch.float)
        rots[:, 0] = 1
        opacity = safe_inverse_sigmoid(
            0.1 * torch.ones((num_anchor, int(include_opa)), dtype=torch.float))
        semantic = torch.randn(num_anchor, semantic_dim, dtype=torch.float)
        return torch.cat([xyz, scale, rots, opacity, semantic], dim=-1)

    def init_weight(self):
        """Reset the anchor parameter to a copy of the construction-time anchors."""
        self.anchor.data = torch.as_tensor(
            self.anchor_init, dtype=self.anchor.dtype, device=self.anchor.device
        ).clone()

    def forward(self, scenemeta, gaussian_pool, global_mask_thistime, flag_depthbranch,
                flag_depthanything_as_gt, depthnet_output, mlvl_img_feats, metas):
        """Assemble this frame's anchors and their instance features.

        Args:
            scenemeta (dict): Reads ``global_scene_origin``, ``global_scene_size``
                and ``global_scene_dim``.
            gaussian_pool (torch.Tensor): Pooled Gaussians indexed as (1, P, C).
                xyz are in world coordinates and column 23 is a tag; the code
                assumes C == 24 (TODO confirm).
            global_mask_thistime (torch.Tensor): Global voxel mask compared
                against ``metas[0]['mask_in_global_from_this']``.
            flag_depthbranch (bool): Must be truthy; anchors are seeded from depth.
            flag_depthanything_as_gt (bool): Use ``depthnet_output`` as the depth
                map instead of ``metas[0]['depth_gt']``.
            depthnet_output (torch.Tensor): Predicted depth map.
            mlvl_img_feats (list[torch.Tensor]): Only ``mlvl_img_feats[0].shape[0]``
                (the batch size) is read.
            metas (list[dict]): Per-frame metadata; only ``metas[0]`` is read.

        Returns:
            tuple: ``(anchor, instance_feature, None, None, None,
            gaussian_pool_new, anchor_tag)``. ``anchor`` is (1, N, C_anchor) with
            the new anchors first and the reused anchors after them.
            ``anchor_tag`` is (1, N, 1): 0 for new anchors and the pool tag for
            reused ones.

        Raises:
            NotImplementedError: If ``flag_depthbranch`` is falsy.
        """
        if not flag_depthbranch:
            # New anchors are sampled from depth-derived points; without the
            # depth branch there is nothing to sample from.
            raise NotImplementedError(
                "GaussianLifterOnline requires flag_depthbranch=True")

        batch_size = mlvl_img_feats[0].shape[0]

        (gaussian_pool_new, anchor_reused, anchor_reused_tag,
         mask_new_thisframe, new_anchor_ratio) = self._reuse_from_pool(
            gaussian_pool, global_mask_thistime, metas)

        # ---- candidate points for new anchors ---------------------------------
        z = depthnet_output if flag_depthanything_as_gt else metas[0]['depth_gt']
        points_cam_all, points_world = self._backproject_depth(z, metas)
        points_new, points_left = self._split_by_new_region(
            points_cam_all, points_world, scenemeta, mask_new_thisframe)

        # ---- how many new anchors to sample -----------------------------------
        nyu_pc_range = metas[0]['cam_vox_range']
        num_depth = metas[0]['num_depth']
        # The budget left by the reused anchors, but never less than
        # int(num_depth * new_anchor_ratio).
        num_depth_new = num_depth - anchor_reused.shape[1]
        if num_depth_new <= num_depth * new_anchor_ratio:
            num_depth_new = int(num_depth * new_anchor_ratio)

        if points_new.shape[0] < num_depth_new:
            # Not enough points in the new region: top up with random in-box
            # points from outside it.
            num_missing = num_depth_new - points_new.shape[0]
            fill = points_left[torch.randperm(points_left.shape[0])[:num_missing]]
            points_cam = torch.cat([points_new, fill], dim=0)
        else:
            points_cam = points_new[torch.randperm(points_new.shape[0])[:num_depth_new]]

        points_cam = torch.clamp(points_cam, nyu_pc_range[:3], nyu_pc_range[3:])
        points_cam = (points_cam - nyu_pc_range[:3]) / (nyu_pc_range[3:] - nyu_pc_range[:3])  # 0-1

        # ---- new anchors -------------------------------------------------------
        anchor_xyz = points_cam.float().unsqueeze(0).to(self.anchor.device)
        num_new = anchor_xyz.shape[1]
        # Scale / rotation / opacity are copied from the first learnable anchor.
        # NOTE(review): columns 3:11 assume include_opa=True (3 + 4 + 1 columns).
        template = self.anchor[None, :1, 3:11].repeat(batch_size, num_new, 1)
        anchor = torch.cat([
            safe_inverse_sigmoid(torch.clamp(anchor_xyz, 0.001, 0.999)),
            template,
            torch.randn(*anchor_xyz.shape[:2], self.semantic_dim,
                        dtype=anchor_xyz.dtype, device=anchor_xyz.device),
        ], dim=-1)
        anchor_tag = torch.zeros((1, anchor.shape[1], 1), dtype=torch.float32, device=anchor.device)

        # ---- reused anchors: move xyz / scale / opacity into logit space --------
        if anchor_reused.shape[1] > 0:
            reused = anchor_reused.squeeze(0)
            reused = torch.cat([
                safe_inverse_sigmoid(torch.clamp(reused[..., :3], 0.001, 0.999)),
                safe_inverse_sigmoid(reused[..., 3:6]),
                reused[..., 6:10],
                safe_inverse_sigmoid(reused[..., 10:11]),
                reused[..., 11:],
            ], dim=-1).unsqueeze(0)
            anchor_tag = torch.cat([anchor_tag, anchor_reused_tag], dim=1)
            anchor = torch.cat([anchor, reused], dim=1)

        instance_feature = self.instance_feature_layer(anchor)

        return anchor, instance_feature, None, None, None, gaussian_pool_new, anchor_tag

    def _reuse_from_pool(self, gaussian_pool, global_mask_thistime, metas):
        """Split the Gaussian pool into Gaussians kept as-is and Gaussians reused as anchors.

        Returns:
            tuple: ``(gaussian_pool_new, anchor_reused, anchor_reused_tag,
            mask_new_thisframe, new_anchor_ratio)``:

            * ``gaussian_pool_new``: Gaussians left in the pool, (1, K, C).
            * ``anchor_reused``: (1, M, C - 1) rows of ``[xyz normalised to
              [0, 1] by cam_vox_range, scale, camera-frame rotation, opacity,
              semantics]``, not yet in logit space.
            * ``anchor_reused_tag``: (1, M, 1) pool tags, row-aligned with
              ``anchor_reused`` (``None`` when the pool is empty).
            * ``mask_new_thisframe``: voxels set in ``mask_in_global_from_this``
              but not in ``global_mask_thistime`` (all of
              ``mask_in_global_from_this`` when the pool is empty).
            * ``new_anchor_ratio``: fraction of this frame's voxels that are new
              (1 when the pool is empty).
        """
        mask_thisframe = metas[0]['mask_in_global_from_this'].to(torch.bool)
        if gaussian_pool.shape[1] == 0:
            anchor_reused = torch.empty(
                (1, 0, 23), dtype=torch.float32, device=self.anchor.device)
            return gaussian_pool, anchor_reused, None, mask_thisframe, 1

        mask_new_thisframe = (~global_mask_thistime) & mask_thisframe
        num_visible = mask_thisframe.sum().item()
        num_overlap = (global_mask_thistime & mask_thisframe).sum().item()
        # NOTE(review): raises ZeroDivisionError if this frame marks no global voxel.
        new_anchor_ratio = (num_visible - num_overlap) / num_visible

        pool = gaussian_pool.squeeze(0)
        pool_xyz = pool[:, :3]  # world coordinates
        world2cam = metas[0]['world2cam'].to(torch.float32)
        cam_k = metas[0]['cam_k'].to(torch.float32)

        # 1) Frustum test: centre in front of the camera and projecting inside the image.
        ones = torch.ones((pool_xyz.shape[0], 1), device=pool_xyz.device)
        pool_xyz_h = torch.cat([pool_xyz, ones], dim=1).to(torch.float32)
        pool_cam = (world2cam @ pool_xyz_h.unsqueeze(-1)).squeeze(-1)[:, :3]
        in_front = pool_cam[:, 2] > 1e-6
        # Replace non-positive depths before dividing; those points fail in_front anyway.
        depth = torch.where(in_front, pool_cam[:, 2], torch.full_like(pool_cam[:, 2], 1e-6))
        pix_x = torch.floor(cam_k[0, 0] * (pool_cam[:, 0] / depth) + cam_k[0, 2]).to(torch.int32)
        pix_y = torch.floor(cam_k[1, 1] * (pool_cam[:, 1] / depth) + cam_k[1, 2]).to(torch.int32)
        in_image = ((pix_x >= 0) & (pix_x < self.img_width)
                    & (pix_y >= 0) & (pix_y < self.img_height))
        in_frustum = in_front & in_image

        # 2) Box test against this frame's voxel volume in world coordinates.
        vox_near = metas[0]['vox_origin']
        vox_far = metas[0]['vox_origin'] + metas[0]['scene_size']
        in_vox = torch.ones_like(in_frustum)
        for axis in range(3):
            in_vox &= (pool_xyz[:, axis] >= vox_near[axis]) & (pool_xyz[:, axis] <= vox_far[axis])

        # NOTE(review): a Gaussian is removed from the pool only when it is both
        # in the frustum and in the box, but it is reused whenever it is in the
        # box. Box-but-not-frustum Gaussians therefore stay in the pool AND are
        # reused. Removed Gaussians that fail the cam_vox_range test below end up
        # in neither output. Confirm this is intended.
        detach_mask = in_frustum & in_vox
        gaussian_pool_new = pool[~detach_mask].unsqueeze(0)
        gaussian_reused = pool[in_vox]

        reused_tag = gaussian_reused[..., 23]
        gaussian_reused = gaussian_reused[..., :-1]  # drop the last (tag) column
        means_world = gaussian_reused[:, :3]
        scales = gaussian_reused[:, 3:6]
        rotations_world = gaussian_reused[:, 6:10]
        opacities = gaussian_reused[:, 10:11]
        semantics = gaussian_reused[:, 11:]

        # 3) Move centres into the camera frame and keep those inside cam_vox_range.
        ones = torch.ones((means_world.shape[0], 1), device=means_world.device)
        means_h = torch.cat([means_world, ones], dim=1)
        means_cam = (world2cam @ means_h.unsqueeze(-1)).squeeze(-1)[:, :3]
        nyu_pc_range = metas[0]['cam_vox_range']
        in_cam_range = torch.ones(means_cam.shape[0], dtype=torch.bool, device=means_cam.device)
        for axis in range(3):
            in_cam_range &= ((means_cam[:, axis] >= nyu_pc_range[axis])
                             & (means_cam[:, axis] <= nyu_pc_range[axis + 3]))
        means_cam = means_cam[in_cam_range]
        scales = scales[in_cam_range]
        rotations_world = rotations_world[in_cam_range]
        opacities = opacities[in_cam_range]
        semantics = semantics[in_cam_range]
        # Filter the tags with the same mask so anchor_tag stays row-aligned with anchor.
        reused_tag = reused_tag[in_cam_range]

        # 4) Compose each rotation with the world-to-camera rotation and normalise
        #    the centres to [0, 1].
        w2c_rot = metas[0]['world2cam'][:3, :3].to(torch.float32)
        w2c_quat = safe_get_quaternion(w2c_rot.unsqueeze(0)).squeeze(0)
        rotations_cam = batch_quaternion_multiply(w2c_quat, rotations_world)
        means_cam01 = (means_cam - nyu_pc_range[:3]) / (nyu_pc_range[3:] - nyu_pc_range[:3])
        anchor_reused = torch.cat(
            [means_cam01, scales, rotations_cam, opacities, semantics], dim=-1).unsqueeze(0)
        return (gaussian_pool_new, anchor_reused, reused_tag.unsqueeze(0).unsqueeze(-1),
                mask_new_thisframe, new_anchor_ratio)

    def _backproject_depth(self, z, metas):
        """Back-project a dense depth map through ``metas[0]['cam_k']``.

        Args:
            z (torch.Tensor): Depth map whose trailing dims broadcast with
                ``(img_height, img_width)``.
            metas (list[dict]): Reads ``cam_k`` and ``cam2world`` from ``metas[0]``.

        Returns:
            tuple[torch.Tensor, torch.Tensor]: ``(points_cam, points_world)``, each
            (N, 3) with one row per depth sample, in camera and world coordinates.
        """
        # Build the pixel grid on the depth map's device rather than a hard-coded GPU.
        device = z.device
        cam_k = metas[0]['cam_k']
        f_x = torch.as_tensor(cam_k[0, 0], device=device)
        f_y = torch.as_tensor(cam_k[1, 1], device=device)
        c_x = torch.as_tensor(cam_k[0, 2], device=device)
        c_y = torch.as_tensor(cam_k[1, 2], device=device)
        # meshgrid over (width, height), then transposed so both grids are (H, W).
        u, v = torch.meshgrid(
            torch.arange(self.img_width, dtype=torch.float32, device=device),
            torch.arange(self.img_height, dtype=torch.float32, device=device))
        u = u.permute(1, 0)
        v = v.permute(1, 0)
        x = (u - c_x) / f_x
        y = (v - c_y) / f_y
        points_cam = torch.stack((x * z, y * z, z), dim=-1).reshape(-1, 3)
        ones = torch.ones((points_cam.shape[0], 1), device=points_cam.device)
        points_h = torch.cat((points_cam, ones), dim=1).to(torch.float32)
        cam2world = metas[0]['cam2world'].to(torch.float32)
        points_world = (cam2world @ points_h.unsqueeze(-1)).squeeze(-1)[:, :3]
        return points_cam, points_world

    @staticmethod
    def _split_by_new_region(points_cam, points_world, scenemeta, mask_new_thisframe):
        """Keep points inside the global scene box and split them by ``mask_new_thisframe``.

        Returns:
            tuple[torch.Tensor, torch.Tensor]: ``(points_new, points_left)``,
            the camera-frame points whose global voxel is set in
            ``mask_new_thisframe`` and the remaining in-box points.
        """
        scene_near = scenemeta['global_scene_origin']
        scene_far = scenemeta['global_scene_origin'] + scenemeta['global_scene_size']
        scene_dim = scenemeta['global_scene_dim']
        # Shrink the box slightly (presumably to guard against rounding at the box faces).
        eps = 1e-3
        in_room = torch.ones(points_world.shape[0], dtype=torch.bool, device=points_world.device)
        for axis in range(3):
            in_room &= ((points_world[:, axis] >= (scene_near[axis] + eps))
                        & (points_world[:, axis] < (scene_far[axis] - eps)))

        inroom_world = points_world[in_room]
        voxel_idx = torch.floor(
            (inroom_world - scene_near) / (scene_far - scene_near) * scene_dim).long()
        in_new_region = mask_new_thisframe[
            voxel_idx[:, 0], voxel_idx[:, 1], voxel_idx[:, 2]].bool()

        inroom_cam = points_cam[in_room]
        return inroom_cam[in_new_region], inroom_cam[~in_new_region]