
import copy
import math

import torch
import torch.nn.functional as F
from torch import nn
from torch.nn.init import xavier_uniform_, constant_, normal_

from misc.detr_utils.misc import  inverse_sigmoid
from pdvc.ops.modules import MSDeformAttn
# Feature width used by prepare_decoder_input_proposal to size its projection
# layers (Linear(256, d_model), Linear(d_model, 2 * d_model), LayerNorm(2 * d_model)).
d_model = 768
def inverse_sigmoid(x, eps=1e-5):
    """Return the logit ``log(x / (1 - x))`` of ``x`` with clamping for stability.

    ``x`` is first clamped to ``[0, 1]``; the numerator and the denominator
    are then each floored at ``eps`` so the logarithm never sees zero.

    Note: this definition shadows the ``inverse_sigmoid`` imported from
    ``misc.detr_utils.misc`` at the top of the file.

    Args:
        x: Tensor of values to convert.
        eps: Lower bound applied to both ``x`` and ``1 - x``.

    Returns:
        Tensor of the same shape as ``x`` holding the clamped logits.
    """
    bounded = torch.clamp(x, min=0, max=1)
    numerator = torch.clamp(bounded, min=eps)
    denominator = torch.clamp(1 - bounded, min=eps)
    return (numerator / denominator).log()



def get_proposal_pos_embed(proposals, num_pos_feats=256, temperature=10000):
    """Encode unactivated proposal coordinates as sinusoidal position features.

    The input is passed through a sigmoid, scaled to ``[0, 2*pi]`` and divided
    by a geometric series of frequencies; sine and cosine of the result are
    interleaved along the feature axis.

    Args:
        proposals: Rank-3 tensor ``(N, L, C)`` of pre-sigmoid coordinates.
        num_pos_feats: Number of features produced per coordinate. Must be a
            positive even number because sine/cosine values are paired.
        temperature: Base of the geometric frequency series.

    Returns:
        Tensor of shape ``(N, L, C * num_pos_feats)``.

    Raises:
        ValueError: If ``num_pos_feats`` is not a positive even integer.
    """
    if num_pos_feats <= 0 or num_pos_feats % 2 != 0:
        raise ValueError(
            f"num_pos_feats must be a positive even integer, got {num_pos_feats}"
        )
    scale = 2 * math.pi

    dim_t = torch.arange(num_pos_feats, dtype=torch.float32, device=proposals.device)
    # Consecutive index pairs (2k, 2k + 1) share one frequency.
    dim_t = temperature ** (2 * (dim_t // 2) / num_pos_feats)
    # N, L, C
    proposals = proposals.sigmoid() * scale
    # N, L, C, num_pos_feats
    pos = proposals[:, :, :, None] / dim_t
    # N, L, C, num_pos_feats // 2, 2 -> flattened to N, L, C * num_pos_feats
    pos = torch.stack((pos[:, :, :, 0::2].sin(), pos[:, :, :, 1::2].cos()), dim=4).flatten(2)
    return pos

def prepare_decoder_input_proposal(gt_reference_points, *, pos_trans_norm=None,
                                   pos_trans=None, resize_cluster=None):
    """Build decoder inputs (target and query embedding) from reference points.

    The reference points are converted to logits with ``inverse_sigmoid``,
    encoded by ``get_proposal_pos_embed``, projected through
    ``resize_cluster`` -> ``pos_trans`` -> ``pos_trans_norm`` and finally split
    in half along dim 2.

    Any layer that is not supplied is created here with fresh random weights,
    so outputs differ between calls unless the caller passes persistent
    modules. Freshly created layers are moved to the device of
    ``gt_reference_points``; caller-supplied layers are used as given.

    NOTE(review): ``resize_cluster`` defaults to ``Linear(256, d_model)`` while
    ``get_proposal_pos_embed`` emits ``last_dim * 256`` features by default, so
    this appears to require inputs whose last dimension is 1 -- confirm with
    callers.

    Args:
        gt_reference_points: Rank-3 tensor of reference points; values are
            clamped to ``[0, 1]`` by ``inverse_sigmoid`` before encoding.
        pos_trans_norm: Optional module applied last; defaults to a new
            ``nn.LayerNorm(d_model * 2)``.
        pos_trans: Optional module mapping ``d_model`` to ``d_model * 2``
            features; defaults to a new ``nn.Linear``.
        resize_cluster: Optional module mapping the positional embedding to
            ``d_model`` features; defaults to a new ``nn.Linear(256, d_model)``.

    Returns:
        Tuple ``(init_reference_out, tgt, reference_points, query_embed)``.
        ``init_reference_out`` and ``reference_points`` are the input tensor
        itself; ``query_embed`` and ``tgt`` are the first and second halves of
        the projected embedding along dim 2.
    """
    device = gt_reference_points.device

    # Default layers are built in the original order (norm, pos_trans,
    # resize_cluster) so a seeded RNG yields the same initial weights as before.
    if pos_trans_norm is None:
        pos_trans_norm = nn.LayerNorm(d_model * 2).to(device)
    if pos_trans is None:
        pos_trans = nn.Linear(d_model, d_model * 2).to(device)
    if resize_cluster is None:
        resize_cluster = nn.Linear(256, d_model).to(device)

    topk_coords_unact = inverse_sigmoid(gt_reference_points)
    pos_embed = get_proposal_pos_embed(topk_coords_unact)
    pos_trans_out = pos_trans_norm(pos_trans(resize_cluster(pos_embed)))
    query_embed, tgt = torch.chunk(pos_trans_out, 2, dim=2)

    reference_points = gt_reference_points
    init_reference_out = reference_points
    return init_reference_out, tgt, reference_points, query_embed