# -*- coding: utf-8 -*-
import glob
import json
import math
import os
import time

import numpy as np
import torch
import torchvision
from torchvision import transforms
from PIL import Image

from .shared_dataset import SharedDataset
from utils.camera_utils import get_loop_cameras, camera_normalization_objaverse, build_camera_principle, get_rays

from utils.graphics_utils import getProjectionMatrix

from utils.image_utils import make_normalize_transform

# Dataset locations. Export the OBJAVERSE_ROOT / OBJAVERSE_LVIS_ANNOTATION_PATH
# environment variables, or replace the lookups below with hard-coded paths.
# Both stay None (and the import fails below) when nothing is configured.
OBJAVERSE_ROOT = os.environ.get("OBJAVERSE_ROOT")  # Change this to your data directory
OBJAVERSE_LVIS_ANNOTATION_PATH = os.environ.get("OBJAVERSE_LVIS_ANNOTATION_PATH")  # .json list of object paths

# Fail at import time when the paths are not configured. The check is raised
# explicitly rather than via `assert` so it is not stripped under `python -O`;
# AssertionError and the messages are kept for backward compatibility.
if OBJAVERSE_ROOT is None:
    raise AssertionError("Update dataset path")
if OBJAVERSE_LVIS_ANNOTATION_PATH is None:
    raise AssertionError("Update filtering .json path")


