"""Fit colored RBFs to a rendered transient of the Spot model and render comparison figures."""
import os
import time
import torch
from torchvision.utils import save_image
from totri.types import BinDef, VolumeDef
from totri.util import datetime_str
from totri.data import repo_path, load_sample_mesh
from totri.data.util import make_wall_grid, transient_noise
from totri.render import MeshRenderConfocal
from totri.fitting.rbf import RbfFitting, GaussianRbfBase
from totri.util.render import UnitCubeMeshRender

def make_data():
    """Render a noisy transient of the Spot mesh and fit colored RBFs to it.

    Loads the ``spot_hi`` sample mesh onto CUDA and averages its vertex colors
    over the last axis. It renders a transient with ``MeshRenderConfocal`` over a
    ``resolution_grid`` x ``resolution_grid`` grid of scan points on the z=0
    wall, adds noise with ``transient_noise``, and fits a ``GaussianRbfBase``
    with ``RbfFitting``.

    The fitted vertices, the ground-truth mesh/colors and the wall-clock time
    of the fit are saved to ``samples/wacv/color/spot/data.pth`` (resolved via
    ``repo_path``). ``make_figures`` reads them from there.
    """
    out_dir = 'samples/wacv/color/spot'
    # exist_ok=True: the directory survives from previous runs; a bare
    # makedirs would raise FileExistsError on every re-run of the script.
    os.makedirs(repo_path(out_dir), exist_ok=True)

    verts_gt, colors_gt, faces_gt = load_sample_mesh('spot_hi', 'cuda')
    # Collapse the last axis to a single averaged channel
    # (presumably RGB -> gray; TODO confirm colors_gt layout).
    colors_gt = colors_gt.mean(dim=-1, keepdim=True)

    # NOTE(review): positional BinDef args presumably are bin width, start
    # offset and bin count, in distance units since unit_is_time=False — confirm.
    bin_def = BinDef(2*2/512, 0.5, 512, unit_is_time=False)
    resolution_volume = 32
    resolution_grid = 64
    # Volume over [-1, 1] x [-1, 1] x [0, 2], resolution_volume samples per
    # axis (assuming VolumeDef(min_corner, max_corner, resolution)).
    volume_def = VolumeDef([-1, -1, 0], [1, 1, 2], [resolution_volume] * 3)
    scan_points = make_wall_grid(
        -1, 1, resolution_grid,
        -1, 1, resolution_grid,
        z_val=0)

    # scan_points[None] adds a leading batch axis; [0] takes the first result.
    transient = MeshRenderConfocal.apply(
            verts_gt, faces_gt,
            scan_points[None],
            bin_def,
            colors_gt,
            None)[0]
    # NOTE(review): no RNG seed is set in this script, so the noisy transient
    # may differ between runs.
    transient = transient_noise(transient, 10, 0.25)

    fitting = RbfFitting(
        bin_def,
        scan_points,
        None, None, None,
        optimize_check_delete=True,
        delete_Factor=1.02,
        tensorboard=False,
        video=False,
        log_dir=repo_path(f'{out_dir}/{datetime_str()}'),
        project_transient=False,
    )
    print('Start fitting')
    tic = time.perf_counter()
    base = fitting.fit(
        GaussianRbfBase(
            volume_def,
            has_color=True,
            sigma_init=0.05),
        transient,
        # First batch element of the ground-truth mesh.
        verts_input=verts_gt[0],
        faces_input=faces_gt[0],
        subdivide_at=[33, 73],
        num_iter=90,
    )
    toc = time.perf_counter()

    verts = base.make_verts()

    # Store everything on CPU so the file can be loaded without a GPU.
    torch.save({
        'verts': verts.detach().cpu(),
        'verts_gt': verts_gt.detach().cpu(),
        'colors_gt': colors_gt.detach().cpu(),
        'faces_gt': faces_gt.detach().cpu(),
        'time_s': toc - tic,
        }, repo_path(f'{out_dir}/data.pth'))

def make_figures():
    """Render the results stored by ``make_data`` and save them as images.

    Loads ``samples/wacv/color/spot/data.pth`` (via ``repo_path``), moves the
    tensors to CUDA and prints the recorded fitting runtime. It then saves two
    renders from ``UnitCubeMeshRender(4096, distance=0.5)`` into the same
    directory:

    * ``spot_rbf.jpg`` — fitted vertices (first three columns) with
      ``faces=None``; the remaining columns are expanded to three channels
      as colors.
    * ``spot_gt.jpg`` — the ground-truth mesh (first batch element) with its
      single-channel colors expanded to three channels.
    """
    def render_to(filename, points, tris, tint):
        # Fresh renderer per image, matching the original construction order.
        picture = UnitCubeMeshRender(4096, distance=0.5).apply(points, tris, tint)
        save_image(picture, repo_path('samples/wacv/color/spot/' + filename))

    saved = torch.load(repo_path('samples/wacv/color/spot/data.pth'))
    rbf_verts = saved['verts'].cuda()
    gt_verts = saved['verts_gt'].cuda()
    gt_colors = saved['colors_gt'].cuda()
    gt_faces = saved['faces_gt'].cuda()

    print(f"Runtime: {saved['time_s']}s")

    render_to('spot_rbf.jpg',
              rbf_verts[:, :3], None, rbf_verts[:, 3:].expand(-1, 3))
    render_to('spot_gt.jpg',
              gt_verts[0], gt_faces[0], gt_colors[0].expand(-1, 3))

if __name__ == '__main__':
    # Guarded so that importing this module (e.g. to reuse make_figures) does
    # not trigger a full GPU fitting run as an import side effect.
    make_data()
    make_figures()
