# -*- coding: utf-8 -*-
import glob
import os
import warnings
import json
import random

import hydra
import lpips as lpips_lib
import numpy as np
import torch
import wandb
from lightning.fabric import Fabric
from omegaconf import DictConfig
from omegaconf import OmegaConf
from torch.nn.utils import clip_grad
from torch.utils.data import DataLoader
import torch.distributed as dist
import deepspeed

from datasets.dataset_factory import get_dataset
from eval import evaluate_dataset
from gaussian_renderer import render_predicted
from scene.gaussian_predictor import GaussianSplatPredictor
from utils.general_utils import safe_state
from utils.loss_utils import l1_loss
from utils.loss_utils import l2_loss
from utils.loss_utils import ssim as ssim_fn


def count_parameters(model):
    """Count trainable and frozen parameters of ``model``.

    Iterates once over ``model.parameters()`` and sums ``numel()`` separately
    for parameters whose ``requires_grad`` flag is truthy and for the rest.

    Returns:
        A ``(trainable, frozen)`` tuple. Both values are element counts
        divided by 1024 twice, i.e. expressed in units of 2**20 parameters.
    """
    # Bucket the element counts by the (boolean) requires_grad flag.
    totals = {True: 0, False: 0}
    for tensor in model.parameters():
        totals[bool(tensor.requires_grad)] += tensor.numel()

    trainable = totals[True] / 1024 / 1024
    frozen = totals[False] / 1024 / 1024
    return trainable, frozen

def optimizer_to(optim, device):
    """Move the tensors held in ``optim.state`` to ``device`` in place.

    Each value of ``optim.state`` is handled as follows: a tensor is moved
    directly; a dict has each of its tensor values moved; anything else is
    left untouched. For every moved tensor, its ``_grad`` (when not None) is
    moved as well. Nothing is returned.
    """

    def _relocate(tensor):
        # Rebind the storage (and the gradient storage, if any) on `device`.
        tensor.data = tensor.data.to(device)
        grad = tensor._grad
        if grad is not None:
            grad.data = grad.data.to(device)

    for entry in optim.state.values():
        if isinstance(entry, torch.Tensor):
            # Not sure there are any global tensors in the state dict
            candidates = (entry,)
        elif isinstance(entry, dict):
            candidates = entry.values()
        else:
            continue

        for candidate in candidates:
            if isinstance(candidate, torch.Tensor):
                _relocate(candidate)

def _scaled_interval(base_interval, equ_bs, global_batch_size):
    """Convert an interval counted in reference-batch steps into loop iterations.

    The result is clamped to at least 1 so it can always be used as the
    right-hand side of ``%`` without raising ZeroDivisionError.
    """
    return max(1, int(base_interval * equ_bs / global_batch_size))


def _predict_gaussians(gaussian_predictor, batch, cfg):
    """Slice the input views out of ``batch`` and run ``gaussian_predictor`` on them.

    The first ``cfg.data.input_images`` views are used for every input except
    the foreground masks, which are sliced with
    ``cfg.data.coarse_stage_input_images``. Intrinsics are always passed as None.
    """
    n_in = cfg.data.input_images

    rot_transform_quats = batch["source_cv2wT_quat"][:, :n_in]
    input_intrinsics = None
    input_images = batch["norm_imgs"][:, :n_in, ...]
    input_masks = (batch["fg_masks"][:, :cfg.data.coarse_stage_input_images, ...]
                   if cfg.data.use_mask else None)
    input_cameras = batch["source_camera"][:, :n_in, ...] if cfg.data.mod_camera_dec else None
    plucker_emb = batch["plucker_emb"][:, :n_in, ...] if cfg.data.use_plucker_emb else None

    return gaussian_predictor(input_images,
                              input_masks,
                              input_intrinsics,
                              batch["view_to_world_transforms"][:, :n_in, ...],
                              rot_transform_quats,
                              input_cameras=input_cameras,
                              plucker_emb=plucker_emb,
                              unnorm_imges=batch["gt_images"][:, :n_in, ...],
                              source_cameras_view_to_world_coarse=batch["view_to_world_transforms"][:, :n_in, ...],
                              source_cv2wT_quat_coarse=batch["source_cv2wT_quat"][:, :n_in, ...])


