#import open3d as o3d
import trimesh
import mmcv
import numpy as np

from mmdet3d.core.points import BasePoints, get_points_type
from mmdet.datasets.builder import PIPELINES
from mmdet.datasets.pipelines import LoadAnnotations, LoadImageFromFile
import random
import os


@PIPELINES.register_module()
class LoadOccupancy(object):
    """Load a ground-truth occupancy grid into ``results['gt_occ']``.

    The grid is built from one of three sources:

    * KITTI (``is_kitti=True``): the ``.npy`` grid at
      ``results['pts_filename']`` is loaded as-is; when that file is missing
      an all-zero float grid of shape ``(256, 256, 32)`` is used instead.
    * Pre-voxelised nuScenes labels (``load_lidar`` is ``'dense'`` or
      ``'sparse'``): an ``.npy`` array located with ``results['scene_token']``
      and ``results['lidar_token']``. Its first three columns are voxel
      indices (their order is reversed on load) and its last column is a
      semantic label. The labels are scattered into an integer grid of shape
      ``(200, 200, 16)``.
    * Raw lidarseg (any other ``load_lidar`` value): points from
      ``results['pts_filename']`` and per-point labels from
      ``results['lidarseg']`` are cropped to ``results['pc_range']`` and
      voxelised into an integer grid of shape ``results['occ_size']``.

    In both nuScenes paths label ``0`` is rewritten to ``255``.

    Args:
        to_float32 (bool): Stored and shown in ``repr``; not used when
            loading. Defaults to True.
        use_semantic (bool): Stored but not used when loading.
            Defaults to False.
        load_lidar (str): ``'dense'``, ``'sparse'``, or any other value to
            voxelise raw lidarseg points. Defaults to ``'dense'``.
        is_kitti (bool): Load KITTI grids. Defaults to False.
        is_test (bool): For ``'dense'``/``'sparse'``, read labels from the
            validation label root and use their voxel indices without
            downsampling. Defaults to False.
    """

    # Extent of the source dense label grid; voxel indices are clamped to it.
    OCC_SIZE_ORI = (600, 600, 48)
    # Shape of the grid produced from the pre-voxelised .npy labels.
    GRID_SHAPE = (200, 200, 16)
    # Maps OCC_SIZE_ORI indices onto GRID_SHAPE (600 / 200 == 48 / 16 == 3).
    DENSE_DOWNSAMPLE = 3
    # Placeholder grid shape used when a KITTI label file is missing.
    KITTI_EMPTY_SHAPE = (256, 256, 32)
    # Data locations; override in a subclass to load from elsewhere.
    DENSE_ROOT = '/mnt/cfs/algorithm/temp/semantic_label/'
    SPARSE_ROOT = '/mnt/cfs/algorithm/temp/train_sparse_semantic_key_frame/'
    TEST_ROOT = '/mnt/cfs/algorithm/temp/semantic_label_val0.5_new/val_set_new/'
    LABEL_MAPPING = '/mnt/cfs/algorithm/temp/BEVFormer/util/nuscenes.yaml'
    LIDARSEG_ROOT = ('/mnt/cfs/algorithm/temp/BEVFormer/nus_lidarseg/'
                     'lidarseg/v1.0-trainval')

    def __init__(self, to_float32=True, use_semantic=False, load_lidar='dense', is_kitti=False, is_test=False):
        self.to_float32 = to_float32
        self.use_semantic = use_semantic
        self.load_lidar = load_lidar
        self.is_kitti = is_kitti
        self.is_test = is_test
        # Lazily-loaded lidarseg raw-label -> training-label mapping.
        self._learning_map = None

    def __call__(self, results):
        """Add ``results['gt_occ']`` and return ``results``.

        Args:
            results (dict): Pipeline dict. Keys read depend on the source
                (see the class docstring).

        Returns:
            dict: ``results`` with ``'gt_occ'`` set.
        """
        if self.is_kitti:
            results['gt_occ'] = self._load_kitti(results['pts_filename'])
        elif self.load_lidar in ('dense', 'sparse'):
            results['gt_occ'] = self._load_voxel_labels(results)
        else:
            results['gt_occ'] = self._load_lidarseg(results)
        return results

    def _load_kitti(self, filename):
        """Return the KITTI grid at ``filename``, or zeros if it is missing."""
        if not os.path.exists(filename):
            return np.zeros(self.KITTI_EMPTY_SHAPE)
        return np.load(filename)

    def _voxel_label_path(self, results):
        """Return the .npy label path for the configured split/density."""
        if self.is_test:
            root, subdir = self.TEST_ROOT, 'dense_voxels_with_semantic'
        elif self.load_lidar == 'dense':
            root, subdir = self.DENSE_ROOT, 'dense_voxels_with_semantic'
        else:
            root, subdir = self.SPARSE_ROOT, 'key_frame_voxels_with_semantic'
        rel_path = 'scene_{0}/{1}/{2}.npy'.format(
            results['scene_token'], subdir, results['lidar_token'])
        return root + rel_path

    def _load_voxel_labels(self, results):
        """Scatter pre-voxelised labels into a ``GRID_SHAPE`` integer grid."""
        pcd = np.load(self._voxel_label_path(results))
        # Reverse the first three columns to get grid-axis order
        # (the file presumably stores z, y, x -- TODO confirm).
        coords = pcd[..., [2, 1, 0]].astype(int)
        # View into ``pcd``; relabel "0" as 255 in place.
        semantics = pcd[..., -1]
        semantics[semantics == 0] = 255
        for axis, size in enumerate(self.OCC_SIZE_ORI):
            coords[:, axis][coords[:, axis] >= size] = size - 1

        occ = np.zeros(self.GRID_SHAPE, dtype=int)
        if self.is_test or self.load_lidar == 'sparse':
            # Indices are used directly on the output grid.
            occ[coords[:, 0], coords[:, 1], coords[:, 2]] = semantics
        else:
            # Dense training labels are on OCC_SIZE_ORI; downsample to GRID_SHAPE.
            step = self.DENSE_DOWNSAMPLE
            occ[coords[:, 0] // step, coords[:, 1] // step, coords[:, 2] // step] = semantics
        return occ

    def _get_learning_map(self):
        """Return (and cache) the ``learning_map`` from ``LABEL_MAPPING``."""
        if self._learning_map is None:
            import yaml
            with open(self.LABEL_MAPPING, 'r') as stream:
                self._learning_map = yaml.safe_load(stream)['learning_map']
        return self._learning_map

    def _load_lidarseg(self, results):
        """Voxelise raw lidar points with lidarseg labels into an int grid."""
        learning_map = self._get_learning_map()
        labels_file = os.path.join(self.LIDARSEG_ROOT, results['lidarseg'])
        points_label = np.fromfile(labels_file, dtype=np.uint8).reshape([-1, 1])
        points_label = np.vectorize(learning_map.__getitem__)(points_label)

        # Each point record holds 5 float32 values; only the first 3 are used.
        points = np.fromfile(results['pts_filename'],
                             dtype=np.float32,
                             count=-1).reshape(-1, 5)[..., :3]
        pcd_np = np.concatenate([points, points_label], axis=-1)

        pc_range = results['pc_range']
        occ_size = results['occ_size']
        # Keep points strictly inside the (min xyz, max xyz) box.
        mask = np.ones(len(pcd_np), dtype=bool)
        for axis in range(3):
            mask &= (pcd_np[:, axis] > pc_range[axis]) & (pcd_np[:, axis] < pc_range[axis + 3])
        pcd_np = pcd_np[mask]

        # Convert metric coordinates to (fractional) voxel coordinates.
        for axis in range(3):
            low, high, size = pc_range[axis], pc_range[axis + 3], occ_size[axis]
            coord = pcd_np[:, axis]
            coord[:] = (coord - low) / (high - low) * size
            coord[coord >= size] = size - 1

        occ_np = np.zeros((occ_size[0], occ_size[1], occ_size[2]), dtype=int)
        semantics = pcd_np[..., -1]
        semantics[semantics == 0] = 255
        occ_np[pcd_np[:, 0].astype(int), pcd_np[:, 1].astype(int), pcd_np[:, 2].astype(int)] = semantics
        return occ_np

    def __repr__(self):
        """str: Return a string that describes the module."""
        return (f'{self.__class__.__name__}(to_float32={self.to_float32}, '
                f'use_semantic={self.use_semantic}, '
                f'load_lidar={self.load_lidar!r}, '
                f'is_kitti={self.is_kitti}, is_test={self.is_test})')



@PIPELINES.register_module()
class LoadMesh(object):
    """Load a validation occupancy grid into ``results``.

    * KITTI (``is_kitti=True``): ``results['gt_semantic']`` is set to the
      ``.npy`` grid at ``results['pts_filename']``. When that file is
      missing, an all-zero float grid of shape ``(256, 256, 32)`` is used.
    * nuScenes: the dense voxel label ``.npy`` for ``results['scene_token']``
      and ``results['lidar_token']`` is read from the validation label root.
      It is scattered into an integer ``(200, 200, 16)`` grid stored in
      ``results['gt_occ']``. With ``load_semantic`` the grid holds semantic
      labels, with ``0`` rewritten to ``255``. Otherwise every voxel whose
      label is > 0 becomes ``1`` and everything else stays ``0``.

    NOTE(review): the KITTI branch writes ``'gt_semantic'`` while the
    nuScenes branches write ``'gt_occ'``; confirm which key downstream
    transforms expect.

    Args:
        to_float32 (bool): Stored and shown in ``repr``; not used when
            loading. Defaults to True.
        load_semantic (bool): Keep semantic labels instead of binary
            occupancy. Defaults to False.
        is_kitti (bool): Load KITTI grids. Defaults to False.
    """

    # Shape of the grid produced from the nuScenes validation labels.
    GRID_SHAPE = (200, 200, 16)
    # Placeholder grid shape used when a KITTI label file is missing.
    KITTI_EMPTY_SHAPE = (256, 256, 32)
    # Validation label root; override in a subclass to load from elsewhere.
    VAL_ROOT = '/mnt/cfs/algorithm/temp/semantic_label_val0.5_new/val_set_new/'

    def __init__(self, to_float32=True, load_semantic=False, is_kitti=False):
        self.to_float32 = to_float32
        self.load_semantic = load_semantic
        self.is_kitti = is_kitti

    def __call__(self, results):
        """Add the ground-truth grid to ``results`` and return it.

        Args:
            results (dict): Pipeline dict. Reads ``'pts_filename'`` for KITTI,
                and ``'scene_token'``/``'lidar_token'`` otherwise.

        Returns:
            dict: ``results`` with ``'gt_semantic'`` (KITTI) or ``'gt_occ'``
            (nuScenes) set.
        """
        if self.is_kitti:
            if not os.path.exists(results['pts_filename']):
                results['gt_semantic'] = np.zeros(self.KITTI_EMPTY_SHAPE)
            else:
                results['gt_semantic'] = np.load(results['pts_filename'])
            return results

        pcd = self._load_dense_voxels(results)
        if self.load_semantic:
            voxels = pcd[..., [2, 1, 0, 3]].astype(int)
            # View into ``voxels``; relabel "0" as 255 in place.
            labels = voxels[:, 3]
            labels[labels == 0] = 255
            results['gt_occ'] = self._scatter(voxels)
        else:
            # Filter on the raw label before the int cast, as before.
            pcd = pcd[pcd[..., 3] > 0]
            occ = self._scatter(pcd[..., [2, 1, 0, 3]].astype(int))
            occ[occ > 0] = 1
            results['gt_occ'] = occ
        return results

    def _load_dense_voxels(self, results):
        """Load the validation dense-voxel label array for this sample."""
        rel_path = 'scene_{0}/dense_voxels_with_semantic/{1}.npy'.format(
            results['scene_token'], results['lidar_token'])
        return np.load(self.VAL_ROOT + rel_path)

    def _scatter(self, voxels):
        """Write ``voxels[:, 3]`` at indices ``voxels[:, :3]`` into a zero grid.

        The indices are used as-is (assumed to already lie inside
        ``GRID_SHAPE`` -- TODO confirm); the first three columns are the
        label file's reversed coordinate columns.
        """
        occ = np.zeros(self.GRID_SHAPE, dtype=int)
        occ[voxels[:, 0], voxels[:, 1], voxels[:, 2]] = voxels[:, 3]
        return occ

    def __repr__(self):
        """str: Return a string that describes the module."""
        return (f'{self.__class__.__name__}(to_float32={self.to_float32}, '
                f'load_semantic={self.load_semantic}, '
                f'is_kitti={self.is_kitti})')
