# Copyright (c) Meta Platforms, Inc. and affiliates.

# This source code is licensed under the MIT license found in the
# LICENSE file in the root directory of this source tree.

import torch
import torch.nn as nn
import torch.nn.functional as F
import sys
sys.path.append('..')
from hyper.hypernetwork import EquiRotate
from hyper.models import InvertibleNet, FiLMAlign, CondAlignMLP, MultiLinearAlign, WarpAlign
from scipy.spatial.transform import Rotation as R

import torch.distributed as dist
from classy_vision.generic.distributed_util import (
    convert_to_distributed_tensor,
    convert_to_normal_tensor,
    is_distributed_training_run,
)


class ProjectionMLP(nn.Module):
    """
    Projection head: two (Linear -> BatchNorm1d -> ReLU) stages followed by
    Linear -> BatchNorm1d. Every Linear layer is bias-free.

    Maps (N, in_dim) -> (N, out_dim).
    """

    def __init__(self, in_dim, hidden_dim, out_dim):
        super().__init__()
        layers = []
        # Two hidden stages: in_dim -> hidden_dim, then hidden_dim -> hidden_dim.
        for fan_in in (in_dim, hidden_dim):
            layers += [
                nn.Linear(fan_in, hidden_dim, bias=False),
                nn.BatchNorm1d(hidden_dim),
                nn.ReLU(inplace=True),
            ]
        # Output stage (normalized, no activation).
        layers += [
            nn.Linear(hidden_dim, out_dim, bias=False),
            nn.BatchNorm1d(out_dim),
        ]
        self.net = nn.Sequential(*layers)

    def forward(self, x):
        return self.net(x)
    

class PredictionMLP(nn.Module):
    """
    Prediction head: two (Linear -> BatchNorm1d -> ReLU) stages with
    bias-free Linear layers, then a final Linear layer (with bias).

    Maps (N, in_dim) -> (N, out_dim).
    """

    def __init__(self, in_dim, hidden_dim, out_dim):
        super().__init__()
        blocks = []
        for width_in in (in_dim, hidden_dim):
            blocks.extend((
                nn.Linear(width_in, hidden_dim, bias=False),
                nn.BatchNorm1d(hidden_dim),
                nn.ReLU(inplace=True),
            ))
        # Final affine layer keeps its bias (PyTorch default).
        blocks.append(nn.Linear(hidden_dim, out_dim))
        self.net = nn.Sequential(*blocks)

    def forward(self, x):
        return self.net(x)
        

class Identity(nn.Module):
    """Pass-through module: ``forward`` returns its argument object unchanged."""

    def __init__(self):
        super(Identity, self).__init__()

    def forward(self, x):
        output = x  # no copy, no computation
        return output

