import torch
import torch.nn as nn
import torch.nn.functional as F
    
class EquiRotate(nn.Module):
    """Learnable map over feature vectors of size ``equi_repr_dim``.

    Depending on ``use_mlp`` the map is either a single square ``nn.Linear``
    or a two-layer MLP (``Linear -> ReLU -> Linear``) whose hidden width is
    twice ``equi_repr_dim``. Both variants receive a custom initialisation
    right after construction (see the ``_init_*`` methods).

    Args:
        equi_repr_dim: Size of the input and output feature dimension.
        use_mlp: If true build the two-layer MLP, otherwise the single
            linear layer.
    """

    def __init__(self, equi_repr_dim: int, use_mlp: bool):
        super().__init__()
        self.use_mlp = use_mlp
        self.equi_repr_dim = equi_repr_dim

        if not self.use_mlp:
            # Single square linear layer, started close to the identity.
            self.net = nn.Linear(equi_repr_dim, equi_repr_dim)
            self._init_identity_plus_noise()
            return

        width = 2 * equi_repr_dim
        self.net = nn.Sequential(
            nn.Linear(equi_repr_dim, width),
            nn.ReLU(),
            nn.Linear(width, equi_repr_dim),
        )
        self._init_two_mlp()

    @torch.no_grad()
    def _init_identity_plus_noise(self):
        """Set the linear layer to identity plus small Gaussian noise, zero bias."""
        layer = self.net
        jitter = torch.randn_like(layer.weight) * 0.01
        layer.weight.copy_(jitter + torch.eye(self.equi_repr_dim))
        nn.init.zeros_(layer.bias)

    @torch.no_grad()
    def _init_two_mlp(self):
        """Initialise the two linear layers of the MLP variant.

        The first layer's weight becomes a row-wise repeated identity (each
        identity row appears ``out_features // equi_repr_dim`` times in a row)
        plus small Gaussian noise. The second layer's weight is drawn from a
        normal distribution with standard deviation 0.01. Both biases are zero.
        """
        expand, project = self.net[0], self.net[2]

        copies = expand.out_features // self.equi_repr_dim
        stacked_eye = torch.eye(self.equi_repr_dim).repeat_interleave(copies, dim=0)
        jitter = torch.randn_like(expand.weight) * 0.01
        expand.weight.copy_(jitter + stacked_eye)
        nn.init.zeros_(expand.bias)

        nn.init.normal_(project.weight, mean=0.0, std=0.01)
        nn.init.zeros_(project.bias)

    def forward(self, r: torch.Tensor):
        """Apply the configured network to ``r`` and return the result."""
        return self.net(r)
    
class RotationDeepMLP(nn.Module):
    """Stack of residual MLP blocks that is exactly the identity map at init.

    Each block is ``LayerNorm -> Linear(dim, hidden_dim) -> GELU ->
    Linear(hidden_dim, dim)`` and is applied as ``x = x + block(x)``.
    Because every block's last linear layer is zero-initialised, the freshly
    constructed module returns its input unchanged.

    Args:
        dim: Size of the input/output feature dimension.
        hidden_dim: Width of the hidden layer inside each block.
        depth: Number of residual blocks.
    """

    def __init__(self, dim=2048, hidden_dim=512, depth=3):
        super().__init__()
        self.layers = nn.ModuleList()
        for _ in range(depth):
            self.layers.append(nn.Sequential(
                nn.LayerNorm(dim),
                nn.Linear(dim, hidden_dim),
                nn.GELU(),
                nn.Linear(hidden_dim, dim)
            ))

        self._init_weights()

    def _init_weights(self):
        """Initialise all blocks.

        The first block's input projection is set to a (possibly rectangular)
        identity plus small Gaussian noise; later blocks use Kaiming-uniform.
        All biases and every output projection are zeroed.
        """
        with torch.no_grad():
            for i, block in enumerate(self.layers):
                lin1 = block[1]
                lin2 = block[3]

                if i == 0:
                    # The weight is (out_features, in_features). Build the
                    # identity with that exact shape so the copy also works
                    # when hidden_dim != dim (a square eye of size
                    # in_features would not fit and copy_ would raise).
                    eye = torch.eye(lin1.out_features, lin1.in_features)
                    noise = 0.01 * torch.randn_like(eye)
                    lin1.weight.copy_(eye + noise)
                else:
                    nn.init.kaiming_uniform_(lin1.weight)
                nn.init.zeros_(lin1.bias)

                # Zero output projection => block(x) == 0 at init, so the
                # residual connection makes the whole module an identity.
                nn.init.zeros_(lin2.weight)
                nn.init.zeros_(lin2.bias)

    def forward(self, x):
        """Apply every residual block in order and return the result."""
        for block in self.layers:
            x = x + block(x)
        return x
     
