import torch
import torch.nn as nn
import torch.nn.functional as F

from activation import trunc_exp
from renderer import NeRFRenderer

from encoding import get_encoder

import sys
sys.path.append('../')
from utils.nerf_utils import safe_normalize, MLP



class NeRFNetwork(NeRFRenderer):
    """Hash-grid NeRF whose geometry MLP feeds an RGB head and a latent head.

    A ``tiledgrid`` encoder feeds ``sigma_net``. Its first output channel is
    the raw density; the remaining ``opt.geom_feature_size`` channels are a
    geometry feature. That feature is consumed either by ``color_net``
    (3 channels, sigmoid) or by ``latent_net`` (4 channels, no activation).
    An optional background MLP is built when ``self.bg_radius > 0``.

    NOTE(review): ``self.bound``, ``self.bg_radius`` and ``self.latent_mode``
    are read but never assigned here; presumably ``NeRFRenderer.__init__``
    sets them — confirm.
    """

    def __init__(self,
                 opt):

        super().__init__(opt)

        # Position encoder; its output width (self.in_dim) sizes the geometry MLP.
        self.encoder, self.in_dim = get_encoder('tiledgrid', input_dim=3, log2_hashmap_size=16, desired_resolution=2048 * self.bound)

        # Geometry network: channel 0 -> raw density, channels 1: -> geometry feature.
        self.geom_feature_size = opt.geom_feature_size
        self.num_layers_geom = opt.num_layers_geom
        self.hidden_dim = opt.hidden_dim

        self.sigma_net = MLP(dim_in=self.in_dim, dim_out=1 + self.geom_feature_size, dim_hidden=self.hidden_dim, num_layers=self.num_layers_geom, bias=True)

        # RGB head: 3 channels, squashed with a sigmoid in common_forward.
        self.num_layers_color = opt.num_layers_color
        self.color_net = MLP(dim_in=self.geom_feature_size, dim_out=3, dim_hidden=self.hidden_dim, num_layers=self.num_layers_color, bias=True)

        # Latent head: 4 raw channels; shares the layer count of the RGB head.
        self.latent_net = MLP(dim_in=self.geom_feature_size, dim_out=4, dim_hidden=self.hidden_dim, num_layers=self.num_layers_color, bias=True)

        # Optional background network, queried by `background(d)`.
        if self.bg_radius > 0:
            self.num_layers_bg = opt.num_layers_bg
            self.hidden_dim_bg = opt.hidden_dim_bg

            # Use a very simple (low-frequency) encoding so the background
            # cannot easily learn the prompt content.
            self.encoder_bg, self.in_dim_bg = get_encoder('frequency', input_dim=3, multires=4)

            self.bg_net = MLP(self.in_dim_bg, 4, self.hidden_dim_bg, self.num_layers_bg, bias=True)

        else:
            self.bg_net = None

    @staticmethod
    def _append_zero_channel(color):
        """Append one all-zero column to a 2-D ``[N, C]`` color tensor.

        Used to pad 3-channel RGB outputs to 4 channels (presumably so they
        share a layout with ``latent_net``'s 4-channel output — confirm).
        """
        return torch.cat([color, torch.zeros((color.shape[0], 1), device=color.device)], dim=1)

    def gaussian(self, x):
        """Return a density blob centred at the origin.

        Computes ``5 * exp(-|x|^2 / (2 * 0.2^2))`` over the last axis of ``x``.
        Adding it to the raw density biases the model toward putting mass
        near the scene centre.

        Args:
            x: points with coordinates on the last axis (e.g. ``[N, 3]``).

        Returns:
            One blob value per point (the last axis is reduced).
        """
        d = (x ** 2).sum(-1)
        return 5 * torch.exp(-d / (2 * 0.2 ** 2))

    def _geometry(self, x):
        """Run the encoder and geometry MLP once.

        Returns:
            ``(sigma, geom_feature)``, where ``sigma = trunc_exp(h[..., 0] +
            gaussian(x))`` and ``geom_feature = h[..., 1:]``.
        """
        h = self.sigma_net(self.encoder(x, bound=self.bound))
        sigma = trunc_exp(h[..., 0] + self.gaussian(x))
        return sigma, h[..., 1:]

    def get_sigma(self, x):
        """Return only the density at positions ``x`` (expected in [-bound, bound])."""
        sigma, _ = self._geometry(x)
        return sigma

    def common_forward(self, x, type='rgb'):
        """Compute density and either RGB or latent color at positions ``x``.

        Args:
            x: ``[N, 3]`` positions in [-bound, bound].
            type: ``'rgb'`` for sigmoid RGB padded with a zero 4th channel
                (``[N, 4]``); any other value returns the raw ``latent_net``
                output (4 channels).

        Returns:
            ``(sigma, color)``.
        """
        sigma, geom_feature = self._geometry(x)
        if type == 'rgb':
            color = self._append_zero_channel(torch.sigmoid(self.color_net(geom_feature)))
        else:
            color = self.latent_net(geom_feature)
        return sigma, color

    # ref: https://github.com/zhaofuq/Instant-NSR/blob/main/nerf/network_sdf.py#L192
    def finite_difference_normal(self, x, epsilon=1e-2):
        """Estimate the negated density gradient with central differences.

        Each component is ``0.5 * (sigma(x + e_i) - sigma(x - e_i)) / epsilon``.
        Perturbed points are clamped to ``[-bound, bound]`` before querying.
        The result is not normalised (see ``normal``).

        Args:
            x: ``[N, 3]`` positions.
            epsilon: finite-difference step.

        Returns:
            ``[N, 3]`` tensor equal to ``-grad(sigma)`` (approximate).
        """
        # Rows are epsilon * e_x, epsilon * e_y, epsilon * e_z; each broadcasts over [N, 3].
        offsets = torch.eye(3, device=x.device) * epsilon
        components = []
        for offset in offsets:
            # get_sigma returns a single tensor, so no tuple unpacking here.
            sigma_pos = self.get_sigma((x + offset).clamp(-self.bound, self.bound))
            sigma_neg = self.get_sigma((x - offset).clamp(-self.bound, self.bound))
            components.append(0.5 * (sigma_pos - sigma_neg) / epsilon)

        normal = torch.stack(components, dim=-1)
        return -normal

    def autograd_normal(self, x):
        """Return unit normals as the normalised negative autograd density gradient.

        Runs under ``torch.enable_grad()`` so it also works when the caller is
        in ``no_grad`` mode. It differentiates w.r.t. a detached copy of ``x``,
        so the caller's tensor is not modified. NaNs produced by
        normalisation are replaced via ``torch.nan_to_num``.
        """
        with torch.enable_grad():
            x = x.detach().requires_grad_(True)
            # Only the density is needed; skip the color head.
            sigma = self.get_sigma(x)
            gradients = torch.autograd.grad(
                outputs=sigma,
                inputs=x,
                grad_outputs=torch.ones_like(sigma),
                create_graph=False,
                retain_graph=False,
            )[0]
        normal = safe_normalize(-gradients)
        return torch.nan_to_num(normal)

    def normal(self, x):
        """Return finite-difference normals, normalised, with NaN entries set to 0."""
        normal = safe_normalize(self.finite_difference_normal(x))
        normal[torch.isnan(normal)] = 0
        return normal

    def forward(self, x, d, l=None, ratio=1, shading='albedo'):
        """Query density, shaded color and (optionally) normals.

        Args:
            x: ``[N, 3]`` positions in [-bound, bound].
            d: ``[N, 3]`` view directions (unused by this implementation).
            l: ``[3]`` light direction; required for any shading other than
                ``'albedo'``.
            ratio: ambient ratio. Lambertian term is
                ``ratio + (1 - ratio) * max(normal . l, 0)``; 1 means albedo
                only, 0 means shading only.
            shading: ``'albedo'``, ``'textureless'``, ``'normal'``, or anything
                else for lambertian.

        Returns:
            ``(sigma, color, normal)``. ``color`` is ``[N, 4]`` (RGB plus a zero
            channel). ``normal`` is ``None`` for ``'albedo'``.
        """
        if shading == 'albedo':
            # No need to query normals.
            sigma, color = self.common_forward(x)
            return sigma, color, None

        sigma, albedo = self.common_forward(x)
        normal = self.normal(x)

        lambertian = ratio + (1 - ratio) * (normal @ l).clamp(min=0)  # [N,]

        if shading == 'textureless':
            color = lambertian.unsqueeze(-1).repeat(1, 3)
        elif shading == 'normal':
            color = torch.clamp((normal + 1) / 2, 0, 1)
        else:  # 'lambertian'
            # common_forward already padded albedo to 4 channels; shade RGB
            # only so the single padding below yields [N, 4], not [N, 5].
            color = albedo[:, :3] * lambertian.unsqueeze(-1)

        return sigma, self._append_zero_channel(color), normal

    def density(self, x):
        """Return ``{'sigma': density}`` at positions ``x`` in [-bound, bound]."""
        return {
            'sigma': self.get_sigma(x),
        }

    def background(self, d):
        """Evaluate the background MLP on directions ``d``.

        A sigmoid is applied unless ``self.latent_mode`` is truthy, in which
        case the raw 4-channel output is returned. Only valid when
        ``bg_radius > 0``; otherwise ``encoder_bg`` / ``bg_net`` are not built.
        """
        h = self.bg_net(self.encoder_bg(d))  # [N, 4]

        if self.latent_mode:
            return h
        return torch.sigmoid(h)

    # optimizer utils
    def get_params(self, lr):
        """Return optimizer parameter groups; encoders train at 10x ``lr``."""
        params = [
            {'params': self.encoder.parameters(), 'lr': lr * 10},
            {'params': self.sigma_net.parameters(), 'lr': lr},
            {'params': self.color_net.parameters(), 'lr': lr},
            {'params': self.latent_net.parameters(), 'lr': lr},
        ]

        if self.bg_radius > 0:
            params.append({'params': self.encoder_bg.parameters(), 'lr': lr * 10})
            params.append({'params': self.bg_net.parameters(), 'lr': lr})

        return params

if __name__ == '__main__':
    # Smoke check: build the network from default options and print its layers.
    # `sys` is already imported and '../' already appended to sys.path at
    # module level, so neither is repeated here.
    from options import TrainNGPOptions

    model = NeRFNetwork(TrainNGPOptions())
    print(model)