"""Tracking two armadillos"""
import os
import time
import torch
import matplotlib.pyplot as plt
from torchvision.utils import save_image
from pytorch3d.transforms import axis_angle_to_quaternion
from totri.types import BinDef
from totri.data.util import make_wall_grid, transient_noise, repo_path
from totri.data.mesh import load_sample_mesh
from totri.fitting.pose import PoseVideoFitting, Translation, Rotation, merge_meshes, apply_transform
from totri.render import MeshRenderConfocal
from totri.util import datetime_str
from totri.util.render import UnitCubeMeshRender

def get_rotation(angle):
    """Return a Rotation of ``angle`` radians about the z axis.

    The axis-angle vector ``(0, 0, angle)`` is turned into a quaternion with
    pytorch3d's ``axis_angle_to_quaternion`` (whose input magnitude is the
    rotation angle in radians) and passed to ``Rotation`` as a plain list.
    """
    axis_angle = torch.tensor([0.0, 0.0, angle])
    return Rotation(axis_angle_to_quaternion(axis_angle).tolist())

def make_data():
    """Simulate two moving armadillos, fit their poses and save everything.

    Two copies of the armadillo sample mesh (scaled by 0.75) follow
    straight-line trajectories while rotating about the z axis for
    ``num_steps`` frames. For every frame the confocal transient is rendered
    with ``MeshRenderConfocal``, a noisy copy is made with ``transient_noise``,
    and an RGB visualisation is rendered with ``UnitCubeMeshRender``.

    ``PoseVideoFitting`` is then run on the noisy transients. It is initialised
    with the ground-truth pose of frame 0, and its wall-clock runtime is
    measured.

    Ground truth, fitted poses, transients, images and runtime are written to
    ``samples/wacv/tracking/armadillos/data.pth``. The ``render/``
    subdirectory is created as well.
    """
    torch.manual_seed(42)
    num_steps = 32
    base = 'samples/wacv/tracking/armadillos'
    os.makedirs(repo_path(base + '/render'), exist_ok=True)

    # Load Mesh
    bin_def = BinDef(2/256, 1.0, 512, unit_is_time=False)
    resolution = 32
    # Presumably a 2*resolution x resolution grid over x in [-2, 2],
    # y in [-1, 1] on the z=0 wall -- see make_wall_grid to confirm.
    scan_points = make_wall_grid(
        -2, 2, 2*resolution,
        -1, 1,   resolution,
        z_val=0, x_fastest=False)
    verts, faces = load_sample_mesh('armadillo')
    verts = verts[0] * 0.75
    faces = faces[0]
    # The same mesh is used for both objects; they only differ by pose.
    verts_list = [verts, verts]
    faces_list = [faces, faces]

    # Ground-truth trajectories: one row (x, y, z) per frame.
    pos_1 = torch.stack((
        torch.linspace(-0.8,  0.8, num_steps, device='cuda'),
        torch.linspace( 0.2, -0.2, num_steps, device='cuda'),
        torch.linspace( 1.0,  1.0, num_steps, device='cuda'),
        ), dim=1)  # (num_steps, 3)
    pos_2 = torch.stack((
        torch.linspace( 0.8, -0.8, num_steps, device='cuda'),
        torch.linspace( 0.2, -0.2, num_steps, device='cuda'),
        torch.linspace( 1.5,  1.5, num_steps, device='cuda'),
        ), dim=1)  # (num_steps, 3)
    # Per-frame rotation angle (radians, about z); the objects turn in
    # opposite directions.
    rot_1 = torch.linspace(0,  1.0, num_steps)
    rot_2 = torch.linspace(0, -1.0, num_steps)

    rotation_lists_ref = []
    transient_gt_noisy_list = []
    transient_gt_list = []
    img_list = []
    for i in range(num_steps):
        translation_list = [Translation(pos_1[i].tolist()), Translation(pos_2[i].tolist())]
        rotation_list = [get_rotation(rot_1[i]), get_rotation(rot_2[i])]
        rotation_lists_ref.append(rotation_list)
        if i == 0:
            # The fit is initialised with the ground-truth pose of the first frame.
            translation_list_0 = translation_list
            rotation_list_0 = rotation_list
        with torch.no_grad():
            verts_list_transformed, faces_list_transformed = merge_meshes(
                apply_transform(verts_list, translation_list, rotation_list), faces_list)
            transient_gt = MeshRenderConfocal.apply(
                verts_list_transformed[None], faces_list_transformed[None],
                scan_points[None],
                bin_def, None, None)[0]
            transient_gt_noisy = transient_noise(transient_gt, scale=0.1, bias=2.5)
            transient_gt_list.append(transient_gt)
            transient_gt_noisy_list.append(transient_gt_noisy)
            img = UnitCubeMeshRender(1024, distance=1.5).apply(verts_list_transformed, faces_list_transformed)
            img_list.append(img)

    # Optimize (timed)
    tic = time.perf_counter()
    translation_lists, rotation_lists = PoseVideoFitting(
        bin_def,
    ).fit(
        transient_gt_noisy_list, scan_points,
        verts_list, faces_list,
        translation_list_0 = translation_list_0,
        rotation_list_0 = rotation_list_0,
        num_iter = 100,
    )
    toc = time.perf_counter()

    # Get reconstructed translations, concatenated over frames.
    pos_1_optim = torch.cat(
        [t[0].translation for t in translation_lists],
        dim=0)
    pos_2_optim = torch.cat(
        [t[1].translation for t in translation_lists],
        dim=0)

    torch.save({
        'transient_gt_list': [t.detach().cpu() for t in transient_gt_list],
        'transient_gt_noisy_list': [t.detach().cpu() for t in transient_gt_noisy_list],
        'img_list': [t.detach().cpu() for t in img_list],
        'pos_1': pos_1.detach().cpu(),
        'pos_2': pos_2.detach().cpu(),
        'pos_1_optim': pos_1_optim.detach().cpu(),
        'pos_2_optim': pos_2_optim.detach().cpu(),
        'rotation_lists_ref': rotation_lists_ref,
        'rotation_lists': rotation_lists,
        'time_s': toc - tic,
        }, repo_path(base + '/data.pth'))

