import numpy as np
import torch
from torch import nn

# Use the GPU when torch reports one as available; otherwise stay on the CPU.
device = torch.device("cuda") if torch.cuda.is_available() else torch.device("cpu")

class MLP(torch.nn.Sequential):
    """Fully connected network built from a configuration mapping.

    The network is a chain of ``Linear -> activation`` pairs, one per entry of
    ``hidden_channels[:-1]``, followed by a final ``Linear`` of width
    ``hidden_channels[-1]`` (no activation) and a ``Dropout`` layer.

    Args:
        MLP_configs: mapping that must provide
            ``'in_channels'`` (input width),
            ``'hidden_channels'`` (sequence of layer widths; the last entry is
            the output width),
            ``'mlp_bias'`` (constant every linear bias is initialised to) and
            ``'activation_layer'`` (callable invoked with no arguments to
            create each activation module).
        bias: whether the linear layers carry a bias term.
        dropout: probability handed to the trailing ``Dropout`` layer.

    Raises:
        KeyError: if one of the required configuration keys is missing.
    """

    def __init__(self, MLP_configs, bias=True, dropout=0.0):
        super().__init__()

        in_channels = MLP_configs['in_channels']
        hidden_channels = MLP_configs['hidden_channels']
        self.mlp_bias = MLP_configs['mlp_bias']
        activation_layer = MLP_configs['activation_layer']

        layers = []
        in_dim = in_channels
        for hidden_dim in hidden_channels[:-1]:
            layers.append(torch.nn.Linear(in_dim, hidden_dim, bias=bias))
            layers.append(activation_layer())
            in_dim = hidden_dim

        # The output layer is linear only: no activation, just dropout.
        layers.append(torch.nn.Linear(in_dim, hidden_channels[-1], bias=bias))
        layers.append(torch.nn.Dropout(dropout))

        # Kept as a nested ``layers`` container so parameter names stay
        # ``layers.<idx>.weight`` / ``layers.<idx>.bias``.
        self.layers = nn.Sequential(*layers)
        self.layers.apply(self.init_weights)

    def init_weights(self, m):
        """Initialise ``m`` if it is a linear layer; leave other modules alone.

        Weights are drawn from a truncated normal with ``std=0.001`` and the
        bias, when present, is filled with ``self.mlp_bias``.
        """
        if isinstance(m, nn.Linear):
            nn.init.trunc_normal_(m.weight, std=0.001)
            # ``nn.Linear(..., bias=False)`` stores ``bias`` as None, so the
            # constant fill must be skipped in that case.
            if m.bias is not None:
                nn.init.constant_(m.bias, self.mlp_bias)

    def forward(self, x):
        """Run ``x`` through the layer stack and return the result."""
        out = self.layers(x)
        return out


class SineLayer(nn.Module):
    """Linear layer with a sine activation modulated by external coefficients.

    The forward pass evaluates
    ``exp(a) * sin(exp(b) * omega_0 * linear(x) + c) + d``.
    """

    def __init__(self, in_features, out_features, bias=True, is_first=False, omega_0=30):
        super().__init__()

        self.in_features = in_features
        self.is_first = is_first
        self.omega_0 = omega_0
        self.linear = nn.Linear(in_features, out_features, bias=bias)

        self.init_weights()

    def init_weights(self):
        """Re-draw the linear weights from a symmetric uniform distribution.

        The half-width is ``1 / in_features`` for the first layer and
        ``sqrt(6 / in_features) / omega_0`` otherwise. Only the weight is
        touched; the bias keeps whatever ``nn.Linear`` gave it.
        """
        if self.is_first:
            limit = 1 / self.in_features
        else:
            limit = np.sqrt(6 / self.in_features) / self.omega_0

        with torch.no_grad():
            self.linear.weight.uniform_(-limit, limit)

    def forward(self, input, a_param, b_param, c_param, d_param):
        """Apply the linear map, then the modulated sine.

        ``a_param`` .. ``d_param`` are combined element-wise with the linear
        output, so they must broadcast against it.
        """
        pre_activation = self.linear(input)
        frequency = torch.exp(b_param) * self.omega_0
        amplitude = torch.exp(a_param)
        return amplitude * torch.sin(frequency * pre_activation + c_param) + d_param
    

class NormalSineLayer(nn.Module):
    """Linear layer followed by a plain sine: ``sin(omega_0 * linear(x))``."""

    def __init__(self, in_features, out_features, bias=True, is_first=False, omega_0=30):
        super().__init__()

        self.in_features = in_features
        self.is_first = is_first
        self.omega_0 = omega_0
        self.linear = nn.Linear(in_features, out_features, bias=bias)

        self.init_weights()

    def init_weights(self):
        """Re-draw the linear weights from a symmetric uniform distribution.

        The half-width is ``1 / in_features`` for the first layer and
        ``sqrt(6 / in_features) / omega_0`` otherwise. The bias is left as
        ``nn.Linear`` initialised it.
        """
        if self.is_first:
            half_width = 1 / self.in_features
        else:
            half_width = np.sqrt(6 / self.in_features) / self.omega_0

        with torch.no_grad():
            self.linear.weight.uniform_(-half_width, half_width)

    def forward(self, input):
        """Return the sine of the frequency-scaled linear projection."""
        projected = self.linear(input)
        return torch.sin(self.omega_0 * projected)