class MoCo(nn.Module):
    """
    Build a MoCo model with: a query encoder, a key encoder, and a queue
    https://arxiv.org/abs/1911.05722

    With ``args.gie`` set, each encoder's ``fc`` head is split off into a
    separate projector (``projector_q`` / ``projector_k``). Backbone features
    ``FX`` first go through ``extract_guided_output``. There a predictor
    scores ``self.order`` transformation classes and ``FX`` is mapped through
    the matching candidate transform. Only then are the features projected.
    ``forward`` additionally returns an equivariance loss and the predictor's
    cross-entropy loss.
    """

    # R_choices whose equi_transform is applied repeatedly: R(x), R(R(x)), ...
    _ITERATED_R_CHOICES = ('linear', 'affine_coupling', 'warp_align')
    # R_choices whose equi_transform is called as equi_transform(x, steps)
    # with a per-sample integer tensor ``steps``.
    _CONDITIONED_R_CHOICES = ('film', 'cond_align', 'multi_linear_align')

    def __init__(
        self,
        args,
        base_encoder,
        dim: int = 128,
        K: int = 65536,
        m: float = 0.999,
        T: float = 0.07,
        mlp: bool = False,
    ) -> None:
        """
        args: config namespace. Read here: gie, stop_gradient1, stop_gradient2,
            pred_hidden_dim, eqv_loss_type, R_choices, use_mlp, connector.
            Read later: multi_num, reverse.
        base_encoder: callable; ``base_encoder(num_classes=dim)`` must return
            a network with a linear ``fc`` head
        dim: feature dimension (default: 128)
        K: queue size; number of negative keys (default: 65536)
        m: moco momentum of updating key encoder (default: 0.999)
        T: softmax temperature (default: 0.07)
        mlp: if True, prepend a Linear(dim_mlp, dim_mlp) + ReLU to ``fc``
        """
        super(MoCo, self).__init__()

        self.args = args
        # Number of transformation classes: predictor outputs and candidate
        # transforms in extract_guided_output.
        self.order = 4
        self.gie = args.gie

        self.K = K
        self.m = m
        self.T = T

        # create the encoders
        # num_classes is the output fc dimension
        self.encoder_q = base_encoder(num_classes=dim)
        self.encoder_k = base_encoder(num_classes=dim)

        # Input width of the original fc head == backbone feature width.
        dim_mlp = self.encoder_q.fc.weight.shape[1]

        if mlp:  # hack: brute-force replacement
            self.encoder_q.fc = nn.Sequential(
                nn.Linear(dim_mlp, dim_mlp), nn.ReLU(), self.encoder_q.fc
            )
            self.encoder_k.fc = nn.Sequential(
                nn.Linear(dim_mlp, dim_mlp), nn.ReLU(), self.encoder_k.fc
            )

        if self.gie:
            # Move the fc head out of the encoder so the backbone returns FX.
            self.projector_q = self.encoder_q.fc
            self.projector_k = self.encoder_k.fc
            self.encoder_q.fc = nn.Identity()
            self.encoder_k.fc = nn.Identity()

            if self.args.stop_gradient1:
                # Separate predictors; predictor_k is a momentum copy.
                self.predictor_q = PredictionMLP(dim_mlp, args.pred_hidden_dim, self.order)
                self.predictor_k = PredictionMLP(dim_mlp, args.pred_hidden_dim, self.order)
            elif self.args.stop_gradient2:
                # One predictor shared by the query and key branches.
                self.predictor = PredictionMLP(dim_mlp, args.pred_hidden_dim, self.order)

            if args.eqv_loss_type != 'shift':
                if args.R_choices == 'linear':
                    self.equi_transform = EquiRotate(dim_mlp, self.args.use_mlp)
                elif args.R_choices == 'film':
                    self.equi_transform = FiLMAlign(args, dim_mlp, dim_mlp*2)
                elif args.R_choices == 'cond_align':
                    self.equi_transform = CondAlignMLP(args, dim_mlp, embed_dim=16, hidden_dim=dim_mlp*2)
                elif args.R_choices == 'multi_linear_align':
                    self.equi_transform = MultiLinearAlign(self.args, dim_mlp)
                elif args.R_choices == 'affine_coupling':
                    self.equi_transform = InvertibleNet(args, dim_mlp, dim_mlp*2, 2)
                elif args.R_choices == 'warp_align':
                    self.equi_transform = WarpAlign(dim_mlp, dim_mlp*2)

        # Key-side modules start as frozen copies of the query-side ones and
        # are only updated by _momentum_update_key_encoder.
        for module_q, module_k in self._momentum_pairs():
            self._copy_and_freeze(module_q, module_k)

        # create the queue
        self.register_buffer("queue", torch.randn(dim, K))
        self.queue = nn.functional.normalize(self.queue, dim=0)

        self.register_buffer("queue_ptr", torch.zeros(1, dtype=torch.long))

        if args.connector == 'softmax':
            self.connector = torch.nn.Softmax(dim=1)
        elif args.connector == 'shift':
            # None selects the hard (argmax) path in extract_guided_output.
            self.connector = None
            # Row i is roll(arange(order), -i): the channel permutation used
            # for class i when eqv_loss_type == 'shift'.
            permute_patterns = [torch.roll(torch.arange(self.order), shifts=-i).tolist() for i in range(self.order)]
            self.permute_tensor = torch.tensor(permute_patterns).cuda()

    def _momentum_pairs(self):
        """Return the (query-side, key-side) module pairs kept in sync by momentum."""
        pairs = [(self.encoder_q, self.encoder_k)]
        if self.gie:
            pairs.append((self.projector_q, self.projector_k))
            if self.args.stop_gradient1:
                pairs.append((self.predictor_q, self.predictor_k))
        return pairs

    @staticmethod
    def _copy_and_freeze(module_q, module_k) -> None:
        """Copy ``module_q``'s parameters into ``module_k`` and disable their gradients."""
        for param_q, param_k in zip(module_q.parameters(), module_k.parameters()):
            param_k.data.copy_(param_q.data)  # initialize
            param_k.requires_grad = False  # not update by gradient

    @torch.no_grad()
    def _momentum_update_key_encoder(self) -> None:
        """
        Momentum update of the key encoder.

        The update also covers the key projector (under gie) and, with
        stop_gradient1, the key predictor:
        ``param_k <- m * param_k + (1 - m) * param_q``.
        """
        for module_q, module_k in self._momentum_pairs():
            for param_q, param_k in zip(module_q.parameters(), module_k.parameters()):
                param_k.data = param_k.data * self.m + param_q.data * (1.0 - self.m)

    @torch.no_grad()
    def _dequeue_and_enqueue(self, keys) -> None:
        """Write the all-gathered ``keys`` into the ring-buffer queue at ``queue_ptr``."""
        # gather keys before updating queue
        keys = concat_all_gather(keys)

        batch_size = keys.shape[0]

        ptr = int(self.queue_ptr)
        assert self.K % batch_size == 0  # for simplicity

        # replace the keys at ptr (dequeue and enqueue)
        self.queue[:, ptr : ptr + batch_size] = keys.T
        ptr = (ptr + batch_size) % self.K  # move pointer

        self.queue_ptr[0] = ptr

    @torch.no_grad()
    def _batch_shuffle_ddp(self, x):
        """
        Batch shuffle, for making use of BatchNorm.
        *** Only support DistributedDataParallel (DDP) model. ***
        """
        # gather from all gpus
        batch_size_this = x.shape[0]
        x_gather = concat_all_gather(x)
        batch_size_all = x_gather.shape[0]

        num_gpus = batch_size_all // batch_size_this

        # random shuffle index
        idx_shuffle = torch.randperm(batch_size_all).cuda()

        # broadcast to all gpus
        torch.distributed.broadcast(idx_shuffle, src=0)

        # index for restoring
        idx_unshuffle = torch.argsort(idx_shuffle)

        # shuffled index for this gpu
        gpu_idx = torch.distributed.get_rank()
        idx_this = idx_shuffle.view(num_gpus, -1)[gpu_idx]

        return x_gather[idx_this], idx_unshuffle

    @torch.no_grad()
    def _batch_unshuffle_ddp(self, x, idx_unshuffle):
        """
        Undo batch shuffle.
        *** Only support DistributedDataParallel (DDP) model. ***
        """
        # gather from all gpus
        batch_size_this = x.shape[0]
        x_gather = concat_all_gather(x)
        batch_size_all = x_gather.shape[0]

        num_gpus = batch_size_all // batch_size_this

        # restored index for this gpu
        gpu_idx = torch.distributed.get_rank()
        idx_this = idx_unshuffle.view(num_gpus, -1)[gpu_idx]

        return x_gather[idx_this]

    def _select_predictor(self, qk):
        """Return the transformation-class predictor for branch ``qk`` ('q' or 'k')."""
        if self.args.stop_gradient1:
            if qk == 'q':
                return self.predictor_q
            if qk == 'k':
                return self.predictor_k
            raise ValueError(f"qk must be 'q' or 'k', got {qk!r}")
        if self.args.stop_gradient2:
            # Shared predictor; qk is irrelevant.
            return self.predictor
        raise ValueError("gie requires args.stop_gradient1 or args.stop_gradient2")

    def _transform_powers(self, FX, **kwargs):
        """
        Return ``[FX, R(FX), R(R(FX)), ...]`` with ``self.order`` entries.

        ``R(x)`` is ``self.equi_transform(x, **kwargs)``.
        """
        powers = [FX]
        for _ in range(self.order - 1):
            powers.append(self.equi_transform(powers[-1], **kwargs))
        return powers

    @staticmethod
    def _select_power(powers, steps):
        """
        For each sample ``b``, pick row ``b`` of ``powers[steps[b]]``.

        powers: list of (B, D) tensors; steps: (B,) integer tensor.
        Returns a (B, D) tensor.
        """
        stacked = torch.stack(powers, dim=1)  # (B, len(powers), D)
        index = steps.view(-1, 1, 1).expand(-1, 1, stacked.size(-1))
        return torch.gather(stacked, dim=1, index=index).squeeze(1)

    def _guided_candidates(self, FX):
        """
        Return the ``order`` candidate tensors blended by the softmax connector.

        Each candidate has shape (B, C); candidate ``i`` corresponds to
        predicted class ``i``.
        """
        b, c = FX.shape
        if self.args.eqv_loss_type == 'shift':
            # Cyclic shift by -i within each group of ``order`` channels.
            FX_re = FX.reshape([b, c // self.order, self.order])
            return [
                torch.roll(FX_re, shifts=-i, dims=2).reshape([b, c])
                for i in range(self.order)
            ]

        # Candidate i applies equi_transform for (order - i) % order steps.
        steps = [(self.order - i) % self.order for i in range(self.order)]
        R_choice = self.args.R_choices
        if R_choice in self._ITERATED_R_CHOICES:
            powers = self._transform_powers(FX)
            return [powers[s] for s in steps]
        if R_choice in self._CONDITIONED_R_CHOICES:
            candidates = []
            for i, s in enumerate(steps):
                if i == 0 and self.args.multi_num == 3:
                    # NOTE(review): with multi_num == 3, class 0 uses FX itself,
                    # whereas the argmax path calls equi_transform(FX, 0).
                    candidates.append(FX)
                else:
                    index = torch.full((b,), s, dtype=torch.long, device=FX.device)
                    candidates.append(self.equi_transform(FX, index))
            return candidates
        raise ValueError(f"unsupported R_choices: {R_choice!r}")

    def extract_guided_output(self, FX, qk='q'):
        """
        Score the transformation classes of ``FX`` and build guided features.

        FX: backbone features of shape (B, C). When ``args.eqv_loss_type`` is
            'shift', C must be divisible by ``self.order``.
        qk: 'q' or 'k'. With ``args.stop_gradient1`` it picks ``predictor_q``
            or ``predictor_k``; it is ignored when only ``args.stop_gradient2``
            is set.

        Returns ``(eqv_logit, HX)``:
        - ``eqv_logit``: predictor logits of shape (B, order).
        - ``HX``: guided features of shape (B, C). Class ``i`` uses
          equi_transform applied ``(order - i) % order`` times, or a cyclic
          roll by ``-i`` of each channel group for 'shift'.

        With a connector module (softmax), HX is the score-weighted sum of all
        candidates. With ``connector is None``, only the argmax candidate is
        used.

        Raises ValueError for unsupported configurations.
        """
        b, c = FX.shape
        eqv_logit = self._select_predictor(qk)(FX).flatten(1)

        if self.connector is not None:
            eqv_score = self.connector(eqv_logit)  # (B, order)
            candidates = torch.stack(self._guided_candidates(FX), dim=-1)  # (B, C, order)
            # (B, C, order) @ (B, order, 1) -> (B, C). Squeeze only the last
            # dim so a batch of size 1 keeps its batch axis.
            HX = torch.matmul(candidates, eqv_score.unsqueeze(dim=-1)).squeeze(-1)
            return eqv_logit, HX

        eqv_idx = torch.argmax(eqv_logit, dim=1)
        if self.args.eqv_loss_type == 'shift':
            # permute_tensor[i] == roll(arange(order), -i), matching candidate i.
            batch_perm = self.permute_tensor[eqv_idx].unsqueeze(1).expand(-1, c // self.order, -1)
            FX_re = FX.reshape([b, c // self.order, self.order])
            HX = FX_re.gather(2, batch_perm).reshape([b, c])
        else:
            trans = (self.order - eqv_idx) % self.order
            R_choice = self.args.R_choices
            if R_choice in self._ITERATED_R_CHOICES:
                HX = self._select_power(self._transform_powers(FX), trans)
            elif R_choice in self._CONDITIONED_R_CHOICES:
                HX = self.equi_transform(FX, trans)
            else:
                raise ValueError(f"unsupported R_choices: {R_choice!r}")
        return eqv_logit, HX

    def _equivariance_loss(self, FX_q, FX_k, r1, r2):
        """
        Equivariance loss for a non-'shift' ``args.eqv_loss_type``.

        ``FX_k`` is predicted by applying ``equi_transform`` to ``FX_q`` for
        ``(r2 - r1) % order`` steps. The prediction is compared with ``FX_k``
        using 'mse' or 'infonce'.

        With ``args.reverse`` (affine_coupling only), ``FX_q`` is also
        predicted from ``FX_k`` using ``equi_transform(..., reverse=True)``.
        The two directions are then averaged.
        """
        trans = (r2 - r1) % self.order
        R_choice = self.args.R_choices
        FX_q_pred = None
        if R_choice in self._ITERATED_R_CHOICES:
            FX_k_pred = self._select_power(self._transform_powers(FX_q), trans)
            if R_choice == 'affine_coupling' and self.args.reverse:
                FX_q_pred = self._select_power(
                    self._transform_powers(FX_k, reverse=True), trans
                )
        elif R_choice in self._CONDITIONED_R_CHOICES:
            FX_k_pred = self.equi_transform(FX_q, trans)
        else:
            raise ValueError(f"unsupported R_choices: {R_choice!r}")

        reverse = self.args.reverse
        if reverse and FX_q_pred is None:
            raise ValueError("args.reverse is only supported with R_choices='affine_coupling'")

        loss_type = self.args.eqv_loss_type
        if loss_type == 'mse':
            if reverse:
                return F.mse_loss(FX_k, FX_k_pred) / 2 + F.mse_loss(FX_q, FX_q_pred) / 2
            return F.mse_loss(FX_k, FX_k_pred)
        if loss_type == 'infonce':
            if reverse:
                return (infoNCE(FX_k, FX_k_pred) / 4 + infoNCE(FX_k_pred, FX_k) / 4
                        + infoNCE(FX_q, FX_q_pred) / 4 + infoNCE(FX_q_pred, FX_q) / 4)
            return infoNCE(FX_k, FX_k_pred) / 2 + infoNCE(FX_k_pred, FX_k) / 2
        raise ValueError(f"unsupported eqv_loss_type: {loss_type!r}")

    def forward(self, im_q, im_k, r1, r2):
        """
        Input:
            im_q: a batch of query images
            im_k: a batch of key images
            r1, r2: (N,) integer labels in [0, order) for im_q / im_k. They
                are used only with ``args.gie``: as predictor cross-entropy
                targets, and via ``r2 - r1`` in the equivariance loss.
        Output:
            logits: (N, 1 + K) contrastive logits; the positive is column 0
            labels: (N,) zeros
            eqv_loss: equivariance loss (``tensor(0)`` without gie)
            pred_loss: predictor cross-entropy (``tensor(0)`` without gie)
        """

        # compute query features
        if not self.gie:
            q = self.encoder_q(im_q)  # queries: NxC
            q = nn.functional.normalize(q, dim=1)

            # compute key features
            with torch.no_grad():  # no gradient to keys
                self._momentum_update_key_encoder()  # update the key encoder

                # shuffle for making use of BN
                im_k, idx_unshuffle = self._batch_shuffle_ddp(im_k)

                k = self.encoder_k(im_k)  # keys: NxC
                k = nn.functional.normalize(k, dim=1)

                # undo shuffle
                k = self._batch_unshuffle_ddp(k, idx_unshuffle)

            eqv_loss = torch.tensor(0)
            pred_loss = torch.tensor(0)

        else:
            FX_q = self.encoder_q(im_q)
            eqv_logit_q, HX_q = self.extract_guided_output(FX_q, qk='q')
            q = self.projector_q(HX_q)
            q = nn.functional.normalize(q, dim=1)

            # Key branch: momentum encoder/projector (plus, with
            # stop_gradient1, the momentum predictor), all without gradient.
            # qk='k' is ignored when only stop_gradient2 is set.
            with torch.no_grad():
                self._momentum_update_key_encoder()
                im_k, idx_unshuffle = self._batch_shuffle_ddp(im_k)

                FX_k = self.encoder_k(im_k)
                eqv_logit_k, HX_k = self.extract_guided_output(FX_k, qk='k')
                k = self.projector_k(HX_k)
                k = nn.functional.normalize(k, dim=1)

                FX_k = self._batch_unshuffle_ddp(FX_k, idx_unshuffle)
                k = self._batch_unshuffle_ddp(k, idx_unshuffle)
                eqv_logit_k = self._batch_unshuffle_ddp(eqv_logit_k, idx_unshuffle)

            if self.args.eqv_loss_type == 'shift':
                eqv_loss = rotation_equivariance_loss(FX_q, FX_k, r1, r2, order=self.order)
            else:
                eqv_loss = self._equivariance_loss(FX_q, FX_k, r1, r2)
            pred_loss = F.cross_entropy(eqv_logit_q, r1) / 2 + F.cross_entropy(eqv_logit_k, r2) / 2

        # compute logits
        # Einstein sum is more intuitive
        # positive logits: Nx1
        l_pos = torch.einsum("nc,nc->n", [q, k]).unsqueeze(-1)
        # negative logits: NxK
        l_neg = torch.einsum("nc,ck->nk", [q, self.queue.clone().detach()])

        # logits: Nx(1+K)
        logits = torch.cat([l_pos, l_neg], dim=1)

        # apply temperature
        logits /= self.T

        # labels: positive key indicators (created on the logits' device)
        labels = torch.zeros(logits.shape[0], dtype=torch.long, device=logits.device)

        # dequeue and enqueue
        self._dequeue_and_enqueue(k)

        return logits, labels, eqv_loss, pred_loss


# utils
@torch.no_grad()
def concat_all_gather(tensor):
    """
    Gather ``tensor`` from every process and concatenate the copies along dim 0.

    Requires an initialized default process group. Receive buffers are
    shaped like ``tensor``, so every rank must contribute the same shape.

    *** Warning ***: torch.distributed.all_gather has no gradient.
    """
    world_size = torch.distributed.get_world_size()
    buckets = [torch.ones_like(tensor) for _ in range(world_size)]
    torch.distributed.all_gather(buckets, tensor, async_op=False)
    return torch.cat(buckets, dim=0)

class GatherLayer(torch.autograd.Function):
    """
    All-gather that keeps the autograd graph intact.

    ``forward`` returns one tensor per rank (a tuple of ``world_size``
    tensors shaped like ``x``). ``backward`` sums the incoming gradients
    across ranks with ``all_reduce`` and hands this rank its own slice.
    Unlike ``torch.distributed.all_gather``, this does not cut the gradients.
    """

    @staticmethod
    def forward(ctx, x):
        world_size = dist.get_world_size()
        gathered = [torch.zeros_like(x) for _ in range(world_size)]
        dist.all_gather(gathered, x)
        return tuple(gathered)

    @staticmethod
    def backward(ctx, *grads):
        stacked = torch.stack(grads)  # (world_size, *x.shape)
        dist.all_reduce(stacked)  # sum over ranks, in place
        rank = dist.get_rank()
        return stacked[rank]
    
def gather_from_all(tensor: torch.Tensor) -> torch.Tensor:
    """
    Concatenate ``tensor`` from every worker along dim 0, keeping gradients.

    Similar to classy_vision.generic.distributed_util.gather_from_all,
    except that it goes through ``GatherLayer`` and so does not cut the
    gradients. A 0-dim input is first promoted to shape (1,). Outside a
    distributed run, the (possibly promoted) input is concatenated on its
    own, which returns a copy of it.
    """
    # 0 dim tensors cannot be gathered, so unsqueeze.
    if tensor.ndim == 0:
        tensor = tensor.unsqueeze(0)

    if not is_distributed_training_run():
        return torch.cat([tensor], 0)

    dist_tensor, orig_device = convert_to_distributed_tensor(tensor)
    pieces = [
        convert_to_normal_tensor(piece, orig_device)
        for piece in GatherLayer.apply(dist_tensor)
    ]
    return torch.cat(pieces, 0)

def infoNCE(nn, p, temperature=0.2):
    """
    InfoNCE loss in which row ``i`` of ``nn`` is the positive for row ``i`` of ``p``.

    nn, p: (N, D) tensors. Inside this function the name ``nn`` shadows the
        ``torch.nn`` alias; it is kept for keyword-argument compatibility.
        Both are L2-normalized along dim 1 and then concatenated across
        workers with ``gather_from_all``, which keeps gradients.
    temperature: divisor applied to the cosine-similarity logits.

    Returns the mean cross-entropy of ``nn @ p.T / temperature`` against the
    diagonal targets ``0..N_total-1``.
    """
    nn = torch.nn.functional.normalize(nn, dim=1)
    p = torch.nn.functional.normalize(p, dim=1)
    nn = gather_from_all(nn)
    p = gather_from_all(p)
    logits = nn @ p.T
    logits /= temperature
    # Build the labels on the logits' device instead of hard-coding .cuda(),
    # so the loss also works on CPU and on non-default CUDA devices.
    labels = torch.arange(p.shape[0], dtype=torch.long, device=logits.device)
    return torch.nn.functional.cross_entropy(logits, labels)

def rotation_equivariance_loss(FX1, FX2, r1, r2, order=4):
    """
    MSE between ``FX2`` and ``FX1`` cyclically shifted by ``r2 - r1`` steps.

    FX1, FX2: (B, C) features; C must be divisible by ``order``. Channels are
        viewed as (C // order, order) groups and the shift acts on the last axis.
    r1, r2: (B,) integer labels in [0, order), presumably rotation indices.
    order: size of each cyclic channel group (default 4).

    For each sample ``b`` this is equivalent to
    ``torch.roll(FX1_re[b], shifts=r2[b] - r1[b], dims=1)``. The shift is
    done with a single ``gather`` rather than a Python loop with a
    per-sample ``.item()`` host synchronization.
    """
    B, C = FX1.shape
    FX1_re = FX1.reshape(B, C // order, order)  # (B, C', order)

    # Relative shift per sample, moved to FX1's device.
    r_diff = (r2 - r1).to(device=FX1.device, dtype=torch.long)  # (B,)

    # torch.roll(v, s)[j] == v[(j - s) % order]. Tensor `%` follows Python
    # semantics, so the source index is non-negative for negative shifts.
    positions = torch.arange(order, device=FX1.device)
    src = (positions.unsqueeze(0) - r_diff.unsqueeze(1)) % order  # (B, order)
    src = src.unsqueeze(1).expand(-1, C // order, -1)  # (B, C', order)

    shifted_FX1 = torch.gather(FX1_re, 2, src).reshape(B, C)

    # Compare the shifted FX1 against FX2.
    return F.mse_loss(shifted_FX1, FX2)