def _save_eps_and_png(fig, stem):
    """Save ``fig`` as ``<stem>.eps`` and ``<stem>.png`` (repo-relative) and close it."""
    for ext in ('eps', 'png'):
        fig.savefig(repo_path(f'{stem}.{ext}'), bbox_inches='tight', dpi=160, pad_inches=0)
    plt.close(fig)


def _plot_coords(pos_optim, pos_gt, figsize, legend):
    """Plot fitted vs. ground-truth coordinates over frames and return the figure.

    Fitted x/y/z are solid red/green/blue lines. Ground truth is drawn as
    dashed black lines. ``legend`` toggles the x/y/z legend.
    """
    fig = plt.figure(figsize=figsize)
    for dim, (color, name) in enumerate(zip('rgb', 'xyz')):
        plt.plot(pos_optim[:, dim], f'{color}-', linewidth=2, label=name)
    for dim in range(3):
        plt.plot(pos_gt[:, dim], '--', color='#000000', linewidth=1)
    plt.xlabel('Frame')
    plt.ylabel('Position [m]')
    if legend:
        plt.legend()
    plt.gca().set_ylim([-0.9, 1.6])
    plt.tight_layout()
    return fig


def make_figures():
    """Create the tracking figures from ``samples/wacv/tracking/armadillos/data.pth``.

    Prints the fitting runtime and the PSNR of the noisy vs. clean transients.

    Writes the following files into the same directory:
    - ``coords_1`` / ``coords_2``: fitted vs. ground-truth position plots.
    - ``error``: per-frame Euclidean position error of both objects.
    - ``render/render_<i>.jpg``: cropped frame renders.
    - ``transient.png`` / ``transient_noisy.png``: first-frame transient images.

    Every figure is closed after saving so repeated calls do not accumulate
    open matplotlib figures.
    """
    base = 'samples/wacv/tracking/armadillos'
    data_path = repo_path(base + '/data.pth')
    try:
        # data.pth also pickles Rotation objects, which the weights-only
        # loader (torch.load's default since torch 2.6) rejects. The file is
        # produced locally by this script, so full unpickling is acceptable.
        data = torch.load(data_path, weights_only=False)
    except TypeError:
        # torch < 1.13 does not know the `weights_only` argument.
        data = torch.load(data_path)
    transient_gt_list = data['transient_gt_list']
    transient_gt_noisy_list = data['transient_gt_noisy_list']
    img_list = data['img_list']
    pos_1 = data['pos_1']
    pos_2 = data['pos_2']
    pos_1_optim = data['pos_1_optim']
    pos_2_optim = data['pos_2_optim']
    time_s = data['time_s']

    print(f'Runtime: {time_s}s')

    # PSNR of the noisy transients w.r.t. the clean ones, using the larger of
    # the two maxima as peak value.
    t0 = torch.stack(transient_gt_list)
    t1 = torch.stack(transient_gt_noisy_list)
    mse = torch.mean((t0 - t1) ** 2)
    psnr = 20 * torch.log10(max(t0.max().item(), t1.max().item()) / torch.sqrt(mse))
    print(f'PSNR: {psnr}')

    scale = 0.75
    figsize = (4*scale, 3*scale)

    _save_eps_and_png(_plot_coords(pos_1_optim, pos_1, figsize, legend=False), base + '/coords_1')
    _save_eps_and_png(_plot_coords(pos_2_optim, pos_2, figsize, legend=True), base + '/coords_2')

    error_1 = ((pos_1_optim - pos_1)**2).sum(dim=1)**0.5
    error_2 = ((pos_2_optim - pos_2)**2).sum(dim=1)**0.5
    fig = plt.figure(figsize=figsize)
    plt.plot(error_2, label='Back')
    plt.plot(error_1, label='Front')
    plt.xlabel('Frame')
    plt.ylabel('Error [m]')
    plt.legend()
    plt.tight_layout()
    _save_eps_and_png(fig, base + '/error')

    # Images: crop 192 pixels from both ends of dim 1 (hard-coded for the
    # 1024-pixel renders produced by make_data).
    for i, img in enumerate(img_list):
        img = img[:, 128+64:1024-128-64, :]
        save_image(img, repo_path(f'{base}/render/render_{i}.jpg'))

    for transient, name in ((transient_gt_list[0], 'transient'),
                            (transient_gt_noisy_list[0], 'transient_noisy')):
        fig = plt.figure()
        plt.imshow(transient)
        fig.savefig(repo_path(f'{base}/{name}.png'), bbox_inches='tight', dpi=160)
        plt.close(fig)

