"""Rbf Reconstruction of a rendered Bunny"""
import os
import time
import torch
import numpy as np
import matplotlib.pyplot as plt
from matplotlib.colors import ListedColormap
from pytorch3d.io import save_obj, load_obj
from torchvision.utils import save_image
from totri.types import VolumeDef, BinDef
from totri.data.util import make_wall_grid, repo_path
from totri.data.mesh import load_sample_mesh
from totri.fitting.rbf import GaussianRbfBase, RbfFitting
from totri.util import UnitDepthMapRender, faces_of_verts
from totri.util.render import UnitCubeMeshRender
from totri.render import MeshRenderConfocal

def make_data():
    """Render the sample bunny and fit a Gaussian RBF reconstruction to it.

    Intermediate meshes are written to ``samples/wacv/bunny/rbf/<iteration>.obj``
    (every 5th iteration below 100, every 10th from 100 on, counting from 1),
    and the elapsed time in milliseconds for each of them is appended to
    ``samples/wacv/bunny/rbf/runtimes.txt``.
    """
    torch.manual_seed(42)
    os.makedirs(repo_path('samples/wacv/bunny/rbf'), exist_ok=True)

    # Scene: scan points on the z = 0 wall (presumably a 64x64 grid) plus the
    # sample bunny mesh.
    bin_def = BinDef(0.02, 0, 200, unit_is_time=False)
    grid_resolution = 64
    scan_points = make_wall_grid(
        -1, 1, grid_resolution,
        -1, 1, grid_resolution,
        z_val=0,
    )
    verts, faces = load_sample_mesh('bunny')

    # Ground-truth measurement: confocal render of the mesh; the scan points
    # get a leading batch axis and the first (only) result is kept.
    batched_transients = MeshRenderConfocal.apply(
        verts, faces, scan_points[None], bin_def, None, None)
    transient_gt = batched_transients[0]

    # Reconstruction volume with 64 cells per axis; the two corner lists are
    # presumably the min/max bounds -- confirm against VolumeDef.
    volume_resolution = 64
    volume_def = VolumeDef([-1, -1, 0], [1, 1, 2], [volume_resolution] * 3)

    # Reseed before fitting (presumably so the fit's RNG state does not depend
    # on the work above), then start the wall clock.
    torch.manual_seed(42)
    tic = time.perf_counter()
    with open(repo_path('samples/wacv/bunny/rbf/runtimes.txt'), "w") as log:
        log.write("Iteration and Runtime\n")

    def callback(i, base, background_model):
        # ``i`` is converted to a 1-based step count before checking the cadence.
        step = i + 1
        every = 5 if step < 100 else 10
        if step % every != 0:
            return
        toc = time.perf_counter()
        runtime = toc - tic
        print(f'Callback {step} after {runtime} s')
        mesh_verts = base.make_verts()
        save_obj(repo_path(f'samples/wacv/bunny/rbf/{step}.obj'), mesh_verts, faces_of_verts(mesh_verts))
        with open(repo_path('samples/wacv/bunny/rbf/runtimes.txt'), "a") as log:
            log.write(f"{step}, {1000*runtime}\n")

    # The fitter is constructed before the RBF base, as in the original call
    # chain, so any RNG consumption happens in the same order.
    fitting = RbfFitting(bin_def, scan_points)
    fitting.fit(
        GaussianRbfBase(volume_def, has_color=False, sigma_init=0.05),
        transient_gt,
        num_iter=500,
        callback=callback,
    )

def _mesh_metrics(renderer, obj_path, depth_gt, mask_gt):
    """Return ``(mae, iou)`` of the mesh stored at *obj_path*.

    The mesh is loaded with ``load_obj``, moved to CUDA and rendered with
    *renderer*, whose ``apply`` yields a depth map and a mask.
    ``mae`` is the mean absolute depth difference to *depth_gt* over pixels
    covered by both masks (NaN when they do not overlap); ``iou`` is the
    intersection over union of the two masks (NaN when both are empty).
    """
    verts, faces, _ = load_obj(obj_path)
    verts = verts.to(device="cuda")
    faces = faces.verts_idx.to(device="cuda")
    depth, mask = renderer.apply(verts, faces)
    overlap = mask & mask_gt
    mae = (depth - depth_gt)[overlap].abs().mean().item()
    union = (mask | mask_gt).sum().item()
    iou = overlap.sum().item() / union if union else float("nan")
    return mae, iou


