import os
import time
import torch
import numpy as np
import matplotlib.pyplot as plt
from matplotlib.colors import ListedColormap
from pytorch3d.io import load_obj
from torchvision.utils import save_image
from totri.types import VolumeDef, BinDef
from totri.data import repo_path, Zaragoza, make_wall_grid
from totri.fitting.rbf import GaussianRbfBase, RbfFitting
from totri.util import UnitDepthMapRender, faces_of_verts
from totri.util.render import UnitCubeMeshRender
from totri.render import MeshRenderConfocal
from tomcubes import MarchingCubesFunction
from totri.util.format import verts_transform, faces_of_verts
from pytorch3d.io import save_obj

from torch.utils.tensorboard import SummaryWriter

class ResidualBlock(torch.nn.Module):
    """Pointwise residual block: ``x + convs(x)``.

    ``convs`` is two rounds of ReLU followed by a 1x1 convolution that keeps
    the number of channels unchanged, so the input and output shapes match.
    """

    def __init__(self, channels):
        """Build the block for feature maps with ``channels`` channels."""
        super().__init__()
        layers = []
        for _ in range(2):
            # Pre-activation ordering: the non-linearity comes before each conv.
            layers.append(torch.nn.ReLU())
            layers.append(torch.nn.Conv2d(channels, channels, 1))
        self.convs = torch.nn.Sequential(*layers)

    def forward(self, x):
        """Return the input plus the output of the convolutional branch."""
        branch = self.convs(x)
        return x + branch

class Reconstruction(torch.nn.Module):
    """Pointwise encoder/decoder mapping a transient image to a volume.

    The network uses only 1x1 convolutions; spatial context is mixed by
    folding pixels into channels (``PixelUnshuffle``) on the way down and
    unfolding them again (``PixelShuffle``) on the way up. Four unshuffle
    steps of factor 2 are applied, so the spatial size of the input must be
    divisible by 16. The size comments below assume a 16x16 input.

    Args:
        channels: Number of input channels (the ``T`` axis of the transient).
        depth: Number of output channels (the ``D`` axis of the volume).
    """

    def __init__(self, channels, depth):
        super().__init__()
        self.convs = torch.nn.Sequential(
            # Encoder: each unshuffle multiplies channels by 4, the 1x1 conv
            # brings them back down to ``channels``.
            torch.nn.PixelUnshuffle(2), # 8x8
            torch.nn.Conv2d(4 * channels, 1 * channels, 1),
            torch.nn.ReLU(),
            torch.nn.PixelUnshuffle(2), # 4x4
            torch.nn.Conv2d(4 * channels, 1 * channels, 1),
            torch.nn.ReLU(),
            torch.nn.PixelUnshuffle(2), # 2x2
            torch.nn.Conv2d(4 * channels, 1 * channels, 1),
            torch.nn.ReLU(),
            torch.nn.PixelUnshuffle(2), # 1x1
            # Bottleneck operating on 4 * channels features.
            ResidualBlock(4 * channels),
            ResidualBlock(4 * channels),
            ResidualBlock(4 * channels),
            ResidualBlock(4 * channels),
            torch.nn.ReLU(),
            # Decoder: each shuffle divides channels by 4, the 1x1 conv
            # expands them again ahead of the next shuffle.
            torch.nn.PixelShuffle(2), # 2x2
            torch.nn.Conv2d(1 * channels, 4 * channels, 1),
            torch.nn.ReLU(),
            torch.nn.PixelShuffle(2), # 4x4
            torch.nn.Conv2d(1 * channels, 4 * channels, 1),
            torch.nn.ReLU(),
            torch.nn.PixelShuffle(2), # 8x8
            torch.nn.Conv2d(1 * channels, 4 * channels, 1),
            torch.nn.ReLU(),
            torch.nn.PixelShuffle(2), # 16x16
            torch.nn.Conv2d(channels, depth, 1),
            torch.nn.Sigmoid(),
        )

    def forward(self, transient):
        """Predict a per-sample standardised volume.

        Args:
            transient: Tensor of shape (B, T, H, W) with ``T == channels``.

        Returns:
            Tensor of shape (B, D, H, W) with ``D == depth``. Each sample is
            shifted to zero mean and scaled by its standard deviation, so
            the values straddle zero.
        """
        # transient B, T, H, W
        # volume    B, D, H, W
        volume = self.convs(transient)
        # Centre each sample around zero.
        volume = volume - volume.flatten(1, -1).mean(dim=1)[:, None, None, None]
        # Scale by the per-sample spread. Dividing by the mean here would be
        # wrong: after centring the mean is ~0, so the quotient would blow up
        # to roughly 1e6 times the centred values.
        scale = volume.flatten(1, -1).std(dim=1)[:, None, None, None]
        volume = volume / (scale + 1.e-6)
        return volume

