import logging
import math
from typing import Tuple

import torch
from torch import nn
from torch.nn import functional as F

from detectron2.config import configurable
from detectron2.data import MetadataCatalog
from detectron2.modeling import META_ARCH_REGISTRY, build_backbone, build_sem_seg_head
from detectron2.modeling.backbone import Backbone
from detectron2.structures import ImageList, BitMasks

from mask2former.maskformer_model import MaskFormer
from mask2former.modeling.criterion import SetCriterion
from mask2former.modeling.matcher import HungarianMatcher

from ctvis.modeling.tracker import build_tracker
from ctvis.modeling.cl_plugin.simple_cl_plugin import build_cl_plugin
from ctvis.modeling.fusion_module import build_fusion_module
from ctvis.utils import retry_if_cuda_oom

logger = logging.getLogger(__name__)


@META_ARCH_REGISTRY.register()
class CTVISModel(MaskFormer):
    """Consistent Training for Online VIS Model.
    """
    @configurable
    def __init__(
            self,
            *,
            backbone: Backbone,
            sem_seg_head: nn.Module,
            criterion: nn.Module,
            num_queries: int,
            object_mask_threshold: float,
            overlap_threshold: float,
            metadata,
            size_divisibility: int,
            sem_seg_postprocess_before_inference: bool,
            pixel_mean: Tuple[float],
            pixel_std: Tuple[float],
            # video
            num_topk: int,
            num_frames,
            tracker,
            clip_num_frames: int,
            cl_plugin,
            test_interpolate_chunk_size: int,
            fusion_module,
            visualize_query_mode
    ):
        """Store the sub-modules and hyper-parameters of the model.

        All arguments are keyword-only; ``from_config`` produces them from a
        config node.

        Args:
            backbone: called on the normalized image tensor to get features.
            sem_seg_head: called on the backbone features; this class also
                reads its ``num_classes`` and ``predictor.class_embed``.
            criterion: called as ``criterion(outputs, targets)`` during
                training; its ``weight_dict`` and ``matcher`` are read too.
            num_queries: stored on the instance.
            object_mask_threshold: stored on the instance.
            overlap_threshold: stored on the instance.
            metadata: stored on the instance.
            size_divisibility: passed to ``ImageList.from_tensors``. A
                negative value means "use ``backbone.size_divisibility``".
            sem_seg_postprocess_before_inference: stored on the instance.
            pixel_mean: per-channel values subtracted from every frame.
            pixel_std: per-channel values every frame is divided by.
            num_topk: ``k`` used when keeping the top-k (query, class) scores
                in ``inference_video``.
            num_frames: stored on the instance.
            tracker: object whose ``inference(det_outputs, class_embed)`` is
                called at test time.
            clip_num_frames: maximum number of frames sent through the
                backbone/head in one forward pass at test time.
            cl_plugin: optional object; when not ``None`` its ``train_loss``
                adds extra entries to the training losses.
            test_interpolate_chunk_size: number of frames whose masks are
                upsampled together in ``inference_video``.
            fusion_module: callable applied to the head outputs.
            visualize_query_mode: when truthy (and not training), ``forward``
                returns ``save_query_embedding`` results instead of
                predictions.
        """
        # ``super(MaskFormer, self)`` starts the MRO lookup *after* MaskFormer,
        # so MaskFormer.__init__ is deliberately skipped and every attribute
        # is set up here instead (presumably the next __init__ is
        # nn.Module's -- confirm against MaskFormer's bases).
        # Background on super(): https://blog.csdn.net/wanzew/article/details/106993425
        super(MaskFormer, self).__init__()
        self.backbone = backbone
        self.sem_seg_head = sem_seg_head
        self.criterion = criterion
        self.num_queries = num_queries
        self.overlap_threshold = overlap_threshold
        self.object_mask_threshold = object_mask_threshold
        self.metadata = metadata
        if size_divisibility < 0:
            # use backbone size_divisibility if not set
            size_divisibility = self.backbone.size_divisibility
        self.size_divisibility = size_divisibility
        self.sem_seg_postprocess_before_inference = sem_seg_postprocess_before_inference
        # Reshaped to (C, 1, 1) so they broadcast over a (C, H, W) frame. The
        # trailing ``False`` is ``persistent=False``: the buffers follow the
        # module across devices but are left out of the state dict.
        self.register_buffer("pixel_mean", torch.Tensor(pixel_mean).view(-1, 1, 1), False)
        self.register_buffer("pixel_std", torch.Tensor(pixel_std).view(-1, 1, 1), False)
        self.num_frames = num_frames

        # track head
        self.tracker = tracker
        self.clip_num_frames = clip_num_frames  # Avoid OOM
        self.test_interpolate_chunk_size = test_interpolate_chunk_size  # Avoid OOM
        self.cl_plugin = cl_plugin
        self.num_topk = num_topk

        # fusion_module
        self.fusion_module = fusion_module

        # visualize query mode
        self.visualize_query_mode = visualize_query_mode

    @classmethod
    def from_config(cls, cfg):
        """Build the keyword arguments of ``__init__`` from a config node.

        Args:
            cfg: config node exposing the ``MODEL``, ``INPUT``, ``TEST`` and
                ``DATASETS`` entries read below.

        Returns:
            dict: one entry per keyword argument of ``__init__``.
        """
        mask_former_cfg = cfg.MODEL.MASK_FORMER

        backbone = build_backbone(cfg)
        sem_seg_head = build_sem_seg_head(cfg, backbone.output_shape())

        # The same three weights are used as matching costs and loss weights.
        class_weight = mask_former_cfg.CLASS_WEIGHT
        dice_weight = mask_former_cfg.DICE_WEIGHT
        mask_weight = mask_former_cfg.MASK_WEIGHT
        train_num_points = mask_former_cfg.TRAIN_NUM_POINTS

        matcher = HungarianMatcher(
            cost_class=class_weight,
            cost_mask=mask_weight,
            cost_dice=dice_weight,
            num_points=train_num_points,
        )

        base_weights = {
            "loss_ce": class_weight,
            "loss_mask": mask_weight,
            "loss_dice": dice_weight,
        }
        weight_dict = dict(base_weights)
        if mask_former_cfg.DEEP_SUPERVISION:
            # One "<loss>_<i>" copy of every base weight for each of the
            # first DEC_LAYERS - 1 indices.
            for layer_idx in range(mask_former_cfg.DEC_LAYERS - 1):
                for loss_name, loss_weight in base_weights.items():
                    weight_dict[f"{loss_name}_{layer_idx}"] = loss_weight

        criterion = SetCriterion(
            sem_seg_head.num_classes,
            matcher=matcher,
            weight_dict=weight_dict,
            eos_coef=mask_former_cfg.NO_OBJECT_WEIGHT,
            losses=["labels", "masks"],
            num_points=train_num_points,
            oversample_ratio=mask_former_cfg.OVERSAMPLE_RATIO,
            importance_sample_ratio=mask_former_cfg.IMPORTANCE_SAMPLE_RATIO,
        )

        # Video-specific components, built in the same order as before.
        tracker = build_tracker(cfg)
        cl_plugin = build_cl_plugin(cfg)
        fusion_module = build_fusion_module(cfg)

        return dict(
            backbone=backbone,
            sem_seg_head=sem_seg_head,
            criterion=criterion,
            num_queries=mask_former_cfg.NUM_OBJECT_QUERIES,
            object_mask_threshold=mask_former_cfg.TEST.OBJECT_MASK_THRESHOLD,
            overlap_threshold=mask_former_cfg.TEST.OVERLAP_THRESHOLD,
            metadata=MetadataCatalog.get(cfg.DATASETS.TRAIN[0]),
            size_divisibility=mask_former_cfg.SIZE_DIVISIBILITY,
            sem_seg_postprocess_before_inference=True,
            pixel_mean=cfg.MODEL.PIXEL_MEAN,
            pixel_std=cfg.MODEL.PIXEL_STD,
            # video
            num_frames=cfg.INPUT.SAMPLING_FRAME_NUM,
            tracker=tracker,
            clip_num_frames=cfg.MODEL.CLIP_NUM_FRAMES,
            cl_plugin=cl_plugin,
            # top-k size defaults to the number of classes
            num_topk=cfg.MODEL.SEM_SEG_HEAD.NUM_CLASSES,
            test_interpolate_chunk_size=cfg.TEST.TEST_INTERPOLATE_CHUNK_SIZE,
            fusion_module=fusion_module,
            visualize_query_mode=cfg.MODEL.VISUALIZE_QUERY_MODE,
        )

    @property
    def device(self):
        """Device that currently holds the ``pixel_mean`` buffer."""
        mean_buffer = self.pixel_mean
        return mean_buffer.device

    def forward(self, batched_inputs):
        """Dispatch a batch to the training, visualization or inference path.

        Training mode takes precedence; otherwise ``visualize_query_mode``
        selects ``save_query_embedding`` over ``inference_model``.
        """
        if self.training:
            handler = self.train_model
        elif self.visualize_query_mode:
            handler = self.save_query_embedding
        else:
            handler = self.inference_model
        return handler(batched_inputs)

    def pre_process(self, batched_inputs):
        """Normalize every frame of every video and batch them together.

        Args:
            batched_inputs: iterable of per-video dicts whose ``"image"``
                entry is an iterable of frame tensors.

        Returns:
            ImageList: all frames (videos concatenated in order), moved to
            the model device, normalized with ``pixel_mean``/``pixel_std``
            and batched with ``size_divisibility``.
        """
        target_device = self.device
        normalized_frames = [
            (frame.to(target_device) - self.pixel_mean) / self.pixel_std
            for video in batched_inputs
            for frame in video["image"]
        ]
        return ImageList.from_tensors(normalized_frames, self.size_divisibility)

    def train_model(self, batched_inputs):
        """Run one training forward pass and return the loss dict.

        Args:
            batched_inputs: list of per-video dicts with an ``"image"`` entry
                (frames) and, normally, an ``"instances"`` entry (per-frame
                ground truth).

        Returns:
            dict: criterion losses whose key appears in
            ``criterion.weight_dict``, each scaled by its weight, plus the
            entries returned by ``cl_plugin.train_loss`` when a plugin is set.

        Raises:
            ValueError: if a ``cl_plugin`` is configured but the batch has no
                ``"instances"``; the plugin needs the ground truth.
        """
        images = self.pre_process(batched_inputs)

        features = self.backbone(images.tensor)
        det_outputs = self.sem_seg_head(features)

        det_outputs = self.fusion_module(det_outputs)

        # mask classification target
        # Both names are initialised so the code below never hits an unbound
        # local when the batch carries no annotations.
        gt_instances = None
        targets = None
        if "instances" in batched_inputs[0]:
            gt_instances = [
                frame.to(self.device)
                for video in batched_inputs
                for frame in video["instances"]
            ]
            targets = self.prepare_targets(gt_instances, images)

        # bipartite matching-based loss
        losses = self.criterion(det_outputs, targets)

        # Keep only the losses that have a configured weight and scale them.
        # Iterate over a snapshot of the keys because entries are removed.
        weight_dict = self.criterion.weight_dict
        for k in list(losses.keys()):
            if k in weight_dict:
                losses[k] *= weight_dict[k]
            else:
                losses.pop(k)

        if self.cl_plugin is not None:
            if gt_instances is None:
                raise ValueError(
                    'cl_plugin is enabled but batched_inputs has no "instances"; '
                    "ground-truth instances are required to compute its loss."
                )
            losses.update(self.cl_plugin.train_loss(det_outputs, gt_instances, self.criterion.matcher))

        return losses

    def inference_model(self, batched_inputs):
        """Detect, track and post-process the instances of a video.

        The frames are sent through the backbone/head in clips of at most
        ``clip_num_frames`` frames to bound memory, the per-frame outputs are
        linked by the tracker, and the tracked results are converted to
        full-resolution masks by ``inference_video``.

        Args:
            batched_inputs: list of per-video dicts. Only the first entry's
                optional ``"height"``/``"width"`` are read for the output
                size (looks like a single video per batch is expected --
                confirm against the data loader).

        Returns:
            dict: ``"image_size"`` as ``(height, width)``, ``"pred_scores"``,
            ``"pred_labels"`` and ``"pred_masks"``; the three lists are empty
            when the tracker returns nothing.
        """
        images = self.pre_process(batched_inputs)

        # Avoid Out-of-Memory: long videos keep their final masks on the CPU.
        num_frames = len(images)
        to_store = self.device if num_frames <= 20 else "cpu"

        with torch.no_grad():
            if num_frames <= self.clip_num_frames:
                features = self.backbone(images.tensor)
                det_outputs = self.sem_seg_head(features)
                det_outputs = self.fusion_module(det_outputs)
            else:
                # Process the video clip by clip and stitch the per-frame
                # outputs back together along the frame dimension.
                pred_logits, pred_masks, pred_embeds, pred_queries = [], [], [], []
                num_clips = math.ceil(num_frames / self.clip_num_frames)
                for i in range(num_clips):
                    start_idx = i * self.clip_num_frames
                    end_idx = (i + 1) * self.clip_num_frames
                    clip_images_tensor = images.tensor[start_idx:end_idx, ...]
                    clip_features = self.backbone(clip_images_tensor)
                    clip_outputs = self.sem_seg_head(clip_features)
                    clip_outputs = self.fusion_module(clip_outputs)

                    pred_logits.append(clip_outputs['pred_logits'])
                    pred_masks.append(clip_outputs['pred_masks'])
                    pred_embeds.append(clip_outputs['pred_embeds'])
                    pred_queries.append(clip_outputs['pred_queries'])

                det_outputs = {
                    'pred_logits': torch.cat(pred_logits, dim=0),
                    'pred_masks': torch.cat(pred_masks, dim=0),
                    'pred_embeds': torch.cat(pred_embeds, dim=0),
                    'pred_queries': torch.cat(pred_queries, dim=0)
                }

        class_embed = self.sem_seg_head.predictor.class_embed
        outputs = self.tracker.inference(det_outputs, class_embed)

        input_per_image = batched_inputs[0]
        image_tensor_size = images.tensor.shape
        image_size = images.image_sizes[0]  # image size without padding after data augmentation

        height = input_per_image.get("height", image_size[0])  # raw image size before data augmentation
        width = input_per_image.get("width", image_size[1])

        if len(outputs['pred_logits']) == 0:
            # Report the same (height, width) as the regular path. The
            # previous code returned the sizes of the first two frames here,
            # which had a different meaning and failed on one-frame videos.
            return {
                "image_size": (height, width),
                "pred_scores": [],
                "pred_labels": [],
                "pred_masks": []
            }

        mask_cls_results = outputs["pred_logits"]
        mask_pred_results = outputs["pred_masks"]

        # Drop the references that are no longer needed before the
        # memory-hungry mask upsampling.
        del outputs, batched_inputs, images, det_outputs

        video_output = self.inference_video(mask_cls_results, mask_pred_results, image_tensor_size, image_size,
                                            height, width, to_store)

        return video_output

    def inference_video(self, mask_cls_results, mask_pred_results, image_tensor_size, image_size, height, width, to_store):
        """Turn tracked logits and mask logits into scored, full-size masks.

        Only the first element of ``mask_cls_results``/``mask_pred_results``
        is used.

        Args:
            mask_cls_results: class logits; ``[0]`` is indexed per tracked
                instance and its last class column is dropped after softmax.
            mask_pred_results: mask logits; ``[0]`` is indexed as
                (instance, frame, ...) -- dimension 1 is chunked as frames.
            image_tensor_size: shape of the padded input batch; its last two
                entries are the size the masks are first upsampled to.
            image_size: unpadded (height, width) used to crop the padding.
            height: output mask height.
            width: output mask width.
            to_store: device on which the per-chunk boolean masks are kept.

        Returns:
            dict: ``"image_size"`` as ``(height, width)``, ``"pred_scores"``
            and ``"pred_labels"`` as Python lists, and ``"pred_masks"`` as a
            list of CPU boolean tensors, one per kept prediction.
        """
        mask_cls_result = mask_cls_results[0]
        mask_pred_result = mask_pred_results[0]

        if len(mask_cls_result) == 0:
            return {
                "image_size": (height, width),
                "pred_scores": [],
                "pred_labels": [],
                "pred_masks": [],
            }

        num_classes = self.sem_seg_head.num_classes
        scores = F.softmax(mask_cls_result, dim=-1)[:, :-1]
        # Class id of every flattened (instance, class) score.
        labels = torch.arange(num_classes, device=self.device).unsqueeze(0).repeat(
            len(mask_cls_result), 1).flatten(0, 1)  # noqa

        # keep top-k predictions
        # NOTE(review): 25 classes presumably singles out one dataset that
        # uses a fixed budget of 50 -- confirm. The attribute write is kept
        # for backward compatibility.
        if num_classes == 25:
            self.num_topk = min(labels.shape[0], 50)
        flat_scores = scores.flatten(0, 1)
        # topk raises when k exceeds the number of candidates, so clamp it.
        num_topk = min(self.num_topk, flat_scores.shape[0])
        scores_per_video, topk_indices = flat_scores.topk(num_topk, sorted=True)  # typically larger -> better

        labels_per_video = labels[topk_indices]
        # Map the flattened (instance, class) index back to the instance.
        topk_indices = topk_indices // num_classes

        mask_pred_result = mask_pred_result[topk_indices]

        num_frame = mask_pred_result.shape[1]
        chunk_size = self.test_interpolate_chunk_size

        masks_list = []
        numerator = torch.zeros(mask_pred_result.shape[0], dtype=torch.float, device=self.device)
        denominator = torch.zeros(mask_pred_result.shape[0], dtype=torch.float, device=self.device)
        # Upsample a few frames at a time to bound peak memory.
        for start in range(0, num_frame, chunk_size):
            temp_pred_mask = mask_pred_result[:, start:start + chunk_size, ...]
            # 1) upsample to the padded network input size
            temp_pred_mask = retry_if_cuda_oom(F.interpolate)(
                temp_pred_mask,
                size=(image_tensor_size[-2], image_tensor_size[-1]),
                mode="bilinear",
                align_corners=False)
            # 2) crop the padding away
            temp_pred_mask = temp_pred_mask[:, :, : image_size[0], : image_size[1]]
            # 3) resize to the requested output resolution
            temp_pred_mask = retry_if_cuda_oom(F.interpolate)(
                temp_pred_mask, size=(height, width), mode="bilinear", align_corners=False)
            masks = (temp_pred_mask > 0.).float()
            # Accumulate the sum of foreground probabilities and the
            # foreground pixel count for every instance. The original
            # comment credits mmdet's Mask2Former for this trick.
            numerator += (temp_pred_mask.sigmoid() * masks).flatten(1).sum(1)
            denominator += masks.flatten(1).sum(1)

            masks_list.append(masks.bool().to(to_store))

        # Scale each class score by the mean foreground mask probability; the
        # epsilon guards against instances with an empty mask.
        scores_per_video *= (numerator / (denominator + 1e-6))
        masks = torch.cat(masks_list, dim=1)

        return {
            "image_size": (height, width),
            "pred_scores": scores_per_video.tolist(),
            "pred_labels": labels_per_video.tolist(),
            "pred_masks": list(masks.cpu()),
        }

    def prepare_targets(self, targets, images):
        """Pad ground-truth masks to the batch size and drop invalid ids.

        Args:
            targets: iterable of per-frame ground truth exposing ``gt_masks``
                (a ``BitMasks`` or a tensor), ``gt_ids`` and ``gt_classes``.
            images: batched images; the last two dimensions of
                ``images.tensor`` give the padded height and width.

        Returns:
            list[dict]: one dict per frame with ``"labels"``, ``"masks"`` and
            ``"ids"``, restricted to instances whose id is not ``-1``.
        """
        padded_h, padded_w = images.tensor.shape[-2:]

        prepared = []
        for frame_gt in targets:
            frame_masks = frame_gt.gt_masks
            if isinstance(frame_masks, BitMasks):
                frame_masks = frame_masks.tensor

            # Zero canvas of the padded size with the masks' dtype/device;
            # the original masks are pasted into its top-left corner.
            canvas = frame_masks.new_zeros((frame_masks.shape[0], padded_h, padded_w))
            canvas[:, : frame_masks.shape[1], : frame_masks.shape[2]] = frame_masks

            instance_ids = frame_gt.gt_ids
            is_valid = instance_ids != -1

            frame_target = {
                "labels": frame_gt.gt_classes[is_valid],
                "masks": canvas[is_valid],
                "ids": instance_ids[is_valid],  # for MoCo purpose
            }
            prepared.append(frame_target)

        return prepared

    @torch.no_grad()
    def save_query_embedding(self, batched_inputs):
        """Collect the embeddings of matched queries with their instance ids.

        The detector is run on the frames (in clips of at most
        ``clip_num_frames`` frames), the criterion's matcher assigns queries
        to ground-truth instances per frame, and the embedding of every
        matched query is returned with the id of its instance.
        NOTE(review): unlike ``inference_model`` the fusion module is not
        applied here -- confirm this is intended.

        Args:
            batched_inputs: list of per-video dicts with ``"image"`` and
                ``"instances"`` entries.

        Returns:
            dict: ``"data"``, a numpy array stacking the matched query
            embeddings of all frames, and ``"label"``, a numpy array with the
            corresponding ground-truth ids. Both are empty when no frame has
            a valid instance.

        Raises:
            ValueError: if the batch carries no ``"instances"``.
        """
        if "instances" not in batched_inputs[0]:
            # Without ground truth there is nothing to match the queries to;
            # previously this failed later on a ``None`` target.
            raise ValueError('save_query_embedding requires "instances" in batched_inputs.')

        images = self.pre_process(batched_inputs)
        num_frames = len(images)

        # The decorator already disables gradients for the whole method.
        if num_frames <= self.clip_num_frames:
            det_outputs = self.sem_seg_head(self.backbone(images.tensor))
        else:
            # Avoid Out-of-Memory: run clip by clip, then concatenate the
            # per-frame outputs along the frame dimension.
            output_keys = ("pred_logits", "pred_masks", "pred_embeds", "pred_queries")
            per_clip = {key: [] for key in output_keys}
            for start_idx in range(0, num_frames, self.clip_num_frames):
                clip_images_tensor = images.tensor[start_idx:start_idx + self.clip_num_frames, ...]
                clip_outputs = self.sem_seg_head(self.backbone(clip_images_tensor))
                for key in output_keys:
                    per_clip[key].append(clip_outputs[key])
            det_outputs = {key: torch.cat(chunks, dim=0) for key, chunks in per_clip.items()}

        # mask classification target
        gt_instances = [
            frame.to(self.device)
            for video in batched_inputs
            for frame in video["instances"]
        ]
        targets = self.prepare_targets(gt_instances, images)

        # Retrieve the matching between the outputs of the last layer and the targets
        indices = self.criterion.matcher(det_outputs, targets)

        pred_embeds = det_outputs["pred_embeds"]
        query_list = []
        label_list = []
        for frame_i, pred_embed in enumerate(pred_embeds):
            indice = indices[frame_i]
            target = targets[frame_i]
            if target["ids"].shape[0] == 0:
                continue

            # -1 marks unmatched queries; prepare_targets has already removed
            # ground-truth ids equal to -1, so the sentinel cannot collide.
            query_label = pred_embed.new_full((pred_embed.shape[0],), -1, dtype=torch.long)
            query_label[indice[0]] = target["ids"][indice[1]]

            keep_index = query_label != -1
            query_list.append(pred_embed[keep_index])
            label_list.append(query_label[keep_index])

        if not query_list:
            # torch.cat rejects an empty list; return empty arrays that keep
            # the trailing embedding dimensions instead of crashing.
            data = pred_embeds.new_zeros((0, *pred_embeds.shape[2:])).cpu().numpy()
            label = torch.empty(0, dtype=torch.long).numpy()
        else:
            data = torch.cat(query_list, dim=0).detach().cpu().numpy()
            label = torch.cat(label_list, dim=0).detach().cpu().numpy()

        return {
            "data": data,
            "label": label
        }
