"""f-k Migration Reconstruction of the Zaragoza Bunny"""
import os
import torch
import numpy as np
import matplotlib.pyplot as plt
from matplotlib.colors import ListedColormap
from pytorch3d.io import load_obj
from totri.types import C0
from totri.data import repo_path, Zaragoza
from totri.reconstruct import FkMigration
from totri.util import UnitDepthMapRender

# Load dataset.
# The parsed Zaragoza dataset object is cached with torch.save in /tmp and
# reloaded from there on later runs instead of being rebuilt from the HDF5 file.
# NOTE(review): the cache file lives in world-writable /tmp and is loaded with
# full unpickling (required to restore the dataset object), which can execute
# arbitrary code stored in that file. Only rely on this on a single-user
# machine, or move the cache to a private directory.
ds_file_name = '/tmp/bunny_zaragoza_rbf_data.pt'
if not os.path.isfile(ds_file_name):
    ds = Zaragoza(
        repo_path('data/zaragoza/bunny_l[0.00,-0.50,0.00]_r[1.57,0.00,3.14]_v[0.21]_s[64]_l[64]_gs[0.60]_conf.hdf5'),
        rectify=True,
        device='cuda')
    torch.save(ds, ds_file_name)
else:
    try:
        # PyTorch >= 2.6 defaults to weights_only=True, which refuses to
        # unpickle arbitrary Python objects such as this dataset.
        ds = torch.load(ds_file_name, weights_only=False)
    except TypeError:
        # PyTorch < 1.13 has no weights_only argument (it always unpickles fully).
        ds = torch.load(ds_file_name)

# Fk migration.
# The reconstruction volume's depth extent is the bin count times the bin
# width; its lateral size is the scan-point extent in each direction.
fk_migration = FkMigration(
    depth=ds.bin_def.num * ds.bin_def.width,
    width=ds.scan_point_extent[0],
    height=ds.scan_point_extent[1],
)
volume = fk_migration(ds.transient)[0]

# Maximum intensity projection over the first axis of the volume: per pixel,
# keep the peak value and the index along that axis where it occurs.
intensity, depth = volume.max(dim=0)
# Mirror both maps along both image axes.
intensity, depth = intensity.flip([0, 1]), depth.flip([0, 1])
# Turn the bin index into a distance. The bin width is halved, presumably to
# convert round-trip path length into one-way depth -- TODO confirm.
depth = depth * ds.bin_def.width / 2

# Load Mesh.
# Only vertices and face indices are used; the auxiliary OBJ data is dropped.
obj_path = repo_path('data/zaragoza/bunny_l[0.00,-0.50,0.00]_r[1.57,0.00,3.14]_v[0.21]_bunny.obj')
verts, faces, _ = load_obj(obj_path, device='cuda')
# Re-order the vertex axes: (x, y, z) -> (x, z, -y).
vx, vy, vz = verts.unbind(dim=-1)
verts = torch.stack((vx, vz, -vy), dim=-1)
faces = faces.verts_idx

# Render depth map.
# 256 px render: only its mask is kept (used for the outline overlay).
# 64 px render: depth and mask used for the pixel-wise comparison with the
# reconstruction (presumably matching the 64x64 scan grid -- confirm).
_, mask_hi = UnitDepthMapRender(256, ds.scan_point_extent[0], device='cuda').apply(verts, faces)
depth_gt, mask_gt = UnitDepthMapRender(64, ds.scan_point_extent[0], device='cuda').apply(verts, faces)

# Maximize IOU.
# Scan 1000 evenly spaced intensity thresholds between the minimum and maximum
# projected intensity. Keep the thresholded mask whose intersection-over-union
# with the ground-truth mask is largest; the first threshold wins on ties.
# `mask` starts as an all-False boolean tensor (not the integer 0). This keeps
# it a bool tensor of the right shape even when no threshold reaches a
# positive IoU (e.g. an empty ground-truth mask), because it is later indexed
# and converted with `.float()`.
mask = torch.zeros_like(intensity, dtype=torch.bool)
best_iou = 0.0
for threshold in torch.linspace(intensity.min().item(), intensity.max().item(), 1000):
    mask_test = intensity >= threshold
    intersection = (mask_test & mask_gt).sum().item()
    union = (mask_test | mask_gt).sum().item()
    if union > 0:
        iou = intersection / union
        if iou > best_iou:
            best_iou = iou
            mask = mask_test