def _read_runtimes_minutes(path):
    """Parse an ``iteration, runtime`` log into ``{iteration: minutes}``.

    Runtimes are divided by 1000 * 60, i.e. assumed to be milliseconds (the
    unit make_data writes). Lines without a comma (e.g. a header) are
    skipped and only the first two comma-separated fields are used. A later
    line for the same iteration overrides an earlier one.
    """
    runtimes = {}
    with open(path, "r") as file:
        for line in file:
            if "," not in line:
                continue
            it, runtime = line.split(",")[:2]
            runtimes[int(it.strip())] = float(runtime.strip()) / 1000 / 60
    return runtimes


def _evaluate_method(name, obj_path, runtimes_path, samples, checkpoints, metrics_of):
    """Evaluate one method's saved meshes and print its checkpoint MAEs.

    Args:
        name: Method tag used in the printed labels (e.g. "rbf", "original").
        obj_path: Callable mapping an iteration to that iteration's .obj path.
        runtimes_path: Path of the method's ``iteration, runtime`` log.
        samples: Iterations to evaluate and plot.
        checkpoints: ``{iteration: label}``; the MAE of each checkpoint that
            appears in the runtime log is printed. Checkpoints outside
            *samples* are evaluated from their own mesh.
        metrics_of: Callable mapping an .obj path to ``(mae, iou)``.

    Returns:
        ``(runtimes_min, maes, ious)`` for the sampled iterations that have a
        logged runtime, in *samples* order, so the three lists always align.
    """
    runtimes = _read_runtimes_minutes(runtimes_path)
    mae, iou = {}, {}
    for i in samples:
        mae[i], iou[i] = metrics_of(obj_path(i))

    for it, label in checkpoints.items():
        if it not in runtimes:
            continue  # only report checkpoints that were actually logged
        if it not in mae:
            # Look the value up by iteration instead of by list position, so a
            # checkpoint beyond *samples* reports its own mesh, not the last sample.
            path = obj_path(it)
            if not os.path.exists(path):
                print(f"mae_{name}, {name} {label}: missing {path}")
                continue
            mae[it], _ = metrics_of(path)
        print(f"mae_{name}, {name} {label}", mae[it])

    plotted = [i for i in samples if i in runtimes]
    return (
        [runtimes[i] for i in plotted],
        [mae[i] for i in plotted],
        [iou[i] for i in plotted],
    )


