import torch
from torch import nn

from utils.general import ce_pq_loss

class ShapeLoss(nn.Module):
    """Loss pulling NeRF occupancy toward the inside/outside labels of a sketch shape.

    Query points are labelled inside/outside ``sketch_shape`` from its winding
    number. NeRF densities are mapped to occupancy ``1 - exp(-delta * sigma)``.
    The two are then compared with ``ce_pq_loss``, optionally with per-point
    weights derived from the distance to the sketch surface.

    Args:
        proximal_surface: When > 0, passed to
            ``sketch_shape.gaussian_weighted_distance`` to build per-point
            weights ``1 - gaussian_weighted_distance(...)``. When <= 0, no
            weighting is applied.
        sketch_shape: Object exposing ``winding_number(xyzs)`` and
            ``gaussian_weighted_distance(xyzs, proximal_surface)``.
        delta: Scale applied to the densities before the exponential
            (keyword-only, default ``0.2`` -- the previously hard-coded value).
    """

    def __init__(self, proximal_surface, sketch_shape, *, delta: float = 0.2):
        super().__init__()
        self.sketch_shape = sketch_shape
        self.proximal_surface = proximal_surface
        self.delta = delta

    def _sigma_to_occupancy(self, sigmas: torch.Tensor) -> torch.Tensor:
        """Map densities to occupancy ``1 - exp(-delta * sigmas)``, clamped at 0.

        ``-expm1(-x)`` equals ``1 - exp(-x)`` mathematically but avoids
        catastrophic cancellation when ``delta * sigmas`` is small. In float32,
        ``1 - exp(-x)`` becomes exactly 0 for x below roughly 3e-8.
        """
        occupancy = -torch.expm1(-self.delta * sigmas)
        # expm1(y) >= -1, so occupancy never exceeds 1 and no upper clamp is
        # needed (the former max=1.1 was a no-op). The lower clamp matters:
        # negative densities would otherwise produce negative occupancy.
        return occupancy.clamp(min=0)

    def forward(self, xyzs, sigmas):
        """Compute the shape loss for query points and their NeRF densities.

        Args:
            xyzs: Query points handed to ``sketch_shape`` (presumably (N, 3)
                coordinates -- TODO confirm against the caller).
            sigmas: NeRF densities at ``xyzs``.

        Returns:
            The value of ``ce_pq_loss`` for NeRF occupancy vs. the shape's
            inside/outside indicator (presumably a scalar tensor -- confirm
            in ``utils.general``).
        """
        # Points with winding number > 0.5 are labelled inside (1.0), others 0.0.
        mesh_occ = self.sketch_shape.winding_number(xyzs)
        indicator = (mesh_occ > 0.5).float()

        weight = None
        if self.proximal_surface > 0:
            # NOTE(review): presumably gaussian_weighted_distance is ~1 near the
            # sketch surface, so points close to it are down-weighted -- confirm.
            weight = 1 - self.sketch_shape.gaussian_weighted_distance(
                xyzs, self.proximal_surface
            )

        nerf_occ = self._sigma_to_occupancy(sigmas)
        # Argument order matters for the CE loss; the second argument
        # (the binary indicator) may not be optimized.
        return ce_pq_loss(nerf_occ, indicator, weight=weight)