
import torch
import torch.nn.functional as F


def squared_error(input, target):
    """Return the sum (not the mean) of element-wise squared differences.

    Args:
        input: prediction tensor.
        target: reference tensor, broadcast-compatible with ``input``.

    Returns:
        0-dim tensor holding ``sum((input - target) ** 2)``.
    """
    residual = input - target
    return residual.pow(2).sum()

# Code adapted from: https://blog.csdn.net/RSstudent/article/details/109078141 (with modifications)
def Ra_discriminator_loss(real_output, fake_output):
    """Relativistic-average discriminator loss.

    Computes
        -mean(log(sigmoid(real - mean(fake))))
        -mean(log(1 - sigmoid(fake - mean(real))))
    so the discriminator is pushed to score real samples above the average
    fake score, and fake samples below the average real score.

    Args:
        real_output: discriminator logits (pre-sigmoid) for real samples.
        fake_output: discriminator logits (pre-sigmoid) for generated samples.

    Returns:
        0-dim loss tensor.
    """
    real_vs_fake = real_output - torch.mean(fake_output)
    fake_vs_real = fake_output - torch.mean(real_output)
    # log(sigmoid(z)) == logsigmoid(z) and log(1 - sigmoid(z)) == logsigmoid(-z).
    # Going through sigmoid first saturates in float32 (e.g. 1 - sigmoid(z)
    # is exactly 0 for z above roughly 17), which turns the loss into inf and
    # the gradient into NaN; logsigmoid stays finite.
    L_Ra_d = (
        - torch.mean(F.logsigmoid(real_vs_fake))
        - torch.mean(F.logsigmoid(-fake_vs_real))
    )
    return L_Ra_d

def Ra_generator_adversarial_loss(real_output, fake_output):
    """Relativistic-average adversarial loss for the generator.

    Mirror image of the discriminator loss: computes
        -mean(log(1 - sigmoid(real - mean(fake))))
        -mean(log(sigmoid(fake - mean(real))))
    so the generator is rewarded when fake samples score above the average
    real score and real samples score below the average fake score.

    Args:
        real_output: discriminator logits (pre-sigmoid) for real samples.
        fake_output: discriminator logits (pre-sigmoid) for generated samples.

    Returns:
        0-dim loss tensor.
    """
    real_vs_fake = real_output - torch.mean(fake_output)
    fake_vs_real = fake_output - torch.mean(real_output)
    # log(1 - sigmoid(z)) == logsigmoid(-z) and log(sigmoid(z)) == logsigmoid(z);
    # logsigmoid avoids log(0) = -inf once the sigmoid saturates in float32.
    L_Ra_g = (
        - torch.mean(F.logsigmoid(-real_vs_fake))
        - torch.mean(F.logsigmoid(fake_vs_real))
    )
    return L_Ra_g


# Code adapted from: https://github.com/uzh-rpg/rpg_ev-transfer (with modifications)
def event_reconstruction_loss(gt_histogram, predicted_histogram):
    """Balanced L1 reconstruction loss between two histograms.

    The per-cell L1 distance is summed over dim 1 (presumably the channel /
    polarity axis; confirm against callers). Cells where the ground truth
    sums to a positive value and cells where it does not are averaged
    separately, and the two means are added. This keeps the (usually many)
    empty cells from dominating the loss. If only one kind of cell is
    present, the plain mean over all cells is returned instead.

    Args:
        gt_histogram: ground-truth histogram tensor.
        predicted_histogram: predicted histogram, same shape as gt_histogram.

    Returns:
        0-dim loss tensor.
    """
    per_cell_l1 = (gt_histogram - predicted_histogram).abs().sum(dim=1)
    has_events = gt_histogram.sum(dim=1) > 0
    no_events = ~has_events

    # Splitting only makes sense when both groups are non-empty; otherwise one
    # of the two means would be NaN.
    if not (has_events.any() and no_events.any()):
        return per_cell_l1.mean()

    return per_cell_l1[has_events].mean() + per_cell_l1[no_events].mean()




def normalize_l2(x):
    """L2-normalize ``x`` along dim 1 after merging its last two dimensions.

    The last two dims are flattened into one, ``F.normalize(p=2, dim=1)`` is
    applied, and the result is reshaped back to ``x.shape``.

    NOTE(review): the axis that gets normalized depends on the rank of ``x``.
    For a 3-D input (C, H, W) the flattened result is (C, H*W), so each
    row is normalized over the spatial positions. For a 4-D input
    (B, C, H, W) it is (B, C, H*W), so dim 1 is the channel axis and each
    spatial position is normalized across channels. Confirm which one
    callers intend.
    """
    merged = x.flatten(start_dim=-2)
    unit = F.normalize(merged, p=2, dim=1)
    return unit.reshape_as(x)
