"""Rbf Reconstruction of the Zaragoza Bunny Sinogram"""
import os
import time
import torch
import numpy as np
import matplotlib.pyplot as plt
from matplotlib.colors import ListedColormap
from pytorch3d.io import load_obj
from torchvision.utils import save_image
from totri.types import VolumeDef
from totri.util import datetime_str
from totri.data import repo_path, Zaragoza
from totri.fitting.rbf import GaussianRbfBase, RbfFitting
from totri.util import UnitDepthMapRender, faces_of_verts
from totri.util.render import UnitCubeMeshRender

def _sample_sinogram(scan_points, transient, radius=0.3, num_samples=4, num_angles=360):
    """Resample measured transients onto virtual scan points along a circle.

    For each of ``num_angles`` equally spaced angles ``a`` a virtual scan point
    is placed at ``(radius*sin(a), radius*cos(a), 0)``. Its transient is the
    inverse-distance-weighted mean of the transients of the ``num_samples``
    closest real scan points.

    Args:
        scan_points: Real scan point positions, shape (P, 3).
        transient: Measured transients, shape (B, P), one column per scan
            point (B is presumably the number of time bins — confirm).
        radius: Radius of the circle of virtual scan points.
        num_samples: Number of nearest real scan points to interpolate from.
        num_angles: Number of virtual scan points on the circle.

    Returns:
        Tuple ``(sinogram_scan_points, sinogram_transient)`` with shapes
        (num_angles, 3) and (B, num_angles).
    """
    sinogram_scan_points = []
    sinogram_transient = []
    for i in range(num_angles):
        angle = i / num_angles * 2 * np.pi
        pos = torch.tensor(
            [[radius * np.sin(angle), radius * np.cos(angle), 0.0]],
            device=scan_points.device,
            dtype=scan_points.dtype)
        # Euclidean distance from the virtual point to every real scan point.
        distances = ((scan_points - pos) ** 2).sum(dim=1) ** 0.5
        values, indices = torch.topk(distances, num_samples, dim=0, largest=False)
        # Inverse-distance weights; the clamp avoids a division by zero when
        # the virtual point coincides with a real scan point.
        weights = 1 / values.clamp(min=1.e-8)
        weights = weights / weights.sum()
        sinogram_transient.append(
            (transient[:, indices] * weights.view(1, num_samples)).sum(dim=1))
        sinogram_scan_points.append(pos[0])
    return (torch.stack(sinogram_scan_points, dim=0),
            torch.stack(sinogram_transient, dim=1))


def make_data():
    """Fit a Gaussian RBF reconstruction to a circular sinogram of the bunny scan.

    Loads the rectified Zaragoza bunny dataset, caching it in /tmp. Resamples
    it onto 360 virtual scan points on a circle of radius 0.3. Fits an RBF
    volume with ``RbfFitting``. Saves the reconstructed vertices, the
    scan-point extent and the fitting runtime to
    ``samples/wacv/bunny_zaragoza/sinogram_rbf/data.pth``.
    """
    out_dir = 'samples/wacv/bunny_zaragoza/sinogram_rbf'
    # exist_ok: the script must stay re-runnable once the output folder exists.
    os.makedirs(repo_path(out_dir), exist_ok=True)

    # Cache the parsed dataset so the HDF5 file is not re-read on every run.
    # NOTE(review): torch.load unpickles arbitrary objects and /tmp is
    # writable by other users, so only load a cache file you created yourself.
    # Newer PyTorch versions (weights_only=True by default) may also refuse to
    # load a pickled dataset object.
    ds_file_name = '/tmp/bunny_zaragoza_rbf_data.pt'
    if os.path.isfile(ds_file_name):
        ds = torch.load(ds_file_name)
    else:
        ds = Zaragoza(
            repo_path('data/zaragoza/bunny_l[0.00,-0.50,0.00]_r[1.57,0.00,3.14]_v[0.21]_s[64]_l[64]_gs[0.60]_conf.hdf5'),
            rectify=True,
            device='cuda')
        torch.save(ds, ds_file_name)

    scan_points = ds.scan_points.view(-1, 3)
    # NOTE(review): a raw view to (shape[1], -1) only keeps one row per bin if
    # every axis before dim 1 has size 1 — confirm the layout of ds.transient.
    transient = ds.transient.view(ds.transient.shape[1], -1)
    sinogram_scan_points, sinogram_transient = _sample_sinogram(
        scan_points, transient, radius=0.3, num_samples=4, num_angles=360)

    resolution = 32
    volume_def = VolumeDef([-0.3, -0.3, 0.2], [0.3, 0.3, 0.8], [resolution] * 3)

    # Fix the global RNG seed so repeated fits are reproducible.
    torch.manual_seed(42)
    tic = time.perf_counter()
    base = RbfFitting(
        ds.bin_def,
        sinogram_scan_points,
        tensorboard=True,
        delete_Factor=1.01,
        optimize_check_delete=True,
        log_dir=repo_path(f'{out_dir}/run')
    ).fit(
        GaussianRbfBase(
            volume_def,
            has_color=False,
            sigma_init=0.05),
        sinogram_transient,
        subdivide_at=[33, 53],
        num_iter=60,
    )
    toc = time.perf_counter()
    verts = base.make_verts()

    torch.save({
        'verts': verts.detach().cpu(),
        'scan_point_extent': ds.scan_point_extent,
        'time_s': toc - tic,
        }, repo_path(f'{out_dir}/data.pth'))