def _save_checkpoint(fabric, directory, name, state, info):
    """Save ``state`` through ``fabric.save`` plus a ``<name>.json`` side file with ``info``."""
    fabric.barrier()
    fabric.save(os.path.join(directory, name), state)
    # Save the data to a JSON file
    with open(os.path.join(directory, f'{name}.json'), "w") as json_file:
        json.dump(info, json_file)
    fabric.barrier()


@hydra.main(version_base=None, config_path='configs', config_name="default_config")
def main(cfg: DictConfig):
    """Train ``GaussianSplatPredictor.stage2_net`` with Lightning Fabric (DeepSpeed strategy).

    Sets up Fabric, wandb (global rank zero only), the model, an Adam optimizer
    and the train/val/vis dataloaders, optionally restores a checkpoint, and
    then runs the training loop with periodic logging, visualisation,
    checkpointing and validation.

    Args:
        cfg: Hydra configuration. The ``general``, ``opt``, ``data``,
            ``model``, ``logging`` and ``wandb`` groups are read.

    Raises:
        ValueError: if ``cfg.opt.loss`` is neither ``"l1"`` nor ``"l2"``.
        SystemExit: if a NaN is found in the predicted Gaussians on any rank.
    """

    torch.set_float32_matmul_precision('high')
    if cfg.general.mixed_precision:
        fabric = Fabric(accelerator="cuda", devices=cfg.general.num_devices, strategy="deepspeed",
                        precision="bf16-mixed")
    else:
        fabric = Fabric(accelerator="cuda", devices=cfg.general.num_devices, strategy="deepspeed")
    print(f'Use {cfg.general.num_devices} GPUS')
    fabric.launch()

    # Number of samples consumed per optimizer step across all devices.
    global_bs = cfg.opt.batch_size * cfg.general.num_devices

    # One definition of "resuming", shared by the output-directory choice
    # and the checkpoint loading below.
    resuming = cfg.opt.resume not in (None, 'None', '')
    if resuming:
        vis_dir = cfg.opt.resume
        print('resume training from: ', vis_dir)
    else:
        vis_dir = os.getcwd()

    # Only global rank zero owns a wandb run; other ranks keep None so the
    # final `finish()` can be skipped safely.
    wandb_run = None
    if fabric.is_global_zero:
        # torch.autograd.set_detect_anomaly(True)
        dict_cfg = OmegaConf.to_container(
            cfg, resolve=True, throw_on_missing=True
        )

        run_paths = []
        if os.path.isdir(os.path.join(vis_dir, "wandb")):
            run_paths = glob.glob(os.path.join(vis_dir, "wandb", "latest-run", "run-*"))

        if run_paths:
            run_name_path = run_paths[0]
            print("Got run name path {}".format(run_name_path))
            run_id = os.path.basename(run_name_path).split("run-")[1].split(".wandb")[0]
            print("Resuming run with id {}".format(run_id))

            wandb_run = wandb.init(project=cfg.wandb.project, resume=True,
                                   id=run_id, config=dict_cfg)
        else:
            # No previous run file found: start a fresh run.
            wandb_run = wandb.init(project=cfg.wandb.project, reinit=True,
                                   config=dict_cfg, mode='online', name=cfg.wandb.name)

    first_iter = 0
    device = safe_state(cfg)

    gaussian_predictor = GaussianSplatPredictor(cfg)
    if fabric.is_global_zero:
        trainable, frozen = count_parameters(gaussian_predictor)

        print(f"Trainable parameters (MB): {trainable}")
        print(f"Frozen parameters (MB): {frozen}")
    gaussian_predictor = gaussian_predictor.to(memory_format=torch.channels_last)

    # Only the stage-2 network is handed to the optimizer.
    param_groups = [{'params': gaussian_predictor.stage2_net.parameters(),
                     'lr': cfg.opt.base_lr}]
    optimizer = torch.optim.Adam(param_groups, lr=0.0, eps=1e-15,
                                 betas=cfg.opt.betas)

    # set equal bs for log all experiments with different bs at the same step
    if cfg.data.category == "objaverse":
        equ_bs = 16
    else:
        equ_bs = 8

    if cfg.opt.loss == "l2":
        loss_fn = l2_loss
    elif cfg.opt.loss == "l1":
        loss_fn = l1_loss
    else:
        # Fail early instead of hitting a NameError at the first loss computation.
        raise ValueError(f"Unsupported cfg.opt.loss: {cfg.opt.loss!r} (expected 'l1' or 'l2')")

    if cfg.opt.lambda_lpips != 0:
        lpips_fn = fabric.to_device(lpips_lib.LPIPS(net='vgg'))
    lambda_lpips = cfg.opt.lambda_lpips
    lambda_l12 = 1.0 - lambda_lpips
    lambda_ssim = cfg.opt.lambda_ssim
    lambda_mask = cfg.opt.lambda_mask

    bg_color = [1, 1, 1] if cfg.data.white_background else [0, 0, 0]
    background = torch.tensor(bg_color, dtype=torch.float32)
    background = fabric.to_device(background)

    if cfg.data.category in ["nmr", "objaverse"]:
        num_workers = cfg.data.data_workers
        persistent_workers = True
    else:
        num_workers = 0
        persistent_workers = False

    dataset = get_dataset(cfg, "train")
    dataloader = DataLoader(dataset,
                            batch_size=cfg.opt.batch_size,
                            shuffle=True,
                            num_workers=num_workers,
                            persistent_workers=persistent_workers)

    val_dataset = get_dataset(cfg, "val")
    val_dataloader = DataLoader(val_dataset,
                                batch_size=1,
                                shuffle=False,
                                num_workers=1,
                                persistent_workers=True,
                                pin_memory=True)

    test_dataset = get_dataset(cfg, "vis")
    test_dataloader = DataLoader(test_dataset,
                                 batch_size=1,
                                 shuffle=True)

    # distribute model and training dataset
    gaussian_predictor, optimizer = fabric.setup(
        gaussian_predictor, optimizer
    )
    dataloader = fabric.setup_dataloaders(dataloader)

    ### Resume training
    ckpt_save_dict = {
        "optimizer": optimizer,
        "model": gaussian_predictor
    }
    if resuming:
        print('Loading an existing model from ', os.path.join(vis_dir, f"model_{cfg.opt.resume_epoch}"))
        fabric.load(os.path.join(vis_dir, f"model_{cfg.opt.resume_epoch}"), ckpt_save_dict)
        with open(os.path.join(vis_dir, f"model_{cfg.opt.resume_epoch}.json"), "r") as j:
            model_info = json.load(j)
        first_iter = model_info['iteration']
        best_PSNR = model_info["best_PSNR"]
    elif cfg.opt.pretrained_ckpt is not None:
        # Pretrained weights: restore the state but restart the step counter.
        fabric.load(os.path.join(cfg.opt.pretrained_ckpt, f"model_{cfg.opt.resume_epoch}"), ckpt_save_dict)
        with open(os.path.join(cfg.opt.pretrained_ckpt, f"model_{cfg.opt.resume_epoch}.json"), "r") as j:
            model_info = json.load(j)
        best_PSNR = model_info["best_PSNR"]
    else:
        best_PSNR = 0.0

    print('Loaded model from iter: ', first_iter)
    gaussian_predictor.train()

    print("Beginning training")
    first_iter += 1
    first_iter = int(first_iter)
    # `first_iter` is stored in log-step units (see `log_step` below), so it is
    # converted back with the same `equ_bs` that produced it.
    iteration = int(first_iter * equ_bs / global_bs)

    # Logging/checkpoint intervals are configured in reference-batch steps;
    # convert them once into loop iterations.
    loss_log_every = _scaled_interval(cfg.logging.loss_log, equ_bs, global_bs)
    render_log_every = _scaled_interval(cfg.logging.render_log, equ_bs, global_bs)
    loop_log_every = _scaled_interval(cfg.logging.loop_log, equ_bs, global_bs)
    ckpt_every = _scaled_interval(cfg.logging.ckpt_iterations, equ_bs, global_bs)
    latest_ckpt_every = _scaled_interval(cfg.logging.latest_ckpt_iterations, equ_bs, global_bs)
    val_every = _scaled_interval(cfg.logging.val_log, equ_bs, global_bs)

    # Created lazily on the first visualisation step (global rank zero only).
    test_iterator = None

    for num_epoch in range((cfg.opt.iterations + 1 - first_iter) // len(dataloader) + 1):
        sampler = getattr(dataloader, "sampler", None)
        if hasattr(sampler, "set_epoch"):
            sampler.set_epoch(num_epoch)
        if cfg.data.use_random_num_input:
            # Draw the number of input views for this epoch.
            n_view = random.randint(2, 8)
            cfg.data.input_images = n_view
        for data in dataloader:

            iteration += 1
            # Step index normalised to the reference batch size `equ_bs`.
            log_step = int(iteration * global_bs / equ_bs)

            # =============== Forward ================
            gaussian_splats_list = _predict_gaussians(gaussian_predictor, data, cfg)

            # =============== NaN guard ================
            # Ranks must agree on whether to stop: the local flag is
            # max-reduced so that every rank leaves together instead of one
            # rank waiting on a collective the others never reach.
            local_nan_keys = [key
                              for gaussian_splats in gaussian_splats_list
                              for key, value in gaussian_splats.items()
                              if torch.isnan(value).any()]
            nan_flag = fabric.to_device(torch.tensor(1.0 if local_nan_keys else 0.0))
            distributed = dist.is_available() and dist.is_initialized()
            if distributed:
                dist.all_reduce(nan_flag, op=dist.ReduceOp.MAX)
            if nan_flag.item() > 0:
                if local_nan_keys:
                    print(local_nan_keys[0], "NAN is detected. Stopping all processes.")
                elif fabric.is_global_zero:
                    print("NAN is detected on another rank. Stopping all processes.")
                # Gracefully terminate all processes
                if distributed:
                    dist.destroy_process_group()
                raise SystemExit

            # =============== Losses ================
            total_loss = 0
            l12_loss_sum_ls = []
            mask_loss_sum_ls = []
            lpips_loss_sum_ls = []
            ssim_loss_sum_ls = []

            for layer_idx, gaussian_splats in enumerate(gaussian_splats_list, start=1):
                # Terms that are switched off contribute a plain 0.0.
                lpips_loss_sum = 0.0
                ssim_loss_sum = 0.0
                mask_loss_sum = 0.0
                rendered_images = []
                gt_images = []
                gt_masks = []
                pred_masks = []

                for b_idx in range(data["gt_images"].shape[0]):
                    # The first cfg.data.input_images views are network inputs,
                    # the remaining views are rendering targets.
                    # Rendering is done sequentially because gaussian rasterization code
                    # does not support batching
                    gaussian_splat_batch = {k: v[b_idx].contiguous() for k, v in gaussian_splats.items()}
                    for r_idx in range(cfg.data.input_images, data["gt_images"].shape[1]):
                        if "focals_pixels" in data:
                            focals_pixels_render = data["focals_pixels"][b_idx, r_idx].cpu()
                        else:
                            focals_pixels_render = None

                        # stage 2 rendered image
                        render_result = render_predicted(gaussian_splat_batch,
                                                         data["world_view_transforms"][b_idx, r_idx],
                                                         data["full_proj_transforms"][b_idx, r_idx],
                                                         data["camera_centers"][b_idx, r_idx],
                                                         background,
                                                         cfg,
                                                         focals_pixels=focals_pixels_render)
                        # Put in a list for a later loss computation.
                        # `image` / `gt_image` keep the last rendered pair for
                        # the render logging further down.
                        image = render_result["render"]
                        rendered_images.append(image)
                        gt_image = data["gt_images"][b_idx, r_idx]
                        gt_images.append(gt_image)
                        if cfg.model.masked_loss:
                            gt_masks.append(data["fg_masks"][b_idx, r_idx])
                            pred_masks.append(render_result['alpha'])
                rendered_images = torch.stack(rendered_images, dim=0)
                gt_images = torch.stack(gt_images, dim=0)
                if cfg.model.masked_loss:
                    gt_masks = torch.stack(gt_masks, dim=0)
                    pred_masks = torch.stack(pred_masks, dim=0)

                l12_loss_sum = loss_fn(rendered_images, gt_images)
                l12_loss_sum_ls.append(l12_loss_sum)
                if cfg.model.masked_loss:
                    mask_loss_sum = loss_fn(pred_masks, gt_masks)
                    mask_loss_sum_ls.append(mask_loss_sum)
                # lpips loss
                if cfg.opt.lambda_lpips != 0:
                    lpips_loss_sum = torch.mean(
                        lpips_fn(rendered_images * 2 - 1, gt_images * 2 - 1),
                    )
                lpips_loss_sum_ls.append(lpips_loss_sum)
                # ssim loss
                if cfg.opt.lambda_ssim != 0:
                    ssim_loss_sum = torch.mean(
                        ssim_fn(rendered_images, gt_images),
                    )
                ssim_loss_sum_ls.append(ssim_loss_sum)

                # With weight_layer_loss, the first predicted layer is down-weighted.
                layer_weight = 0.1 if (cfg.opt.weight_layer_loss and layer_idx == 1) else 1.0
                total_loss += layer_weight * (l12_loss_sum * lambda_l12
                                              + mask_loss_sum * lambda_mask
                                              + lpips_loss_sum * lambda_lpips
                                              + ssim_loss_sum * lambda_ssim)

            assert not total_loss.isnan(), "Found NaN loss!"
            fabric.backward(total_loss)
            # Clip the global gradient norm, then each parameter's own norm.
            clip_grad.clip_grad_norm_(gaussian_predictor.parameters(), cfg.opt.max_grad_norm)
            for param_group in optimizer.param_groups:
                for param in param_group["params"]:
                    clip_grad.clip_grad_norm_(param, cfg.opt.max_grad_norm)
            # ============ Optimization ===============
            optimizer.step()
            optimizer.zero_grad()

            gaussian_predictor.eval()

            # ========= Logging =============
            with torch.no_grad():

                if (iteration % loss_log_every == 0 or iteration < 50) and fabric.is_global_zero:
                    # Get current learning rates from the optimizer
                    current_lr = optimizer.param_groups[0]['lr']
                    wandb.log({'learning_rate': current_lr}, step=log_step)
                    # Losses are logged as log10 values.
                    wandb.log({"training_loss": np.log10(total_loss.item() + 1e-8)}, step=log_step)
                    for layer_idx, l12_layer_loss in enumerate(l12_loss_sum_ls, start=1):
                        wandb.log({f"training_l12_loss_layer{layer_idx}": np.log10(l12_layer_loss.item() + 1e-8)}, step=log_step)
                        if cfg.model.masked_loss:
                            wandb.log({f"training_mask_loss{layer_idx}": np.log10(mask_loss_sum_ls[layer_idx - 1].item() + 1e-8)}, step=log_step)
                        if cfg.opt.lambda_lpips != 0:
                            wandb.log({f"training_lpips_loss{layer_idx}": np.log10(lpips_loss_sum_ls[layer_idx - 1].item() + 1e-8)}, step=log_step)
                        if cfg.opt.lambda_ssim != 0:
                            wandb.log({f"training_ssim_loss{layer_idx}": np.log10(ssim_loss_sum_ls[layer_idx - 1].item() + 1e-8)}, step=log_step)

                if (iteration % render_log_every == 0 or iteration == 1) and fabric.is_global_zero:
                    wandb.log({"render": wandb.Image(image.clamp(0.0, 1.0).permute(1, 2, 0).detach().cpu().numpy())}, step=log_step)
                    wandb.log({"gt": wandb.Image(gt_image.permute(1, 2, 0).detach().cpu().numpy())}, step=log_step)

                if (iteration % loop_log_every == 0 or iteration < 5) and fabric.is_global_zero:
                    if test_iterator is None:
                        test_iterator = iter(test_dataloader)
                    try:
                        vis_data = next(test_iterator)
                    except StopIteration:
                        # Visualisation set exhausted: start over.
                        test_iterator = iter(test_dataloader)
                        vis_data = next(test_iterator)

                    vis_data = {k: fabric.to_device(v) for k, v in vis_data.items()}
                    gaussian_splats_vis = _predict_gaussians(gaussian_predictor, vis_data, cfg)

                    # Only the last predicted layer of the first sample is rendered.
                    vis_splats = {k: v[0].contiguous() for k, v in gaussian_splats_vis[-1].items()}

                    test_loop = []
                    test_loop_gt = []
                    for r_idx in range(vis_data["gt_images"].shape[1]):
                        # We don't change the input or output of the network, just the rendering cameras
                        if "focals_pixels" in vis_data:
                            focals_pixels_render = vis_data["focals_pixels"][0, r_idx]
                        else:
                            focals_pixels_render = None
                        test_image = render_predicted(vis_splats,
                                                      vis_data["world_view_transforms"][0, r_idx],
                                                      vis_data["full_proj_transforms"][0, r_idx],
                                                      vis_data["camera_centers"][0, r_idx],
                                                      background,
                                                      cfg,
                                                      focals_pixels=focals_pixels_render)["render"]
                        test_loop_gt.append((np.clip(vis_data["gt_images"][0, r_idx].detach().cpu().numpy(), 0, 1)*255).astype(np.uint8))
                        test_loop.append((np.clip(test_image.detach().cpu().numpy(), 0, 1)*255).astype(np.uint8))

                    wandb.log({"rot": wandb.Video(np.asarray(test_loop), fps=20, format="mp4")},
                              step=log_step)
                    wandb.log({"rot_gt": wandb.Video(np.asarray(test_loop_gt), fps=20, format="mp4")},
                              step=log_step)

            # ========= Checkpointing =============
            save_numbered = (iteration + 1) % ckpt_every == 0
            save_latest = (iteration + 1) % latest_ckpt_every == 0
            if save_numbered or save_latest:
                ckpt_save_dict = {
                    "optimizer": optimizer,
                    "model": gaussian_predictor
                }
                model_info_dict = {
                    "iteration": log_step,
                    "loss": total_loss.item(),
                    "best_PSNR": best_PSNR
                }
                if save_numbered:
                    _save_checkpoint(fabric, vis_dir, f"model_{log_step}", ckpt_save_dict, model_info_dict)
                if save_latest:
                    _save_checkpoint(fabric, vis_dir, "model_latest", ckpt_save_dict, model_info_dict)

            # ========= Validation =============
            if (iteration + 1) % val_every == 0 and fabric.is_global_zero:
                torch.cuda.empty_cache()
                print("\n[ITER {}] Validating".format(log_step+1))

                scores = evaluate_dataset(
                    gaussian_predictor,
                    val_dataloader,
                    device=device,
                    model_cfg=cfg)
                wandb.log(scores, step=log_step+1)
                # If the newest PSNR is better than the best one so far,
                # overwrite model_best.pth.
                if scores["PSNR_novel"] > best_PSNR:
                    best_PSNR = scores["PSNR_novel"]
                    best_ckpt = {
                        "model_state_dict": gaussian_predictor.state_dict(),
                        "loss": total_loss.item(),
                        "best_PSNR": best_PSNR,
                        "iteration": log_step
                    }
                    torch.save(best_ckpt, os.path.join(vis_dir, 'model_best.pth'))

                    print("\n[ITER {}] Saving new best checkpoint PSNR:{:.2f}".format(
                        log_step + 1, best_PSNR))
                torch.cuda.empty_cache()

            gaussian_predictor.train()
            if iteration < first_iter + 2 or iteration == 100:
                print('success!')

    # The wandb run exists on global rank zero only.
    if wandb_run is not None:
        wandb_run.finish()


# Script entry point. `main` is called without arguments here; it is wrapped
# by the `@hydra.main` decorator above, which presumably supplies `cfg`
# (see the Hydra documentation).
if __name__ == "__main__":
    main()
