
import copy
import math

import torch
import torch.nn.functional as F
from torch import nn
from torch.nn.init import xavier_uniform_, constant_, normal_

from misc.detr_utils.misc import  inverse_sigmoid
from pdvc.ops.modules import MSDeformAttn
def prepare_encoder_inputs(self, srcs, masks, pos_embeds, cl_labels):
    """Flatten per-level temporal inputs into single sequences for the encoder.

    Each level's tensors are moved to a "length before channels" layout
    (dims 1 and 2 swapped) and concatenated along dim 1, so all levels form
    one long sequence per batch item.

    Args:
        self: Object providing ``ablation``, ``level_embed`` (indexed by level)
            and ``get_valid_ratio(mask)``. This is a module-level function that
            takes ``self`` explicitly -- presumably bound onto a model class
            elsewhere; TODO confirm.
        srcs: Per-level feature tensors of shape ``(bs, c, L_l)``.
        masks: Per-level masks, concatenated along dim 1; assumed
            ``(bs, L_l)`` -- TODO confirm.
        pos_embeds: Per-level positional embeddings, transposed on dims 1/2;
            assumed ``(bs, d_m, L_l)`` -- TODO confirm.
        cl_labels: Per-level label tensors, transposed on dims 1/2; assumed
            ``(bs, n, L_l)`` -- TODO confirm.

    Returns:
        ``None`` unless ``self.ablation`` is ``"all"`` or ``"EPE"``. Otherwise a
        7-tuple ``(src_flatten, temporal_shapes, level_start_index,
        valid_ratios, lvl_pos_embed_flatten, mask_flatten, cl_labels_flatten)``:

        - src_flatten: ``(bs, sum(L_l), c)``.
        - temporal_shapes: 1-D long tensor ``(num_levels,)`` holding each ``L_l``.
        - level_start_index: 1-D ``(num_levels,)`` offset of each level's first
          position in the flattened sequence, i.e. ``[0, L_0, L_0 + L_1, ...]``.
        - valid_ratios: ``get_valid_ratio`` results stacked along dim 1.
        - lvl_pos_embed_flatten: positional embeddings plus the per-level
          embedding, transposed and concatenated like ``src_flatten``.
        - mask_flatten: masks concatenated along dim 1.
        - cl_labels_flatten: labels transposed and concatenated along dim 1.
    """
    # Only these ablation modes use this path; anything else returns None,
    # exactly as before. NOTE(review): callers that unpack the result will fail
    # for other modes -- confirm they never call this then.
    if self.ablation not in ("all", "EPE"):
        return None

    src_flatten = []
    mask_flatten = []
    lvl_pos_embed_flatten = []
    temporal_shapes = []
    cl_labels_flatten = []
    for lvl, (src, mask, pos_embed, cl_label) in enumerate(zip(srcs, masks, pos_embeds, cl_labels)):
        # Unpacking (rather than src.shape[-1]) keeps the implicit check that
        # src is 3-D: (bs, c, L).
        _, _, length = src.shape
        temporal_shapes.append(length)
        src_flatten.append(src.transpose(1, 2))  # (bs, L, c)
        # Add this level's embedding to every position so levels remain
        # distinguishable after concatenation.
        lvl_pos_embed = pos_embed.transpose(1, 2) + self.level_embed[lvl].view(1, 1, -1)
        lvl_pos_embed_flatten.append(lvl_pos_embed)
        mask_flatten.append(mask)
        cl_labels_flatten.append(cl_label.transpose(1, 2))

    src_flatten = torch.cat(src_flatten, 1)  # (bs, sum(L), c)
    mask_flatten = torch.cat(mask_flatten, 1)
    lvl_pos_embed_flatten = torch.cat(lvl_pos_embed_flatten, 1)
    cl_labels_flatten = torch.cat(cl_labels_flatten, 1)
    temporal_shapes = torch.as_tensor(
        temporal_shapes, dtype=torch.long, device=src_flatten.device
    )  # (num_levels,)
    # Start offset of each level in the flattened sequence: [0, L0, L0+L1, ...]
    level_start_index = torch.cat(
        (temporal_shapes.new_zeros((1,)), temporal_shapes.cumsum(0)[:-1])
    )
    valid_ratios = torch.stack([self.get_valid_ratio(m) for m in masks], 1)

    return (
        src_flatten,
        temporal_shapes,
        level_start_index,
        valid_ratios,
        lvl_pos_embed_flatten,
        mask_flatten,
        cl_labels_flatten,
    )