class ObjaverseDataset(SharedDataset):
    def __init__(self,
                 cfg,
                 dataset_name="train"
                 ) -> None:
        """Build the object list for one split and cache camera constants.

        Args:
            cfg: configuration object. Reads ``cfg.data.subset``,
                ``cfg.data.znear``, ``cfg.data.zfar``, ``cfg.data.fov``,
                ``cfg.data.training_resolution`` and ``cfg.opt.imgs_per_obj``.
            dataset_name: ``"val"`` / ``"vis"`` select the held-out tail of the
                object list, ``"test"`` is unsupported, and any other value
                (including the default ``"train"``) selects the head.

        Raises:
            NotImplementedError: if ``dataset_name == "test"``.
        """
        # Zero-argument super() is bound to this instance. The previous
        # `super(ObjaverseDataset).__init__()` created an *unbound* super
        # object, so the base-class initialiser was never actually run.
        super().__init__()
        self.cfg = cfg
        self.root_dir = OBJAVERSE_ROOT

        # load the file names
        with open(OBJAVERSE_LVIS_ANNOTATION_PATH) as f:
            self.paths = json.load(f)
        train_rate = 99.9  # percentage of objects assigned to the training split
        print('total number of training objects: ', len(self.paths))
        self.normalize = transforms.Compose(
            [
                make_normalize_transform(),
            ]
        )

        # split the dataset for training and validation
        split_index = math.floor(len(self.paths) / 100. * train_rate)
        self.dataset_name = dataset_name
        if dataset_name in ("val", "vis"):
            # validation or visualisation on Objaverse: last 0.1% of the list
            self.paths = self.paths[split_index:]
        elif dataset_name == "test":
            raise NotImplementedError  # Objaverse does not have separate test subset
        else:
            # NOTE(review): __getitem__ also recognises "val_vis_test", which
            # lands in this branch and therefore uses the training split --
            # confirm that this is intended.
            self.paths = self.paths[:split_index]  # first 99.9% for training

        # optionally truncate the split (-1 keeps everything)
        if cfg.data.subset != -1:
            self.paths = self.paths[:cfg.data.subset]

        print('============= length of dataset %d =============' % len(self.paths))

        # the same field of view is used for both axes; degrees -> radians
        fov_radians = cfg.data.fov * 2 * np.pi / 360
        self.projection_matrix = getProjectionMatrix(
            znear=self.cfg.data.znear, zfar=self.cfg.data.zfar,
            fovX=fov_radians,
            fovY=fov_radians).transpose(0, 1)

        self.image_side_target = self.cfg.data.training_resolution
        # flips the y and z axes (OpenGL -> COLMAP / OpenCV camera convention)
        self.opengl_to_colmap = torch.tensor([[1,  0,  0,  0],
                                              [0, -1,  0,  0],
                                              [0,  0, -1,  0],
                                              [0,  0,  0,  1]], dtype=torch.float32)

        self.imgs_per_obj_train = self.cfg.opt.imgs_per_obj
        

    def __len__(self):
        return len(self.paths)

    def load_imgs_and_convert_cameras(self, paths, num_views):
        """
        Load the images, camera matrices and projection matrices for a given object.

        Args:
            paths: list of ``.png`` render paths of one object. Each image is
                expected to carry an alpha channel (channel index 3) and to
                have a camera file with the same stem and a ``.npy`` extension.
            num_views: number of views to load. Outside training the first
                ``num_views`` entries of ``paths`` are used; in training the
                views are sampled at random and the first
                ``cfg.data.input_images`` samples are repeated at the front.

        Returns:
            dict with the keys ``gt_images``, ``norm_imgs``, ``w2c_source``,
            ``world_view_transforms``, ``view_to_world_transforms``,
            ``full_proj_transforms``, ``camera_centers``, ``pps_pixels`` and
            ``fg_masks``, plus ``source_camera`` when
            ``cfg.data.mod_camera_dec`` is set. All entries are stacked along
            the view dimension.
        """
        bg_color = torch.tensor([1., 1., 1.], dtype=torch.float32).unsqueeze(1).unsqueeze(2)
        homogeneous_row = torch.tensor([[0, 0, 0, 1]], dtype=torch.float32)
        world_view_transforms = []
        view_world_transforms = []
        w2c_source = []

        camera_centers = []
        imgs = []
        norm_imgs = []
        fg_masks = []

        # validation dataset is used for scoring - fix cond frame for reproducibility
        # in training need to randomly sample the conditioning frame
        if self.dataset_name != "train":
            indexes = torch.arange(num_views)
        else:
            indexes = torch.randperm(len(paths))[:num_views]
            indexes = torch.cat([indexes[:self.cfg.data.input_images], indexes], dim=0)

        # load the images and cameras
        for i in indexes.tolist():
            path = paths[i]
            # The context manager releases the file handle even if decoding
            # fails; the tensor conversion happens inside it because resize
            # may hand back the very same (lazily loaded) image object.
            with Image.open(path) as raw_img:
                # resize to the training resolution
                resized = torchvision.transforms.functional.resize(
                    raw_img,
                    self.cfg.data.training_resolution,
                    interpolation=torchvision.transforms.InterpolationMode.LANCZOS)
                # read to [0, 1] FloatTensor
                img = torchvision.transforms.functional.pil_to_tensor(resized) / 255.0

            # the alpha channel doubles as the foreground mask
            alpha = img[3:, ...]
            fg_masks.append(alpha)

            # set background: alpha-composite the RGB channels over white
            composited = img[:3, ...] * alpha + bg_color * (1 - alpha)
            imgs.append(composited)
            # assumes self.normalize is out-of-place (load_loop relies on the
            # same assumption) -- TODO confirm
            norm_imgs.append(self.normalize(composited))

            # Swap only the file extension. A plain str.replace('png', 'npy')
            # would also rewrite any "png" inside a directory or file name.
            camera_path = os.path.splitext(path)[0] + '.npy'
            # .npy files store world-to-camera matrix in column major order
            w2c_cmo = torch.tensor(np.load(camera_path)).float()  # 3x4
            w2c_source.append(w2c_cmo)
            w2c_cmo = torch.cat([w2c_cmo, homogeneous_row], dim=0)  # 4x4
            # camera poses in .npy files are in OpenGL convention:
            #     x right, y up, z into the camera (backward),
            # need to transform to COLMAP / OpenCV:
            #     x right, y down, z away from the camera (forward)
            w2c_cmo = torch.matmul(self.opengl_to_colmap, w2c_cmo)
            # need row major order
            world_view_transform = w2c_cmo.transpose(0, 1)
            view_world_transform = w2c_cmo.inverse().transpose(0, 1)
            camera_center = view_world_transform[3, :3].clone()

            world_view_transforms.append(world_view_transform)
            view_world_transforms.append(view_world_transform)
            camera_centers.append(camera_center)

        imgs = torch.stack(imgs)
        norm_imgs = torch.stack(norm_imgs)
        fg_masks = torch.stack(fg_masks)
        w2c_source = torch.stack(w2c_source, dim=0)  # (n_views, 3, 4)
        world_view_transforms = torch.stack(world_view_transforms)  # (n_views, 4, 4)
        view_world_transforms = torch.stack(view_world_transforms)
        camera_centers = torch.stack(camera_centers)

        pps_pixels = torch.zeros((imgs.shape[0], 2))

        # built from the raw poses, i.e. before the rescaling below
        if self.cfg.data.mod_camera_dec:
            poses = camera_normalization_objaverse(normed_dist_to_center='auto', poses=w2c_source)
            intrinsics = torch.tensor([self.cfg.data.intrinsics[:2], self.cfg.data.intrinsics[2:], [128, 128]]).repeat(len(indexes), 1, 1)
            source_camera = build_camera_principle(poses, intrinsics)

        # fix the distance of the source camera to the object / world center
        source_distance = torch.norm(camera_centers[0])
        assert source_distance > 1e-5, \
            "Camera is at {} from center".format(source_distance)
        translation_scaling_factor = 2.0 / source_distance
        world_view_transforms[:, 3, :3] *= translation_scaling_factor
        view_world_transforms[:, 3, :3] *= translation_scaling_factor
        camera_centers *= translation_scaling_factor

        full_proj_transforms = world_view_transforms.bmm(self.projection_matrix.unsqueeze(0).expand(
            world_view_transforms.shape[0], 4, 4))

        data_dict = {"gt_images": imgs,
                    "norm_imgs": norm_imgs,
                    "w2c_source": w2c_source,
                    "world_view_transforms": world_view_transforms,
                    "view_to_world_transforms": view_world_transforms,
                    "full_proj_transforms": full_proj_transforms,
                    "camera_centers": camera_centers,
                    "pps_pixels": pps_pixels,
                    "fg_masks": fg_masks}
        if self.cfg.data.mod_camera_dec:
            data_dict['source_camera'] = source_camera

        return data_dict

    def load_loop(self, paths, num_imgs_in_loop):
        """
        Assemble the source views followed by a loop of cameras for visualisation.

        The first ``cfg.data.input_images`` entries are the source views taken
        from the loaded renders. They are followed by one entry per camera
        returned by ``get_loop_cameras``; since no render exists for those
        cameras, each borrows the image and mask of the loaded view whose
        camera center is closest.

        Args:
            paths: list of ``.png`` render paths of one object.
            num_imgs_in_loop: forwarded to ``get_loop_cameras``.

        Returns:
            dict with the keys ``gt_images``, ``norm_imgs``,
            ``world_view_transforms``, ``view_to_world_transforms``,
            ``full_proj_transforms``, ``camera_centers``, ``pps_pixels`` and
            ``fg_masks``, plus ``source_camera`` when
            ``cfg.data.mod_camera_dec`` is set. Every per-view entry,
            including ``norm_imgs``, has one row per source view followed by
            one row per loop camera.
        """
        w2c_source = []
        world_view_transforms = []
        view_world_transforms = []
        camera_centers = []
        imgs = []
        norm_imgs = []
        fg_masks = []

        gt_imgs_and_cameras = self.load_imgs_and_convert_cameras(paths, len(paths))
        loop_cameras_c2w_cmo, camera_poses = get_loop_cameras(num_imgs_in_loop=num_imgs_in_loop)

        for src_idx in range(self.cfg.data.input_images):
            imgs.append(gt_imgs_and_cameras["gt_images"][src_idx])
            # Keep norm_imgs aligned with gt_images. The source views used to
            # be skipped here, which left norm_imgs shorter than, and shifted
            # against, every other per-view entry.
            norm_imgs.append(gt_imgs_and_cameras["norm_imgs"][src_idx])
            fg_masks.append(gt_imgs_and_cameras["fg_masks"][src_idx])
            camera_centers.append(gt_imgs_and_cameras["camera_centers"][src_idx])
            world_view_transforms.append(gt_imgs_and_cameras["world_view_transforms"][src_idx])
            view_world_transforms.append(gt_imgs_and_cameras["view_to_world_transforms"][src_idx])
            w2c_source.append(gt_imgs_and_cameras["w2c_source"][src_idx])

        for loop_idx, loop_camera_c2w_cmo in enumerate(loop_cameras_c2w_cmo):
            c2w = torch.from_numpy(loop_camera_c2w_cmo)
            view_world_transform = c2w.transpose(0, 1)
            world_view_transform = c2w.inverse().transpose(0, 1)
            camera_center = view_world_transform[3, :3].clone()

            camera_centers.append(camera_center)
            world_view_transforms.append(world_view_transform)
            view_world_transforms.append(view_world_transform)
            w2c_source.append(camera_poses[loop_idx])

            # use the closest camera as reference gt image
            closest_gt_idx = torch.argmin(torch.norm(
                gt_imgs_and_cameras["camera_centers"] - camera_center.unsqueeze(0), dim=-1)).item()
            imgs.append(gt_imgs_and_cameras["gt_images"][closest_gt_idx])
            # the loader already normalised this exact image
            norm_imgs.append(gt_imgs_and_cameras["norm_imgs"][closest_gt_idx])
            fg_masks.append(gt_imgs_and_cameras["fg_masks"][closest_gt_idx])

        imgs = torch.stack(imgs)
        norm_imgs = torch.stack(norm_imgs)
        fg_masks = torch.stack(fg_masks)
        world_view_transforms = torch.stack(world_view_transforms)
        view_world_transforms = torch.stack(view_world_transforms)
        camera_centers = torch.stack(camera_centers)

        full_proj_transforms = world_view_transforms.bmm(self.projection_matrix.unsqueeze(0).expand(
            world_view_transforms.shape[0], 4, 4))

        pps_pixels = torch.zeros((imgs.shape[0], 2))

        if self.cfg.data.mod_camera_dec:
            w2c_source = torch.stack(w2c_source, dim=0)
            poses = camera_normalization_objaverse(normed_dist_to_center='auto', poses=w2c_source)
            intrinsics = torch.tensor([self.cfg.data.intrinsics[:2], self.cfg.data.intrinsics[2:], [128, 128]]).repeat(w2c_source.shape[0], 1, 1)
            source_camera = build_camera_principle(poses, intrinsics)
        data_dict = {"gt_images": imgs.to(memory_format=torch.channels_last),
                    "norm_imgs": norm_imgs,
                    "world_view_transforms": world_view_transforms,
                    "view_to_world_transforms": view_world_transforms,
                    "full_proj_transforms": full_proj_transforms,
                    "camera_centers": camera_centers,
                    "pps_pixels": pps_pixels,
                    "fg_masks": fg_masks}
        if self.cfg.data.mod_camera_dec:
            data_dict['source_camera'] = source_camera
        return data_dict

    def get_example_id(self, index):
        example_id = self.paths[index]
        return example_id

    def __getitem__(self, index):
        """
        Load one object and return its images, masks and camera data.

        Objects that cannot be used (too few renders in training, a loading
        error, or an all-zero foreground mask) are skipped by falling through
        to the next index, wrapping around at the end of the dataset. If no
        object is usable the recursion eventually fails with RecursionError.

        Args:
            index: position of the object in ``self.paths``.

        Returns:
            The dict produced by ``load_loop`` (for the "vis" /
            "val_vis_test" datasets) or ``load_imgs_and_convert_cameras``,
            passed through ``make_poses_relative_to_first`` and extended with
            ``source_cv2wT_quat`` and, when ``cfg.data.use_plucker_emb`` is
            set, ``plucker_emb``.
        """
        # load the rendered images
        filename = os.path.join(self.root_dir, self.paths[index])
        # glob returns files in arbitrary, filesystem-dependent order; sort so
        # that the fixed view indices used outside training are reproducible
        paths = sorted(glob.glob(filename + '/*.png'))
        # fallback sample; the modulo keeps it in range for the last object
        next_index = (index + 1) % len(self.paths)

        if self.dataset_name == "vis" or self.dataset_name == "val_vis_test":
            images_and_camera_poses = self.load_loop(paths, 100)
        else:
            if self.dataset_name == "train":
                if len(paths) < self.imgs_per_obj_train:
                    print(filename, f'has only {len(paths)} images rendered!')
                    return self.__getitem__(next_index)
                num_views = self.imgs_per_obj_train
            else:
                num_views = len(paths)
            try:
                images_and_camera_poses = self.load_imgs_and_convert_cameras(paths, num_views)
            except Exception as err:
                # Still best-effort, but no longer silent, and no longer
                # swallowing KeyboardInterrupt / SystemExit.
                print(filename, f'could not be loaded ({err!r}), skipping')
                return self.__getitem__(next_index)
        images_and_camera_poses = self.make_poses_relative_to_first(images_and_camera_poses)
        if self.cfg.data.use_plucker_emb:
            resolution = self.cfg.data.training_resolution
            plucker_embs = []
            for input_idx in range(self.cfg.data.input_images):
                rays_o, rays_d = get_rays(images_and_camera_poses["view_to_world_transforms"][input_idx],
                                          resolution, resolution,
                                          self.cfg.data.fov, opengl=False)  # [h, w, 3]
                # Plucker coordinates of each ray: (origin x direction, direction)
                plucker_emb = torch.cat([torch.cross(rays_o, rays_d, dim=-1), rays_d], dim=-1)  # [h, w, 6]
                plucker_embs.append(plucker_emb)

            plucker_embs = torch.stack(plucker_embs, dim=0).permute(0, 3, 1, 2).contiguous()  # [V, 6, h, w]
            images_and_camera_poses["plucker_emb"] = plucker_embs
        images_and_camera_poses["source_cv2wT_quat"] = self.get_source_cw2wT(images_and_camera_poses["view_to_world_transforms"])

        if images_and_camera_poses['fg_masks'].sum() == 0:
            # Record the object with an empty foreground mask. The log path is
            # machine-specific, so a failure to write must not abort loading.
            try:
                with open('/comp_robot/jiaminwu/3D_reconstruction/multi-view-3dgs/bad_mask.txt', 'a+') as f:
                    f.write(str(filename) + '\n')
            except OSError as err:
                print(filename, f'has an empty foreground mask (could not log it: {err!r})')
            return self.__getitem__(next_index)
        return images_and_camera_poses