# Pixels inside both the best reconstructed mask and the ground-truth mask.
mask_intersection = (mask & mask_gt)

# Depth error.
# Signed error (reconstruction minus ground truth). Pixels outside the mask
# intersection are set to +inf; matplotlib's imshow treats non-finite values
# as invalid and leaves them unpainted, so only the overlap is colour-coded.
if not mask_intersection.any():
    raise RuntimeError(
        'Reconstructed and ground-truth masks do not overlap; '
        'the depth error is undefined.')
depth_error = depth - depth_gt
depth_error[torch.logical_not(mask_intersection)] = np.inf
# Largest absolute error on the overlap, used as a symmetric colour range.
max_abs_error = depth_error[mask_intersection].abs().max().item()

# Margin hi res.
# Outline of the high-resolution ground-truth mask. Every cell inspects a 2x2
# window of mask_hi (so the result is one row and one column smaller) and
# belongs to the outline when the window is mixed: at least one but not all
# four pixels are set. Non-outline cells become +inf so they stay unpainted
# when drawn with imshow.
hi = mask_hi.int()
window_sum = hi[:-1, :-1] + hi[:-1, 1:] + hi[1:, :-1] + hi[1:, 1:]
margin = ((window_sum > 0) & (window_sum < 4)).float().cpu()
margin[margin == 0] = np.inf

# Plot.
# Layers, bottom to top:
#   1. the best-IoU reconstructed mask (white = 0, black = 1);
#   2. the signed depth error on the mask intersection (+inf pixels unpainted),
#      with a symmetric colour range;
#   3. the high-resolution ground-truth outline in red.
# Axes are labelled in metres and centred on the scan area.
scan_w = ds.scan_point_extent[0]
scan_h = ds.scan_point_extent[1]
imshow_args = {
    'interpolation': 'nearest',
    'extent': [-scan_w/2, scan_w/2, -scan_h/2, scan_h/2],
}
# `margin` is built from 2x2 windows of `mask_hi`, so it has one row and column
# fewer and its cells sit on the corners between high-resolution pixels.
# Under the same "pixels tile the extent" convention used for the other
# images, that means insetting its extent by half a high-resolution pixel per
# side. Otherwise the outline is stretched and misregistered by up to half a
# pixel near the borders.
half_px_x = scan_w / mask_hi.shape[1] / 2
half_px_y = scan_h / mask_hi.shape[0] / 2
margin_args = dict(imshow_args, extent=[
    -scan_w/2 + half_px_x, scan_w/2 - half_px_x,
    -scan_h/2 + half_px_y, scan_h/2 - half_px_y,
])
plt.figure(figsize=(0.75*6, 0.75*4.5))
plt.imshow(mask.float().cpu(), cmap=ListedColormap(['white', 'black']), **imshow_args)
plt.imshow(depth_error.cpu(), vmin=-max_abs_error, vmax=max_abs_error, **imshow_args)
plt.colorbar().set_label('error [m]', rotation=270)
plt.imshow(margin, cmap=ListedColormap(['red']), **margin_args)
plt.xlabel('m')
plt.ylabel('m')
png_path = repo_path('samples/wacv/bunny_zaragoza/bunny_zaragoza_fk/bunny_error_fk.png')
eps_path = repo_path('samples/wacv/bunny_zaragoza/bunny_zaragoza_fk/bunny_error_fk.eps')
# savefig does not create missing directories.
os.makedirs(os.path.dirname(png_path), exist_ok=True)
plt.savefig(png_path, bbox_inches='tight', pad_inches=0)
plt.savefig(eps_path, bbox_inches='tight', pad_inches=0)


# Report the best mask IoU, then the mean absolute depth error over the
# intersection of the reconstructed and ground-truth masks.
print(f'IOU = {best_iou}')
mean_abs_error = depth_error[mask_intersection].abs().mean().item()
print(f'MAE = {mean_abs_error}')
