"""Depth Map Reconstruction of the Zaragoza Bunny"""
import os
import time
import torch
import numpy as np
import matplotlib.pyplot as plt
from torchvision.utils import save_image
from matplotlib.colors import ListedColormap
from pytorch3d.io import load_obj
from totri.types import VolumeDef
from totri.data import repo_path, Zaragoza
from totri.fitting.depth_map import DepthMapFitting, DepthMap, MultiplicativeInterpolator
from totri.fitting.background import TransientBackgroundConfocal
from totri.util import UnitDepthMapRender
from totri.util.render import UnitCubeMeshRender

def make_data():
    """Fit a depth map to the Zaragoza bunny transients and save the results.

    Steps:

    * Loads the rectified confocal Zaragoza bunny dataset onto CUDA. The loaded
      dataset object is cached with ``torch.save`` in ``/tmp`` so later runs can
      skip constructing it from the HDF5 file.
    * Fits a ``DepthMap`` defined on a 16x16x16 ``VolumeDef`` spanning
      [-0.3, 0.3] x [-0.3, 0.3] x [0.2, 0.6] using ``DepthMapFitting``.
    * Writes the results to ``samples/wacv/bunny_zaragoza/depth_map/data.pth``
      for ``make_figures``: the (flipped) colour and depth maps, the fitted mesh
      (vertices, faces, colours), the dataset's scan-point extent and the
      wall-clock fitting time in seconds.

    Requires a CUDA device.
    """
    out_dir = 'samples/wacv/bunny_zaragoza/depth_map'
    # exist_ok: this script is rerun routinely; without it os.makedirs raises
    # FileExistsError on every run after the first.
    os.makedirs(repo_path(out_dir), exist_ok=True)

    ds_file_name = '/tmp/bunny_zaragoza_rbf_data.pt'
    if not os.path.isfile(ds_file_name):
        ds = Zaragoza(
            repo_path('data/zaragoza/bunny_l[0.00,-0.50,0.00]_r[1.57,0.00,3.14]_v[0.21]_s[64]_l[64]_gs[0.60]_conf.hdf5'),
            rectify=True,
            device='cuda')
        torch.save(ds, ds_file_name)
    else:
        # NOTE(review): the cache is never invalidated. Delete the file after
        # changing the HDF5 path or the Zaragoza arguments above.
        # NOTE(review): torch.load unpickles arbitrary objects, and /tmp is
        # shared, so only trust this cache on a single-user machine. Newer
        # torch releases default to weights_only=True, which may refuse to
        # load a pickled dataset object; pass weights_only=False there if
        # needed.
        ds = torch.load(ds_file_name)

    resolution = 16
    volume_def = VolumeDef([-0.3, -0.3, 0.2], [0.3, 0.3, 0.6], [resolution] * 3)

    method = DepthMapFitting(
        ds.bin_def,
        tensorboard=False,
        log_dir=repo_path(f'{out_dir}/run'))
    depth_map = DepthMap(volume_def)

    # CUDA kernels are launched asynchronously. Synchronise on both sides of
    # the timed region so that time_s measures exactly the fit, not leftover
    # dataset work before it or still-queued kernels after it.
    torch.cuda.synchronize()
    tic = time.perf_counter()
    method.fit(
        depth_map,
        ds.transient.view(ds.transient.shape[1], -1),
        ds.scan_points.view(-1, 3),
        background_model=TransientBackgroundConfocal(
            ds.transient.shape[1], (-0.3, -0.3), (0.3, 0.3),
            clamping_factor=0.1, device='cuda'),
        scan_point_limit=-1,
        num_iter=1700,
        lr_interpolator=MultiplicativeInterpolator(1.e-2, 1.e-3, 0.999),
        lambda_depth=1.0,
        lambda_color=1.0,
        lambda_w=0,
        lambda_black=0.0,
        warmup=5,
        subdivide_at=[500, 1000, 1500],
    )
    torch.cuda.synchronize()
    toc = time.perf_counter()

    # Flip both image axes of the fitted maps (a 180-degree rotation). This is
    # presumably to match the orientation of the ground-truth render in
    # make_figures; TODO confirm.
    color = depth_map.color.flip([0, 1])
    depth = depth_map.abs_depth().flip([0, 1])

    verts_dm = depth_map.verts()
    faces_dm = depth_map.faces
    colors_dm = depth_map.colors()

    torch.save({
        'color': color.detach().cpu(),
        'depth': depth.detach().cpu(),
        'verts_dm': verts_dm.detach().cpu(),
        'faces_dm': faces_dm.detach().cpu(),
        'colors_dm': colors_dm.detach().cpu(),
        'scan_point_extent': ds.scan_point_extent,
        'time_s': toc - tic,
        }, repo_path(f'{out_dir}/data.pth'))

def _best_iou_mask(color, mask_gt, num_thresholds=1000):
    """Binarise ``color`` at the threshold whose silhouette best matches ``mask_gt``.

    Sweeps ``num_thresholds`` evenly spaced thresholds from ``color.min()`` to
    ``color.max()``. For each threshold it forms the mask ``color >= threshold``
    and computes its intersection-over-union with ``mask_gt``. Thresholds that
    give an empty union are skipped. If several thresholds tie, the lowest one
    wins.

    Returns:
        ``(mask, iou)``: the best boolean mask and its IoU as a Python float.

    Raises:
        ValueError: if every threshold produces an empty union.
    """
    best_iou = -1.0
    best_mask = None
    thresholds = torch.linspace(color.min().item(), color.max().item(), num_thresholds)
    for threshold in thresholds:
        candidate = color >= threshold
        union = (candidate | mask_gt).sum().item()
        if union == 0:
            continue
        iou = (candidate & mask_gt).sum().item() / union
        if iou > best_iou:
            best_iou = iou
            best_mask = candidate
    if best_mask is None:
        raise ValueError('No colour threshold produced a non-empty union with the ground-truth mask')
    return best_mask, best_iou


