import os
import torch
import random
import numpy as np
import tqdm
import math
import torchvision.transforms as transforms
from typing import  Callable, Dict, Optional, Tuple
from torchvision.datasets import DatasetFolder
#Code adapted from: https://github.com/fangwei123456/spikingjelly/tree/master/spikingjelly/datasets with modification

class NeuromorphicDatasetFolder(DatasetFolder):
    """Base class for neuromorphic (event-camera) datasets stored as ``.npz`` files.

    Adapted from SpikingJelly's ``NeuromorphicDatasetFolder``. Only the
    ``data_type='event'`` path is implemented here: samples are discovered under
    ``root`` by :class:`torchvision.datasets.DatasetFolder` and loaded with
    :func:`numpy.load`. The frame-integration arguments of the upstream class are
    still accepted so existing call sites keep working, but they are unused.

    Subclasses must implement :meth:`get_H_W`.
    """

    def __init__(
            self,
            root: str,
            train: Optional[bool] = None,
            data_type: str = 'event',
            frames_number: Optional[int] = None,
            split_by: Optional[str] = None,
            duration: Optional[int] = None,
            custom_integrate_function: Optional[Callable] = None,
            custom_integrated_frames_dir_name: Optional[str] = None,
            transform: Optional[Callable] = None,
            target_transform: Optional[Callable] = None,
    ) -> None:
        '''
        :param root: root path of the dataset; passed unchanged to ``DatasetFolder``,
            which collects every ``.npz`` file below it.
        :type root: str
        :param train: kept for compatibility with the upstream signature; unused here.
        :type train: bool or None
        :param data_type: must be ``'event'``; each sample is whatever ``numpy.load``
            returns for the ``.npz`` file.
        :type data_type: str
        :param frames_number: upstream frame-integration option; unused here.
        :type frames_number: int or None
        :param split_by: upstream frame-integration option; unused here.
        :type split_by: str or None
        :param duration: upstream frame-integration option; unused here.
        :type duration: int or None
        :param custom_integrate_function: upstream frame-integration option; unused here.
        :type custom_integrate_function: Callable or None
        :param custom_integrated_frames_dir_name: upstream frame-integration option; unused here.
        :type custom_integrated_frames_dir_name: str or None
        :param transform: a function/transform that takes in a sample and returns a
            transformed version; forwarded to ``DatasetFolder``.
        :type transform: callable
        :param target_transform: a function/transform that takes in the target and
            transforms it; forwarded to ``DatasetFolder``.
        :type target_transform: callable
        :raises NotImplementedError: if ``data_type == 'frame'`` (frame integration was
            not carried over from SpikingJelly).
        :raises ValueError: for any other unsupported ``data_type``.
        '''
        if data_type == 'frame':
            raise NotImplementedError(
                "data_type='frame' is not implemented in this NeuromorphicDatasetFolder; use 'event'")
        if data_type != 'event':
            raise ValueError(f"unsupported data_type {data_type!r}; expected 'event'")

        # Query the resolution up front so a subclass that forgot get_H_W fails at
        # construction time rather than on first use.
        self.get_H_W()

        super().__init__(root=root, loader=np.load, extensions=('.npz', ), transform=transform,
                         target_transform=target_transform)

    def get_H_W(self) -> Tuple[int, int]:
        '''
        :return: A tuple ``(H, W)``, where ``H`` is the height and ``W`` the width of the data.
        :rtype: tuple
        :raises NotImplementedError: always; subclasses must override this.
        '''
        raise NotImplementedError(f"{type(self).__name__} must implement get_H_W()")