def make_metrics():
    """Compare the RBF reconstruction with Iseringhausen and Hullin's method.

    For every sampled iteration, the saved mesh of each method is rendered to
    a depth map and scored against the ground-truth bunny (depth MAE over the
    common mask, mask IoU). The scores are paired with the logged runtimes and
    plotted over runtime to ``mae.png/.eps`` and ``iou.png/.eps``. The MAE at
    the checkpoint iterations rendered by make_renders is also printed.
    """
    original_dir = repo_path('samples/wacv/bunny/original')
    rbf_dir = repo_path('samples/wacv/bunny/rbf')
    samples = list(range(5, 100, 5)) + list(range(100, 510, 10))

    # Ground truth: depth map and mask of the sample bunny mesh.
    verts_gt, faces_gt = load_sample_mesh('bunny')
    depth_map_renderer = UnitDepthMapRender(256, 2, device="cuda")
    depth_gt, mask_gt = depth_map_renderer.apply(verts_gt[0], faces_gt[0])

    def metrics_of(obj_path):
        return _mesh_metrics(depth_map_renderer, obj_path, depth_gt, mask_gt)

    # Checkpoint iterations and their time labels match the meshes used in
    # make_renders; the original method's 24h mesh (iteration 610) lies past
    # the last sampled iteration and is evaluated on its own.
    runtimes_rbf, mae_rbf, iou_rbf = _evaluate_method(
        "rbf",
        lambda i: os.path.join(rbf_dir, f"{i}.obj"),
        os.path.join(rbf_dir, "runtimes.txt"),
        samples,
        {5: "10s", 15: "1min", 130: "10min", 500: "45min"},
        metrics_of,
    )
    runtimes_original, mae_original, iou_original = _evaluate_method(
        "original",
        lambda i: os.path.join(original_dir, f"{i:06d}.obj"),
        os.path.join(original_dir, "log.txt"),
        samples,
        {5: "1min", 15: "10min", 35: "45min", 610: "24h"},
        metrics_of,
    )

    print('max', max(runtimes_original))

    scale = 0.75
    figsize = (scale*6, scale*3)

    fig = plt.figure(figsize=figsize)
    plt.semilogx(runtimes_rbf, mae_rbf, label="Ours")
    plt.semilogx(runtimes_original, mae_original, label="Iseringhausen and Hullin")
    plt.xlabel('Runtime [min]')
    plt.ylabel('Mean Absolute Error')
    plt.legend()
    plt.savefig(repo_path('samples/wacv/bunny/mae.png'), bbox_inches='tight', pad_inches=0)
    plt.savefig(repo_path('samples/wacv/bunny/mae.eps'), bbox_inches='tight', pad_inches=0)
    plt.close(fig)

    fig = plt.figure(figsize=figsize)
    plt.semilogx(runtimes_rbf, iou_rbf, label="Ours")
    plt.semilogx(runtimes_original, iou_original, label="Iseringhausen and Hullin")
    plt.xlabel('Runtime [min]')
    plt.ylabel('Intersection over Union')
    plt.ylim((0, 1))
    plt.savefig(repo_path('samples/wacv/bunny/iou.png'), bbox_inches='tight', pad_inches=0)
    plt.savefig(repo_path('samples/wacv/bunny/iou.eps'), bbox_inches='tight', pad_inches=0)
    plt.close(fig)

    print("mae_rbf", mae_rbf[-1])
    print("mae_original", mae_original[-1])



def make_renders():
    """Render the checkpoint meshes of both methods to 512x512 PNG images.

    Each checkpoint pairs a time label with the RBF mesh and the original
    method's mesh saved at that time (``None`` where a method has no mesh).
    Meshes are rendered at resolution 2048 and area-downsampled to 512x512.
    """
    original_dir = repo_path('samples/wacv/bunny/original')
    rbf_dir = repo_path('samples/wacv/bunny/rbf')
    renderer = UnitCubeMeshRender(resolution=2048, distance=0.2)

    def render_to(obj_path, png_path):
        # Load the mesh, move it to CUDA, render, downsample by area, save.
        mesh_verts, mesh_faces, _ = load_obj(obj_path)
        mesh_verts = mesh_verts.to(device="cuda")
        mesh_faces = mesh_faces.verts_idx.to(device="cuda")
        image = renderer.apply(mesh_verts, mesh_faces)
        image = torch.nn.functional.interpolate(
            image.unsqueeze(0), (512, 512), mode='area')[0]
        save_image(image, png_path)

    # (time label, RBF mesh file, original-method mesh file)
    checkpoints = (
        ("10s", "5.obj", None),
        ("1min", "15.obj", "000005.obj"),
        ("10min", "130.obj", "000015.obj"),
        ("45min", "500.obj", "000035.obj"),
        ("24h", None, "000610.obj"),
    )
    for t, rbf, org in checkpoints:
        if rbf is not None:
            render_to(os.path.join(rbf_dir, rbf),
                      repo_path(f'samples/wacv/bunny/render_rbf_{t}.png'))
        if org is not None:
            render_to(os.path.join(original_dir, org),
                      repo_path(f'samples/wacv/bunny/render_org_{t}.png'))


# Pipeline stages, run in order when this script executes. make_data writes
# the RBF meshes that make_metrics and make_renders read; uncomment a stage
# to enable it.
_STAGES = (
    # make_data,
    make_metrics,
    # make_renders,
)
for _stage in _STAGES:
    _stage()