import numpy as np, math
import torch
from torch import nn
import numpy as np

from mmengine import MODELS
from mmcv.utils import build_from_cfg

# Upper clamp bound applied before taking the logit; the lower bound is
# its mirror image, ``1 - LOGIT_MAX``.
LOGIT_MAX = 0.99


def safe_inverse_sigmoid(tensor):
    """Return the inverse sigmoid (logit) of ``tensor``, clamped for safety.

    Values are first clamped to ``[1 - LOGIT_MAX, LOGIT_MAX]`` so that the
    logarithm never sees 0 or a division by zero. The input is not modified.

    Args:
        tensor: A ``torch.Tensor`` of values to map through the logit.

    Returns:
        A new tensor of the same shape holding ``log(p / (1 - p))`` of the
        clamped values ``p``.
    """
    lower_bound = 1 - LOGIT_MAX
    clamped = tensor.clamp(lower_bound, LOGIT_MAX)
    odds = clamped / (1 - clamped)
    return odds.log()


@MODELS.register_module()
class GaussianLifter(nn.Module):
    """Hold a set of Gaussian anchors and per-anchor instance features.

    Each anchor row is the concatenation, in this order, of
    ``xyz (3)``, ``scale (3)``, ``rotation (4)``, ``opacity (0 or 1)`` and
    ``semantic (semantic_dim)``. When the anchors are generated randomly,
    ``xyz``, ``scale`` and ``opacity`` are stored in logit space (they are
    passed through ``safe_inverse_sigmoid``).

    Args:
        embed_dims: Width of the per-anchor instance feature vector
            (the original inline note lists 96 — presumably the value used
            in the authors' config; confirm against the config).
        num_anchor: Maximum number of anchors kept. Supplied anchors are
            truncated to their first ``num_anchor`` rows (original inline
            note: 21600).
        anchor: Source of the initial anchors. A ``str`` is loaded with
            ``np.load``; a ``list``/``tuple`` is converted with ``np.array``;
            ``None`` generates random anchors; any other value (e.g. an
            ``ndarray`` or tensor) is used as given.
        anchor_grad: Whether the anchor parameter requires gradients.
        feat_grad: Whether the instance-feature parameter requires gradients.
        semantic_dim: Number of semantic channels appended to randomly
            generated anchors (original inline note: 13).
        include_opa: Whether randomly generated anchors carry one opacity
            channel.
        include_v: Accepted for interface compatibility but currently
            unused: no velocity channels are appended to the anchors.
    """

    def __init__(
        self,
        embed_dims,
        num_anchor=25600,
        anchor=None,
        anchor_grad=True,
        feat_grad=False,
        semantic_dim=0,
        include_opa=True,
        include_v=False,
    ):
        super().__init__()
        self.embed_dims = embed_dims

        if isinstance(anchor, str):
            anchor = np.load(anchor)
        elif isinstance(anchor, (list, tuple)):
            anchor = np.array(anchor)
        elif anchor is None:
            anchor = self._random_anchor(num_anchor, semantic_dim, include_opa)

        self.num_anchor = min(len(anchor), num_anchor)
        anchor = anchor[:num_anchor]

        # Copy-constructing with ``torch.tensor(existing_tensor)`` is
        # discouraged by PyTorch and emits a UserWarning, so tensors are
        # copied with detach + ``to(copy=True)`` instead.
        if isinstance(anchor, torch.Tensor):
            anchor_data = anchor.detach().to(dtype=torch.float32, copy=True)
        else:
            anchor_data = torch.tensor(anchor, dtype=torch.float32)

        self.anchor = nn.Parameter(anchor_data, requires_grad=anchor_grad)
        # Kept untouched (same object/type as before) so ``init_weight`` can
        # restore the anchors after they have been trained or modified.
        self.anchor_init = anchor
        self.instance_feature = nn.Parameter(
            torch.zeros([self.anchor.shape[0], self.embed_dims]),
            requires_grad=feat_grad,
        )

    @staticmethod
    def _random_anchor(num_anchor, semantic_dim, include_opa):
        """Build randomly initialised anchors of shape ``(num_anchor, C)``."""
        # Random draws happen in the same order as before (positions first,
        # semantics last) so seeded runs stay reproducible.
        xyz = safe_inverse_sigmoid(
            torch.rand(num_anchor, 3, dtype=torch.float)
        )
        scale = safe_inverse_sigmoid(
            torch.full((num_anchor, 3), 0.5, dtype=torch.float)
        )
        # Rotation is (1, 0, 0, 0) for every anchor — presumably the identity
        # quaternion in (w, x, y, z) order; confirm against the consumer.
        rots = torch.zeros(num_anchor, 4, dtype=torch.float)
        rots[:, 0] = 1
        # ``int(include_opa)`` yields a zero-width column when opacity is off.
        opacity = safe_inverse_sigmoid(
            torch.full((num_anchor, int(include_opa)), 0.1, dtype=torch.float)
        )
        semantic = torch.randn(num_anchor, semantic_dim, dtype=torch.float)
        return torch.cat([xyz, scale, rots, opacity, semantic], dim=-1)

    def init_weight(self):
        """Reset anchors to their initial values and re-init trainable features.

        The instance features are re-initialised with Xavier uniform only
        when they require gradients.
        """
        initial = self.anchor_init
        if isinstance(initial, torch.Tensor):
            # Avoid ``new_tensor(tensor)``, which warns about copy-construction.
            restored = initial.detach().to(
                device=self.anchor.device,
                dtype=self.anchor.dtype,
                copy=True,
            )
        else:
            restored = self.anchor.data.new_tensor(initial)
        self.anchor.data = restored

        if self.instance_feature.requires_grad:
            torch.nn.init.xavier_uniform_(self.instance_feature.data, gain=1)

    def forward(self, mlvl_img_feats):
        """Replicate the anchors and instance features for every batch item.

        Args:
            mlvl_img_feats: Indexable collection whose first element exposes
                the batch size as ``shape[0]``; nothing else is read from it.

        Returns:
            A tuple ``(anchor, instance_feature)``, each with a new leading
            batch dimension of size ``batch_size``.
        """
        batch_size = mlvl_img_feats[0].shape[0]
        anchor = torch.tile(self.anchor[None], (batch_size, 1, 1))
        instance_feature = torch.tile(
            self.instance_feature[None], (batch_size, 1, 1)
        )
        return anchor, instance_feature