class MultiLinearAlign(nn.Module):
    """Apply a per-group learned transform to the rows of a batch.

    Each row of ``x`` is routed by the matching entry of ``r_diff`` (values
    0..3 are handled) to one of ``args.multi_num`` transforms:

    * ``args.multi_num == 4``: group ``k`` uses transform ``k``.
    * otherwise: group 0 is passed through unchanged and group ``k >= 1``
      uses transform ``k - 1`` (so three transforms are needed to cover
      groups 1..3; fewer raises ``IndexError`` when such a group occurs).

    A transform is an affine map ``x @ W.T + b`` (``W`` initialised to
    identity plus small noise, ``b`` to zero) when ``args.deep`` is false,
    and a ``RotationDeepMLP`` block when it is true.

    Args:
        args: Configuration object; the attributes read are ``multi_num``,
            ``deep`` and, when ``deep`` is true, ``depth``.
        feat_dim: Feature dimension of the rows being transformed.
    """

    def __init__(self, args, feat_dim: int):
        super().__init__()
        self.args = args
        self.multi_num = self.args.multi_num

        if not self.args.deep:
            self.Ws = nn.ParameterList([
                nn.Parameter(torch.eye(feat_dim) + 0.01*torch.randn(feat_dim, feat_dim))
                for _ in range(self.multi_num)
            ])
            self.bs = nn.ParameterList([
                nn.Parameter(torch.zeros(feat_dim)) for _ in range(self.multi_num)
            ])
        else:
            self.blocks = nn.ModuleList()
            for _ in range(self.multi_num):
                self.blocks.append(RotationDeepMLP(feat_dim, feat_dim, depth=args.depth))

    def _transform(self, index: int, rows: torch.Tensor) -> torch.Tensor:
        """Run transform number ``index`` on ``rows`` and return the result."""
        if self.args.deep:
            return self.blocks[index](rows)
        # Parameters are cast to the input dtype before the affine map.
        weight = self.Ws[index].to(rows.dtype)
        bias = self.bs[index].to(rows.dtype)
        return rows @ weight.T + bias

    def forward(self, x: torch.Tensor, r_diff: torch.Tensor) -> torch.Tensor:
        """Transform each row of ``x`` according to its ``r_diff`` group.

        Args:
            x: 2-D tensor of shape ``(B, D)``.
            r_diff: Tensor compared element-wise against 0..3 and used as a
                boolean row mask on ``x``.
                # assumes shape (B,) with integer group ids - TODO confirm

        Returns:
            Tensor shaped like ``x``. Rows whose ``r_diff`` value is not in
            0..3 (and the identity group, if any) are returned unchanged.

        Raises:
            ValueError: If ``x`` is not 2-D.
        """
        if x.dim() != 2:
            raise ValueError(
                f"expected x of shape (B, D), got {tuple(x.shape)}"
            )

        # Start from a copy of the input rather than uninitialised memory so
        # rows that match no handled group have a well-defined value.
        out = x.clone()

        # With four transforms group k maps to transform k; otherwise group 0
        # is the identity and group k maps to transform k - 1.
        offset = 0 if self.multi_num == 4 else 1
        for k in range(4):
            index = k - offset
            if index < 0:
                # Identity group: `out` already holds these rows of `x`.
                continue
            idx_k = (r_diff == k)
            if not idx_k.any():
                continue
            out[idx_k] = self._transform(index, x[idx_k])

        return out