"""Optimize depth map of the statue (64x64)"""
import os
import time
import torch
from torchvision.utils import save_image
from totri.types import VolumeDef
from totri.util import datetime_str
from totri.data import repo_path, LindellFk2019
from totri.fitting.depth_map import DepthMapFitting, DepthMap, TransientBackgroundConfocal, MultiplicativeInterpolator
from totri.util.render import UnitCubeMeshRender

def make_data():
    """Fit a depth map to the statue measurement and save the resulting mesh.

    Steps:
      1. Load the statue dataset. It is cached as a pickled object at
         ``/tmp/statue_depth_map_data.pt`` after the first run.
      2. Keep every 8th sample along the last two axes of the transient and
         scan-point tensors.
      3. Fit a ``DepthMap`` with ``DepthMapFitting``.
      4. Write the mesh ``verts``, ``faces``, ``colors`` and the fitting
         wall-clock time ``time_s`` (seconds) to
         ``samples/wacv/statue/statue_64/data.pth``.
    """
    out_rel = 'samples/wacv/statue/statue_64'
    # exist_ok: the script is meant to be re-run; without it every run after
    # the first one aborts with FileExistsError before doing any work.
    os.makedirs(repo_path(out_rel), exist_ok=True)

    # Load data (cached in /tmp to skip re-reading the .mat files).
    ds_file_name = '/tmp/statue_depth_map_data.pt'
    if not os.path.isfile(ds_file_name):
        ds = LindellFk2019(
            repo_path('data/measurement/statue/meas_180min.mat'),
            repo_path('data/measurement/statue/tof.mat'),
            device='cuda')
        torch.save(ds, ds_file_name)
    else:
        # NOTE(review): torch.load unpickles arbitrary objects. /tmp is
        # usually world-writable, so only trust this cache on a single-user
        # machine. Also, torch>=2.6 defaults to weights_only=True, which
        # refuses to unpickle a dataset object; pass weights_only=False
        # there — confirm the torch version in use.
        ds = torch.load(ds_file_name)

    # Downsample: keep every 8th sample along the last two axes.
    ds.transient = ds.transient[:, :, ::8, ::8].contiguous()
    ds.scan_points = ds.scan_points[:, ::8, ::8].contiguous()
    num_scan_points = ds.transient.shape[2] * ds.transient.shape[3]

    # Fit depth map inside a volume scaled by `scale` in the first two axes.
    scale = 0.8
    ds.volume_def.start = (-0.6*scale, -0.8*scale, 0)
    ds.volume_def.end   = ( 0.6*scale,  0.8*scale, 2)
    res = 64
    # Keep the first resolution entry; set the others to res and 3/4 * res.
    ds.volume_def.resolution = [ds.volume_def.resolution[0], res, res-res//4]
    method = DepthMapFitting(
        ds.bin_def,
        tensorboard=False,
        log_dir=repo_path(f'{out_rel}/run'))
    depth_map = DepthMap(ds.volume_def)
    tic = time.perf_counter()
    method.fit(
        depth_map,
        ds.transient.view(-1, num_scan_points),
        # NOTE(review): view() reinterprets memory without transposing. If
        # scan_points is channel-first (C, H, W), this does not produce
        # per-point rows — confirm against what fit() expects.
        ds.scan_points.view(num_scan_points, -1),
        background_model = TransientBackgroundConfocal(ds.transient.shape[1], (-1, -1), (1, 1), clamping_factor=1.0, device='cuda'),
        # scan_point_limit=4096,
        num_iter = 900,
        lr_interpolator=MultiplicativeInterpolator(2.e-2, 1.e-4, 0.995),
        lambda_depth=1.0,
        lambda_color=1.3,
        lambda_w=0,
        lambda_black=0,
        warmup=5,
        subdivide_at=[300, 800],
    )
    if torch.cuda.is_available():
        # CUDA work is queued asynchronously. Wait for it so the measured
        # time covers the whole fit, not only the kernel launches.
        torch.cuda.synchronize()
    toc = time.perf_counter()

    verts = depth_map.verts()
    faces = depth_map.faces
    colors = depth_map.colors()

    torch.save({
        'verts': verts.detach().cpu(),
        'faces': faces.detach().cpu(),
        'colors': colors.detach().cpu(),
        'time_s': toc - tic,
        }, repo_path(f'{out_rel}/data.pth'))

def _render_and_save(verts, faces, colors, path, margin=120):
    """Render a mesh on the GPU, blank its borders, normalize, and save it.

    ``margin`` entries are zeroed at each end of the last axis of the rendered
    image (presumably the image width — confirm against UnitCubeMeshRender).
    The image is then divided by its maximum so the brightest value is 1.
    """
    img = UnitCubeMeshRender(1024, distance=0.05, specular=0).apply(
        verts.cuda(), faces.cuda(), colors.cuda())
    if margin > 0:
        # Guarded because img[:, :, -0:] would select (and blank) everything.
        img[:, :, :margin] = 0
        img[:, :, -margin:] = 0
    peak = img.max()
    # An all-black render would otherwise be divided by 0 and saved as NaNs.
    if peak != 0:
        img = img / peak
    save_image(img, path)


def make_figures():
    """Render the mesh saved by make_data() into two images.

    Loads ``samples/wacv/statue/statue_64/data.pth``, prints the recorded
    fitting time, and writes into the same directory:
      * ``statue_64.jpg``: the mesh rendered with its fitted colors;
      * ``statue_64_b.jpg``: the same mesh with colors binarized at 0.11.
    """
    out_rel = 'samples/wacv/statue/statue_64'
    data = torch.load(repo_path(f'{out_rel}/data.pth'))
    verts = data['verts']
    faces = data['faces']
    colors = data['colors']
    time_s = data['time_s']

    print(f'Runtime: {time_s}s')

    _render_and_save(verts, faces, colors,
                     repo_path(f'{out_rel}/statue_64.jpg'))

    # Binarized variant: each color value above 0.11 becomes 1, the rest 0.
    binary_colors = (colors > 0.11).float()
    _render_and_save(verts, faces, binary_colors,
                     repo_path(f'{out_rel}/statue_64_b.jpg'))

if __name__ == '__main__':
    # Fit first (writes data.pth), then render the figures from that file.
    make_data()
    make_figures()
