import torch
import torchvision
import numpy as np
from torch import nn
from utils import gather_from_all
import torch.nn.functional as F
from models import EquiRotate, MultiLinearAlign
from lightcnn import LightCNN

class ProjectionMLP(nn.Module):
    """Projection head applied to features before the contrastive loss.

    When ``pretrain_set`` contains ``'EMNIST'`` or ``'AffNIST'`` a two-layer
    head (Linear -> BN -> ReLU -> Linear, with biases) is built. Otherwise a
    three-layer head with bias-free linears and a trailing BatchNorm is used.
    """

    def __init__(self, in_dim, hidden_dim, out_dim, pretrain_set):
        super().__init__()
        uses_small_head = any(tag in pretrain_set for tag in ('EMNIST', 'AffNIST'))

        # Layers are created in the same order as they appear in the stack so
        # that parameter initialisation and ``net.<idx>`` state-dict keys are
        # stable.
        if uses_small_head:
            layers = [
                nn.Linear(in_dim, hidden_dim),
                nn.BatchNorm1d(hidden_dim),
                nn.ReLU(inplace=True),
                nn.Linear(hidden_dim, out_dim),
            ]
        else:
            layers = [
                nn.Linear(in_dim, hidden_dim, bias=False),
                nn.BatchNorm1d(hidden_dim),
                nn.ReLU(inplace=True),
                nn.Linear(hidden_dim, hidden_dim, bias=False),
                nn.BatchNorm1d(hidden_dim),
                nn.ReLU(inplace=True),
                nn.Linear(hidden_dim, out_dim, bias=False),
                nn.BatchNorm1d(out_dim),
            ]
        self.net = nn.Sequential(*layers)

    def forward(self, x):
        """Run ``x`` through the projection stack and return the result."""
        projected = self.net(x)
        return projected
    

class PredictionMLP(nn.Module):
    """MLP head that maps features to ``out_dim`` prediction logits.

    When ``pretrain_set`` contains ``'EMNIST'`` or ``'AffNIST'`` one hidden
    layer is used; otherwise two hidden layers are used. Hidden linears have
    no bias (each is followed by BatchNorm); the output linear keeps its bias.
    """

    def __init__(self, in_dim, hidden_dim, out_dim, pretrain_set):
        super().__init__()
        uses_small_head = any(tag in pretrain_set for tag in ('EMNIST', 'AffNIST'))

        # First hidden stage is shared by both variants.
        layers = [
            nn.Linear(in_dim, hidden_dim, bias=False),
            nn.BatchNorm1d(hidden_dim),
            nn.ReLU(inplace=True),
        ]
        if not uses_small_head:
            # Second hidden stage, only for the larger variant.
            layers += [
                nn.Linear(hidden_dim, hidden_dim, bias=False),
                nn.BatchNorm1d(hidden_dim),
                nn.ReLU(inplace=True),
            ]
        layers.append(nn.Linear(hidden_dim, out_dim))
        self.net = nn.Sequential(*layers)

    def forward(self, x):
        """Run ``x`` through the prediction stack and return the logits."""
        logits = self.net(x)
        return logits
        

class Identity(nn.Module):
    """Parameter-free module whose forward pass returns its input unchanged."""

    def __init__(self):
        super(Identity, self).__init__()

    def forward(self, x):
        """Return ``x`` as-is (the same object, no copy)."""
        passthrough = x
        return passthrough
    
