import torch
import torch.nn as nn
import torch.nn.functional as F
import math

class AffineCoupling(nn.Module):
    """
    RealNVP-style affine coupling layer.

    One half of the features passes through unchanged and feeds a small MLP
    that predicts a log-scale ``s`` and a shift ``t`` for the other half:
    ``y2 = x2 * exp(s) + t`` in the forward direction and
    ``x2 = (y2 - t) * exp(-s)`` in the reverse direction.

    Args
    ----
    args : object
        Configuration object; only ``args.scale_embedding`` is read. When it
        is truthy, the ``scale_clip`` argument is ignored and a learnable
        bound of shape ``[dim // 2]``, initialised to 0.1, is used instead.
    dim : int
        Feature dimension (must be even).
    hidden_dim : int
        Hidden layer width for the scale/shift network. Any positive width is
        accepted; it does not have to be a multiple of ``dim // 2``.
    mask_flip : bool
        If True, swap the two channel halves (for alternating masks).
    scale_clip : float
        Bound on the log-scale: ``s = tanh(s_raw) * scale_clip`` is applied
        before ``exp``, which helps numeric stability.
    """
    def __init__(self, args, dim: int, hidden_dim: int, mask_flip: bool = False,
                 scale_clip: float = 2.0):
        super().__init__()
        assert dim % 2 == 0, "dim must be even"
        self.dim = dim
        if args.scale_embedding:
            # Learnable per-feature bound; the scalar `scale_clip` is unused here.
            self.scale_clip = nn.Parameter(torch.ones(dim // 2) * 0.1)
        else:
            self.scale_clip = scale_clip
        self.mask_flip = mask_flip

        self.net = nn.Sequential(
            nn.Linear(dim // 2, hidden_dim),
            nn.ReLU(inplace=True),
            nn.Linear(hidden_dim, hidden_dim),
            nn.ReLU(inplace=True),
            nn.Linear(hidden_dim, dim)  # outputs [s, t]
        )

        # First layer gets a rotation-like pattern; the last layer starts near
        # zero so that s ~ 0 and t ~ 0, i.e. the layer starts close to identity.
        self._init_rotation_like()

    def _make_block_rotation(self, angle_deg: int, dim: int) -> torch.Tensor:
        """
        Build a ``[dim, dim]`` block-diagonal matrix of 2x2 rotations.

        Each 2x2 block rotates by ``angle_deg`` degrees. When ``dim`` is odd,
        a trailing 1x1 identity block is appended so the result is always
        ``dim x dim``.
        """
        angle_rad = math.radians(angle_deg)
        cos_a, sin_a = math.cos(angle_rad), math.sin(angle_rad)
        rot2d = torch.tensor([[cos_a, -sin_a], [sin_a, cos_a]])

        blocks = [rot2d for _ in range(dim // 2)]
        if dim % 2 == 1:
            # Leftover feature that cannot be paired into a 2x2 rotation.
            blocks.append(torch.ones(1, 1))
        return torch.block_diag(*blocks)

    def _init_rotation_like(self):
        """
        Initialise the first and last linear layers of ``self.net``.

        The first layer's weight is set to a 90-degree block rotation whose
        rows are repeated (and cropped if needed) to fill ``out_features``
        rows, plus N(0, 0.01^2) noise; its bias is zeroed. The last layer's
        weight is drawn from N(0, 0.01^2) and its bias is zeroed.
        """
        with torch.no_grad():
            first_linear = self.net[0]
            in_f = first_linear.in_features
            out_f = first_linear.out_features
            base_rotation = self._make_block_rotation(angle_deg=90, dim=in_f)
            # Ceil division so the repeated pattern always has >= out_f rows;
            # when out_f is an exact multiple of in_f nothing is cropped.
            repeat_factor = -(-out_f // in_f)
            repeated_rot = base_rotation.repeat_interleave(repeat_factor, dim=0)[:out_f]
            noise1 = 0.01 * torch.randn_like(first_linear.weight)
            first_linear.weight.copy_(repeated_rot + noise1)
            first_linear.bias.zero_()

            second_linear = self.net[-1]
            second_linear.weight.normal_(mean=0.0, std=0.01)
            second_linear.bias.zero_()

    def forward(self, x, reverse: bool = False, compute_ldj: bool = False):
        """
        Parameters
        ----------
        x : Tensor
            [..., dim] input tensor.
        reverse : bool
            If True, applies the inverse transformation.
        compute_ldj : bool
            If True, returns (y, ldj) where ldj is the log-det Jacobian,
            summed over the last dimension.

        Returns
        -------
        Tensor or (Tensor, Tensor)
            y (and optionally ldj).
        """
        # x1 is the conditioning half (passed through), x2 is transformed.
        if self.mask_flip:
            x2, x1 = x.chunk(2, dim=-1)
        else:
            x1, x2 = x.chunk(2, dim=-1)

        st = self.net(x1)
        s, t = st.chunk(2, dim=-1)
        s = torch.tanh(s) * self.scale_clip  # bound the log-scale

        if reverse:
            x2 = (x2 - t) * torch.exp(-s)
            ldj = -s.sum(dim=-1)
        else:
            x2 = x2 * torch.exp(s) + t
            ldj = s.sum(dim=-1)

        # Put the halves back in their original positions.
        y = torch.cat([x1, x2], dim=-1) if not self.mask_flip else torch.cat([x2, x1], dim=-1)
        if compute_ldj:
            return y, ldj
        return y
    
class InvertibleNet(nn.Module):
    """
    Stack of alternating affine coupling layers.

    Layer ``i`` is built with ``mask_flip=(i % 2 == 1)`` so consecutive layers
    transform opposite halves of the features.

    Args
    ----
    args : object
        Configuration object forwarded to every ``AffineCoupling``.
    dim : int
        Feature dimension passed to each coupling layer.
    hidden_dim : int
        Hidden width passed to each coupling layer.
    depth : int
        Number of coupling layers; 0 yields an identity map.
    """
    def __init__(self, args, dim: int, hidden_dim: int, depth: int = 4):
        super().__init__()
        # Every layer uses scale_clip=2.0, passed explicitly.
        self.layers = nn.ModuleList(
            AffineCoupling(args, dim, hidden_dim, mask_flip=(i % 2 == 1), scale_clip=2.0)
            for i in range(depth)
        )

    def forward(self, x, reverse: bool = False, compute_ldj: bool = False):
        """
        Apply the stack to ``x``.

        Parameters
        ----------
        x : Tensor
            [..., dim] input tensor.
        reverse : bool
            If True, run the layers last-to-first, each in reverse mode.
        compute_ldj : bool
            If True, also return the sum of the per-layer log-det Jacobians.

        Returns
        -------
        Tensor or (Tensor, Tensor)
            The transformed tensor, and with ``compute_ldj`` the accumulated
            log-det Jacobian. For an empty stack the latter is a zero tensor
            of shape ``x.shape[:-1]`` so callers always receive a tensor.
        """
        ldj_total = None
        ordered = reversed(self.layers) if reverse else self.layers
        for layer in ordered:
            if compute_ldj:
                x, ldj = layer(x, reverse=reverse, compute_ldj=True)
                ldj_total = ldj if ldj_total is None else ldj_total + ldj
            else:
                x = layer(x, reverse=reverse)

        if not compute_ldj:
            return x
        if ldj_total is None:
            # No layers: the identity map has a log-det Jacobian of zero.
            ldj_total = x.new_zeros(x.shape[:-1])
        return x, ldj_total
    
class FiLMAlign(nn.Module):
    """
    FiLM-style conditional alignment: ``x' = gamma * x + beta``.

    An integer condition ``r_diff`` is embedded, and two separate MLP heads
    turn that embedding into a per-feature scale (gamma) and shift (beta).
    Both the scale and the shift are learned.
    """

    def __init__(self, args, feat_dim: int, hidden_dim: int = 128):
        super().__init__()
        self.args = args
        self.multi_num = self.args.multi_num
        # One embedding row per condition index; the table has multi_num rows.
        self.embed = nn.Embedding(num_embeddings=self.multi_num, embedding_dim=hidden_dim)
        # Scale head first, then shift head (same layout for both).
        self.mlp_g = self._make_head(hidden_dim, feat_dim)
        self.mlp_b = self._make_head(hidden_dim, feat_dim)

    @staticmethod
    def _make_head(in_dim: int, out_dim: int) -> nn.Sequential:
        """Build one Linear -> ReLU -> Linear head mapping in_dim to out_dim."""
        first = nn.Linear(in_dim, out_dim)
        second = nn.Linear(out_dim, out_dim)
        return nn.Sequential(first, nn.ReLU(), second)

    def forward(self, x: torch.Tensor, r_diff: torch.Tensor) -> torch.Tensor:
        """
        x:      [B, D] features (e.g. FX1)
        r_diff: [B] integer condition indices; the original note gives the
                range as {0, 1, 2, 3}, and each index must be a valid row of
                the embedding table

        returns: aligned features of shape [B, D] (e.g. FX1 -> predicted FX2)
        """
        cond = self.embed(r_diff)      # [B, hidden_dim]
        scale = self.mlp_g(cond)       # [B, D]
        shift = self.mlp_b(cond)       # [B, D]
        return scale * x + shift       # [B, D]
    
class CondAlignMLP(nn.Module):
    """
    Conditional alignment MLP.

    The input features are concatenated with a learned embedding of the
    condition index ``r_diff`` and mapped back to ``feat_dim`` by a
    three-layer MLP.
    """

    def __init__(self, args, feat_dim: int, embed_dim: int = 16, hidden_dim: int = 256):
        super().__init__()
        self.args = args
        self.multi_num = self.args.multi_num
        # Condition index (0-3) -> embed_dim-dimensional embedding.
        self.embed = nn.Embedding(num_embeddings=4, embedding_dim=embed_dim)

        # Input [x ; embedding] of width feat_dim + embed_dim -> hidden -> feat_dim.
        in_width = feat_dim + embed_dim
        stages = [
            nn.Linear(in_width, hidden_dim),
            nn.ReLU(inplace=True),
            nn.Linear(hidden_dim, hidden_dim),
            nn.ReLU(inplace=True),
            nn.Linear(hidden_dim, feat_dim),
        ]
        self.net = nn.Sequential(*stages)

        # Output layer: small-noise weights and zero bias.
        head = self.net[-1]
        nn.init.normal_(head.weight, mean=0.0, std=0.01)
        nn.init.zeros_(head.bias)

    def forward(self, x: torch.Tensor, r_diff: torch.Tensor) -> torch.Tensor:
        """
        x: [B, D] ; r_diff: [B] integer indices (values in {0, 1, 2, 3})

        returns: [B, D] tensor (the predicted aligned features)
        """
        cond = self.embed(r_diff)                 # [B, embed_dim]
        joined = torch.cat((x, cond), dim=1)      # [B, D + embed_dim]
        return self.net(joined)                   # [B, D]

class RotationDeepMLP(nn.Module):
    """
    Stack of residual blocks, each LayerNorm -> Linear -> GELU -> Linear.

    Every block's output Linear is zero-initialised, so right after
    construction each residual branch contributes zero and the module
    returns its input unchanged.

    Args
    ----
    dim : int
        Input/output feature dimension.
    hidden_dim : int
        Width of the hidden layer inside each block; may differ from ``dim``.
    depth : int
        Number of residual blocks.
    """
    def __init__(self, dim=2048, hidden_dim=512, depth=3):
        super().__init__()
        self.layers = nn.ModuleList()
        for _ in range(depth):
            self.layers.append(nn.Sequential(
                nn.LayerNorm(dim),
                nn.Linear(dim, hidden_dim),
                nn.GELU(),
                nn.Linear(hidden_dim, dim)
            ))

        self._init_weights()

    def _init_weights(self):
        """
        Initialise all blocks.

        The first block's input Linear gets an identity pattern plus
        N(0, 0.01^2) noise; later blocks use Kaiming-uniform. All input-Linear
        biases and all output-Linear weights and biases are set to zero.
        """
        with torch.no_grad():
            for i, block in enumerate(self.layers):
                lin1 = block[1]
                lin2 = block[3]

                if i == 0:
                    # Rectangular identity matching the weight's
                    # [out_features, in_features] shape, so hidden_dim != dim
                    # works; for a square layer this is the usual eye(dim).
                    eye = torch.eye(lin1.out_features, lin1.in_features)
                    noise = 0.01 * torch.randn_like(eye)
                    lin1.weight.copy_(eye + noise)
                else:
                    nn.init.kaiming_uniform_(lin1.weight)
                nn.init.zeros_(lin1.bias)

                nn.init.zeros_(lin2.weight)
                nn.init.zeros_(lin2.bias)

    def forward(self, x):
        """Apply each block as a residual update ``x = x + block(x)`` and return x."""
        for block in self.layers:
            x = x + block(x)
        return x


class MultiLinearAlign(nn.Module):
    """
    Per-label alignment: each row of ``x`` is transformed by the map selected
    by its integer label in ``r_diff``.

    ``args.multi_num`` transforms are created. With ``args.deep`` falsy each
    transform is an affine map ``x @ W.T + b`` (W starts near identity, b at
    zero); otherwise each transform is a ``RotationDeepMLP`` of depth
    ``args.depth``.

    Label-to-transform mapping used by ``forward`` (labels 0..3):
      * ``multi_num == 4``: label ``k`` uses transform ``k``.
      * otherwise: label 0 is the identity and label ``k >= 1`` uses
        transform ``k - 1`` (so labels up to 3 need ``multi_num >= 3``).
    """

    # forward() dispatches on the labels 0 .. _NUM_LABELS - 1.
    _NUM_LABELS = 4

    def __init__(self, args, feat_dim: int):
        super().__init__()
        self.args = args
        self.multi_num = self.args.multi_num

        if not self.args.deep:
            self.Ws = nn.ParameterList([
                nn.Parameter(torch.eye(feat_dim) + 0.01*torch.randn(feat_dim, feat_dim))
                for _ in range(self.multi_num)
            ])
            self.bs = nn.ParameterList([
                nn.Parameter(torch.zeros(feat_dim)) for _ in range(self.multi_num)
            ])
        else:
            self.blocks = nn.ModuleList()
            for _ in range(self.multi_num):
                self.blocks.append(RotationDeepMLP(feat_dim, feat_dim, depth=args.depth))

    def _transform(self, index: int, xk: torch.Tensor) -> torch.Tensor:
        """Apply transform number ``index`` to the selected rows ``xk``."""
        if self.args.deep:
            return self.blocks[index](xk)
        weight = self.Ws[index].to(xk.dtype)
        bias = self.bs[index].to(xk.dtype)
        return xk @ weight.T + bias

    def forward(self, x: torch.Tensor, r_diff: torch.Tensor) -> torch.Tensor:
        """
        Parameters
        ----------
        x : Tensor
            Features whose first dimension is the batch.
        r_diff : Tensor
            One integer label per row of ``x``.

        Returns
        -------
        Tensor
            Same shape and dtype as ``x``. Rows with labels 0..3 are
            transformed as described on the class. Rows with any other label
            are returned unchanged instead of being left as uninitialised
            memory.
        """
        # With one transform per label there is no identity slot; otherwise
        # label 0 is the identity and the remaining labels shift down by one.
        offset = 0 if self.multi_num == 4 else 1

        # Start from a copy so identity / unmatched rows are already correct.
        out = x.clone()
        for k in range(self._NUM_LABELS):
            index = k - offset
            if index < 0:
                continue  # identity label: the row is already in `out`
            idx_k = (r_diff == k)
            if idx_k.any():
                out[idx_k] = self._transform(index, x[idx_k])
        return out
    
class WarpAlign(nn.Module):
    """
    Residual alignment: a two-layer MLP predicts an offset that is added to
    the input, i.e. ``x_aligned = x + net(x)``.
    """

    def __init__(self, feat_dim: int, hidden: int = 256):
        super().__init__()
        # feat_dim -> hidden -> feat_dim, with a ReLU in between.
        expand = nn.Linear(feat_dim, hidden)
        project = nn.Linear(hidden, feat_dim)
        self.net = nn.Sequential(expand, nn.ReLU(), project)

    def forward(self, x: torch.Tensor) -> torch.Tensor:
        """Return ``x`` plus the predicted offset (both [B, D] per the original note)."""
        return x + self.net(x)
