import torch
import torch.nn as nn
import torch.nn.functional as F
import diffdist


class TCVM(nn.Module):
    """Cross-video clip shuffling model producing feature-difference embeddings.

    Each video is cut into ``K`` temporal clips. Clips are gathered from all
    processes, randomly permuted (one permutation, broadcast from rank 0), and
    re-assembled ``K`` at a time along the temporal axis into composite videos.
    The backbone encodes the composites; ``forward`` returns L2-normalized
    embeddings (one per clip slot) together with the index of the source video
    each slot's clip came from.
    """

    def __init__(self, backbone, K=4, T=32, augmentations=lambda x: x):
        """
        :param backbone: The backbone to use for training. Called on the
            composite videos and must expose a ``features_dim`` attribute.
        :param K: Number of clips cut from one video.
        :param T: Number of frames in each video.
        :param augmentations: A function that takes a batch of clips and
            returns a batch of augmented clips.
        """
        # nn.Module.__init__ must run before submodules can be assigned.
        super().__init__()
        self.backbone = backbone
        self.K = K
        self.augmentations = augmentations
        # NOTE(review): ``features_dim // T * self.K`` parses as
        # ``(features_dim // T) * K``; confirm this output width is intended.
        # BatchNorm1d below is sized ``features_dim`` while the Conv3d emits
        # ``features_dim // T * K`` channels, and BatchNorm1d accepts only
        # 2-D/3-D input whereas Conv3d produces 5-D output, so these layers do
        # not compose as written. Left unchanged pending the intended design.
        self.feat_diff_fusion = nn.Sequential(
            nn.Conv3d(self.backbone.features_dim,
                      self.backbone.features_dim // T * self.K,
                      (T // self.K, 1, 1)),
            nn.BatchNorm1d(self.backbone.features_dim),
            nn.ReLU(inplace=True),
        )

    def _cut_and_augment(self, videos):
        """
        Split each video into ``self.K`` consecutive temporal clips and augment them.

        :param videos: Tensor of shape (n, c, v, h, w); dim 2 is the frame axis.
        :return: A list of ``self.K`` augmented clip batches. Clip ``i`` covers
            frames ``[i * v // K, (i + 1) * v // K)``; if ``v`` is not divisible
            by ``K`` the clips differ in length.
        """
        _, _, v, _, _ = videos.shape
        bounds = [i * v // self.K for i in range(self.K + 1)]
        return [self.augmentations(videos[:, :, start:end, ...])
                for start, end in zip(bounds[:-1], bounds[1:])]

    @torch.no_grad()
    def _batch_shuffle_ddp(self, videos):
        """
        Gather clips from all GPUs, shuffle them jointly, and rebuild composite videos.
        *** Only support DistributedDataParallel (DDP) model. ***

        :param videos: List of ``self.K`` clip batches, each (n_local, c, t, h, w),
            all of equal shape.
        :return: ``(composite, permute, n, gpu_idx)`` where ``composite`` holds
            this rank's share of the composite videos (``K`` shuffled clips
            concatenated along dim 2), ``permute`` is the shared permutation over
            all ``K * n`` gathered clips, ``n`` is the number of videos across all
            processes and ``gpu_idx`` is this process's rank.
        """
        batch_size_this = videos[0].shape[0]
        clips_gather = [concat_all_gather(clip) for clip in videos]

        n = clips_gather[0].shape[0]
        num_gpus = n // batch_size_this

        # Create the permutation on the data's device (rather than hard-coding
        # .cuda()) and broadcast it so every rank applies the same shuffle.
        permute = torch.randperm(self.K * n, device=clips_gather[0].device)
        torch.distributed.broadcast(permute, src=0)

        # Stacked row p = k * n + s is clip k of video s.
        shuffled = torch.cat(clips_gather, dim=0)[permute]
        # Composite video s takes shuffled rows s, n + s, ..., (K - 1) * n + s
        # as its K temporal slots.
        composite = torch.cat(shuffled.split(n, dim=0), dim=2)

        bs = composite.shape[0] // num_gpus
        gpu_idx = torch.distributed.get_rank()

        return composite[bs * gpu_idx:bs * (gpu_idx + 1)], permute, n, gpu_idx

    @torch.no_grad()
    def _batch_unshuffle_ddp(self, x, idx_unshuffle):
        """
        Undo batch shuffle.
        *** Only support DistributedDataParallel (DDP) model. ***

        :param x: This rank's shuffled batch.
        :param idx_unshuffle: Index tensor over the gathered batch; its
            ``gpu_idx``-th chunk selects the rows that belong to this rank.
        :return: The rows of the gathered batch selected for this rank.
        """
        # gather from all gpus
        batch_size_this = x.shape[0]
        x_gather = concat_all_gather(x)
        batch_size_all = x_gather.shape[0]

        num_gpus = batch_size_all // batch_size_this

        # restored index for this gpu
        gpu_idx = torch.distributed.get_rank()
        idx_this = idx_unshuffle.view(num_gpus, -1)[gpu_idx]

        return x_gather[idx_this]

    def _feature_diff(self, features):
        """
        Replace the first half of the temporal axis with frame-to-frame differences.

        The input is cloned first so the backbone output is not modified in
        place; the in-place write could otherwise invalidate tensors autograd
        saved for the backward pass.

        :param features: Tensor of shape (n, c, v, h, w).
        :return: A new tensor of the same shape.
        """
        v = features.shape[2]
        feat_diff = features.clone()
        # NOTE(review): the right-hand side has v // 2 - 1 temporal steps while
        # the target slice has v // 2, so this only works by broadcasting
        # (v in {4, 5}). forward() needs v == 4, where both of the first two
        # steps receive features[:, :, 0] - features[:, :, 1]. Confirm intent.
        feat_diff[:, :, :v // 2, ...] = (features[:, :, :v // 2 - 1, ...]
                                         - features[:, :, 1:v // 2, ...])
        return feat_diff

    def forward(self, videos):
        """
        Args:
            videos: [batch_size, c, v, h, w]

        Returns:
            ``(feat_diff, label)``: L2-normalized embeddings gathered from all
            processes (gpu-major, then clip-slot-major order) and, for each row,
            the index of the source video (in the gathered batch) its clip came
            from.
        """

        assert self.K == 4

        n = videos.shape[0]

        clips = self._cut_and_augment(videos)
        videos_gather, permute, bs_all, _ = self._batch_shuffle_ddp(clips)
        feature = self.backbone(videos_gather)

        feat_diff = self._feature_diff(feature)
        # Global spatial average: use the feature map's own spatial size, not
        # the input video's, since the backbone may downsample.
        feat_diff = F.avg_pool3d(
            feat_diff, kernel_size=(1, feat_diff.shape[3], feat_diff.shape[4]))
        seg1, seg2, seg3, seg4 = feat_diff.split([1, 1, 1, 1], dim=2)
        # Row k * n + i <- temporal slot k of local composite video i.
        feat_diff = torch.cat([seg1, seg2, seg3, seg4], dim=0).view(n * 4, -1)
        # NOTE(review): feat_diff_fusion starts with nn.Conv3d, which expects a
        # 4-D (unbatched) or 5-D input, but feat_diff is 2-D here, so this call
        # raises as written. See the note in __init__.
        feat_diff = self.feat_diff_fusion(feat_diff)
        feat_diff = concat_all_gather(feat_diff)

        # permute is indexed by k * bs_all + s (slot k of global composite s),
        # while gathered feature rows are ordered g * K * n + k * n + i with
        # s = g * n + i. Reorder the labels so row r of label matches row r of
        # feat_diff (a no-op on a single GPU).
        num_gpus = bs_all // n
        label = permute % bs_all
        label = label.view(self.K, num_gpus, n).transpose(0, 1).reshape(-1)
        feat_diff = F.normalize(feat_diff, dim=1)

        return feat_diff, label


def concat_all_gather(tensor):
    """Gather ``tensor`` from every process and concatenate the copies along dim 0.

    One buffer shaped like ``tensor`` is allocated per process in the default
    process group; ``diffdist.functional.all_gather`` fills them in place, and
    the filled buffers are stacked in the order diffdist returns them.
    NOTE(review): diffdist is presumably used instead of
    ``torch.distributed.all_gather`` so gradients flow through the gather;
    confirm against diffdist's documentation.

    :param tensor: The local tensor to share.
    :return: Tensor whose dim 0 is ``world_size`` times that of ``tensor``.
    """
    world_size = torch.distributed.get_world_size()
    buffers = [torch.ones_like(tensor) for _ in range(world_size)]
    gathered = diffdist.functional.all_gather(buffers,
                                              tensor,
                                              next_backprop=None,
                                              inplace=True)
    return torch.cat(gathered, dim=0)
