"""Optimize depth map of the diffuse S"""
import os
import time
import torch
from torchvision.utils import save_image
from totri.data import repo_path, OTooleConfocal2018
from totri.types import VolumeDef
from totri.util import datetime_str
from totri.fitting.depth_map import DepthMapFitting, DepthMap, MultiplicativeInterpolator
from totri.fitting.background import TransientBackgroundConfocal
from totri.util.render import UnitCubeMeshRender

def make_data():
    """Fit a depth map to the diffuse-S measurement and save the fitted mesh.

    Steps:
    * Load the ``OTooleConfocal2018`` dataset from
      ``data/measurement/diffuse_s/data_diffuse_s.mat`` on CUDA.
    * Call ``downsample_temporal(8)`` and keep every second scan point along
      both scan-grid axes.
    * Fit a ``DepthMap`` with ``DepthMapFitting`` for 8000 iterations against
      the transient loaded from ``data/pfm/diffuse_S.pth``. The dataset
      supplies the bin/volume definitions and scan points.
    * Write the fitted vertices, faces, colors and the fitting wall-clock time
      (seconds) to ``samples/wacv/diffuse_s/flatfield/data.pth``.

    Requires a CUDA device.
    """
    out_dir = 'samples/wacv/diffuse_s/flatfield'
    # exist_ok=True so the script can be re-run without deleting old output.
    os.makedirs(repo_path(out_dir), exist_ok=True)

    ds = OTooleConfocal2018(repo_path('data/measurement/diffuse_s/data_diffuse_s.mat'), device='cuda')
    ds.downsample_temporal(8)
    # Keep every second scan point on both scan-grid axes (transient dims 2/3,
    # scan_points dims 1/2); .contiguous() so the later .view() calls work.
    ds.transient = ds.transient[:, :, ::2, ::2].contiguous()
    ds.scan_points = ds.scan_points[:, ::2, ::2, :].contiguous()

    # NOTE(review): flip([0, 1]) and view(-1, num_scan_points) below assume a
    # specific layout of the stored transient -- confirm against how
    # diffuse_S.pth was written.
    transient = torch.load(repo_path('data/pfm/diffuse_S.pth')).cuda()
    transient = transient.flip([0, 1])

    num_scan_points = ds.transient.shape[2] * ds.transient.shape[3]
    res = 16
    # Keep the first volume axis, use a res x res grid for the other two.
    ds.volume_def.resolution = [ds.volume_def.resolution[0], res, res]

    method = DepthMapFitting(
        ds.bin_def,
        tensorboard=False,
        log_dir=repo_path(f'{out_dir}/run'))
    depth_map = DepthMap(ds.volume_def)

    # CUDA kernels run asynchronously; synchronize before reading the clock so
    # the measured interval covers all GPU work issued by the fit.
    torch.cuda.synchronize()
    tic = time.perf_counter()
    method.fit(
        depth_map,
        transient.view(-1, num_scan_points),
        ds.scan_points.view(num_scan_points, -1),
        background_model=TransientBackgroundConfocal(ds.transient.shape[1], (-0.4, -0.4), (0.4, 0.4), clamping_factor=2.0, device='cuda'),
        num_iter=8000,
        lr_interpolator=MultiplicativeInterpolator(1.e-2, 1.e-3, 0.999),
        lambda_depth=4.0,
        lambda_color=1.0,
        lambda_w=0,
        lambda_black=0,
        warmup=5,
        subdivide_at=[2000, 4000, 6000],
    )
    torch.cuda.synchronize()
    toc = time.perf_counter()

    verts = depth_map.verts()
    faces = depth_map.faces
    colors = depth_map.colors()

    torch.save({
        'verts': verts.detach().cpu(),
        'faces': faces.detach().cpu(),
        'colors': colors.detach().cpu(),
        'time_s': toc - tic,
        }, repo_path(f'{out_dir}/data.pth'))

def make_figures():
    """Render the mesh saved by ``make_data`` and report the fitting time.

    Reads ``samples/wacv/diffuse_s/flatfield/data.pth`` and prints the stored
    runtime. It then renders the mesh on the GPU with
    ``UnitCubeMeshRender(4096, distance=0.105, specular=0)`` and writes the
    image to ``samples/wacv/diffuse_s/flatfield/diffuse_s_flatfield.jpg``.
    """
    out_dir = 'samples/wacv/diffuse_s/flatfield'
    saved = torch.load(repo_path(f'{out_dir}/data.pth'))
    verts, faces, colors, time_s = (
        saved[key] for key in ('verts', 'faces', 'colors', 'time_s'))

    print(f'Runtime: {time_s}s')

    renderer = UnitCubeMeshRender(4096, distance=0.105, specular=0)
    image = renderer.apply(verts.cuda(), faces.cuda(), colors.cuda())
    save_image(image, repo_path(f'{out_dir}/diffuse_s_flatfield.jpg'))

if __name__ == '__main__':
    # Guarded so importing this module (e.g. to reuse make_figures) does not
    # launch the 8000-iteration GPU fit and rendering.
    make_data()
    make_figures()
