# Copyright (c) Meta Platforms, Inc. and affiliates.

# All rights reserved.

# This source code is licensed under the license found in the
# LICENSE file in the root directory of this source tree.

import torch
import numpy as np
import torch.nn as nn

class HyperNet(nn.Module):
    """Map a latent vector to a flat vector of ``output_size`` values.

    Two variants are available, selected by ``args.hypernetwork``:

    * ``"linear"``: a single ``nn.Linear(latent_size, output_size)``.
    * anything else: a three-layer MLP with two hidden layers of width
      ``latent_size`` and ReLU activations in between.

    ``args.bias_hypernet`` is forwarded as the ``bias`` flag of every linear
    layer.
    """

    def __init__(self, latent_size: int, output_size: int, args):
        super().__init__()
        use_bias = args.bias_hypernet
        if use_bias:
            print("Bias in the hypernetwork")

        if args.hypernetwork == "linear":
            # Plain linear combination of the latent coordinates.
            stack = [nn.Linear(latent_size, output_size, bias=use_bias)]
        else:
            stack = [
                nn.Linear(latent_size, latent_size, bias=use_bias),
                nn.ReLU(),
                nn.Linear(latent_size, latent_size, bias=use_bias),
                nn.ReLU(),
                nn.Linear(latent_size, output_size, bias=use_bias),
            ]
        # Wrapped in a Sequential in both cases so sub-module names stay
        # ``net.0``, ``net.1``, ...
        self.net = nn.Sequential(*stack)

    def forward(self, x: torch.Tensor):
        """Return ``self.net`` applied to ``x``."""
        return self.net(x)

    
class ParametrizedNet(nn.Module):
    """MLP predictor whose parameters are generated per sample by a hypernetwork.

    The layer widths are ``equivariant_size`` -> hidden widths parsed from the
    dash-separated string ``args.predictor`` (may be empty) ->
    ``equivariant_size``. A ``HyperNet`` maps the latent ``z`` to one flat
    vector holding all layer parameters; for layer ``i`` the slice
    ``cum_params[i]:cum_params[i + 1]`` stores the weights first and, when
    ``args.bias_pred`` is set, the biases at the end.

    Fields read from ``args`` here: ``predictor``, ``bias_pred`` and
    ``predictor_relu`` (``args`` is also handed to ``HyperNet``).
    """

    def __init__(self, equivariant_size: int, latent_size: int, args):
        super(ParametrizedNet, self).__init__()
        if args.predictor == "":
            archi_str = f"{equivariant_size}-{equivariant_size}"
        else:
            archi_str = f"{equivariant_size}-{args.predictor}-{equivariant_size}"
        print("Predictor architecture: ", archi_str)
        self.predictor = [int(width) for width in archi_str.split("-")]
        self.args = args

        n_layers = len(self.predictor) - 1
        self.num_weights_each = [
            self.predictor[i] * self.predictor[i + 1] for i in range(n_layers)
        ]

        if self.args.bias_pred:
            self.num_biases_each = [self.predictor[i + 1] for i in range(n_layers)]
            self.num_params_each = [
                n_w + n_b
                for n_w, n_b in zip(self.num_weights_each, self.num_biases_each)
            ]
        else:
            self.num_params_each = self.num_weights_each
        print(self.num_params_each)
        # Offsets into the flat hypernetwork output; converted to plain ints so
        # layer sizes and slice bounds are not numpy scalars.
        self.cum_params = [0] + [int(c) for c in np.cumsum(self.num_params_each)]
        self.hypernet = HyperNet(latent_size, self.cum_params[-1], self.args)
        self.activation = nn.ReLU() if args.predictor_relu else nn.Identity()

    def forward(self, x: torch.Tensor, z: torch.Tensor):
        """Apply the predictor generated from ``z`` to ``x``.

        x must be (batch_size, 1, size)

        Since F.linear(x, A, b) = x @ A.T + b (A is (out_dim, in_dim), coherent
        with nn.Linear) and torch.bmm(x, A)_i = x_i @ A_i, each generated weight
        matrix is transposed along its last two axes before bmm to emulate the
        same behaviour.

        Returns the output with axis 1 removed, i.e. (batch_size, size). Only
        that axis is squeezed so that a batch holding a single sample keeps its
        batch dimension.
        """
        weights = self.hypernet(z)
        out = x
        n_layers = len(self.predictor) - 1
        for i in range(n_layers):
            in_dim, out_dim = self.predictor[i], self.predictor[i + 1]
            start = self.cum_params[i]
            w = weights[..., start:start + self.num_weights_each[i]].view(-1, out_dim, in_dim)
            out = torch.bmm(out, torch.transpose(w, -2, -1))
            if self.args.bias_pred:
                # Biases occupy the tail of this layer's parameter slice.
                end = self.cum_params[i + 1]
                b = weights[..., end - self.num_biases_each[i]:end].unsqueeze(1)
                out = out + b
            # No activation after the last layer.
            if i < n_layers - 1:
                out = self.activation(out)

        # A bare squeeze() would also drop the batch axis when batch_size == 1.
        return out.squeeze(1)

