import torch
import torch.nn as nn

import numpy as np

import sys
sys.path.append('../../../code')
sys.path.append('../../models')

from utils.sketch_utils import Sketches

from options import RendererOptions
from render_wrapper import Renderer

class SilhouetteLoss(nn.Module):
    """Negative log-likelihood of rendered opacity inside a silhouette mask.

    The loss is ``-mean(log(clamp(weights_sum, eps, 1 - eps)) * silhouette)``.
    It is minimized by driving ``weights_sum`` toward 1 wherever the target
    silhouette is non-zero. Pixels where the target is zero contribute
    nothing: there is no penalty for opacity outside the silhouette.

    Args:
        silhouette_canvases: Target mask. It must broadcast against
            ``weights_sum``; both are described as ``(B, H, W)`` in this
            file. May be a tensor or anything ``torch.as_tensor`` accepts
            (e.g. a numpy array).
        eps: Clamp margin that keeps ``log`` finite. Must satisfy
            ``0 < eps < 0.5``. Defaults to ``1e-5``, the value that was
            previously hard-coded.

    Raises:
        ValueError: If ``eps`` is outside ``(0, 0.5)``.
    """

    def __init__(self, silhouette_canvases, eps=1e-5):
        super().__init__()
        if not 0 < eps < 0.5:
            # eps >= 0.5 would make the clamp range empty; eps <= 0 allows log(0).
            raise ValueError(f"eps must be in (0, 0.5), got {eps}")
        self.silhouette_canvases = silhouette_canvases
        self.eps = eps

    def forward(self, weights_sum):
        """Return the scalar silhouette loss for ``weights_sum``.

        Args:
            weights_sum: Accumulated per-pixel ray weights, a tensor that
                broadcasts against the stored silhouette mask.

        Returns:
            A 0-dim tensor: the mean over all elements of
            ``-log(clamped weights_sum) * silhouette``.
        """
        # Put the target on the prediction's device. Its dtype is left as-is,
        # so type promotion in the product below is the same as before.
        target = torch.as_tensor(self.silhouette_canvases, device=weights_sum.device)

        # Clamp so that rays with zero accumulated weight do not produce
        # log(0) = -inf (and NaN gradients).
        log_weights = torch.log(weights_sum.clamp(self.eps, 1 - self.eps))
        return -torch.mean(log_weights * target)



if __name__ == '__main__':
    # Smoke test: render the current model along the sketch rays and score the
    # accumulated ray weights against the sketch silhouettes.
    renderer = Renderer(RendererOptions())

    sketches = Sketches('../../../data/plant', H=128, W=128, type='manual', preprocess_sketch=True)
    canvas, poses, intrinsics, bboxes, silhouettes, rays = sketches.get_sketches()

    outputs = renderer.model.render(rays['rays_o'], rays['rays_d'], staged=False, bg_color=None, perturb=True, force_all_rays=False)
    # Reshape the per-ray weights to (N, H, W) using the canvas's spatial dims.
    weights_sum = outputs['weights_sum'].reshape(-1, canvas.shape[1], canvas.shape[2])

    # The silhouette masks are the loss target.
    # NOTE(review): assumes `silhouettes` has the same (N, H, W) layout as
    # `canvas` -- confirm against Sketches.get_sketches.
    silhouette_loss = SilhouetteLoss(silhouettes)
    loss = silhouette_loss(weights_sum)
    print('silhouette loss:', float(loss))