class SecoINR(nn.Module):
    """Coordinate network made of three cooperating sub-networks.

    * ``segnet`` (pixel class representation network): ``NormalSineLayer``
      stack mapping 2-feature coordinates to ``num_classes + 1`` scores.
    * ``aux_mlp`` (conditioner network): an ``MLP`` fed with the softmax of
      the ``segnet`` scores; its output is reshaped to ``(-1, 4, 4)`` and read
      as one ``(a, b, c, d)`` row per modulated layer.
    * ``net`` (adaptive SIREN network): ``SineLayer`` stack whose activations
      are modulated by those rows.

    Args:
        in_features: input width of ``net``.
        hidden_features: width of every hidden layer in ``net`` and ``segnet``.
        hidden_layers: number of hidden sine layers after the first one.
        out_features: output width of ``net``.
        outermost_linear: if true, both stacks end in a plain ``nn.Linear``;
            otherwise they end in another sine layer.
        first_omega_0: ``omega_0`` of the first sine layer of each stack.
        hidden_omega_0: ``omega_0`` of all remaining sine layers.
        MLP_configs: configuration forwarded to ``MLP``; ``MLP_configs['GT']``
            is additionally stored as ``self.ground_truth``.
        num_classes: ``segnet`` emits ``num_classes + 1`` scores.

    NOTE(review): the default ``MLP_configs`` has no ``'GT'`` entry although
    the constructor reads it, so callers must pass a complete configuration.
    NOTE(review): the coefficient tensor always has four rows, so at most
    four layers of ``net`` can be modulated.
    """

    def __init__(self, in_features, hidden_features, hidden_layers, out_features,
                 outermost_linear=True, first_omega_0=30, hidden_omega_0=30,
                 MLP_configs={'in_channels': 64, 'hidden_channels': [64, 32, 4], 'activation_layer': nn.SiLU},
                 num_classes=11):
        super().__init__()

        self.num_classes = num_classes
        self.ground_truth = MLP_configs['GT']
        self.nonlin = SineLayer
        self.normalnonlin = NormalSineLayer
        self.hidden_layers = hidden_layers

        # Conditioner Network
        self.aux_mlp = MLP(MLP_configs)

        # Adaptive SIREN Network
        self.net = self._build_stack(self.nonlin, in_features, hidden_features,
                                     hidden_layers, out_features, outermost_linear,
                                     first_omega_0, hidden_omega_0)

        # Pixel Class Representation Network: always takes 2 input features
        # and emits one score per class plus one extra.
        self.segnet = self._build_stack(self.normalnonlin, 2, hidden_features,
                                        hidden_layers, int(self.num_classes + 1),
                                        outermost_linear, first_omega_0, hidden_omega_0)
        self.softmax = nn.Softmax(dim=-1)

    @staticmethod
    def _final_linear(hidden_features, out_features, hidden_omega_0):
        """Create the closing linear layer with weights in +/- sqrt(6/h)/omega."""
        final_linear = nn.Linear(hidden_features, out_features, dtype=torch.float)
        with torch.no_grad():
            # max(..., 1e-12) guards the division against hidden_omega_0 == 0.
            const = np.sqrt(6 / hidden_features) / max(hidden_omega_0, 1e-12)
            final_linear.weight.uniform_(-const, const)
        return final_linear

    def _build_stack(self, layer_cls, in_features, hidden_features, hidden_layers,
                     out_features, outermost_linear, first_omega_0, hidden_omega_0):
        """Assemble first layer, hidden layers and output layer of one stack."""
        layers = [layer_cls(in_features, hidden_features,
                            is_first=True, omega_0=first_omega_0)]
        for _ in range(hidden_layers):
            layers.append(layer_cls(hidden_features, hidden_features,
                                    is_first=False, omega_0=hidden_omega_0))
        if outermost_linear:
            layers.append(self._final_linear(hidden_features, out_features,
                                             hidden_omega_0))
        else:
            # The closing sine layer is of the same class as the rest of the
            # stack, so forward() can call it the same way as its siblings.
            layers.append(layer_cls(hidden_features, out_features,
                                    is_first=False, omega_0=hidden_omega_0))
        return nn.Sequential(*layers)

    @staticmethod
    def _modulation(coefs, layer_idx):
        """Return (a, b, c, d) for one layer, each with added outer axes."""
        coef = coefs[:, layer_idx, :]
        return tuple(coef[..., k][None, ..., None] for k in range(4))

    def forward(self, coords):
        """Evaluate the network at ``coords``.

        Presumably ``coords`` is ``(1, N, 2)``: only ``seg_output[0]`` is fed
        to the conditioner and ``segnet`` expects 2 input features — TODO
        confirm against the caller.

        Returns:
            A list ``[output, coefs, seg_output]`` holding the ``net`` output,
            the ``(-1, 4, 4)`` coefficient tensor and the raw ``segnet``
            scores (before softmax).
        """
        # Pixel Class Representation Network
        seg_output = coords
        for layer in self.segnet:
            seg_output = layer(seg_output)

        # Conditioner Network
        coefs = self.aux_mlp(self.softmax(seg_output[0])).reshape(-1, 4, 4)

        # Adaptive SIREN Network: the first layer and every hidden layer get
        # their own coefficient row.
        output = coords
        for idx in range(self.hidden_layers + 1):
            output = self.net[idx](output, *self._modulation(coefs, idx))

        final_layer = self.net[self.hidden_layers + 1]
        if isinstance(final_layer, self.nonlin):
            # Without a linear output layer the last module is a modulated
            # sine layer too and needs the next coefficient row; this only
            # exists while hidden_layers + 2 <= 4.
            output = final_layer(
                output, *self._modulation(coefs, self.hidden_layers + 1))
        else:
            output = final_layer(output)

        return [output, coefs, seg_output]