def uniform(min_val, max_val):
    """Draw one Python float uniformly from ``[min_val, max_val)``.

    The sample comes from torch's global random generator, so it follows
    ``torch.manual_seed``.
    """
    span = max_val - min_val
    sample = torch.rand(()).item()
    return min_val + sample * span

def get_volume_mask(resolution, device="cuda"):
    """Return a float32 mask that is 1 inside a 3D grid and 0 on its faces.

    Args:
        resolution: Sizes of the three grid axes (any sequence accepted by
            ``torch.zeros``).
        device: Device on which to allocate the mask. Defaults to ``"cuda"``.

    Returns:
        Tensor of shape ``resolution`` whose first and last slice along each
        of the three axes is 0 and whose interior is 1. An axis of size 2 or
        less has no interior, so the mask is then all zeros.
    """
    volume_mask = torch.zeros(
        resolution,
        dtype=torch.float32, device=device)
    # Only the interior is switched on; the one-voxel border stays zero.
    volume_mask[1:-1, 1:-1, 1:-1] = 1
    return volume_mask

def load_mesh(obj_path, device):
    """Load an OBJ file and return batched vertices and face indices.

    Args:
        obj_path: Path handed to ``load_obj``.
        device: Device both returned tensors are moved to.

    Returns:
        Tuple ``(verts, faces)``: the vertices and the ``verts_idx`` face
        indices from ``load_obj``, each with a leading axis of size 1.
    """
    verts, face_data, _aux = load_obj(obj_path)
    batched_verts = verts[None].to(device=device)
    batched_faces = face_data.verts_idx[None].to(device=device)
    return batched_verts, batched_faces

def train():
    """Train the reconstruction network with a self-supervised rendering loss.

    Every iteration builds a batch of random Gaussian-RBF scenes, renders
    their transients, lets the network predict a volume from each transient,
    extracts a mesh from the volume with marching cubes (iso level 0.0),
    re-renders that mesh and penalises the L1 difference between the two
    transients plus a small smoothness term on the volume.

    Scalars and meshes are logged to TensorBoard under
    ``samples/wacv/self_supervised/log`` and the final weights are written to
    ``samples/wacv/self_supervised/state_dict.pth``. Requires a CUDA device.
    """
    # Settings
    b = 32
    volume_resolution = 16
    scan_resolution = 16
    bin_def = BinDef(0.05, 0.0, 256)
    volume_def = VolumeDef([-0.5, -0.5, 0.1], [0.5, 0.5, 1.1], [volume_resolution,]*3)
    scan_points = make_wall_grid(
        -0.5, 0.5, scan_resolution,
        -0.5, 0.5, scan_resolution,
        z_val=0)[None]
    volume_mask = get_volume_mask(volume_def.resolution)
    # The network input channels are the 256 time bins of ``bin_def``.
    model = Reconstruction(256, volume_resolution).cuda()

    # Init Training
    summary_writer = SummaryWriter(repo_path("samples/wacv/self_supervised/log"))
    optim = torch.optim.Adam(model.parameters(), 1.0e-5)
    loss_fn = torch.nn.L1Loss()

    # Training Loop
    try:
        for i in range(500000):
            optim.zero_grad()

            # Make sample: one random RBF scene per batch element.
            transient = []
            for _sample in range(b):
                rbf = GaussianRbfBase(volume_def, has_color=False, sigma_init=uniform(0.05, 0.1))
                # Between 1 and 19 blobs (upper bound of randint is exclusive).
                for _blob in range(torch.randint(1, 20, ()).item()):
                    rbf.append(uniform(-0.4, 0.4), uniform(-0.4, 0.5), uniform(0.3, 0.9))
                verts_gt = rbf.make_verts()
                transient.append(MeshRenderConfocal.apply(verts_gt[None], None, scan_points, bin_def, None, None))
            transient = torch.cat(transient, dim=0)

            # Reconstruct volume
            volume = model(transient.view(b, bin_def.num, scan_resolution, scan_resolution))
            # Force the border voxels to -1 (below the iso level 0.0).
            volume = volume * volume_mask[None] + (1-volume_mask[None]) * -1

            # Self Supervised Rendering
            verts_batch = MarchingCubesFunction.apply(volume[:,None], 0.0, volume_def.start, volume_def.end)
            transient_out = []
            for verts in verts_batch:
                transient_out.append(MeshRenderConfocal.apply(verts[None], None, scan_points, bin_def, None, None))
            transient_out = torch.cat(transient_out, dim=0)

            # Losses
            # Smoothness: magnitude of the forward-difference gradient of the
            # volume; the 1e-6 keeps sqrt differentiable at zero.
            smoothness = 1.e-5 * torch.mean(torch.sqrt(
                (volume[:, 1:, :-1, :-1] - volume[:, :-1, :-1, :-1] )**2 +
                (volume[:, :-1, 1:, :-1] - volume[:, :-1, :-1, :-1] )**2 +
                (volume[:, :-1, :-1, 1:] - volume[:, :-1, :-1, :-1] )**2 +
                1.e-6
            ))
            supervision = loss_fn(transient, transient_out)
            loss = smoothness + supervision

            # Step
            loss.backward()
            optim.step()

            # Log (``verts`` / ``verts_gt`` are the last sample of the batch)
            if (i+1) % 100 == 0:
                print(f"{i}: {loss.item()} <smoothness {smoothness.item()}, supervision {supervision.item()}> [Verts {verts.shape[0]}/{verts_gt.shape[0]}]")
                summary_writer.add_scalar("loss/l1", loss, global_step=i)
            if (i+1) % 25000 == 0:
                summary_writer.add_mesh(
                    "mesh_gt/verts",
                    verts_transform(verts_gt[None]),
                    global_step=i,
                    faces=faces_of_verts(verts_gt[None]),
                    colors=None)
                summary_writer.add_mesh(
                    "mesh_out/verts",
                    verts_transform(verts[None]),
                    global_step=i,
                    faces=faces_of_verts(verts[None]),
                    colors=None)
    finally:
        # Flush pending events even if training is interrupted or fails.
        summary_writer.close()

    # Save Model
    torch.save(model.state_dict(), repo_path("samples/wacv/self_supervised/state_dict.pth"))