class DVSCIFA10Dataset(NeuromorphicDatasetFolder):
    """CIFAR10-DVS event dataset that yields two views of every recording.

    Each item is ``(event_tensor, target, event_tensor_pair)``: two normalized
    2-channel event histograms built from the same recording plus the class index.
    When ``augmentation`` is enabled, the two views are augmented independently.
    The augmentations are a random shift, a random horizontal flip, a random
    event-window start and a random crop/resize.
    """

    def __init__(self, root, augmentation=False, nr_events_window=80000):
        """
        Creates an iterator over the CIFAR10-DVS dataset.

        :param root: path to dataset root (``.npz`` recordings, discovered by ``DatasetFolder``)
        :param augmentation: if ``True``, apply shift, flip, random window start and
            random crop/resize to each view independently
        :param nr_events_window: maximum number of events used per view; ``-1`` keeps
            all events
        """
        super().__init__(root, None)
        self.augmentation = augmentation
        self.nr_events_window = nr_events_window
        # Not used by the histogram representation; kept so the attribute still exists.
        self.nr_temporal_bins = 5

    def get_H_W(self):
        '''
        :return: A tuple ``(H, W)``, where ``H`` is the height and ``W`` is the width of the data.
            For CIFAR10-DVS the input size is 128*128.
        :rtype: tuple
        '''
        return 128, 128

    def __getitem__(self, index: int):
        """
        Args:
            index (int): Index

        Returns:
            tuple: ``(event_tensor, target, event_tensor_pair)`` where both tensors are
            normalized ``(2, H, W)`` histograms (negative-polarity channel first) of the
            same recording and ``target`` is the class index.
        """
        path, target = self.samples[index]
        resolution = self.get_H_W()

        events = self._load_events(path)
        events_pair = events.copy()

        if self.augmentation:
            # np.random draw order matches the previous implementation:
            # shift both views, then flip both views.
            events = self._random_shift_events(events, resolution=resolution)
            events_pair = self._random_shift_events(events_pair, resolution=resolution)
            events = self._random_flip_events_along_x(events, resolution=resolution)
            events_pair = self._random_flip_events_along_x(events_pair, resolution=resolution)

        events = self._select_window(events)
        events_pair = self._select_window(events_pair)

        event_tensor = self._events_to_tensor(events, resolution)
        event_tensor_pair = self._events_to_tensor(events_pair, resolution)

        if self.augmentation:
            event_tensor = self._random_crop_resize(
                event_tensor, self._mid_point(events, resolution))
            event_tensor_pair = self._random_crop_resize(
                event_tensor_pair, self._mid_point(events_pair, resolution))

        return event_tensor, target, event_tensor_pair

    def _load_events(self, path):
        """Load one recording as an ``(N, 4)`` float64 array with columns ``x, y, t, p``.

        Polarity is mapped with ``2 * p - 1``. This assumes the file stores polarity
        as 0/1, giving -1/1 (TODO confirm for every data source).
        """
        data = self.loader(path)
        try:
            columns = [np.asarray(data[key]) for key in ('x', 'y', 't', 'p')]
        finally:
            # np.load on an .npz returns an NpzFile that holds the file open until closed.
            close = getattr(data, 'close', None)
            if close is not None:
                close()
        # A signed float dtype keeps "2 * p - 1" and negative shifts from wrapping
        # (or raising a casting error) when the stored columns are unsigned.
        events = np.stack(columns, axis=1).astype(np.float64)
        events[:, 3] = 2 * events[:, 3] - 1
        return events

    def _select_window(self, events):
        """Keep at most ``nr_events_window`` consecutive events (random start if augmenting)."""
        if self.nr_events_window == -1:
            return events
        window_start = 0
        if self.augmentation:
            window_start = random.randrange(0, max(1, events.shape[0] - self.nr_events_window))
        # Slicing clamps at the end, so short recordings simply keep all their events.
        return events[window_start:window_start + self.nr_events_window]

    def _events_to_tensor(self, events, resolution):
        """Build the normalized 2-channel histogram for ``events`` as a torch tensor."""
        histogram = self._generate_event_histogram(events, resolution)
        return torch.from_numpy(self._normalize_event_tensor(histogram))

    @staticmethod
    def _random_shift_events(events, max_shift=10, resolution=(128, 128)):
        """Shift x/y by one random integer offset in ``[-max_shift, max_shift]`` each and drop
        events that leave the ``(H, W)`` frame. Modifies ``events`` in place before filtering."""
        H, W = resolution
        x_shift, y_shift = np.random.randint(-max_shift, max_shift + 1, size=(2,))
        events[:, 0] += x_shift
        events[:, 1] += y_shift
        valid = (events[:, 0] >= 0) & (events[:, 0] < W) & (events[:, 1] >= 0) & (events[:, 1] < H)
        return events[valid]

    @staticmethod
    def _random_flip_events_along_x(events, resolution=(128, 128), p=0.5):
        """With probability ``p``, mirror x coordinates (``x -> W - 1 - x``) in place."""
        _, W = resolution
        if np.random.random() < p:
            events[:, 0] = W - 1 - events[:, 0]
        return events

    @staticmethod
    def _generate_event_histogram(events, shape):
        """Count events per pixel and polarity.

        ``events`` is ``(N, 4)`` with columns x, y, t, p and polarity in {-1, 1}.
        Returns a float32 ``(2, height, width)`` array: channel 0 counts p == -1,
        channel 1 counts p == 1. Events outside the frame are ignored.
        """
        height, width = shape
        x = events[:, 0].astype(np.int64)
        y = events[:, 1].astype(np.int64)
        p = events[:, 3]
        # Drop out-of-frame events instead of letting them wrap into neighbouring rows.
        in_frame = (x >= 0) & (x < width) & (y >= 0) & (y < height)
        flat_index = x + width * y

        img_pos = np.zeros((height * width,), dtype="float32")
        img_neg = np.zeros((height * width,), dtype="float32")
        np.add.at(img_pos, flat_index[in_frame & (p == 1)], 1)
        np.add.at(img_neg, flat_index[in_frame & (p == -1)], 1)

        return np.stack([img_neg, img_pos], 0).reshape((2, height, width))

    @staticmethod
    def _normalize_event_tensor(event_tensor, quantile=98):
        """Clip to the ``quantile``-th percentile of the non-zero values and divide by it.

        The input dtype is preserved. A tensor with no non-zero values is returned
        unchanged.
        """
        flat = event_tensor.ravel()
        nonzero_values = flat[flat != 0]
        if nonzero_values.shape[0] == 0:
            return event_tensor
        try:
            max_val = np.percentile(nonzero_values, quantile, method='nearest')
        except TypeError:
            # NumPy < 1.22 only understands the older keyword.
            max_val = np.percentile(nonzero_values, quantile, interpolation='nearest')
        normalized = np.clip(event_tensor, 0, max_val) / max_val
        return normalized.astype(event_tensor.dtype, copy=False)

    @staticmethod
    def _mid_point(events, resolution):
        """Half of the per-column maximum of the (x, y) coordinates.

        Falls back to half the frame size when no events remain.
        """
        if events.shape[0] == 0:
            H, W = resolution
            return np.array([W // 2, H // 2])
        return np.max(events[:, :2], axis=0) // 2

    @staticmethod
    def _random_crop_resize(tensor, mid_point, crop_range=(-10, 10), scale_range=(0.8, 1)):
        """Randomly crop a ``(C, H, W)`` tensor around ``mid_point`` and resize it back to ``(H, W)``."""
        _, height, width = tensor.shape
        random_delta = torch.rand([2], device='cpu').numpy() * (crop_range[1] - crop_range[0]) + crop_range[0]
        random_scale = torch.rand([2], device='cpu').numpy() * (scale_range[1] - scale_range[0]) + scale_range[0]

        # Positive offsets are capped at mid_point; negative ones are floored at
        # mid_point - (height, width).
        random_delta = np.minimum(random_delta, mid_point) * (random_delta >= 0) + \
            np.maximum(random_delta, mid_point - np.array([height, width])) * (random_delta < 0)

        row_start = int(np.maximum(0, random_delta[0]))
        col_start = int(np.maximum(0, random_delta[1]))
        row_end = int(np.minimum(height, random_delta[0] + height * random_scale[0]))
        col_end = int(np.minimum(width, random_delta[1] + width * random_scale[1]))

        tensor = tensor[:, row_start:row_end, col_start:col_end]
        return torch.nn.functional.interpolate(tensor[None, :, :, :], (height, width)).squeeze(0)