def make_figures():
    """Render, plot and score the reconstruction saved by ``make_data()``.

    Prints the fitting runtime, the maximum and mean absolute depth error
    inside the common silhouette, and the IoU of the reconstructed and
    ground-truth silhouettes. Writes the depth-error plot (PNG and EPS) and a
    mesh rendering (JPG) to ``samples/wacv/bunny_zaragoza/sinogram_rbf``.

    Raises:
        RuntimeError: if the reconstructed and ground-truth silhouettes do
            not overlap, so no depth error can be computed.
    """
    out_dir = 'samples/wacv/bunny_zaragoza/sinogram_rbf'
    data = torch.load(repo_path(f'{out_dir}/data.pth'))
    verts = data['verts'].cuda()
    scan_point_extent = data['scan_point_extent']
    time_s = data['time_s']

    print(f'Runtime: {time_s}s')

    # Load the ground-truth mesh and remap its axes (x, y, z) -> (x, z, -y).
    obj_path = repo_path('data/zaragoza/bunny_l[0.00,-0.50,0.00]_r[1.57,0.00,3.14]_v[0.21]_bunny.obj')
    verts_gt, faces_gt, _aux = load_obj(obj_path, device='cuda')
    verts_gt = torch.stack((verts_gt[:, 0], verts_gt[:, 2], -verts_gt[:, 1]), dim=-1)
    faces_gt = faces_gt.verts_idx

    # Render depth maps. The reconstruction is rendered at 4x the
    # ground-truth resolution and downsampled below to match it.
    depth_gt, mask_gt = UnitDepthMapRender(256, scan_point_extent[0], device='cuda').apply(verts_gt, faces_gt)
    depth, mask = UnitDepthMapRender(1024, scan_point_extent[0], device='cuda').apply(verts, faces_of_verts(verts))

    # Downsample 4x by a per-cell median. Only sub-pixel offsets 0..2 of each
    # 4x4 cell are used (9 samples), presumably so the median is taken over an
    # odd count. For the 0/1 mask this median is a majority vote.
    offsets = [(x, y) for x in range(3) for y in range(3)]
    depth = torch.stack(
        [depth[x::4, y::4] for x, y in offsets],
        dim=0).median(dim=0).values
    mask = torch.stack(
        [mask[x::4, y::4].float() for x, y in offsets],
        dim=0).median(dim=0).values.bool()

    mask_intersection = mask & mask_gt
    if not mask_intersection.any():
        raise RuntimeError(
            'Reconstructed and ground-truth silhouettes do not overlap; '
            'cannot compute a depth error.')

    # Depth error. Outside the common silhouette it is set to +inf. matplotlib
    # draws non-finite values with the colormap's "bad" colour, which is
    # transparent by default, so the mask plotted underneath shows through.
    depth_error = depth - depth_gt
    depth_error[~mask_intersection] = np.inf
    abs_error = depth_error[mask_intersection].abs()
    max_abs_error = abs_error.max().item()
    print(f'Max abs error = {max_abs_error}')

    # Ground-truth outline: pixels whose 2x2 neighbourhood is partly inside
    # and partly outside mask_gt. Non-outline pixels are +inf (transparent).
    # NOTE(review): this map is one pixel smaller than mask_gt in each axis
    # but is drawn over the same extent, so it is slightly stretched.
    margin = (
        mask_gt[:-1, :-1].int() +
        mask_gt[:-1, 1:].int() +
        mask_gt[1:, :-1].int() +
        mask_gt[1:, 1:].int()
        )
    margin = ((margin > 0) & (margin < 4)).float().cpu()
    margin[margin == 0] = np.inf

    # Plot: silhouette mask, then the depth error on top, then the outline.
    imshow_args = {
        'interpolation': 'nearest',
        'extent': [
            -scan_point_extent[0]/2,
            scan_point_extent[0]/2,
            -scan_point_extent[1]/2,
            scan_point_extent[1]/2]
    }
    fig = plt.figure(figsize=(0.75*6, 0.75*4.5))
    plt.imshow(mask.float().cpu(), cmap=ListedColormap(['white', 'black']), **imshow_args)
    plt.imshow(depth_error.cpu(), vmin=-max_abs_error, vmax=max_abs_error, **imshow_args)
    plt.colorbar().set_label('error [m]', rotation=270)
    plt.imshow(margin, cmap=ListedColormap(['red']), **imshow_args)
    plt.xlabel('m')
    plt.ylabel('m')
    plt.savefig(repo_path(f'{out_dir}/error_sinogram_rbf.png'), bbox_inches='tight', pad_inches=0)
    plt.savefig(repo_path(f'{out_dir}/error_sinogram_rbf.eps'), bbox_inches='tight', pad_inches=0)
    # Release the figure; pyplot otherwise keeps it alive for the whole session.
    plt.close(fig)

    img = UnitCubeMeshRender(4096, distance=0.07, specular=0).apply(verts)
    save_image(img, repo_path(f'{out_dir}/bunny_sinogram_rbf.jpg'))

    # Silhouette IoU. The union is non-empty because the intersection is.
    intersection = mask_intersection.sum().item()
    union = (mask | mask_gt).sum().item()
    iou = intersection / union
    print(f'IOU = {iou}')
    print(f'MAE = {abs_error.mean().item()}')

# Only run the (GPU-heavy) pipeline when executed as a script, not on import.
if __name__ == '__main__':
    make_data()
    make_figures()
