from os.path import join, exists
from RandLANetBoosting import Network
from tester_Toronto3D import ModelTester
from helper_ply import read_ply
from helper_tool import Plot
from helper_tool import DataProcessing as DP
from helper_tool import ConfigToronto3D as cfg
import tensorflow as tf
import numpy as np
import pickle, argparse, os


class Toronto3D:
    """Input pipeline for the Toronto-3D point-cloud dataset.

    On construction the sub-sampled clouds and their pickled search trees are
    loaded from ``data/Toronto_3D/input_<grid>``. ``init_train_pipeline`` or
    ``init_test_pipeline`` then builds the ``tf.data`` graph and exposes it
    through ``flat_inputs`` and the matching ``*_init_op`` attributes.
    """

    # Split names accepted by ``get_batch_gen``.
    _SPLITS = ('training', 'validation', 'test')

    def __init__(self, mode='train'):
        """Set up label metadata and file lists, then load the clouds.

        Args:
            mode: ``'test'`` loads the test files together with their
                reprojection indices; any other value loads the training
                and validation files.
        """
        self.name = 'Toronto3D'
        self.path = 'data/Toronto_3D'
        self.label_to_names = {
            0: 'unclassified',
            1: 'Ground',
            2: 'Road marking',
            3: 'Natural',
            4: 'Building',
            5: 'Utility line',
            6: 'Pole',
            7: 'Car',
            8: 'Fence'
        }
        self.num_classes = len(self.label_to_names)
        self.label_values = np.sort(list(self.label_to_names))
        # Raw label value -> contiguous index 0..num_classes-1.
        self.label_to_idx = {l: i for i, l in enumerate(self.label_values)}
        self.ignored_labels = np.sort([0])

        self.full_pc_folder = join(self.path, 'original_ply')

        self.val_split = 3

        # train/val/test -- validation and test both use tile L002.
        self.train_files = self._ply_paths(['L001', 'L003', 'L004'])
        self.val_files = self._ply_paths(['L002'])
        self.test_files = self._ply_paths(['L002'])

        # containers (val_proj / val_labels are not filled by this class)
        self.val_proj = []
        self.val_labels = []
        self.test_proj = []
        self.test_labels = []

        self.possibility = {}
        self.min_possibility = {}
        self.class_weight = {}
        self.input_trees = {'training': [], 'validation': [], 'test': []}
        self.input_colors = {'training': [], 'validation': [], 'test': []}
        self.input_labels = {'training': [], 'validation': []}

        self.load_sub_sampled_clouds(cfg.sub_grid_size, mode)

    def _ply_paths(self, tile_names):
        """Return the ``original_ply`` path of every tile name given."""
        return [
            os.path.join(self.full_pc_folder, name + '.ply')
            for name in tile_names
        ]

    def _class_frequencies(self, labels):
        """Return the relative frequency of every class in ``labels``.

        The result always has ``num_classes`` entries and is indexed by the
        contiguous label index (``label_to_idx``). A class that does not
        occur in ``labels`` gets frequency 0 instead of shifting the entries
        of the classes after it.

        Raises:
            KeyError: if ``labels`` contains a value missing from
                ``label_to_names``.
        """
        values, counts = np.unique(labels, return_counts=True)
        frequencies = np.zeros(self.num_classes, dtype=np.float64)
        for value, count in zip(values, counts):
            frequencies[self.label_to_idx[value]] = count
        return frequencies / np.sum(counts)

    def load_sub_sampled_clouds(self, sub_grid_size, mode):
        """Load sub-sampled clouds, search trees and (for test) projections.

        Files are read from ``<path>/input_<sub_grid_size>``: ``<name>.ply``,
        ``<name>_KDTree.pkl`` and, in test mode, ``<name>_proj.pkl``.

        Args:
            sub_grid_size: grid size used to format the input folder name.
            mode: ``'test'`` loads ``test_files`` into the ``'test'`` split;
                any other value loads training and validation files.

        NOTE: the ``.pkl`` files are read with ``pickle.load``, which can
        execute arbitrary code. Only use preprocessing output you trust.
        """
        tree_path = join(self.path, 'input_{:.3f}'.format(sub_grid_size))
        is_test = mode == 'test'
        if is_test:
            files = self.test_files
        else:
            files = np.hstack((self.train_files, self.val_files))

        for i, file_path in enumerate(files):
            # basename/splitext instead of split('/'): the paths are built
            # with os.path.join, so the separator is platform dependent.
            cloud_name = os.path.splitext(os.path.basename(file_path))[0]
            print('Load_pc_' + str(i) + ': ' + cloud_name)
            if is_test:
                cloud_split = 'test'
            elif file_path in self.val_files:
                cloud_split = 'validation'
            else:
                cloud_split = 'training'

            kd_tree_file = join(tree_path, '{:s}_KDTree.pkl'.format(cloud_name))
            sub_ply_file = join(tree_path, '{:s}.ply'.format(cloud_name))

            data = read_ply(sub_ply_file)

            # Per-point input features, selected by the config flags.
            if cfg.use_rgb:
                channels = [data['red'], data['green'], data['blue']]
                if cfg.use_intensity:
                    channels.append(data['intensity'])
                sub_colors = np.vstack(channels).T
            elif cfg.use_intensity:
                sub_colors = data['intensity'].reshape(-1, 1)
            else:
                # No feature selected: constant placeholder column.
                sub_colors = np.ones((data.shape[0], 1))

            # load KDTree
            with open(kd_tree_file, 'rb') as f:
                search_tree = pickle.load(f)

            self.input_trees[cloud_split].append(search_tree)
            self.input_colors[cloud_split].append(sub_colors)

            if is_test:
                print('\nPreparing reprojection indices for {}'.format(cloud_name))
                proj_file = join(tree_path, '{:s}_proj.pkl'.format(cloud_name))
                with open(proj_file, 'rb') as f:
                    proj_idx, labels = pickle.load(f)
                self.test_proj.append(proj_idx)
                self.test_labels.append(labels)
            else:
                self.input_labels[cloud_split].append(data['class'])

        print('finished')
        return

    # Generate the input data flow
    def get_batch_gen(self, split):
        """Build the sample generator for one split.

        Args:
            split: ``'training'``, ``'validation'`` or ``'test'``.

        Returns:
            ``(gen_func, gen_types, gen_shapes)`` suitable for
            ``tf.data.Dataset.from_generator``. Each generated sample is
            ``(xyz, features, labels, point_indices, cloud_index)``.

        Raises:
            ValueError: if ``split`` is not one of the three known names.
        """
        if split not in self._SPLITS:
            raise ValueError(
                'Unknown split {!r}; expected one of {}'.format(split, self._SPLITS)
            )

        if split == 'training':
            num_per_epoch = cfg.train_steps * cfg.batch_size
        else:
            # validation and test share the validation step settings
            num_per_epoch = cfg.val_steps * cfg.val_batch_size

        # Width of the feature array built in load_sub_sampled_clouds.
        if cfg.use_rgb:
            n_features = 4 if cfg.use_intensity else 3
        else:
            n_features = 1

        self.possibility[split] = []
        self.min_possibility[split] = []
        self.class_weight[split] = []

        # Small random initial possibilities, one value per point.
        for tree in self.input_trees[split]:
            cloud_possibility = np.random.rand(tree.data.shape[0]) * 1e-3
            self.possibility[split].append(cloud_possibility)
            self.min_possibility[split].append(float(np.min(cloud_possibility)))

        if split != 'test':
            self.class_weight[split].append(
                self._class_frequencies(np.hstack(self.input_labels[split]))
            )

        def spatially_regular_gen():
            # Imported once per generator run rather than once per sample.
            from helper_tool import DataProcessing as DP

            # Generator loop
            for _ in range(num_per_epoch):
                # Centre the next sample on the point with the lowest
                # possibility across all clouds of this split.
                cloud_idx = int(np.argmin(self.min_possibility[split]))
                point_ind = np.argmin(self.possibility[split][cloud_idx])

                tree = self.input_trees[split][cloud_idx]
                points = np.asarray(tree.data)
                center_point = points[point_ind, :].reshape(1, -1)
                noise = np.random.normal(
                    scale=cfg.noise_init / 10, size=center_point.shape
                )
                pick_point = center_point + noise.astype(center_point.dtype)
                query_idx = tree.query(pick_point, k=cfg.num_points)[1][0]
                query_idx = DP.shuffle_idx(query_idx)

                # Fancy indexing copies, so the tree data stays untouched.
                queried_pc_xyz = points[query_idx]
                queried_pc_xyz[:, :2] = queried_pc_xyz[:, :2] - pick_point[:, :2]
                queried_pc_colors = self.input_colors[split][cloud_idx][query_idx]

                if split == 'test':
                    queried_pc_labels = np.zeros(queried_pc_xyz.shape[0])
                    queried_pt_weight = 1
                else:
                    raw_labels = self.input_labels[split][cloud_idx][query_idx]
                    # convert from raw label to 0..N_class-1
                    queried_pc_labels = np.array(
                        [self.label_to_idx[l] for l in raw_labels]
                    )
                    queried_pt_weight = self.class_weight[split][0][queried_pc_labels]

                # update possibility: points near the picked centre gain the
                # most, so later samples are drawn from other regions.
                dists = np.sum(
                    np.square((points[query_idx] - pick_point).astype(np.float32)),
                    axis=1
                )
                delta = np.square(1 - dists / np.max(dists)) * queried_pt_weight
                self.possibility[split][cloud_idx][query_idx] += delta
                self.min_possibility[split][cloud_idx] = float(
                    np.min(self.possibility[split][cloud_idx])
                )

                yield (
                    queried_pc_xyz,
                    queried_pc_colors.astype(np.float32),
                    queried_pc_labels,
                    query_idx.astype(np.int32),
                    np.array([cloud_idx], dtype=np.int32)
                )

        gen_func = spatially_regular_gen
        gen_types = (tf.float32, tf.float32, tf.int32, tf.int32, tf.int32)
        gen_shapes = ([None, 3], [None, n_features], [None], [None], [None])
        return gen_func, gen_types, gen_shapes

    def get_tf_mapping(self):
        """Return the ``Dataset.map`` function that builds the network inputs.

        The mapped function returns a flat list laid out as
        ``num_layers`` point tensors, ``num_layers`` neighbour indices,
        ``num_layers`` pooling indices and ``num_layers`` up-sampling
        indices, followed by ``[features, labels, point_idx, cloud_idx]``.
        """
        def tf_map(batch_xyz, batch_features, batch_labels, batch_pc_idx, batch_cloud_idx):
            # Augmentation only changes the features (which get the augmented
            # xyz prepended); the neighbour searches below use the original
            # batch_xyz.
            batch_features = tf.map_fn(
                self.tf_augment_input,
                [batch_xyz, batch_features],
                dtype=tf.float32
            )
            input_points = []
            input_neighbors = []
            input_pools = []
            input_up_samples = []

            from helper_tool import DataProcessing as DP
            for layer in range(cfg.num_layers):
                neigh_idx = tf.py_func(
                    DP.knn_search, [batch_xyz, batch_xyz, cfg.k_n], tf.int32
                )
                # Keep the leading 1/ratio share of the points for the next
                # layer.
                n_kept = tf.shape(batch_xyz)[1] // cfg.sub_sampling_ratio[layer]
                sub_points = batch_xyz[:, :n_kept, :]
                pool_i = neigh_idx[:, :n_kept, :]
                up_i = tf.py_func(DP.knn_search, [sub_points, batch_xyz, 1], tf.int32)
                input_points.append(batch_xyz)
                input_neighbors.append(neigh_idx)
                input_pools.append(pool_i)
                input_up_samples.append(up_i)
                batch_xyz = sub_points

            input_list = (
                input_points + input_neighbors + input_pools + input_up_samples
            )
            input_list += [
                batch_features,
                batch_labels,
                batch_pc_idx,
                batch_cloud_idx
            ]
            return input_list

        return tf_map

    @staticmethod
    def tf_augment_input(inputs):
        """Augment one sample and prepend the augmented xyz to its features.

        Applies a random rotation built from a single angle (leaving the
        third coordinate unchanged), a random scale, optional random sign
        flips per axis and Gaussian noise, all controlled by ``cfg``.

        Args:
            inputs: ``[xyz, features]`` tensors of one sample.

        Returns:
            Tensor with the augmented xyz concatenated in front of
            ``features`` along the last axis.
        """
        xyz = inputs[0]
        features = inputs[1]
        theta = tf.random_uniform((1,), minval=0, maxval=2 * np.pi)

        cos_t, sin_t = tf.cos(theta), tf.sin(theta)
        zero = tf.zeros_like(cos_t)
        one = tf.ones_like(cos_t)
        R = tf.stack(
            [cos_t, -sin_t, zero, sin_t, cos_t, zero, zero, zero, one], axis=1
        )
        R = tf.reshape(R, (3, 3))

        transformed_xyz = tf.reshape(tf.matmul(xyz, R), [-1, 3])

        min_s = cfg.augment_scale_min
        max_s = cfg.augment_scale_max
        if cfg.augment_scale_anisotropic:
            scales = tf.random_uniform((1, 3), minval=min_s, maxval=max_s)
        else:
            scales = tf.random_uniform((1, 1), minval=min_s, maxval=max_s)

        # Random +1/-1 factor on every axis whose symmetry flag is set.
        symmetries = []
        for axis in range(3):
            if cfg.augment_symmetries[axis]:
                symmetries.append(tf.round(tf.random_uniform((1, 1))) * 2 - 1)
            else:
                symmetries.append(tf.ones([1, 1], dtype=tf.float32))
        scales *= tf.concat(symmetries, 1)

        stacked_scales = tf.tile(scales, [tf.shape(transformed_xyz)[0], 1])
        transformed_xyz = transformed_xyz * stacked_scales

        noise = tf.random_normal(tf.shape(transformed_xyz), stddev=cfg.augment_noise)
        transformed_xyz = transformed_xyz + noise
        stacked_features = tf.concat([transformed_xyz, features], axis=-1)
        return stacked_features

    def init_train_pipeline(self):
        """Build the training/validation datasets and their shared iterator.

        Sets ``flat_inputs``, ``train_init_op`` and ``val_init_op`` and
        stores the ignored label indices on ``cfg``.
        """
        print('Initiating training pipelines')
        cfg.ignored_label_inds = [
            self.label_to_idx[ign_label] for ign_label in self.ignored_labels
        ]
        gen_function, gen_types, gen_shapes = self.get_batch_gen('training')
        gen_function_val, _, _ = self.get_batch_gen('validation')
        self.train_data = tf.data.Dataset.from_generator(gen_function, gen_types, gen_shapes)
        self.val_data   = tf.data.Dataset.from_generator(gen_function_val, gen_types, gen_shapes)

        self.batch_train_data = self.train_data.batch(cfg.batch_size)
        self.batch_val_data   = self.val_data.batch(cfg.val_batch_size)
        map_func = self.get_tf_mapping()

        self.batch_train_data = self.batch_train_data.map(map_func=map_func)
        self.batch_val_data   = self.batch_val_data.map(map_func=map_func)

        self.batch_train_data = self.batch_train_data.prefetch(cfg.batch_size)
        self.batch_val_data   = self.batch_val_data.prefetch(cfg.val_batch_size)

        # One re-initialisable iterator serves both datasets.
        data_iterator = tf.data.Iterator.from_structure(
            self.batch_train_data.output_types,
            self.batch_train_data.output_shapes
        )
        self.flat_inputs   = data_iterator.get_next()
        self.train_init_op = data_iterator.make_initializer(self.batch_train_data)
        self.val_init_op   = data_iterator.make_initializer(self.batch_val_data)

    def init_test_pipeline(self):
        """Build the test dataset and its iterator.

        Sets ``flat_inputs`` and ``test_init_op`` and stores the ignored
        label indices on ``cfg``.
        """
        print('Initiating testing pipelines')
        cfg.ignored_label_inds = [
            self.label_to_idx[ign_label] for ign_label in self.ignored_labels
        ]
        gen_function_test, gen_types, gen_shapes = self.get_batch_gen('test')
        self.test_data = tf.data.Dataset.from_generator(gen_function_test, gen_types, gen_shapes)
        self.batch_test_data = self.test_data.batch(cfg.val_batch_size)
        map_func = self.get_tf_mapping()
        self.batch_test_data = self.batch_test_data.map(map_func=map_func)
        self.batch_test_data = self.batch_test_data.prefetch(cfg.val_batch_size)

        data_iterator = tf.data.Iterator.from_structure(
            self.batch_test_data.output_types,
            self.batch_test_data.output_shapes
        )
        self.flat_inputs  = data_iterator.get_next()
        self.test_init_op = data_iterator.make_initializer(self.batch_test_data)

