#Code adapted from:  https://github.com/uzh-rpg/rpg_ev-transfer with modification
import os
import torch
import random
import numpy as np
import datasets.data_util as data_util
from datasets.caltech101_loader import Caltech101RGB
import torchvision.transforms as transforms


class NCaltech101Events(Caltech101RGB):
    """N-Caltech101 event dataset that yields two event-tensor views per sample.

    Both views are built from the same recording. When ``self.augmentation`` is
    set, each view is augmented independently (shift, flip, random window,
    random crop/resize). Configuration attributes read here (``files``,
    ``labels``, ``root``, ``augmentation``, ``nr_events_window``,
    ``event_representation``, ``height``, ``width``, ``nr_temporal_bins``) are
    presumably initialised by the parent ``Caltech101RGB`` class. They are not
    set in this class.
    """

    def getRootPath(self, root):
        """Return the N-Caltech101 data directory under ``root``.

        Presumably called by the parent class to locate the dataset on disk;
        overriding it points the shared loader logic at the event data.
        Also sets ``self.extended_data = False``.

        :param root: base dataset directory
        :return: ``os.path.join(root, 'N-Caltech101')``
        """
        self.extended_data = False  # There is no extended data for events
        return os.path.join(root, 'N-Caltech101')

    def _events_to_tensor(self, events):
        """Turn one event array into a normalized (optionally augmented) tensor.

        Pipeline: optional shift + x-flip augmentation, event-window selection,
        conversion to ``self.event_representation``, normalization, and an
        optional random crop/resize.

        :param events: 2-D event array whose last column is polarity and whose
            first two columns are used as spatial coordinates for the crop
            mid point (assumes column order x, y, t, p -- TODO confirm).
        :return: ``torch.Tensor`` produced by ``torch.from_numpy`` (possibly
            further transformed by ``random_crop_resize``).
        """
        if self.augmentation:
            events = data_util.random_shift_events(events)
            events = data_util.random_flip_events_along_x(events)
        # Count after augmentation: the shift may change the number of events.
        nr_events = events.shape[0]

        window_start = 0
        if self.nr_events_window != -1:
            if self.augmentation:
                window_start = random.randrange(0, max(1, nr_events - self.nr_events_window))
            # Catch case if number of events in sample is lower than number of events in window.
            window_end = min(nr_events, window_start + self.nr_events_window)
        else:
            # nr_events_window == -1 means "use every event" (always from 0).
            window_end = nr_events
        events = events[window_start:window_end, :]

        event_tensor = data_util.generate_input_representation(events, self.event_representation,
                                                               (self.height, self.width),
                                                               nr_temporal_bins=self.nr_temporal_bins)
        event_tensor = torch.from_numpy(data_util.normalize_event_tensor(event_tensor))

        if self.augmentation:
            mid_point = np.max(events[:, :2], axis=0) // 2
            event_tensor = data_util.random_crop_resize(event_tensor, mid_point)
        return event_tensor

    def __getitem__(self, idx):
        """Return two event-tensor views of sample ``idx`` together with its label.

        Events are read from the ``.npy`` file ``self.files[idx]`` under
        ``self.root`` and cast to ``float32``.

        :param idx: index into ``self.files`` / ``self.labels``
        :return: ``(event_tensor, label, event_tensor_pair)``
        """
        label = self.labels[idx]
        filename = self.files[idx]
        events = np.load(os.path.join(self.root, filename)).astype(np.float32)
        # Convert negative polarity to -1 (maps 0 -> -1, 1 -> 1; assumes stored polarity is 0/1).
        events[:, -1] = 2 * events[:, -1] - 1
        # Read the file only once. The second view gets its own copy so that
        # in-place augmentation of one view can never affect the other.
        events_pair = events.copy()

        event_tensor = self._events_to_tensor(events)
        event_tensor_pair = self._events_to_tensor(events_pair)
        return event_tensor, label, event_tensor_pair
