# Code adapted (with modifications) from: https://github.com/uzh-rpg/rpg_ev-transfer
from __future__ import division
import torch
from tqdm import tqdm
tqdm.monitor_interval = 0
import numpy as np
from tensorboardX import SummaryWriter

from utils.saver import CheckpointSaver

from datasets.wrapper_dataloader import WrapperDataset
import utils.viz_utils as viz_utils


class BaseTrainer(object):
    """Base trainer to be inherited by concrete training setups.

    ``__init__`` calls ``init_fn`` and then reads ``self.models_dict`` and
    ``self.optimizers_dict``, so a subclass's ``init_fn`` must create both.
    ``trainEpoch`` calls ``self.train_step``, which is not defined here either
    and must be provided by the subclass.
    """

    def __init__(self, settings):
        self.settings = settings
        self.device = self.settings.gpu_device
        self.do_val_training_epoch = True

        # Subclass hook: expected to build models_dict / optimizers_dict.
        self.init_fn()
        self.createDataLoaders()

        self.models_dict = {k: v.to(self.device) for k, v in self.models_dict.items()}

        # tensorboardX SummaryWriter for use in train_summaries
        self.summary_writer = SummaryWriter(self.settings.ckpt_dir)

        # Needed in both branches: for loading below and for saving in train().
        self.saver = CheckpointSaver(save_dir=settings.ckpt_dir)

        # NOTE(review): optimizer state is never restored, even when resuming,
        # so the LR schedulers below restart from epoch 0. Switching this to
        # True needs a checkpoint with optimizer state, and PyTorch schedulers
        # created with last_epoch != -1 expect 'initial_lr' in every param
        # group -- confirm both before changing.
        load_optimizer = False
        if self.settings.resume_training:
            self.checkpoint = self.saver.load_checkpoint(self.models_dict,
                                                         self.optimizers_dict,
                                                         checkpoint_file=self.settings.resume_ckpt_file,
                                                         load_optimizer=load_optimizer)
            self.epoch_count = self.checkpoint['epoch']
            self.step_count = self.checkpoint['step_count']
        else:
            self.epoch_count = 0
            self.step_count = 0
            self.checkpoint = None

        self.epoch = self.epoch_count

        # With load_optimizer False this gives last_epoch=-1, i.e. a fresh
        # exponential schedule starting from each optimizer's current lr.
        optimizer_epoch_count = self.epoch_count if load_optimizer else 0
        self.lr_schedulers = {k: torch.optim.lr_scheduler.ExponentialLR(v, gamma=self.settings.lr_decay,
                                                                        last_epoch=optimizer_epoch_count - 1)
                              for k, v in self.optimizers_dict.items()}

    def init_fn(self):
        """Hook for subclasses: construct models and optimizers here."""
        pass

    def getDataloader(self, dataset_name):
        """Return the dataset class registered under ``dataset_name``.

        The dataset modules are imported lazily inside the matching branch.

        Raises:
            ValueError: if ``dataset_name`` is not a supported dataset.
        """
        if dataset_name == 'Caltech101_gray':
            from datasets.caltech101_loader import Caltech101Gray
            return Caltech101Gray
        if dataset_name == 'NCaltech101_events':
            from datasets.ncaltech101_loader import NCaltech101Events
            return NCaltech101Events
        # Previously this fell through and returned None, which only failed
        # later with an opaque "'NoneType' object is not callable".
        raise ValueError("Unsupported dataset name: {!r}".format(dataset_name))

    def createDataset(self, dataset_name, dataset_path, img_size, batch_size, nr_events_window, event_representation,
                      nr_temporal_bins):
        """Build the training loader for one sensor.

        Both the 'train' and the 'val' splits are constructed (with
        augmentation enabled) and concatenated, so the validation split is used
        as additional training data; no validation loader is returned.

        Args:
            dataset_name: key understood by ``getDataloader``.
            dataset_path: root path passed to the dataset class.
            img_size: indexable as ``(height, width)``.
            batch_size: training batch size.
            nr_events_window: forwarded to the dataset class.
            event_representation: forwarded to the dataset class.
            nr_temporal_bins: forwarded to the dataset class.

        Returns:
            A shuffled ``torch.utils.data.DataLoader`` over val+train that
            drops the last incomplete batch.
        """
        dataset_builder = self.getDataloader(dataset_name)
        split_kwargs = dict(height=img_size[0],
                            width=img_size[1],
                            nr_events_window=nr_events_window,
                            augmentation=True,
                            event_representation=event_representation,
                            nr_temporal_bins=nr_temporal_bins)

        train_split = dataset_builder(dataset_path, mode='train', **split_kwargs)
        val_split = dataset_builder(dataset_path, mode='val', **split_kwargs)

        # Dataset.__add__ builds a ConcatDataset: val samples first, then train.
        train_dataset = val_split + train_split

        return torch.utils.data.DataLoader(train_dataset, batch_size=batch_size,
                                           num_workers=self.settings.num_cpu_workers,
                                           pin_memory=False, shuffle=True, drop_last=True)

    def combineDataloaders(self):
        """Combine the sensor-a and sensor-b training loaders into ``self.train_loader``."""
        self.train_loader = WrapperDataset(self.train_loader_sensor_a, self.train_loader_sensor_b, self.device)

    def createDataLoaders(self):
        """Create the training loaders for both sensors and the test loaders.

        Sets ``train_loader_sensor_a``, ``train_loader_sensor_b``,
        ``train_loader`` (combined), ``test_loader_sensor`` (sensor b) and
        ``test_loader_sensor_a``.
        """
        # nr_temporal_bins is derived as input_channels // 2 (presumably two
        # channels per temporal bin -- confirm against the dataset classes).
        self.train_loader_sensor_a = self.createDataset(self.settings.dataset_name_a,
                                                        self.settings.dataset_path_a,
                                                        self.settings.img_size_a,
                                                        self.settings.batch_size_a,
                                                        self.settings.nr_events_window_a,
                                                        self.settings.event_representation_a,
                                                        self.settings.input_channels_a // 2)

        self.train_loader_sensor_b = self.createDataset(self.settings.dataset_name_b,
                                                        self.settings.dataset_path_b,
                                                        self.settings.img_size_b,
                                                        self.settings.batch_size_b,
                                                        self.settings.nr_events_window_b,
                                                        self.settings.event_representation_b,
                                                        self.settings.input_channels_b // 2)

        self.combineDataloaders()
        print(len(self.train_loader))

        self.test_loader_sensor = self.createTestDataset(self.settings.dataset_name_b,
                                                         self.settings.dataset_path_b,
                                                         self.settings.img_size_b,
                                                         self.settings.batch_size_b,
                                                         self.settings.nr_events_window_b,
                                                         self.settings.event_representation_b,
                                                         self.settings.input_channels_b // 2)
        self.test_loader_sensor_a = self.createTestDataset(self.settings.dataset_name_a,
                                                           self.settings.dataset_path_a,
                                                           self.settings.img_size_a,
                                                           self.settings.batch_size_a,
                                                           self.settings.nr_events_window_a,
                                                           self.settings.event_representation_a,
                                                           self.settings.input_channels_a // 2)

    def train(self):
        """Run the main loop from ``self.epoch_count`` to ``settings.num_epochs``.

        Each epoch does the following in order:
        - evaluates on the test set when due;
        - trains one epoch;
        - saves a checkpoint every 10 epochs;
        - steps every LR scheduler.
        """
        # Evaluate every 2nd epoch, or every epoch for these dataset names.
        if self.settings.dataset_name_b in ('MVSEC_events', 'OneMpProphesee_events'):
            val_epoch_step = 1
        else:
            val_epoch_step = 2
        clas_good_enough = False  # never changed here; forwarded to train_step
        for _ in tqdm(range(self.epoch_count, self.settings.num_epochs), total=self.settings.num_epochs,
                      initial=self.epoch_count):
            if self.epoch_count % val_epoch_step == 0:
                print(self.epoch_count)
                self.evaluate_testset()

            self.trainEpoch(clas_good_enough)

            if self.epoch_count % 10 == 0:
                self.saver.save_checkpoint(self.models_dict,
                                           self.optimizers_dict, self.epoch_count, self.step_count,
                                           self.settings.batch_size_a,
                                           self.settings.batch_size_b)
                tqdm.write('Checkpoint saved')

            # apply the learning rate scheduling policy
            for opt in self.optimizers_dict:
                self.lr_schedulers[opt].step()
            self.epoch_count += 1

    def trainEpoch(self, clas_good_enough=False):
        """Put every model in train mode and call ``train_step`` on each combined batch."""
        self.train_loader.createIterators()
        for model in self.models_dict.values():
            model.train()

        for sample_batched in self.train_loader:
            self.train_step(sample_batched, clas_good_enough)
            self.step_count += 1

    def evaluate_testset(self):
        """Evaluate the sensor-b front end + back end on the sensor-b test split.

        Prints the accuracy and also returns it as a float. The result is
        NaN if the test loader yields no samples.
        """
        self.eval_backend = self.models_dict["back_end"]
        self.eval_front_b = self.models_dict['front_sensor_b']
        # NOTE(review): the test loader is rebuilt on every call even though
        # createDataLoaders already built one; kept to preserve behavior.
        self.test_loader_sensor = self.createTestDataset(self.settings.dataset_name_b,
                                                         self.settings.dataset_path_b,
                                                         self.settings.img_size_b,
                                                         self.settings.batch_size_b,
                                                         self.settings.nr_events_window_b,
                                                         self.settings.event_representation_b,
                                                         self.settings.input_channels_b // 2)
        self.models_dict = {k: v.to(self.device) for k, v in self.models_dict.items()}

        test_accuracies = []
        with torch.no_grad():
            for model in self.models_dict.values():
                model.eval()
            for sample_batched in self.test_loader_sensor:
                test_accuracies = self.testBatchStep_testset(sample_batched, test_accuracies)

        # np.float was removed in NumPy 1.24; use an explicit float64 dtype.
        # Guard the empty case so np.mean does not emit a RuntimeWarning.
        if test_accuracies:
            accuracy = float(np.mean(np.asarray(test_accuracies, dtype=np.float64)))
        else:
            accuracy = float('nan')
        print('Test Accuracy')
        print(accuracy)
        return accuracy

    def testBatchStep_testset(self, sample_batched, test_accuracies):
        """Classify one test batch and extend ``test_accuracies`` in place.

        ``sample_batched[0]`` is the input and ``sample_batched[1]`` holds the labels.
        One bool per sample (argmax prediction == label) is appended.

        Returns:
            The same ``test_accuracies`` list.
        """
        sample_batched = [tensor.to(self.device) for tensor in sample_batched]
        data = sample_batched[0]
        labels = sample_batched[1]

        # The back end consumes the sum of the two content feature tensors.
        (content_features1, content_features2), _, _, _ = self.eval_front_b(data)
        content_features = content_features1 + content_features2
        pred = self.eval_backend(content_features)

        correct_predictions = torch.eq(torch.argmax(pred, dim=-1), labels).detach().cpu().numpy().tolist()
        test_accuracies += correct_predictions

        return test_accuracies

    def createTestDataset(self, dataset_name, dataset_path, img_size, batch_size, nr_events_window, event_representation,
                          nr_temporal_bins):
        """Build a loader over the 'test' split, without augmentation or shuffling.

        Also sets ``self.object_classes`` from the dataset's ``class_list``.
        ``batch_size`` is accepted for signature symmetry with
        ``createDataset`` but is ignored: the test loader always uses a batch
        size of 150 and keeps the last incomplete batch.
        """
        dataset_builder = self.getDataloader(dataset_name)

        test_dataset = dataset_builder(dataset_path,
                                       height=img_size[0],
                                       width=img_size[1],
                                       nr_events_window=nr_events_window,
                                       augmentation=False,
                                       mode='test',
                                       event_representation=event_representation,
                                       nr_temporal_bins=nr_temporal_bins)

        self.object_classes = test_dataset.class_list

        return torch.utils.data.DataLoader(test_dataset, batch_size=150,
                                           num_workers=self.settings.num_cpu_workers,
                                           pin_memory=False, shuffle=False, drop_last=False)

    # Visualize the fake (generated) events.
    def visul(self):
        """Translate the first sensor-a test batch into sensor-b events for visualization.

        Returns the tuple produced by ``generateImage`` for the first batch
        only, or None if the combined test loader is empty.
        """
        self.eval_front_a = self.models_dict['front_sensor_a']
        self.eval_front_b = self.models_dict['front_sensor_b']
        self.decoder = self.models_dict["decoder_sensor_b"]
        self.refinement = self.models_dict["cross_refinement_net_sensor_b"]
        test_loader = WrapperDataset(self.test_loader_sensor_a, self.test_loader_sensor, self.device)
        with torch.no_grad():
            for model in self.models_dict.values():
                model.eval()
            for sample_batched in test_loader:
                input_image, input_RGB, input_event, output_event = self.generateImage(sample_batched)
                return input_image, input_RGB, input_event, output_event

    def generateImage(self, batch):
        """Generate fake sensor-b data from sensor-a content and sensor-b attributes.

        Uses ``batch[0][0]`` (sensor-a data), ``batch[0][-1]`` (presumably the
        RGB image -- confirm) and ``batch[1][0]`` (sensor-b data).

        Returns:
            CPU copies of the following, in order:
            - the sensor-a input;
            - the RGB tensor;
            - the visualized real sensor-b input;
            - the visualized generated output.
        """
        data_a = batch[0][0]
        RGB_a = batch[0][-1]
        data_b = batch[1][0]

        (content_features1, content_features2), _, _, _ = self.eval_front_a(data_a)
        content_features = content_features1 + content_features2
        _, _, _, attribute_sensor_b = self.eval_front_b(data_b)

        decoder_output = self.decoder.forward(content_features, attribute_sensor_b)
        fake_event = self.refinement.forward(decoder_output, data_a)

        input_image = data_a.cpu().clone()
        input_RGB = RGB_a.cpu().clone()
        input_event = viz_utils.visualizeFlow(data_b).cpu().clone()
        output_event = viz_utils.visualizeFlow(fake_event).cpu().clone()
        return input_image, input_RGB, input_event, output_event

    def visualizeHistogram(self, histogram):
        """Clamp a 4-D ``(batch, C, H, W)`` histogram to [0, 1] and append one all-zero channel.

        The output has ``C + 1`` channels (e.g. 3 for a 2-channel input).
        """
        batch, _, height, width = histogram.shape
        torch_image = torch.zeros([batch, 1, height, width], device=histogram.device)

        return torch.cat([histogram.clamp(0, 1), torch_image], dim=1)
