"""Optimize Rbf for the mannequin measurement"""
import os
import time
import torch
from torchvision.utils import save_image
from totri.types import VolumeDef
from totri.data import VeltenNc2012, repo_path
from totri.util import datetime_str
from totri.fitting.rbf import GaussianRbfBase, RbfFitting
from totri.fitting.background import TransientBackgroundExhaustive
from totri.util.render import UnitCubeMeshRender

def make_data():
    """Fit a Gaussian RBF model to the measurement and save the result.

    Loads the Velten et al. (NC 2012) dataset file
    ``data/veltennc2012/gandalf/dataset_Gandalf.mat`` onto the GPU and keeps
    a window of 224 scan points starting at index 40. It then sets the
    reconstruction volume and runs ``RbfFitting`` on the first transient.
    The fitted vertices (moved to CPU) and the elapsed fitting time in
    seconds are written to ``samples/wacv/mannequin/measurement/data.pth``.
    Logging (``tensorboard``, ``video``) is enabled with ``log_dir`` set to a
    timestamped subdirectory of that folder.
    """
    # NOTE(review): the input is the Gandalf dataset file while all outputs go
    # under mannequin/ -- confirm this is the intended scene.
    file_name = repo_path('data/veltennc2012/gandalf/dataset_Gandalf.mat')
    ds = VeltenNc2012(file_name, device='cuda')

    # Crop measurement: keep scan points [40, 264). Axis 2 of the transient is
    # cropped in lockstep with axis 1 of scan_points, so both presumably index
    # the scan points.
    scan_start = 40
    scan_end = scan_start + 224
    ds.transient = ds.transient[:, :, scan_start:scan_end, :]
    ds.scan_points = ds.scan_points[:, scan_start:scan_end, :]

    # Reconstruction volume; make_figures must render with the same values.
    ds.volume_def.start = (-0.75, 0.0, 1.25)
    ds.volume_def.end = (0.75, 3.0, 2.75)
    ds.volume_def.resolution = (24, 48, 24)

    # Optimize
    print('Start fitting')
    tic = time.perf_counter()
    base = RbfFitting(
        ds.bin_def,
        ds.scan_points[0],
        ds.laser_points[0],
        ds.scan_origin,
        ds.laser_origin,
        optimize_check_delete=True,
        # NOTE(review): unusual capitalization -- confirm RbfFitting expects
        # the keyword 'delete_Factor'.
        delete_Factor=1.002,
        tensorboard=True,
        video=True,
        log_dir=repo_path(f'samples/wacv/mannequin/measurement/{datetime_str()}')
    ).fit(
        GaussianRbfBase(
            ds.volume_def,
            has_color=False,
            sigma_init=0.07,
            sigma_min=0.03),
        ds.transient[0],
        # The two tuples are presumably value ranges searched by the
        # background model -- TODO confirm their meaning.
        TransientBackgroundExhaustive(
            ds.bin_def.num,
            (-1.5611, 0.6892),
            (1.0337, 4.0229),
            device='cuda'),
        # Presumably subdivides the RBF base at iteration 53 of 60.
        subdivide_at=[53],
        num_iter=60,
    )
    # CUDA kernels run asynchronously; wait for queued GPU work so the
    # measured time covers the whole fit.
    torch.cuda.synchronize()
    toc = time.perf_counter()
    verts = base.make_verts()

    out_path = repo_path('samples/wacv/mannequin/measurement/data.pth')
    # Do not rely on the fitting/logging code having created the folder.
    os.makedirs(os.path.dirname(out_path), exist_ok=True)
    torch.save({
        'verts': verts.detach().cpu(),
        'time_s': toc - tic,
        }, out_path)

def make_figures():
    """Render the vertices saved by ``make_data`` and write a JPEG.

    Loads ``samples/wacv/mannequin/measurement/data.pth``, prints the stored
    fitting time, and mirrors and reorders the vertices. It then renders them
    with ``UnitCubeMeshRender`` inside the same volume used for fitting and
    saves the image as ``mannequin_measurement_rbf.jpg`` in the same folder.
    """
    data = torch.load(repo_path('samples/wacv/mannequin/measurement/data.pth'))
    verts = data['verts'].cuda()
    time_s = data['time_s']

    # Negate component 0 of every row along dim 1 (presumably the x
    # coordinate), mirroring the mesh for display.
    verts[:, 0] = -verts[:, 0]
    # Reverse the order along dim 0 -- presumably restores the triangle
    # winding that the mirroring inverted; TODO confirm.
    verts = torch.flip(verts, [0])

    print(f'Runtime: {time_s}s')

    # Must match the volume set in make_data.
    volume_def = VolumeDef(
        (-0.75, 0.0, 1.25),
        (0.75, 3.0, 2.75),
        (24, 48, 24),
    )
    # 4096 is presumably the output image resolution -- TODO confirm.
    img = UnitCubeMeshRender(4096, distance=0.07, specular=0).apply(verts, volume_def=volume_def)
    save_image(img, repo_path('samples/wacv/mannequin/measurement/mannequin_measurement_rbf.jpg'))

# Run the (GPU-heavy) fit and the rendering only when executed as a script,
# not as a side effect of importing this module.
if __name__ == '__main__':
    make_data()
    make_figures()