def str2bool(value):
    """Parse a command-line boolean such as ``True``, ``false``, ``1`` or ``no``.

    ``argparse`` with ``type=bool`` treats every non-empty string, including
    ``'False'``, as true; this converter reads the text instead.

    Raises:
        argparse.ArgumentTypeError: if the text is not a recognised boolean.
    """
    if isinstance(value, bool):
        return value
    text = value.strip().lower()
    if text in ('true', 't', 'yes', 'y', '1'):
        return True
    # The empty string stays false, as it was with ``type=bool``.
    if text in ('false', 'f', 'no', 'n', '0', ''):
        return False
    raise argparse.ArgumentTypeError(
        'expected a boolean value, got {!r}'.format(value)
    )


if __name__ == '__main__':
    parser = argparse.ArgumentParser()
    parser.add_argument('--gpu', type=int, default=0, help='ID of the GPU to use [default: 0]')
    parser.add_argument('--mode', type=str, default='test', help='options: train, test, vis')
    parser.add_argument('--model_path', type=str, default='None', help='pretrained model path')
    parser.add_argument('--test_eval', type=str2bool, default=True, help='evaluate test result on L002')
    FLAGS = parser.parse_args()

    GPU_ID = FLAGS.gpu
    os.environ["CUDA_DEVICE_ORDER"] = "PCI_BUS_ID"
    os.environ['CUDA_VISIBLE_DEVICES'] = str(GPU_ID)
    os.environ['TF_CPP_MIN_LOG_LEVEL'] = '2'

    Mode = FLAGS.mode
    dataset = Toronto3D(mode=Mode)

    if Mode == 'train':
        dataset.init_train_pipeline()
        model = Network(dataset, cfg)
        model.train(dataset)
    elif Mode == 'test':
        cfg.saving = False
        dataset.init_test_pipeline()
        model = Network(dataset, cfg)
        # Compare by value: ``is not`` on a string tests object identity.
        if FLAGS.model_path != 'None':
            chosen_snap = FLAGS.model_path
        else:
            # No path given: use the highest-step snapshot of the last
            # (sorted by name) 'Log*' folder under results/.
            chosen_snapshot = -1
            logs = np.sort([os.path.join('results', f) for f in os.listdir('results') if f.startswith('Log')])
            chosen_folder = logs[-1]
            snap_path = join(chosen_folder, 'snapshots')
            snap_steps = [int(f[:-5].split('-')[-1]) for f in os.listdir(snap_path) if f[-5:] == '.meta']
            chosen_step = np.sort(snap_steps)[chosen_snapshot]
            chosen_snap = os.path.join(snap_path, 'snap-{:d}'.format(chosen_step))
        tester = ModelTester(model, dataset, cfg, restore_snap=chosen_snap)
        if FLAGS.test_eval:
            tester.test(model, dataset, eval=True)
        else:
            tester.test(model, dataset)

    else:
        ##################
        # Visualize data #
        ##################

        # The training pipeline creates train_init_op and flat_inputs, which
        # the loop below needs.
        dataset.init_train_pipeline()

        # flat_inputs holds 4 * num_layers geometry tensors, then features,
        # labels, point indices and cloud indices (see get_tf_mapping).
        labels_pos = 4 * cfg.num_layers + 1

        with tf.Session() as sess:
            sess.run(tf.global_variables_initializer())
            sess.run(dataset.train_init_op)
            while True:
                flat_inputs = sess.run(dataset.flat_inputs)
                pc_xyz = flat_inputs[0]
                sub_pc_xyz = flat_inputs[1]
                labels = flat_inputs[labels_pos]
                Plot.draw_pc_sem_ins(pc_xyz[0, :, :], labels[0, :])
                Plot.draw_pc_sem_ins(sub_pc_xyz[0, :, :], labels[0, 0:np.shape(sub_pc_xyz)[1]])
