# -*- coding: utf-8 -*-
import glob
import json
import math
import os

import numpy as np
import torch
import torchvision
from PIL import Image
from torchvision import transforms

from .objaverse_new_render import ObjaverseDataset
from utils.graphics_utils import getProjectionMatrix
from utils.camera_utils import get_rays
from utils.image_utils import make_normalize_transform

# Root directories of the GSO renders. Either edit the defaults below or export
# environment variables with the same names; both stay None when unset.
# GSODataset lists objects from GSO_ROOT_FIX (its input-view root) and stores
# GSO_ROOT_RANDOM as its supervision-view root.
GSO_ROOT_FIX = os.environ.get("GSO_ROOT_FIX")  # Change this to your data directory
GSO_ROOT_RANDOM = os.environ.get("GSO_ROOT_RANDOM")  # Change this to your data directory


class GSODataset(ObjaverseDataset):
    """Evaluation-only dataset over Google Scanned Objects (GSO) renders.

    Inherits the loading and camera helpers used in ``__getitem__``
    (``load_loop``, ``load_imgs_and_convert_cameras``,
    ``make_poses_relative_to_first``, ``get_source_cw2wT``) from
    ``ObjaverseDataset``, but initialises its own state instead of running the
    parent constructor. One example per entry of ``GSO_ROOT_FIX`` (sorted).
    """

    def __init__(self,
                 cfg,
                 dataset_name="test",
                 total_view_input=4,
                 total_view_sup=32
                 ) -> None:
        """Set up paths, transforms and camera matrices.

        Args:
            cfg: Config object; ``cfg.data`` must provide ``training_resolution``,
                ``znear``, ``zfar`` and ``fov`` (in degrees, converted to radians
                below).
            dataset_name: Split name. ``"train"`` is rejected; ``"vis"`` and
                ``"val_vis_test"`` select the looped-camera path in
                ``__getitem__``.
            total_view_input: Number of input views loaded per example.
            total_view_sup: Number of supervision views loaded per example.

        Raises:
            AssertionError: If ``dataset_name == "train"``.
            ValueError: If the input root directory (``GSO_ROOT_FIX``) is unset.
        """
        # Intentionally bypass ObjaverseDataset.__init__: every attribute this
        # class needs is set below. Only the next constructor in the MRO runs
        # (presumably torch's Dataset / object -- confirm if the base changes).
        super(ObjaverseDataset, self).__init__()

        self.cfg = cfg
        self.root_dir_input = GSO_ROOT_FIX
        self.root_dir_sup = GSO_ROOT_RANDOM
        self.total_view_input = total_view_input
        self.total_view_sup = total_view_sup
        assert dataset_name != "train", "No training on GSO dataset!"

        self.dataset_name = dataset_name

        self.normalize = transforms.Compose(
            [
                make_normalize_transform(),
            ]
        )
        resolution = self.cfg.data.training_resolution
        self.resize = transforms.Resize((resolution, resolution))

        # cfg.data.fov is in degrees; the same FoV is used on both axes.
        fov_rad = cfg.data.fov * 2 * np.pi / 360
        self.projection_matrix = getProjectionMatrix(
            znear=self.cfg.data.znear, zfar=self.cfg.data.zfar,
            fovX=fov_rad,
            fovY=fov_rad).transpose(0, 1)

        self.image_side_target = resolution
        # Flips the y and z axes (OpenGL -> COLMAP/OpenCV camera convention).
        self.opengl_to_colmap = torch.tensor([[1,  0,  0,  0],
                                              [0, -1,  0,  0],
                                              [0,  0, -1,  0],
                                              [0,  0,  0,  1]], dtype=torch.float32)

        # os.listdir(None) silently lists the current working directory, which
        # would turn unrelated files into "objects"; fail loudly instead.
        if self.root_dir_input is None:
            raise ValueError(
                "GSO input root is not configured: set GSO_ROOT_FIX to your "
                "GSO data directory."
            )
        self.object_ids = sorted(os.listdir(self.root_dir_input))

        print('============= length of dataset %d =============' % len(self.object_ids))

        self.test_input_idxs = [0]

    def __len__(self):
        """Return the number of objects found under the input root."""
        return len(self.object_ids)

    def get_example_id(self, index):
        """Return the object id (directory entry name) for ``index``."""
        example_path = self.object_ids[index]
        return os.path.basename(example_path)

    def _plucker_embeddings(self, view_to_world):
        """Return Plucker ray embeddings ``[V, 6, h, w]`` for the first
        ``cfg.data.input_images`` views of ``view_to_world``.

        Each pixel is encoded as ``(o x d, d)`` from the rays returned by
        ``get_rays`` (called with ``opengl=False``).
        """
        resolution = self.cfg.data.training_resolution
        embeddings = []
        for view_idx in range(self.cfg.data.input_images):
            rays_o, rays_d = get_rays(view_to_world[view_idx], resolution, resolution,
                                      self.cfg.data.fov, opengl=False)  # [h, w, 3]
            embeddings.append(
                torch.cat([torch.cross(rays_o, rays_d, dim=-1), rays_d], dim=-1))  # [h, w, 6]
        return torch.stack(embeddings, dim=0).permute(0, 3, 1, 2).contiguous()  # [V, 6, h, w]

    def __getitem__(self, index):
        """Load one object's views and cameras, with poses made relative to the first view.

        For the ``"vis"`` / ``"val_vis_test"`` splits, 100 views come from
        ``load_loop``. Otherwise ``total_view_input + total_view_sup`` views are
        loaded. Adds ``"source_cv2wT_quat"`` and, when
        ``cfg.data.use_plucker_emb`` is set, ``"plucker_emb"`` to the returned dict.
        """
        if self.dataset_name in ("vis", "val_vis_test"):
            images_and_camera_poses = self.load_loop(index, 100)
        else:
            num_views = self.total_view_sup + self.total_view_input
            images_and_camera_poses = self.load_imgs_and_convert_cameras(index, num_views)

        images_and_camera_poses = self.make_poses_relative_to_first(images_and_camera_poses)
        view_to_world = images_and_camera_poses["view_to_world_transforms"]
        images_and_camera_poses["source_cv2wT_quat"] = self.get_source_cw2wT(view_to_world)
        if self.cfg.data.use_plucker_emb:
            images_and_camera_poses["plucker_emb"] = self._plucker_embeddings(view_to_world)
        return images_and_camera_poses