def _unit_quaternions(rotation_lists, obj_idx):
    """Stack object ``obj_idx``'s quaternion from every frame, normalised to unit length.

    Assumes each Rotation's ``quaternion`` is a tensor whose first row holds
    the 4 quaternion components (as the original export indexed it).
    Returns a (num_frames, 4) CPU tensor.
    """
    quats = torch.stack(
        [rotations[obj_idx].quaternion[0].detach().cpu() for rotations in rotation_lists],
        dim=0)
    return quats / (quats**2).sum(dim=1, keepdim=True)**0.5


def export(data_path=None, output_path='armadillo_data.json'):
    """Export ground-truth and fitted poses of both armadillos to JSON.

    Args:
        data_path: Path of the ``data.pth`` written by ``make_data``. Defaults
            to ``samples/wacv/tracking/armadillos/data.pth`` in the repo.
            NOTE(review): the original read ``samples/wacv/pose/armadillos/data.pth``,
            a file this tracking script never writes. Pass that path
            explicitly if it was really intended.
        output_path: Destination JSON file (default: ``armadillo_data.json``
            in the current working directory).

    Output layout: ``{'Armadillo1': {'Pos', 'PosGT', 'Rot', 'RotGT'},
    'Armadillo2': {...}}``. Positions are per-frame ``[x, y, z]`` lists and
    rotations are per-frame unit quaternions. The number of frames is taken
    from the stored rotation lists instead of being hard-coded.
    """
    import json

    if data_path is None:
        data_path = repo_path('samples/wacv/tracking/armadillos/data.pth')
    try:
        # The file pickles Rotation objects, which the weights-only loader
        # (torch.load's default since torch 2.6) rejects. It is produced
        # locally, so full unpickling is acceptable.
        data = torch.load(data_path, weights_only=False)
    except TypeError:
        # torch < 1.13 does not know the `weights_only` argument.
        data = torch.load(data_path)

    rotation_lists_ref = data['rotation_lists_ref']
    rotation_lists = data['rotation_lists']
    num_frames = len(rotation_lists)

    def _entry(obj_idx, pos_key):
        return {
            'Pos': data[pos_key + '_optim'][:num_frames].tolist(),
            'PosGT': data[pos_key][:num_frames].tolist(),
            'Rot': _unit_quaternions(rotation_lists, obj_idx).tolist(),
            'RotGT': _unit_quaternions(rotation_lists_ref[:num_frames], obj_idx).tolist(),
        }

    out = {
        'Armadillo1': _entry(0, 'pos_1'),
        'Armadillo2': _entry(1, 'pos_2'),
    }

    with open(output_path, 'w', encoding='utf-8') as outfile:
        json.dump(out, outfile, indent=4)


# Run the full pipeline only when executed as a script, not on import.
if __name__ == '__main__':
    make_data()
    make_figures()
    # export()