#Code adapted from: https://github.com/uzh-rpg/rpg_ev-transfer with modification
import torch
import torchvision
import numpy as np
import torch.nn as nn
import torchvision.models as models

from utils import pytorch_ssim, radam
import utils.viz_utils as viz_utils
from datasets.wrapper_dataloader import WrapperDataset
from utils.loss_functions import event_reconstruction_loss,Ra_discriminator_loss,Ra_generator_adversarial_loss
from models import refinement_net_gray as refinement_net
from models.style_networks import StyleEncoder, StyleDecoder, ContentDiscriminator, CrossDiscriminator
import training.base_trainer
from utils.sobel_filter_gray import NormGradient, GaussianSmoothing


class ClassificationModel(training.base_trainer.BaseTrainer):
    def __init__(self, settings, train=True):
        """Initialise the trainer and the helper layers used by the image-gradient loss.

        Args:
            settings: configuration object forwarded unchanged to the base trainer.
            train: stored as ``self.is_training``; ``createOptimizerDict`` creates
                no optimizers when it is False.
        """
        # Kept ahead of the base constructor, as in the original ordering
        # (presumably the base class triggers init_fn, which reads this flag -- TODO confirm).
        self.is_training = train
        super().__init__(settings)

        device = self.device
        # Gradient-magnitude layer and 3x3 Gaussian blur consumed by trainImageGradientStep.
        self.norm_gradient_layer = NormGradient(device, ignore_border=True)
        self.gaussian_layer = GaussianSmoothing(channels=1, kernel_size=[3, 3], sigma=[1, 1], device=device)

    def init_fn(self):
        """Build networks and optimizers, then register every loss function used in training."""
        self.buildModels()
        self.createOptimizerDict()

        # Decoder loss: gray images use (L1 - SSIM), event tensors use event_reconstruction_loss.
        self.ssim = pytorch_ssim.SSIM()
        l1_criterion = nn.L1Loss(reduction="mean")

        def image_loss(x, y):
            # Lower L1 and higher SSIM both reduce the loss.
            return l1_criterion(x, y) - self.ssim(x, y)

        loss_by_sensor_type = {'gray': image_loss, 'events': event_reconstruction_loss}
        sensor_types = (('sensor_a', self.settings.sensor_a_name),
                        ('sensor_b', self.settings.sensor_b_name))
        # .get() leaves None for a sensor type without a registered reconstruction loss.
        self.reconst_loss_dict = {sensor_key: loss_by_sensor_type.get(sensor_type)
                                  for sensor_key, sensor_type in sensor_types}

        # Cycle-consistency losses on content features and attribute means.
        self.cycle_content_loss = torch.nn.L1Loss()
        self.cycle_attribute_loss = torch.nn.L1Loss()

        # Task loss
        self.task_loss = nn.CrossEntropyLoss()
        self.train_statistics = {}

    def buildModels(self):
        """Instantiate every sub-network and register it in ``self.models_dict``.

        The encoders and the content discriminator are always created; the task
        backend, decoders and cross-refinement networks (with their
        discriminators) are created only when the matching settings flag is set.
        Construction order is kept fixed because it determines the order in
        which random weight initialisation is drawn.
        """
        cfg = self.settings
        attribute_channels = 128

        # Shared encoder layers: second block of child index 5 of a pretrained ResNet-18.
        self.front_end_shared = list(models.resnet18(pretrained=True).children())[5][1]

        # Sensor-specific front ends (sensor A is fed a single input channel).
        self.front_end_sensor_a = StyleEncoder(1, self.front_end_shared,
                                               attribute_channels, cfg.use_decoder_a)
        self.front_end_sensor_b = StyleEncoder(cfg.input_channels_b, self.front_end_shared,
                                               attribute_channels, cfg.use_decoder_b)

        # Discriminator on the content features
        self.discriminator = ContentDiscriminator(nr_channels=128)

        registry = {"front_sensor_a": self.front_end_sensor_a,
                    "front_sensor_b": self.front_end_sensor_b,
                    "front_shared": self.front_end_shared,
                    "dis": self.discriminator}
        self.models_dict = registry

        # Task backend: remaining ResNet-18 stages followed by a 10-way linear classifier.
        if cfg.use_task_a or cfg.use_task_b:
            resnet_tail = list(models.resnet18(pretrained=True).children())[6:-1]
            self.task_backend = nn.Sequential(*resnet_tail, nn.Flatten(), nn.Linear(512, 10))
            registry["back_end"] = self.task_backend

        # Decoders
        if cfg.use_decoder_a:
            self.decoder_sensor_a = StyleDecoder(input_c=128,
                                                 output_c=cfg.input_channels_a,
                                                 attribute_channels=attribute_channels,
                                                 sensor_name=cfg.sensor_a_name)
            registry["decoder_sensor_a"] = self.decoder_sensor_a

        if cfg.use_decoder_b:
            self.decoder_sensor_b = StyleDecoder(input_c=128,
                                                 output_c=2,
                                                 attribute_channels=attribute_channels,
                                                 sensor_name=cfg.sensor_b_name)
            registry["decoder_sensor_b"] = self.decoder_sensor_b

        def build_refinement(output_channels, sensor_name):
            # Refinement network first, then its discriminator (same order as before).
            network = refinement_net.StyleRefinementNetwork(input_c=2,
                                                            output_c=output_channels,
                                                            sensor=sensor_name,
                                                            channel_list=[16, 8],
                                                            last_layer_pad=1,
                                                            device=self.device)
            critic = CrossDiscriminator(input_dim=output_channels, n_layer=6)
            return network, critic

        # Cross refinement networks
        if cfg.cross_refinement_a:
            self.cross_refinement_net_a, self.refinement_discr_a = build_refinement(cfg.input_channels_a,
                                                                                    cfg.sensor_a_name)
            registry["cross_refinement_net_sensor_a"] = self.cross_refinement_net_a
            registry["refinement_discr_sensor_a"] = self.refinement_discr_a

        if cfg.cross_refinement_b:
            self.cross_refinement_net_b, self.refinement_discr_b = build_refinement(cfg.input_channels_b,
                                                                                    cfg.sensor_b_name)
            registry["cross_refinement_net_sensor_b"] = self.cross_refinement_net_b
            registry["refinement_discr_sensor_b"] = self.refinement_discr_b

    def createOptimizerDict(self):
        """Creates the dictionary containing the optimizer for the the specified subnetworks"""
        if not self.is_training:
            self.optimizers_dict = {}
            return

        dis_params = filter(lambda p: p.requires_grad, self.discriminator.parameters())
        front_sensor_a_params = filter(lambda p: p.requires_grad, self.front_end_sensor_a.parameters())
        front_sensor_b_params = filter(lambda p: p.requires_grad, self.front_end_sensor_b.parameters())
        front_shared_params = filter(lambda p: p.requires_grad, self.front_end_shared.parameters())

        weight_decay = 0.01
        optimizer_dis = radam.RAdam(dis_params,
                                    lr=self.settings.lr_discriminator,
                                    weight_decay=weight_decay,
                                    betas=(0., 0.999))
        optimizer_front_sensor_a = radam.RAdam(front_sensor_a_params,
                                               lr=self.settings.lr_front,
                                               weight_decay=weight_decay,
                                               betas=(0., 0.999))
        optimizer_front_sensor_b = radam.RAdam(front_sensor_b_params,
                                               lr=self.settings.lr_front,
                                               weight_decay=weight_decay,
                                               betas=(0., 0.999))
        optimizer_front_shared = radam.RAdam(front_shared_params,
                                             lr=self.settings.lr_front,
                                             weight_decay=weight_decay,
                                             betas=(0., 0.999))

        self.optimizers_dict = {"optimizer_front_sensor_a": optimizer_front_sensor_a,
                                "optimizer_front_sensor_b": optimizer_front_sensor_b,
                                "optimizer_front_shared": optimizer_front_shared,
                                "optimizer_dis": optimizer_dis}
        # Task
        if self.settings.use_task_a or self.settings.use_task_b:
            back_params = filter(lambda p: p.requires_grad, self.task_backend.parameters())
            optimizer_back = radam.RAdam(back_params,
                                         lr=self.settings.lr_back,
                                         weight_decay=weight_decay,
                                         betas=(0., 0.999))
            self.optimizers_dict["optimizer_back"] = optimizer_back

        # Decoder Task
        if self.settings.use_decoder_a:
            decoder_sensor_a_params = filter(lambda p: p.requires_grad, self.decoder_sensor_a.parameters())
            optimizer_decoder_sensor_a = radam.RAdam(decoder_sensor_a_params,
                                                     lr=self.settings.lr_decoder,
                                                     weight_decay=weight_decay,
                                                     betas=(0., 0.999))

            self.optimizers_dict["optimizer_decoder_sensor_a"] = optimizer_decoder_sensor_a

        if self.settings.use_decoder_b:
            decoder_sensor_b_params = filter(lambda p: p.requires_grad, self.decoder_sensor_b.parameters())
            optimizer_decoder_sensor_b = radam.RAdam(decoder_sensor_b_params,
                                                     lr=self.settings.lr_decoder,
                                                     weight_decay=weight_decay,
                                                     betas=(0., 0.999))
            self.optimizers_dict["optimizer_decoder_sensor_b"] = optimizer_decoder_sensor_b

        # Refinement Task
        if self.settings.cross_refinement_a:
            refinement_a_params = filter(lambda p: p.requires_grad, self.cross_refinement_net_a.parameters())
            refinement_discr_a_params = filter(lambda p: p.requires_grad, self.refinement_discr_a.parameters())
            optimizer_refinement_a = radam.RAdam(refinement_a_params,
                                                 lr=self.settings.lr_decoder,
                                                 weight_decay=weight_decay,
                                                 betas=(0., 0.999))
            optimizer_refinement_discr_a = radam.RAdam(refinement_discr_a_params,
                                                       lr=self.settings.lr_discriminator,
                                                       weight_decay=weight_decay,
                                                       betas=(0., 0.999))
            self.optimizers_dict["optimizer_refinement_a"] = optimizer_refinement_a
            self.optimizers_dict["optimizer_refinement_discr_a"] = optimizer_refinement_discr_a

        if self.settings.cross_refinement_b:
            refinement_b_params = filter(lambda p: p.requires_grad, self.cross_refinement_net_b.parameters())
            refinement_discr_b_params = filter(lambda p: p.requires_grad, self.refinement_discr_b.parameters())
            optimizer_refinement_b = radam.RAdam(refinement_b_params,
                                                 lr=self.settings.lr_decoder,
                                                 weight_decay=weight_decay,
                                                 betas=(0., 0.999))
            optimizer_refinement_discr_b = radam.RAdam(refinement_discr_b_params,
                                                       lr=self.settings.lr_discriminator,
                                                       weight_decay=weight_decay,
                                                       betas=(0., 0.999))
            self.optimizers_dict["optimizer_refinement_b"] = optimizer_refinement_b
            self.optimizers_dict["optimizer_refinement_discr_b"] = optimizer_refinement_discr_b

    def train_step(self, input_batch,clas_good_enough = False):
        if not input_batch or input_batch[0][1].shape[0] == 0:
            print('Empty Labels  %s' % input_batch[0][1].shape[0])
            return {}, {}
        # alternate between gen loss and dis loss
        mod_step = self.step_count % (self.settings.front_iter + self.settings.disc_iter)

        if mod_step < self.settings.disc_iter:
            # Discriminator Step
            optimizers_list = ['optimizer_dis']
            if self.settings.cross_refinement_a:
                optimizers_list.append('optimizer_refinement_discr_a')
            if self.settings.cross_refinement_b:
                optimizers_list.append('optimizer_refinement_discr_b')
            
            norm_loss = 0
            for key_word in optimizers_list:
                optimizer_key_word = self.optimizers_dict[key_word]
                optimizer_key_word.zero_grad()
            

            d_final_loss= self.discriminator_train_step(input_batch)
            d_final_loss.backward()
            
            for i in ['dis',"refinement_discr_sensor_b"]:
                torch.nn.utils.clip_grad_norm_(self.models_dict[i].parameters(), 1, norm_type='inf', error_if_nonfinite=False)

                norm_loss += self.orthogonal_regularization(self.models_dict[i], self.device, 0.5)
            norm_loss.backward()
            for key_word in optimizers_list:

                optimizer_key_word = self.optimizers_dict[key_word]
                optimizer_key_word.step()

            return
        else:
            # Front End Step
            if not clas_good_enough:
                optimizers_list = ['optimizer_front_sensor_a', 'optimizer_front_sensor_b', 'optimizer_front_shared']
            else:
                optimizers_list = ['optimizer_front_sensor_b']
        
            if self.settings.use_decoder_a:
                optimizers_list.append('optimizer_decoder_sensor_a')
            if self.settings.use_decoder_b:
                optimizers_list.append('optimizer_decoder_sensor_b')
            if self.settings.cross_refinement_a:
                optimizers_list.append('optimizer_refinement_a')
            if self.settings.cross_refinement_b:
                optimizers_list.append('optimizer_refinement_b')
            if self.settings.use_task_a or self.settings.use_task_b:
                if not clas_good_enough:
                    optimizers_list.append('optimizer_back')

            for key_word in optimizers_list:
                optimizer_key_word = self.optimizers_dict[key_word]
                optimizer_key_word.zero_grad()
            
            g_final_loss= self.generator_train_step(input_batch)
            

            g_final_loss.backward()
            for key_word in optimizers_list:

                optimizer_key_word = self.optimizers_dict[key_word]
                optimizer_key_word.step()

            return


    def generator_train_step(self, batch):
        """Compute the total generator-side loss for one batch.

        ``batch[0]`` belongs to sensor a and ``batch[1]`` to sensor b. For each
        sensor, entries 0 and 2 are concatenated along the batch dimension as
        input data and entry 1 is duplicated as labels (entry 2 is presumably a
        second view of the same samples -- confirm against the dataloader).

        The loss sums: negative cosine similarity between the two halves of the
        pooled content / event-attribute vectors, absolute cosine similarity
        between event attribute and event content, the adversarial content
        loss, and the optional task, cycle and flow-augmentation losses.

        Returns:
            Scalar tensor with the summed generator loss.
        """
        data_a = torch.cat([batch[0][0], batch[0][2]], dim=0)
        labels_a = torch.cat([batch[0][1], batch[0][1]], dim=0)
        data_b = torch.cat([batch[1][0], batch[1][2]], dim=0)
        labels_b = torch.cat([batch[1][1], batch[1][1]], dim=0)

        # Set BatchNorm Statistics to Train
        for model in self.models_dict.values():
            model.train()

        gen_model_sensor_a = self.models_dict['front_sensor_a']
        gen_model_sensor_b = self.models_dict['front_sensor_b']

        g_loss = 0.
        # Generator output.
        (content_sensor_a1, content_sensor_a2), mu_sensor_a, _, attribute_sensor_a = gen_model_sensor_a(data_a)
        (content_sensor_b1, content_sensor_b2), mu_sensor_b, _, attribute_sensor_b = gen_model_sensor_b(data_b)

        content_sensor_a = content_sensor_a1 + content_sensor_a2
        content_sensor_b = content_sensor_b1 + content_sensor_b2

        # Global average pooling turns each feature map into one vector per sample.
        pool = nn.AdaptiveAvgPool2d((1, 1))
        event_vector = pool(attribute_sensor_b).squeeze()
        content_sensor_b_vector = pool(content_sensor_b).squeeze()
        content_sensor_a_vector = pool(content_sensor_a).squeeze()

        batch_size_a = self.settings.batch_size_a
        batch_size_b = self.settings.batch_size_b
        similarity_content_sensor_a = nn.functional.cosine_similarity(content_sensor_a_vector[:batch_size_a],
                                                                      content_sensor_a_vector[batch_size_a:])
        similarity_content_sensor_b = nn.functional.cosine_similarity(content_sensor_b_vector[batch_size_b:],
                                                                      content_sensor_b_vector[:batch_size_b])
        similarity_event_vector = nn.functional.cosine_similarity(event_vector[batch_size_b:],
                                                                  event_vector[:batch_size_b])

        # we want similarities between the two halves of the batch to be higher
        g_loss += torch.mean(-1 * similarity_content_sensor_a)
        g_loss += torch.mean(-1 * similarity_content_sensor_b)
        g_loss += torch.mean(-1 * similarity_event_vector)

        # ... while event attribute and event content vectors are pushed towards orthogonality.
        similarity = torch.abs(nn.functional.cosine_similarity(event_vector, content_sensor_b_vector))
        g_loss += torch.mean(similarity)

        g_loss += self.trainDiscriminatorStep(content_sensor_a, content_sensor_b)

        if self.settings.use_task_a:
            self.taskloss = self.trainTaskStep('sensor_a', content_sensor_a, labels_a)
            g_loss += self.taskloss
        if self.settings.use_task_b:
            print("Label B")
            g_loss += self.trainTaskStep('sensor_b', content_sensor_b, labels_b)

        # Only the a -> b cycle produces the cross-decoder output needed for flow augmentation.
        cross_decoder_output_b = None
        if self.settings.use_cycle_a_b:
            out = self.trainCycleStep('sensor_a', 'sensor_b', content_sensor_a, mu_sensor_b,
                                      attribute_sensor_b, data_a, labels_a, self.settings.use_task_a,
                                      self.settings.cross_refinement_b, data_b)
            g_loss += out[0]
            cross_decoder_output_b = out[1]

        if self.settings.use_cycle_b_a:
            # The last argument is the real data of the *second* sensor (here sensor a);
            # it was previously missing, which made this call raise a TypeError.
            g_loss += self.trainCycleStep('sensor_b', 'sensor_a', content_sensor_b, mu_sensor_a,
                                          attribute_sensor_a, data_b, labels_b, self.settings.use_task_b,
                                          self.settings.cross_refinement_a, data_a)

        if (self.settings.use_decoder_b and self.settings.sensor_b_name == 'events'
                and cross_decoder_output_b is not None):
            g_loss += self.augmentFlowAttribute('sensor_b', cross_decoder_output_b, data_a, content_sensor_a,
                                                translation=True)

        return g_loss

    def trainDiscriminatorStep(self, content_sensor_a, content_sensor_b):
        """Return the generator-side adversarial loss on the content features of both sensors.

        Both content tensors are passed through the content discriminator in a
        single forward pass and the logits are split back per sensor.
        """
        nr_samples_a = content_sensor_a.shape[0]
        stacked_content = torch.cat((content_sensor_a, content_sensor_b), dim=0)
        logits = self.models_dict['dis'](stacked_content)

        # Compute GAN loss (generator side).
        return Ra_generator_adversarial_loss(logits[:nr_samples_a], logits[nr_samples_a:])

    def trainCycleStep(self, sensor_name, second_sensor_name, content_first_sensor, attribute_mu_second_sensor,
                       attribute_second_sensor, data_first_sensor, labels_first_sensor, use_task_first_sensor,
                       cross_refinement_second_sensor, data_b):
        """Translate first-sensor content into the second sensor and compute the cycle losses.

        The content of the first sensor is decoded with the attribute of the
        second sensor, optionally refined by the cross-refinement network, and
        re-encoded with the second sensor's front end.

        Args:
            sensor_name: key of the source sensor ('sensor_a' or 'sensor_b').
            second_sensor_name: key of the target sensor.
            content_first_sensor: content features of the source sensor.
            attribute_mu_second_sensor: attribute mean of the target sensor,
                compared against the re-encoded attribute mean.
            attribute_second_sensor: attribute fed to the target decoder.
            data_first_sensor: raw source data handed to the refinement network
                and the image-gradient loss.
            labels_first_sensor: labels used by the optional cycle task loss.
            use_task_first_sensor: whether a task loss exists for the source sensor.
            cross_refinement_second_sensor: whether the target sensor has a
                refinement network and refinement discriminator.
            data_b: real samples passed to the refinement discriminator as the
                "real" side of the adversarial loss.

        Returns:
            ``(loss, decoder_output)`` when ``sensor_name == 'sensor_a'``,
            otherwise only ``loss``.
        """
        decoder_second_sensor = self.models_dict['decoder_' + second_sensor_name]
        gen_model_second_sensor = self.models_dict['front_' + second_sensor_name]

        decoder_output = decoder_second_sensor.forward(content_first_sensor, attribute_second_sensor)

        g_loss = 0
        if cross_refinement_second_sensor:
            cross_refinement_net = self.models_dict['cross_refinement_net_' + second_sensor_name]
            refinement_discr = self.models_dict['refinement_discr_' + second_sensor_name]

            out = cross_refinement_net.forward(decoder_output, data_first_sensor,
                                               return_clean_reconst=True, return_flow=True)
            reconst_second_sensor_input, clean_reconst, flow_map = out

            # Image Gradient Loss
            g_loss += self.trainImageGradientStep(data_first_sensor, clean_reconst, flow_map, second_sensor_name)

            # Flow Smoothness Loss
            smoothness_loss = self.flowSmoothnessLoss(flow_map) * self.settings.weight_smoothness_loss
            g_loss += smoothness_loss

            # Adversarial loss against the refinement discriminator
            refinement_logits = refinement_discr(reconst_second_sensor_input)
            real_logits = refinement_discr(data_b)
            refined_generator_loss = Ra_generator_adversarial_loss(real_logits, refinement_logits)
            g_loss += refined_generator_loss
        else:
            # Without a refinement network the raw decoder output is cycled back;
            # previously this name was undefined on this path.
            reconst_second_sensor_input = decoder_output

        (content_cycle1, content_cycle2), attribute_mu_cycle_second_sensor, _, _ = \
            gen_model_second_sensor(reconst_second_sensor_input)
        content_cycle = content_cycle1 + content_cycle2

        # Cycle Content Loss
        cycle_content_loss = self.cycle_content_loss(content_first_sensor, content_cycle) * \
                             self.settings.weight_cycle_loss
        g_loss += cycle_content_loss

        # Cycle Attribute Loss
        cycle_attribute_loss = self.cycle_attribute_loss(attribute_mu_second_sensor,
                                                         attribute_mu_cycle_second_sensor) * \
                               self.settings.weight_cycle_loss
        g_loss += cycle_attribute_loss

        if self.settings.use_cycle_task and use_task_first_sensor:
            if sensor_name == 'sensor_b':
                print("Label B")
            g_loss += self.trainCycleAccuracyStep(sensor_name, reconst_second_sensor_input, labels_first_sensor)

        if sensor_name == 'sensor_a':
            return g_loss, decoder_output
        return g_loss

    def trainCycleAccuracyStep(self, cycle_name, reconst_second_sensor_input, labels):
        """Classification loss on a translated sample, weighted by a fixed factor of 2.

        The translated input is detached, so gradients reach only the encoder
        and the task backend, not the networks that produced the translation.
        ``cycle_name`` is accepted but not used.
        """
        # NOTE(review): the sensor-b front end is used regardless of cycle direction -- confirm intent.
        encoder = self.models_dict['front_sensor_b']
        classifier = self.models_dict["back_end"]

        (content_part_1, content_part_2), _, _, _ = encoder(reconst_second_sensor_input.detach())
        predictions = classifier(content_part_1 + content_part_2)

        return self.task_loss(predictions, target=labels) * 2

    def trainTaskStep(self, sensor_name, content_features, labels):
        task_backend = self.models_dict["back_end"]
        pred_sensor = task_backend(content_features)
        loss_pred = self.task_loss(pred_sensor, target=labels) * self.settings.weight_task_loss

        return loss_pred

    def trainImageGradientStep(self, data_first_sensor, clean_reconst, flow_map, second_sensor_name):
        norm_gradient = self.norm_gradient_layer.forward(data_first_sensor)
        torch_smoothed = self.gaussian_layer(norm_gradient)
        gradient_loss = 0

        summed_events = torch.sum(clean_reconst, dim=1, keepdim=True)

        # Positive Loss
        pos_loss_bool = torch_smoothed > 0.3
        pos_loss_spatial = nn.functional.relu(0.7 - summed_events) * norm_gradient
        pos_loss = pos_loss_spatial[pos_loss_bool].mean()

       
        gradient_loss += pos_loss * 5

        return gradient_loss

    def flowSmoothnessLoss(self, flow_map):
        """Computes the smoothness loss in a neighbour hood of 2 for each pixel based on the Charbonnier loss"""
        displ = flow_map
        displ_c = displ[..., 1:-1, 1:-1]

        displ_u = displ[..., 1:-1, 2:]
        displ_d = displ[..., 1:-1, 0:-2]
        displ_l = displ[..., 2:, 1:-1]
        displ_r = displ[..., 0:-2, 1:-1]

        displ_ul = displ[..., 2:, 2:]
        displ_dr = displ[..., 0:-2, 0:-2]
        displ_dl = displ[..., 0:-2, 2:]
        displ_ur = displ[..., 2:, 0:-2]

        loss = self.charbonnier_loss(displ_l - displ_c) +\
               self.charbonnier_loss(displ_r - displ_c) +\
               self.charbonnier_loss(displ_d - displ_c) +\
               self.charbonnier_loss(displ_u - displ_c) +\
               self.charbonnier_loss(displ_dl - displ_c) +\
               self.charbonnier_loss(displ_dr - displ_c) +\
               self.charbonnier_loss(displ_ul - displ_c) +\
               self.charbonnier_loss(displ_ur - displ_c)
        loss /= 8
        return loss

    def charbonnier_loss(self, delta, exponent=0.45, eps=1e-3):
        # alpha = 0.25
        # epsilon = 1e-8
        return (delta.pow(2) + eps**2).pow(exponent).mean()

    def augmentFlowAttribute(self, sensor_name, cross_decoder_output, data_first_sensor, content_first_sensor,
                             translation=False):
        """Replace the predicted flow direction by a random one and require the decoder to reproduce it.

        The first two channels of ``cross_decoder_output`` are treated as a
        flow map. Its per-pixel magnitude is kept while the direction is
        replaced by either one random translation per sample
        (``translation=True``) or a field pointing away from a random centre.
        The augmented flow is refined into an event representation,
        re-encoded into an attribute and decoded again together with
        ``content_first_sensor``; the reconstruction loss compares the decoded
        flow with the augmented flow.

        Args:
            sensor_name: key of the sensor whose decoder / refinement network is used.
            cross_decoder_output: decoder output; channels 0-1 are the flow.
            data_first_sensor: source data passed to the refinement network.
            content_first_sensor: content features passed to the decoder.
            translation: select the uniform-translation augmentation.

        Returns:
            Reconstruction loss scaled by ``settings.weight_reconstruction_sensor_b_loss``.
        """
        flow_map = cross_decoder_output[:, :2, :, :]
        # Size the random directions from the actual batch rather than the configured
        # batch size, so both augmentation modes agree with flow_map.
        batch = flow_map.shape[0]
        device = flow_map.device
        height, width = self.settings.img_size_b[0], self.settings.img_size_b[1]

        # --- Flow Augmentation
        if not translation:
            # Flow is in Camera Coordinates: (x, y) = (u. v)
            random_x_center = (torch.rand(batch, device=device) * width).long()
            random_y_center = (torch.rand(batch, device=device) * (height / 4) + (height - height / 4) / 2).long()
            x_direction = (torch.arange(width, device=device)[None, :] - random_x_center[:, None]) / (width / 2)
            y_direction = (torch.arange(height, device=device)[None, :] - random_y_center[:, None]) / (height / 2)

        # --- Translation
        else:
            x_direction = (torch.rand([batch, 1], device=device) * 2 - 1).repeat([1, width])
            y_direction = (torch.rand([batch, 1], device=device) * 2 - 1).repeat([1, height])

        # Expand the per-column x and per-row y components to full (batch, 2, height, width) fields.
        augmented_flow_vectors = torch.stack([x_direction[:, None, :].repeat([1, height, 1]),
                                              y_direction[:, :, None].repeat([1, 1, width])], dim=1)

        # Unit direction times the predicted magnitude.
        augmented_flow_vectors = nn.functional.normalize(augmented_flow_vectors, p=2, dim=1)

        pred_flow_magnitude = torch.sqrt((flow_map ** 2).sum(1, keepdim=True))
        augmented_flow = augmented_flow_vectors * pred_flow_magnitude
        # ------

        front_end_sensor_b = self.models_dict['front_sensor_b']
        augmented_input = torch.cat([augmented_flow, cross_decoder_output[:, 2:, :, :]], dim=1)

        cross_refinement_net = self.models_dict['cross_refinement_net_' + sensor_name]
        # The augmented input is detached: no gradient flows back into the original decoder output.
        augmented_event_histo = cross_refinement_net.forward(augmented_input.detach(), data_first_sensor)
        augmented_attribute_f, _, _ = front_end_sensor_b(augmented_event_histo, attribute_only=True)

        decoder_network = self.models_dict['decoder_' + sensor_name]
        augmented_decoder_output = decoder_network.forward(content_first_sensor, augmented_attribute_f)

        reconst_loss = self.reconst_loss_dict[sensor_name](augmented_decoder_output[:, :2, :, :], augmented_flow) * \
                       self.settings.weight_reconstruction_sensor_b_loss

        if self.visualize_epoch():
            nrow = 4
            viz_org_flow = viz_utils.visualizeFlow(flow_map[:nrow])
            viz_flow_augmented = viz_utils.visualizeFlow(augmented_flow[:nrow])
            viz_flow_predicted = viz_utils.visualizeFlow(augmented_decoder_output[:nrow, :2, :, :])

            viz_tensors = torch.cat((data_first_sensor[:nrow].expand(-1, 3, -1, -1),
                                     viz_org_flow,
                                     viz_flow_augmented,
                                     viz_flow_predicted,
                                     viz_utils.createRGBImage(augmented_event_histo[:nrow])), dim=0)
            rgb_grid = torchvision.utils.make_grid(viz_tensors, nrow=nrow)
            self.img_summaries('train/flow_augmentation' + sensor_name + '_img', rgb_grid, self.step_count)

        return reconst_loss

    def discriminator_train_step(self, batch):
        """Compute the total discriminator loss for one batch.

        Content features of both sensors are computed without gradients, so
        only the discriminators receive gradients from the returned loss. The
        loss of each enabled cross-refinement discriminator is added on top of
        the content-discriminator loss.

        Args:
            batch: ``batch[0]`` holds sensor-a data, ``batch[1]`` sensor-b data;
                entries 0 and 2 of each are concatenated along the batch dimension.

        Returns:
            Scalar tensor with the summed discriminator loss.
        """
        data_a = torch.cat([batch[0][0], batch[0][2]], dim=0)
        data_b = torch.cat([batch[1][0], batch[1][2]], dim=0)
        gen_model_sensor_a = self.models_dict['front_sensor_a']
        gen_model_sensor_b = self.models_dict['front_sensor_b']
        dis_model = self.models_dict['dis']

        with torch.no_grad():
            (content_f_sensor_a1, content_f_sensor_a2), _, _, attribute_f_sensor_a = gen_model_sensor_a(data_a)
            (content_f_sensor_b1, content_f_sensor_b2), _, _, attribute_f_sensor_b = gen_model_sensor_b(data_b)

        content_f_sensor_a = content_f_sensor_a1 + content_f_sensor_a2
        content_f_sensor_b = content_f_sensor_b1 + content_f_sensor_b2

        input_disc = torch.cat([content_f_sensor_a, content_f_sensor_b], dim=0)
        logits = dis_model(input_disc)
        # Split by the real number of sensor-a samples (as trainDiscriminatorStep does)
        # instead of relying on the configured batch size.
        nr_samples_a = content_f_sensor_a.shape[0]
        logits_sensor_a = logits[:nr_samples_a]
        logits_sensor_b = logits[nr_samples_a:]

        d_loss = Ra_discriminator_loss(logits_sensor_a, logits_sensor_b)

        if self.settings.cross_refinement_a:
            d_loss += self.discriminatorCrossStep(data_a, data_b, content_f_sensor_b, attribute_f_sensor_a,
                                                  'sensor_a')

        if self.settings.cross_refinement_b:
            d_loss += self.discriminatorCrossStep(data_b, data_a, content_f_sensor_a, attribute_f_sensor_b,
                                                  'sensor_b')

        return d_loss

    def discriminatorCrossStep(self, data_first_sensor, data_second_sensor, content_second_sensor,attribute_first_sensor, first_sensor):
        """Discriminator loss of the refinement discriminator belonging to ``first_sensor``.

        A fake sample is produced without gradients by decoding the second
        sensor's content with the first sensor's attribute and refining the
        result; it is scored against the real ``data_first_sensor``.
        """
        decoder = self.models_dict['decoder_' + first_sensor]
        refiner = self.models_dict['cross_refinement_net_' + first_sensor]
        critic = self.models_dict['refinement_discr_' + first_sensor]

        # The generator side is frozen here: only the critic is trained by this loss.
        with torch.no_grad():
            translated = decoder.forward(content_second_sensor, attribute_first_sensor)
            refined_fake = refiner.forward(translated, data_second_sensor)

        fake_logits = critic(refined_fake.detach())
        real_logits = critic(data_first_sensor)
        return Ra_discriminator_loss(real_logits, fake_logits)
    # Orthogonal regularization adapted from: https://blog.csdn.net/qq_27261889/article/details/87706903
    def orthogonal_regularization(self,model, device, beta):
        r"""
            author: Xu Mingle
            time: 2019,2,19  15:12:43
            input:
                model: which is the model we want to use orthogonal regularization, e.g. Generator or Discriminator
                device: cpu or gpu
                beta: hyperparameter
            output: loss
        """


        loss_orth = torch.tensor(0., dtype=torch.float32, device=device)

        for name, param in model.named_parameters():
            if 'weight' in name and param.requires_grad and len(param.shape)==4:

                N, C, H, W = param.shape

                weight = param.view(N * C, H, W)

                weight_squared = torch.bmm(weight, weight.permute(0, 2, 1)) # (N * C) * H * H

                ones = torch.ones(N * C, H, H, dtype=torch.float32) # (N * C) * H * H

                diag = torch.eye(H, dtype=torch.float32) # (N * C) * H * H

                loss_orth += ((weight_squared * (ones - diag).to(device)) ** 2).sum()


        return loss_orth * beta

