"""Fit the translation and rotation of the sample bunny mesh to a simulated confocal transient, then export figures."""
import os
import time
import torch
import matplotlib.pyplot as plt
from torchvision.utils import save_image
from pytorch3d.transforms import axis_angle_to_quaternion
from totri.types import BinDef
from totri.data.util import make_wall_grid, transient_noise, repo_path
from totri.data.mesh import load_sample_mesh
from totri.fitting.pose import PoseFitting, Translation, Rotation, merge_meshes, apply_transform
from totri.render import MeshRenderConfocal
from totri.util import datetime_str
from totri.util.render import UnitCubeMeshRender

def make_data():
    """Run the bunny pose-fitting experiment and save its results.

    Renders a noisy "ground-truth" confocal transient of the sample bunny mesh
    after a known translation. It then fits a translation + rotation to that
    transient with ``PoseFitting``, starting from a deliberately perturbed
    initial pose.

    Everything is written to ``samples/wacv/tracking/bunny/data.pth``
    (resolved through ``repo_path``):
    - the ground-truth / initial / final transients,
    - the three preview images,
    - the wall-clock time spent in fitting.
    """
    torch.manual_seed(42)
    out_dir = 'samples/wacv/tracking/bunny'
    # exist_ok so the script can be re-run over an existing output directory.
    os.makedirs(repo_path(out_dir), exist_ok=True)

    # Scene setup: transient binning, plus scan points from make_wall_grid
    # over x, y in [-1, 1] at z = 0 with `resolution` samples per axis.
    bin_def = BinDef(2/256, 1.5, 256, unit_is_time=False)
    resolution = 32
    scan_points = make_wall_grid(
        -1, 1, resolution,
        -1, 1, resolution,
        z_val=0)

    # Load the mesh. verts is indexed as [batch, vertex, coord] below; only
    # the first mesh of the batch is used. It is shifted by -1.4 in z here and
    # the ground truth adds the 1.4 back, so the fitted translation must
    # recover that offset.
    verts, faces = load_sample_mesh('bunny')
    verts[:, :, 2] -= 1.4
    verts_list = [verts[0]]
    faces_list = [faces[0]]

    # Ground truth: translate by (+0.25, +0.25, +1.4) with no rotation, then
    # render its transient and apply transient_noise to it.
    verts_transformed = verts.clone()
    verts_transformed[:, :, 0] += 0.25
    verts_transformed[:, :, 1] += 0.25
    verts_transformed[:, :, 2] += 1.4
    verts_transformed_list = [verts_transformed[0]]
    transient_gt = MeshRenderConfocal.apply(
        verts_transformed, faces, scan_points[None],
        bin_def, None, None)[0]
    transient_gt = transient_noise(transient_gt)
    img_gt = UnitCubeMeshRender(1024).apply(verts_transformed[0], faces[0])

    # Initial pose: the z translation matches the ground truth, but x/y are
    # off by 0.5. It is also rotated by an axis-angle vector of norm 0.5
    # (i.e. 0.5 rad about the normalized [-1, 0.5, -2] axis, negated).
    translation = Translation((-0.25, -0.25, 1.4))
    rotation_axis = torch.tensor([-1.0, 0.5, -2.0])
    rotation_axis = -0.5 * rotation_axis / (rotation_axis**2).sum()**0.5
    quaternion = axis_angle_to_quaternion(rotation_axis)
    rotation = Rotation(quaternion.tolist())
    translation_list = [translation]
    rotation_list = [rotation]

    # Transient and preview image for the initial pose.
    verts_init, faces_init = merge_meshes(
        apply_transform(verts_list, translation_list, rotation_list), faces_list)
    transient_init = MeshRenderConfocal.apply(
        verts_init[None], faces_init[None], scan_points[None],
        bin_def, None, None)[0]
    img_init = UnitCubeMeshRender(1024).apply(verts_init, faces_init)

    # Optimize. CUDA kernels run asynchronously, so synchronize before reading
    # the clock on both ends. That way the measurement covers exactly the
    # fitting work when a GPU is involved.
    if torch.cuda.is_available():
        torch.cuda.synchronize()
    tic = time.perf_counter()
    translation_list, rotation_list = PoseFitting(
        bin_def,
        tensorboard=False,
    ).fit(
        transient_gt, scan_points,
        verts_list, faces_list,
        verts_transformed_list,
        translation_list=translation_list,
        rotation_list=rotation_list,
        num_iter=150,
    )
    if torch.cuda.is_available():
        torch.cuda.synchronize()
    toc = time.perf_counter()

    # Transient and preview image for the fitted pose.
    verts_final, faces_final = merge_meshes(
        apply_transform(verts_list, translation_list, rotation_list), faces_list)
    transient_final = MeshRenderConfocal.apply(
        verts_final[None], faces_final[None], scan_points[None],
        bin_def, None, None)[0]
    img_final = UnitCubeMeshRender(1024).apply(verts_final, faces_final)

    torch.save({
        'transient_gt': transient_gt.detach().cpu(),
        'transient_init': transient_init.detach().cpu(),
        'transient_final': transient_final.detach().cpu(),
        'img_gt': img_gt.detach().cpu(),
        'img_init': img_init.detach().cpu(),
        'img_final': img_final.detach().cpu(),
        'time_s': toc - tic,
        }, repo_path(f'{out_dir}/data.pth'))

def make_figures():
    """Export preview images and transient plots from the saved bunny data.

    Loads ``samples/wacv/tracking/bunny/data.pth`` (as written by
    ``make_data``) and prints the recorded fitting runtime. It then saves:
    - the ground-truth / initial / final preview images as ``gt.jpg``,
      ``initial.jpg`` and ``final.jpg``;
    - an ``imshow`` plot of each transient as ``transients_<name>.jpg``.
    """
    out_dir = 'samples/wacv/tracking/bunny'
    data = torch.load(repo_path(f'{out_dir}/data.pth'))

    print(f'Runtime: {data["time_s"]}s')

    # Preview images: (data key, output file stem).
    for key, stem in [('img_gt', 'gt'), ('img_init', 'initial'), ('img_final', 'final')]:
        save_image(data[key], repo_path(f'{out_dir}/{stem}.jpg'))

    # Transient plots, in the original order: init, final, gt.
    for key, tn in [('transient_init', 'init'), ('transient_final', 'final'), ('transient_gt', 'gt')]:
        fig = plt.figure(figsize=(4, 4))
        plt.imshow(data[key])
        plt.axis('off')
        plt.tight_layout()

        fig.savefig(repo_path(f'{out_dir}/transients_{tn}.jpg'), bbox_inches='tight', dpi=320, pad_inches=0)
        # Release the figure; otherwise every figure stays registered with
        # pyplot for the lifetime of the process.
        plt.close(fig)

if __name__ == '__main__':
    # Guarded so importing this module (e.g. to reuse make_figures) does not
    # re-run the full data generation and pose optimization.
    make_data()
    make_figures()