class MLPPredictor(nn.Module):
    """MLP mapping a (latent, representation) pair to a ``repr_dim`` vector.

    The first linear layer takes ``repr_dim + latent_dim`` inputs; every later
    linear layer is ``repr_dim -> repr_dim``. ReLU is placed between
    consecutive linear layers and ``output_activation`` is applied to the
    final output.
    """

    def __init__(self, repr_dim=512, latent_dim=4, n_layers=2, output_activation=nn.Identity()):
        super().__init__()

        self.repr_dim = repr_dim
        self.latent_dim = latent_dim

        # Input projection; it is followed by a ReLU unless it is the only layer.
        self.first_proj = [nn.Linear(self.repr_dim + self.latent_dim, self.repr_dim)]
        if n_layers != 1:
            self.first_proj.append(nn.ReLU())

        # Remaining layers, with a ReLU after each one except the last.
        # ``first_proj`` and ``layers`` are plain lists; the modules are
        # registered through the ``nn.Sequential`` stored in ``self.pred``.
        self.layers = []
        remaining = n_layers - 1
        for idx in range(remaining):
            self.layers.append(nn.Linear(self.repr_dim, self.repr_dim))
            if idx < remaining - 1:
                self.layers.append(nn.ReLU())

        self.pred = nn.Sequential(*self.first_proj, *self.layers)
        print(self.pred)
        self.output_activation = output_activation

    def forward(self, representation, latent=None):
        """Run the MLP on ``latent`` concatenated with ``representation``.

        When ``latent`` is given it is concatenated in front of
        ``representation`` along dim 1; otherwise ``representation`` is fed to
        the MLP unchanged.
        """
        if latent is None:
            features = representation
        else:
            features = torch.cat((latent, representation), dim=1)
        return self.output_activation(self.pred(features))
    
class EquiTrans(nn.Module):
    """
    Hypernetwork-style transformation module.

    A bias-free linear hypernetwork maps ``t`` to the weights of a chain of
    linear blocks (currently one ``equi_repr_dim x equi_repr_dim`` block), and
    those per-sample weights are applied to ``r``.
    """
    def __init__(self, equi_repr_dim: int, trans_repr_dim: int):
        super().__init__()
        self.equitrans_layers = [equi_repr_dim, equi_repr_dim]

        # Number of weights needed for each block (in_dim * out_dim).
        self.num_weights_per_block = [
            n_in * n_out
            for n_in, n_out in zip(self.equitrans_layers, self.equitrans_layers[1:])
        ]
        # Offsets of each block inside the flat hypernetwork output; converted
        # to plain ints so layer sizes and slice bounds are not numpy scalars.
        self.cumulative_params = [0] + [
            int(c) for c in np.cumsum(self.num_weights_per_block)
        ]

        # Hypernetwork to generate weights
        self.hypernet = nn.Linear(trans_repr_dim, self.cumulative_params[-1], bias=False)

    def forward(self, r: torch.Tensor, t: torch.Tensor) -> torch.Tensor:
        """
        Apply transformation predicted by `t` on representation `r`.

        `r` gets a singleton axis inserted at position 1 for the batched matrix
        products, and only that axis is removed again at the end, so a batch
        holding a single sample keeps its batch dimension.
        """
        all_weights = self.hypernet(t)  # shape: [B, total_params]
        output = r.unsqueeze(1)

        # Sequentially apply linear blocks
        for i in range(len(self.equitrans_layers) - 1):
            start_idx = self.cumulative_params[i]
            end_idx = start_idx + self.num_weights_per_block[i]

            w_block = all_weights[..., start_idx:end_idx]
            # Laid out as (out_dim, in_dim) like nn.Linear, hence the transpose.
            w_reshaped = w_block.view(-1,
                                      self.equitrans_layers[i + 1],
                                      self.equitrans_layers[i])
            output = torch.bmm(output, w_reshaped.transpose(-2, -1))

        # A bare squeeze() would also drop the batch axis when B == 1.
        return output.squeeze(1)
    
class EquiRotate(nn.Module):
    """Learned map from ``equi_repr_dim`` to ``equi_repr_dim``.

    With ``use_mlp=False`` the map is one linear layer initialised to the
    identity plus small Gaussian noise. With ``use_mlp=True`` it is a
    Linear -> ReLU -> Linear network with a hidden width of
    ``2 * equi_repr_dim``.
    """

    def __init__(self, equi_repr_dim: int, use_mlp: bool):
        super().__init__()
        self.use_mlp = use_mlp
        self.equi_repr_dim = equi_repr_dim

        if not self.use_mlp:
            self.net = nn.Linear(equi_repr_dim, equi_repr_dim)
            self._init_identity_plus_noise()
            return

        hidden_dim = 2 * equi_repr_dim
        self.net = nn.Sequential(
            nn.Linear(equi_repr_dim, hidden_dim),
            nn.ReLU(),
            nn.Linear(hidden_dim, equi_repr_dim),
        )
        self._init_two_mlp()

    def _init_identity_plus_noise(self):
        """Set the linear layer to identity + N(0, 0.01^2) noise, with zero bias."""
        layer = self.net
        with torch.no_grad():
            identity = torch.eye(self.equi_repr_dim)
            jitter = 0.01 * torch.randn_like(layer.weight)
            layer.weight.copy_(identity + jitter)
            layer.bias.zero_()

    def _init_two_mlp(self):
        """Initialise the two linear layers of the MLP variant.

        The first layer gets identity rows, each repeated consecutively to fill
        the hidden width, plus N(0, 0.01^2) noise. The second layer gets
        N(0, 0.01^2) weights. Both biases are zeroed.
        """
        expand, contract = self.net[0], self.net[2]
        with torch.no_grad():
            copies = expand.out_features // self.equi_repr_dim
            stacked_identity = torch.eye(self.equi_repr_dim).repeat_interleave(copies, dim=0)
            jitter = 0.01 * torch.randn_like(expand.weight)
            expand.weight.copy_(stacked_identity + jitter)
            expand.bias.zero_()

            contract.weight.normal_(mean=0.0, std=0.01)
            contract.bias.zero_()

    def forward(self, r: torch.Tensor):
        """Return ``self.net`` applied to ``r``."""
        return self.net(r)
        
        