class GuidedSimCLR(nn.Module):
    """SimCLR-style contrastive model with an optional guided equivariance branch.

    A backbone encodes two views. When ``args.gie`` is set, a predictor head
    produces ``order`` logits per sample; these logits (softly via softmax, or
    hard via argmax) select how the backbone features are re-aligned before
    the projector. The total loss is the contrastive loss plus, when ``gie``
    is on, weighted equivariance and prediction losses, plus an optional
    variance regulariser.

    ``args`` must provide: ``gie``, ``var_loss``, ``arch``, ``pretrain_set``,
    ``connector``, ``no_ori_loss``, ``inv_score`` and, when ``gie`` is set,
    ``pred_hidden_dim``, ``eqv_loss_type``, ``R_choices`` (plus ``use_mlp`` /
    ``multi_num`` / ``lightcnn_feat_dim`` for the corresponding options).
    """

    def __init__(self, args):
        super().__init__()

        # Size of the cyclic axis: features are viewed as (B, C // order, order).
        self.order = 4
        self.gie = args.gie
        # Used both as an on/off flag and as the weight of the variance term.
        self.var_loss = args.var_loss
        self.args = args

        if args.arch == 'resnet18':
            self.backbone = torchvision.models.resnet18(zero_init_residual=True)
            self.backbone.fc = nn.Identity()
            feature_dim = 512
        elif args.arch == 'resnet50':
            self.backbone = torchvision.models.resnet50(zero_init_residual=True)
            self.backbone.fc = nn.Identity()
            feature_dim = 2048
        elif args.arch == 'lightcnn':
            self.backbone = LightCNN(feat_dim=args.lightcnn_feat_dim)
            self.backbone.fc = nn.Identity()
            feature_dim = args.lightcnn_feat_dim
        else:
            # Previously this fell through and failed later on an undefined
            # ``feature_dim``; fail early with a clear message instead.
            raise ValueError('Unsupported arch: {!r}'.format(args.arch))

        if self.gie:
            self.predictor = PredictionMLP(feature_dim, args.pred_hidden_dim, 4, args.pretrain_set)

            # The 'shift' equivariance loss needs no learned transform.
            if not args.eqv_loss_type == 'shift':
                if args.R_choices == 'linear':
                    self.equi_transform = EquiRotate(feature_dim, self.args.use_mlp)
                elif args.R_choices == 'multi_linear_align':
                    self.equi_transform = MultiLinearAlign(args, feature_dim)

        if 'EMNIST' in args.pretrain_set or 'AffNIST' in args.pretrain_set:
            self.projector = ProjectionMLP(feature_dim, 256, 128, args.pretrain_set)
        else:
            self.projector = ProjectionMLP(feature_dim, 2048, 128, args.pretrain_set)

        if args.connector == 'softmax':
            self.connector = torch.nn.Softmax(dim=1)
        elif args.connector == 'shift':
            self.connector = None
            # Row i is the index pattern of a cyclic left-shift by i positions.
            permute_patterns = [
                torch.roll(torch.arange(self.order), shifts=-i).tolist()
                for i in range(self.order)
            ]
            # Non-persistent buffer: follows ``.to()`` / ``.cuda()`` with the
            # module (no hard-coded device) and stays out of ``state_dict``,
            # so existing checkpoints still load.
            self.register_buffer(
                'permute_tensor', torch.tensor(permute_patterns), persistent=False
            )

        self.no_ori_loss = args.no_ori_loss
        self.inv_score = args.inv_score

    def _transform_chain(self, FX):
        """Return ``[FX, T(FX), T(T(FX)), ...]`` with ``self.order`` entries."""
        chain = [FX]
        for _ in range(self.order - 1):
            chain.append(self.equi_transform(chain[-1]))
        return chain

    def _apply_steps(self, FX, steps):
        """Transform each sample of ``FX`` according to the per-sample index ``steps``.

        For ``R_choices == 'linear'`` sample ``i`` receives the transform
        applied ``steps[i]`` times; for ``'multi_linear_align'`` the index is
        handed to ``equi_transform`` directly.
        """
        if self.args.R_choices == 'linear':
            stacked = torch.stack(self._transform_chain(FX), dim=1)  # (B, order, D)
            index = steps.view(-1, 1, 1).expand(-1, 1, FX.size(-1))
            return torch.gather(stacked, dim=1, index=index).squeeze(1)
        if self.args.R_choices == 'multi_linear_align':
            return self.equi_transform(FX, steps)
        raise ValueError('Unsupported R_choices: {!r}'.format(self.args.R_choices))

    def _soft_candidates(self, FX):
        """Build the ``order`` candidate representations mixed by the soft connector."""
        b, c = FX.shape
        if self.args.eqv_loss_type == 'shift':
            FX_re = FX.reshape([b, c // self.order, self.order])
            return [
                torch.roll(FX_re, shifts=-i, dims=2).reshape([b, c])
                for i in range(self.order)
            ]

        # Candidate k uses transform index (order - k) % order, i.e.
        # [0, 3, 2, 1] for order 4.
        inverse_steps = [(self.order - k) % self.order for k in range(self.order)]

        if self.args.R_choices == 'linear':
            chain = self._transform_chain(FX)
            return [chain[k] for k in inverse_steps]

        if self.args.R_choices == 'multi_linear_align':
            def step_index(k):
                return torch.full((b,), k, dtype=torch.long, device=FX.device)

            if self.args.multi_num == 3:
                # Index 0 is taken to be the untransformed features here.
                return [FX] + [
                    self.equi_transform(FX, step_index(k)) for k in inverse_steps[1:]
                ]
            return [self.equi_transform(FX, step_index(k)) for k in inverse_steps]

        raise ValueError('Unsupported R_choices: {!r}'.format(self.args.R_choices))

    def extract_guided_output(self, FX):
        """Predict the group logits for ``FX`` and project the re-aligned features.

        Args:
            FX: backbone features of shape (B, C); C must be divisible by
                ``self.order`` for the 'shift' variants.

        Returns:
            ``(eqv_logit, out)``: predictor logits of shape (B, order) and the
            projector output computed from the re-aligned features.
        """
        eqv_logit = self.predictor(FX).flatten(1)
        b, c = FX.shape

        if self.connector is not None:
            # Soft path: convex combination of all candidates.
            eqv_score = self.connector(eqv_logit)
            candidates = torch.stack(self._soft_candidates(FX), dim=-1)  # (B, C, order)
            # squeeze(-1), not squeeze(): a bare squeeze() would also drop the
            # batch dimension when B == 1.
            HX = torch.matmul(candidates, eqv_score.unsqueeze(dim=-1)).squeeze(-1)
        else:
            # Hard path: pick a single alignment per sample from the argmax.
            eqv_idx = torch.argmax(eqv_logit, dim=1)
            if self.args.eqv_loss_type == 'shift':
                batch_perm = (
                    self.permute_tensor[eqv_idx]
                    .unsqueeze(1)
                    .expand(-1, c // self.order, -1)
                )
                FX_re = FX.reshape([b, c // self.order, self.order])
                HX = FX_re.gather(2, batch_perm).reshape([b, c])
            else:
                trans = (self.order - eqv_idx) % self.order
                HX = self._apply_steps(FX, trans)

        out = self.projector(HX)
        return eqv_logit, out

    def forward(self, x1, x2, r1, r2, alpha, beta):
        """Compute the training losses for a pair of views.

        Args:
            x1, x2: the two input views fed to the backbone.
            r1, r2: per-sample integer class labels for the predictor
                (cross-entropy targets); their difference modulo ``order``
                selects the transform for the equivariance loss.
                NOTE(review): presumably rotation indices - confirm with caller.
            alpha: weight of the equivariance loss.
            beta: weight of the prediction loss.

        Returns:
            ``(loss, con_loss, eqv_loss, pred_loss, mean_variance)``. When
            ``gie`` is off, ``eqv_loss`` and ``pred_loss`` are zero tensors.
        """
        FX1 = self.backbone(x1)
        FX2 = self.backbone(x2)

        if self.gie:
            eqv_logit1, out1 = self.extract_guided_output(FX1)
            eqv_logit2, out2 = self.extract_guided_output(FX2)

            if self.args.eqv_loss_type == 'shift':
                eqv_loss = rotation_equivariance_loss(FX1, FX2, r1, r2, order=self.order)
            else:
                trans = (r2 - r1) % self.order
                FX2_pred = self._apply_steps(FX1, trans)

                if self.args.eqv_loss_type == 'mse':
                    eqv_loss = F.mse_loss(FX2, FX2_pred)
                elif self.args.eqv_loss_type == 'infonce':
                    eqv_loss = infoNCE(FX2, FX2_pred) / 2 + infoNCE(FX2_pred, FX2) / 2
                else:
                    raise ValueError(
                        'Unsupported eqv_loss_type: {!r}'.format(self.args.eqv_loss_type)
                    )

            pred_loss = F.cross_entropy(eqv_logit1, r1) / 2 + F.cross_entropy(eqv_logit2, r2) / 2
        else:
            out1 = self.projector(FX1)
            out2 = self.projector(FX2)

        con_loss = infoNCE(out1, out2) / 2 + infoNCE(out2, out1) / 2

        loss = con_loss

        if self.gie:
            loss = loss + alpha * eqv_loss + beta * pred_loss
        else:
            eqv_loss = torch.tensor(0)
            pred_loss = torch.tensor(0)

        # Variance of each feature group across the cyclic axis, averaged over
        # groups and batch; always returned for monitoring.
        B, C = FX1.shape
        FX1_re = FX1.reshape(B, C // self.order, self.order)
        variance_across_order = FX1_re.var(dim=2)
        mean_variance = variance_across_order.mean()
        if self.var_loss:
            variance_gap = torch.clamp(0.1 - mean_variance, min=0.0)
            # After the sqrt, apply a floor of 0.001 (as a tensor op).
            var_loss = torch.maximum(
                torch.tensor(0.001, device=variance_gap.device),
                torch.sqrt(variance_gap),
            )
            # Out-of-place add: ``loss`` may still alias ``con_loss`` (gie off),
            # and ``+=`` would silently change the returned contrastive loss.
            loss = loss + self.var_loss * var_loss

        return loss, con_loss, eqv_loss, pred_loss, mean_variance


def infoNCE(nn, p, temperature=0.2):
    """Cross-entropy contrastive loss between two batches of embeddings.

    Both inputs are L2-normalised along dim 1, passed through
    ``gather_from_all`` (presumably an all-gather across distributed workers;
    verify against ``utils``), and compared by dot product. Row ``i`` of
    ``nn`` is treated as the positive for row ``i`` of ``p``.

    Args:
        nn: query embeddings, shape (N, D). NOTE: this parameter shadows the
            ``torch.nn`` module name inside the function, hence the fully
            qualified ``torch.nn.functional`` calls below.
        p: key embeddings, shape (N, D).
        temperature: divisor applied to the similarity logits.

    Returns:
        Scalar loss tensor.
    """
    nn = torch.nn.functional.normalize(nn, dim=1)
    p = torch.nn.functional.normalize(p, dim=1)
    nn = gather_from_all(nn)
    p = gather_from_all(p)
    logits = (nn @ p.T) / temperature
    n = p.shape[0]
    # Create the targets on the logits' device instead of a hard-coded
    # ``.cuda()``, which broke CPU runs and could pick the wrong GPU.
    labels = torch.arange(0, n, dtype=torch.long, device=logits.device)
    loss = torch.nn.functional.cross_entropy(logits, labels)
    return loss


def rotation_equivariance_loss(FX1, FX2, r1, r2, order=4):
    """MSE between ``FX2`` and ``FX1`` cyclically shifted by ``r2 - r1``.

    Args:
        FX1, FX2: (B, C) features from rotated versions of the same image;
            C must be divisible by ``order``.
        r1, r2: (B,) rotation labels (0 ~ order-1).
        order: length of the cyclic axis; features are viewed as
            (B, C // order, order) and shifted along the last axis.

    Returns:
        Scalar MSE loss tensor.
    """
    B, C = FX1.shape
    groups = C // order
    FX1_re = FX1.reshape(B, groups, order)  # shape: (B, C', order)

    # Per-sample relative shift; negative values are valid (cyclic).
    r_diff = (r2 - r1).to(device=FX1.device, dtype=torch.long).reshape(-1, 1)

    # Rolling right by s means out[..., j] = in[..., (j - s) % order]. Build
    # those source indices for the whole batch at once instead of looping over
    # samples with ``.item()`` (one host sync per sample).
    positions = torch.arange(order, device=FX1.device).unsqueeze(0)  # (1, order)
    source_idx = (positions - r_diff) % order  # (B, order)
    source_idx = source_idx.unsqueeze(1).expand(-1, groups, -1)  # (B, C', order)

    shifted_FX1 = FX1_re.gather(2, source_idx).reshape(B, C)

    # Compute MSE loss between shifted FX1 and FX2
    return F.mse_loss(shifted_FX1, FX2)