def _boundary_margin(mask):
    """Return a float CPU image marking the silhouette boundary of a 2-D ``mask``.

    Each output pixel covers a 2x2 window of ``mask``. The output is 1 where
    that window is mixed (some, but not all, pixels set) and ``inf`` everywhere
    else. The result is one pixel smaller than ``mask`` along both axes. The
    ``inf`` values are meant to be masked as invalid by matplotlib's ``imshow``
    and drawn in the colormap's (transparent by default) "bad" colour.
    """
    count = (
        mask[:-1, :-1].int()
        + mask[:-1, 1:].int()
        + mask[1:, :-1].int()
        + mask[1:, 1:].int()
    )
    margin = ((count > 0) & (count < 4)).float().cpu()
    margin[margin == 0] = np.inf
    return margin


def make_figures():
    """Evaluate the fitted depth map against the ground-truth bunny mesh.

    Loads the results written by ``make_data``
    (``samples/wacv/bunny_zaragoza/depth_map/data.pth``) and renders the
    ground-truth OBJ mesh into depth maps. It then binarises the fitted colour
    map at the threshold with the best IoU against the ground-truth silhouette.
    Outputs:

    * ``error_depth_map.png`` / ``.eps``: signed depth error inside both
      silhouettes, drawn over the reconstructed silhouette (black). The outline
      of the high-resolution ground-truth silhouette is drawn in red.
    * ``bunny_depth.jpg``: a render of the fitted mesh.
    * Printed runtime of the fit, best IoU, and mean absolute depth error over
      pixels inside both silhouettes.

    Requires a CUDA device.
    """
    out_dir = 'samples/wacv/bunny_zaragoza/depth_map'

    # NOTE(review): newer torch releases default torch.load to
    # weights_only=True. That may reject 'scan_point_extent' depending on its
    # type; confirm with the installed torch version.
    data = torch.load(repo_path(f'{out_dir}/data.pth'))
    color = data['color'].cuda()
    depth = data['depth'].cuda()
    verts_dm = data['verts_dm'].cuda()
    faces_dm = data['faces_dm'].cuda()
    colors_dm = data['colors_dm'].cuda()
    scan_point_extent = data['scan_point_extent']
    time_s = data['time_s']

    print(f'Runtime: {time_s}s')

    # Load the ground-truth mesh and permute its axes (x, y, z) -> (x, z, -y).
    # This presumably maps the OBJ's frame onto the reconstruction frame; TODO
    # confirm.
    obj_path = repo_path('data/zaragoza/bunny_l[0.00,-0.50,0.00]_r[1.57,0.00,3.14]_v[0.21]_bunny.obj')
    verts, faces, _ = load_obj(obj_path, device='cuda')
    verts = torch.stack((verts[:, 0], verts[:, 2], -verts[:, 1]), dim=-1)
    faces = faces.verts_idx

    # The 256-px silhouette is used only for the red outline overlay. The
    # 128-px depth/mask are compared element-wise with the fitted maps, so they
    # are assumed to have the same resolution as `color`/`depth`.
    _, mask_hi = UnitDepthMapRender(256, scan_point_extent[0], device='cuda').apply(verts, faces)
    depth_gt, mask_gt = UnitDepthMapRender(128, scan_point_extent[0], device='cuda').apply(verts, faces)

    mask, best_iou = _best_iou_mask(color, mask_gt)
    mask_intersection = mask & mask_gt

    # Signed depth error. Pixels outside either silhouette are set to inf so
    # imshow masks them (transparent), and the silhouette layer underneath
    # shows through.
    depth_error = depth - depth_gt
    depth_error[~mask_intersection] = np.inf

    margin = _boundary_margin(mask_hi)

    half_x = scan_point_extent[0] / 2
    half_y = scan_point_extent[1] / 2
    imshow_args = {
        'interpolation': 'nearest',
        'extent': [-half_x, half_x, -half_y, half_y],
    }
    # The colour range is fixed rather than taken from the data maximum. This
    # is presumably so that the figure shares a colour scale with related
    # figures; TODO confirm.
    max_abs_error = 0.162

    plt.figure(figsize=(0.75*6, 0.75*4.5))
    plt.imshow(mask.float().cpu(), cmap=ListedColormap(['white', 'black']), **imshow_args)
    plt.imshow(depth_error.cpu(), vmin=-max_abs_error, vmax=max_abs_error, **imshow_args)
    plt.colorbar().set_label('error [m]', rotation=270)
    plt.imshow(margin, cmap=ListedColormap(['red']), **imshow_args)
    plt.xlabel('m')
    plt.ylabel('m')
    for ext in ('png', 'eps'):
        plt.savefig(repo_path(f'{out_dir}/error_depth_map.{ext}'), bbox_inches='tight', pad_inches=0)
    # Release the figure; pyplot otherwise keeps it alive until exit.
    plt.close()

    img = UnitCubeMeshRender(4096, distance=0.07, specular=0).apply(verts_dm, faces_dm, colors_dm)
    save_image(img, repo_path(f'{out_dir}/bunny_depth.jpg'))

    print(f'IOU = {best_iou}')
    print(f'MAE = {depth_error[mask_intersection].abs().mean().item()}')

if __name__ == '__main__':
    # Fitting is expensive and needs CUDA. It only has to be rerun when data.pth
    # is missing or stale, so it is toggled by hand: uncomment to regenerate.
    # make_data()
    make_figures()