def test():
    """Reconstruct the letter meshes "w", "a", "c", "v" with the trained model.

    For each letter the OBJ mesh is loaded, rendered to a transient, passed
    through the network, and the mesh extracted from the predicted volume is
    saved as ``samples/wacv/self_supervised/log/{letter}_out.obj``. Weights
    are read from ``samples/wacv/self_supervised/state_dict.pth``. Requires a
    CUDA device.
    """
    # Settings
    b = 1
    volume_resolution = 16
    scan_resolution = 16
    bin_def = BinDef(0.05, 0.0, 256)
    volume_def = VolumeDef([-0.5, -0.5, 0.1], [0.5, 0.5, 1.1], [volume_resolution,]*3)
    scan_points = make_wall_grid(
        -0.5, 0.5, scan_resolution,
        -0.5, 0.5, scan_resolution,
        z_val=0)[None]
    volume_mask = get_volume_mask(volume_def.resolution)

    # Load Model
    model = Reconstruction(256, volume_resolution).cuda()
    model.load_state_dict(torch.load(repo_path("samples/wacv/self_supervised/state_dict.pth")))
    model.eval()

    # Reconstruct letters (inference only, so no autograd graph is needed)
    with torch.no_grad():
        for letter in "wacv":
            # The device is an argument of load_mesh, not of repo_path.
            verts_gt, faces_gt = load_mesh(
                repo_path(f"samples/wacv/self_supervised/{letter}.obj"), "cuda")
            transient = MeshRenderConfocal.apply(verts_gt, faces_gt, scan_points, bin_def, None, None)
            volume = model(transient.view(b, bin_def.num, scan_resolution, scan_resolution))
            # Force the border voxels to -1 (below the iso level 0.0).
            volume = volume * volume_mask[None] + (1-volume_mask[None]) * -1
            verts = MarchingCubesFunction.apply(volume[:,None], 0.0, volume_def.start, volume_def.end)[0]
            save_obj(repo_path(f"samples/wacv/self_supervised/log/{letter}_out.obj"), verts, faces_of_verts(verts))

# Run training followed by evaluation only when executed as a script, so that
# importing this module does not start a long GPU job.
if __name__ == "__main__":
    train()
    test()
