# Copyright (c) 2022, NVIDIA CORPORATION & AFFILIATES. All rights reserved.
#
# This work is licensed under a Creative Commons
# Attribution-NonCommercial-ShareAlike 4.0 International License.
# You should have received a copy of the license along with this
# work. If not, see http://creativecommons.org/licenses/by-nc-sa/4.0/

"""Loss functions used in the paper
"Elucidating the Design Space of Diffusion-Based Generative Models"."""

import torch
from torch_utils import persistence

#----------------------------------------------------------------------------
# Loss function corresponding to the variance preserving (VP) formulation
# from the paper "Score-Based Generative Modeling through Stochastic
# Differential Equations".

@persistence.persistent_class
class VPLoss:
    """Denoising loss for the variance preserving (VP) SDE formulation.

    A diffusion time ``t`` is drawn uniformly from ``(epsilon_t, 1]`` for each
    sample, mapped to a noise level with :meth:`sigma`, and the squared
    denoising error is weighted by ``1 / sigma**2``.
    """

    def __init__(self, beta_d=19.9, beta_min=0.1, epsilon_t=1e-5):
        # Coefficients of the exponent 0.5 * beta_d * t**2 + beta_min * t used by sigma().
        self.beta_d = beta_d
        self.beta_min = beta_min
        # Lower end of the sampled diffusion-time range (epsilon_t, 1].
        self.epsilon_t = epsilon_t

    def __call__(self, net, images, labels, augment_pipe=None):
        """Return the element-wise weighted denoising loss.

        ``net`` is called as ``net(noisy, sigma, labels, augment_labels=...)``
        and its output is compared against the (optionally augmented) clean
        images. ``augment_pipe``, when given, must return
        ``(augmented_images, augment_labels)``.
        """
        batch_shape = [images.shape[0], 1, 1, 1]
        u = torch.rand(batch_shape, device=images.device)
        t = 1 + u * (self.epsilon_t - 1)  # u in [0, 1) -> t in (epsilon_t, 1]
        sigma = self.sigma(t)

        if augment_pipe is None:
            clean, augment_labels = images, None
        else:
            clean, augment_labels = augment_pipe(images)

        noisy = clean + torch.randn_like(clean) * sigma
        denoised = net(noisy, sigma, labels, augment_labels=augment_labels)
        return (1 / sigma ** 2) * ((denoised - clean) ** 2)

    def sigma(self, t):
        """Map diffusion time ``t`` (scalar or tensor) to the VP noise level.

        Computes ``sqrt(exp(0.5 * beta_d * t**2 + beta_min * t) - 1)``.
        """
        t = torch.as_tensor(t)
        exponent = 0.5 * self.beta_d * (t ** 2) + self.beta_min * t
        return (exponent.exp() - 1).sqrt()

#----------------------------------------------------------------------------
# Loss function corresponding to the variance exploding (VE) formulation
# from the paper "Score-Based Generative Modeling through Stochastic
# Differential Equations".

@persistence.persistent_class
class VELoss:
    """Denoising loss for the variance exploding (VE) SDE formulation.

    Noise levels are drawn log-uniformly from ``[sigma_min, sigma_max)`` for
    each sample, and the squared denoising error is weighted by
    ``1 / sigma**2``.
    """

    def __init__(self, sigma_min=0.02, sigma_max=100):
        # Range of the log-uniform noise-level distribution.
        self.sigma_min = sigma_min
        self.sigma_max = sigma_max

    def __call__(self, net, images, labels, augment_pipe=None):
        """Return the element-wise weighted denoising loss.

        ``net`` is called as ``net(noisy, sigma, labels, augment_labels=...)``.
        ``augment_pipe``, when given, must return
        ``(augmented_images, augment_labels)``.
        """
        u = torch.rand([images.shape[0], 1, 1, 1], device=images.device)
        # log(sigma) is uniform between log(sigma_min) and log(sigma_max).
        ratio = self.sigma_max / self.sigma_min
        sigma = self.sigma_min * (ratio ** u)

        if augment_pipe is None:
            clean, augment_labels = images, None
        else:
            clean, augment_labels = augment_pipe(images)

        noisy = clean + torch.randn_like(clean) * sigma
        denoised = net(noisy, sigma, labels, augment_labels=augment_labels)
        return (1 / sigma ** 2) * ((denoised - clean) ** 2)

#----------------------------------------------------------------------------
# Improved loss function proposed in the paper "Elucidating the Design Space
# of Diffusion-Based Generative Models" (EDM).

@persistence.persistent_class
class EDMLoss:
    """Improved denoising loss from the EDM paper.

    Each sample's noise level satisfies ``log(sigma) ~ N(P_mean, P_std**2)``.
    The squared denoising error is weighted by
    ``(sigma**2 + sigma_data**2) / (sigma * sigma_data)**2``.
    """

    def __init__(self, P_mean=-1.2, P_std=1.2, sigma_data=0.5):
        # Mean and standard deviation of log(sigma).
        self.P_mean = P_mean
        self.P_std = P_std
        # Data scale used in the loss weighting.
        self.sigma_data = sigma_data

    def __call__(self, net, images, labels=None, augment_pipe=None):
        """Return the element-wise weighted denoising loss.

        ``net`` is called as ``net(noisy, sigma, labels, augment_labels=...)``.
        ``augment_pipe``, when given, must return
        ``(augmented_images, augment_labels)``.
        """
        z = torch.randn([images.shape[0], 1, 1, 1], device=images.device)
        sigma = (z * self.P_std + self.P_mean).exp()
        sd = self.sigma_data
        weight = (sigma ** 2 + sd ** 2) / (sigma * sd) ** 2

        if augment_pipe is None:
            clean, augment_labels = images, None
        else:
            clean, augment_labels = augment_pipe(images)

        noisy = clean + torch.randn_like(clean) * sigma
        denoised = net(noisy, sigma, labels, augment_labels=augment_labels)
        return weight * ((denoised - clean) ** 2)

#----------------------------------------------------------------------------
# Discriminator loss function corresponding to the variance preserving (VP) formulation.

@persistence.persistent_class
class VPDiscriminatorLoss:
    """Discriminator binary cross-entropy evaluated at VP-schedule noise levels.

    A diffusion time ``t`` is drawn uniformly from ``(epsilon_t, 1]`` for each
    sample and mapped to a noise level with :meth:`sigma`. Gaussian noise of
    that scale is added to the (optionally augmented) images. The
    discriminator output is then scored against ``real_or_not`` with an
    unweighted, element-wise BCE.
    """

    def __init__(self, beta_d=19.9, beta_min=0.1, epsilon_t=1e-5):
        # Coefficients of the exponent 0.5 * beta_d * t**2 + beta_min * t used by sigma().
        self.beta_d = beta_d
        self.beta_min = beta_min
        # Lower end of the sampled diffusion-time range (epsilon_t, 1].
        self.epsilon_t = epsilon_t

    def __call__(self, net, images, labels, real_or_not, augment_pipe=None):
        """Return the element-wise BCE between ``net``'s prediction and ``real_or_not``.

        Args:
            net: discriminator called as ``net(noisy, sigma, labels, augment_labels=...)``.
                Its output is fed to binary cross-entropy, so it must lie in [0, 1].
            images: clean input batch; noise levels have shape ``[batch, 1, 1, 1]``.
            labels: conditioning labels forwarded to ``net``.
            real_or_not: BCE target with the same shape as ``net``'s output.
            augment_pipe: optional callable returning ``(augmented_images, augment_labels)``.
        """
        rnd_uniform = torch.rand([images.shape[0], 1, 1, 1], device=images.device)
        sigma = self.sigma(1 + rnd_uniform * (self.epsilon_t - 1))  # t in (epsilon_t, 1]
        y, augment_labels = augment_pipe(images) if augment_pipe is not None else (images, None)
        n = torch.randn_like(y) * sigma
        real_prediction = net(y + n, sigma, labels, augment_labels=augment_labels)
        # The BCE is not weighted by noise level; the previously computed
        # 1 / sigma**2 weight was never applied and has been removed.
        return torch.nn.functional.binary_cross_entropy(real_prediction, real_or_not, reduction='none')

    def sigma(self, t):
        """Map diffusion time ``t`` (scalar or tensor) to the VP noise level.

        Computes ``sqrt(exp(0.5 * beta_d * t**2 + beta_min * t) - 1)``.
        """
        t = torch.as_tensor(t)
        return ((0.5 * self.beta_d * (t ** 2) + self.beta_min * t).exp() - 1).sqrt()

#----------------------------------------------------------------------------
# Discriminator loss function corresponding to the improved loss function proposed in
# the paper "Elucidating the Design Space of Diffusion-Based Generative Models" (EDM).

@persistence.persistent_class
class EDMDiscriminatorLoss:
    """Discriminator loss evaluated at EDM-style random noise levels.

    On every call a per-sample noise level ``sigma = exp(P_mean + P_std * z)``,
    ``z ~ N(0, 1)``, is drawn. Gaussian noise of that scale is added to the
    (optionally augmented) images, and the discriminator ``net`` is scored with
    element-wise binary cross-entropy against ``real_or_not``. Optionally, an
    auxiliary loss built from the input gradient of the discriminator is
    returned as well (see ``__call__``).
    """

    def __init__(self, P_mean=-1.2, P_std=1.2, sigma_data=0.5, noise_type='none', noise_rate=0., num_classes=10, dataset='cifar10'):
        # Mean and standard deviation of log(sigma).
        self.P_mean = P_mean
        self.P_std = P_std
        # NOTE(review): the remaining attributes are stored but not read by
        # any method of this class.
        self.sigma_data = sigma_data
        self.noise_type = noise_type
        self.noise_rate = noise_rate
        self.num_classes = num_classes
        self.dataset = dataset

    @staticmethod
    def _subset(x, mask):
        """Return ``x[mask]``, passing ``None`` (e.g. missing augment labels) through."""
        return None if x is None else x[mask]

    def __call__(self, net, images, labels, real_or_not, diversity=False, score=None, loss_type="distance", augment_pipe=None):
        """Compute the discriminator BCE loss and an optional auxiliary loss.

        Args:
            net: discriminator called as ``net(x, sigma, labels, augment_labels=...)``.
                Its output is fed to binary cross-entropy, so it must lie in [0, 1].
            images: clean input batch; noise levels have shape ``[batch, 1, 1, 1]``.
            labels: conditioning labels forwarded to ``net``. The diversity branch
                reads them as ``[batch, num_classes]`` class scores (argmax over dim 1).
            real_or_not: BCE target with the same shape as ``net``'s output. The
                score branch flattens it to one boolean per sample (True = real).
            diversity: if true, also compute the EDSM term (takes precedence over ``score``).
            score: optional denoiser called like ``net``. When given (and
                ``diversity`` is false), the TDSM term is computed on non-real samples.
            loss_type: TDSM variant, ``"distance"`` or ``"direction"``.
            augment_pipe: optional callable returning ``(augmented_images, augment_labels)``.

        Returns:
            ``(loss, aux_loss, real_prediction)``, where:

            - ``loss`` is the element-wise BCE.
            - ``aux_loss`` is the EDSM term, the TDSM term, or ``None``.
            - ``real_prediction`` is the discriminator output. In the score
              branch it is returned clipped to ``[1e-5, 1 - 1e-5]``.

        Raises:
            ValueError: if ``score`` is used with an unsupported ``loss_type``.
        """
        rnd_normal = torch.randn([images.shape[0], 1, 1, 1], device=images.device)
        sigma = (rnd_normal * self.P_std + self.P_mean).exp()
        y, augment_labels = augment_pipe(images) if augment_pipe is not None else (images, None)
        n = torch.randn_like(y) * sigma
        # Fresh leaf tensor so the auxiliary terms can differentiate the
        # discriminator output with respect to its input.
        perturbed_data = (y + n).detach().clone().requires_grad_(True)
        real_prediction = net(perturbed_data, sigma, labels, augment_labels=augment_labels)
        loss = torch.nn.functional.binary_cross_entropy(real_prediction, real_or_not, reduction='none')

        if diversity:
            loss_edsm = self._edsm_loss(net, perturbed_data, sigma, n, labels, augment_labels, real_prediction)
            return loss, loss_edsm, real_prediction

        if score is not None:
            loss_tdsm, real_prediction = self._tdsm_loss(
                score, perturbed_data, y, sigma, labels, augment_labels, real_or_not, real_prediction, loss_type)
            return loss, loss_tdsm, real_prediction

        return loss, None, real_prediction

    def _edsm_loss(self, net, perturbed_data, sigma, n, labels, augment_labels, real_prediction):
        """EDSM term: ``(sigma * grad_x log(sum_c exp(D_c) + exp(1 - D_c)) + n / sigma) ** 2``."""
        num_label_classes = labels.shape[1]
        batch = labels.shape[0]
        sum_exp_prob_class = torch.exp(real_prediction) + torch.exp(1 - real_prediction)
        class_labels_idx = torch.argmax(labels, dim=1)
        # Other-class predictions are accumulated without autograd, so the
        # input gradient below flows only through the given-label prediction.
        with torch.no_grad():
            for c in range(num_label_classes):
                candidate = class_labels_idx != c
                if not candidate.any():
                    continue
                class_idx = torch.full((batch,), c, dtype=torch.long, device=perturbed_data.device)
                labels_temp = torch.nn.functional.one_hot(class_idx, num_label_classes).to(labels.dtype)
                D_yn = net(perturbed_data[candidate], sigma[candidate], labels_temp[candidate],
                           augment_labels=self._subset(augment_labels, candidate))
                sum_exp_prob_class[candidate] += torch.exp(D_yn) + torch.exp(1 - D_yn)
        log_sum_exp_prob_class = torch.log(sum_exp_prob_class + 1e-8)
        grad_log_sum_exp_prob_class, = torch.autograd.grad(
            log_sum_exp_prob_class.sum(), perturbed_data, create_graph=True)

        lambda_edsm = 1.0
        return lambda_edsm * (sigma * grad_log_sum_exp_prob_class + n / sigma) ** 2

    def _tdsm_loss(self, score, perturbed_data, y, sigma, labels, augment_labels, real_or_not, real_prediction, loss_type):
        """TDSM term matching grad_x of the discriminator log-ratio to the ``score`` denoiser.

        Returns ``(loss_tdsm, clipped_real_prediction)``. Only non-real samples
        are used; "distance" additionally requires ``sigma > 1``. ``loss_tdsm``
        has one row per selected sample, scaled by ``batch / num_selected``. It
        is an empty tensor when no sample is selected.
        """
        if loss_type not in ("distance", "direction"):
            raise ValueError(f"Loss type {loss_type} is not supported.")

        # reshape(-1) (not squeeze) keeps a 1-D mask even for batch size 1.
        is_real = real_or_not.reshape(-1).bool()
        real_prediction = torch.clip(real_prediction, 1e-5, 1. - 1e-5)
        log_ratio = torch.log(real_prediction / (1. - real_prediction))
        if loss_type == "distance":
            candidate = (sigma > 1).reshape(-1) & ~is_real
        else:
            candidate = ~is_real

        with torch.no_grad():
            D_yn = score(perturbed_data.detach()[candidate], sigma[candidate], self._subset(labels, candidate),
                         augment_labels=self._subset(augment_labels, candidate))
        grad_log_ratio = torch.autograd.grad(
            outputs=log_ratio.sum(), inputs=perturbed_data, create_graph=True)[0][candidate]

        lambda_tdsm = 1.0
        num_selected = candidate.sum()
        if loss_type == "distance":
            per_sample = (grad_log_ratio + (D_yn - y[candidate]) / (sigma[candidate] ** 2)) ** 2
        else:
            # Flatten per selected sample; the row count is num_selected, not the batch size.
            per_sample = 1 - torch.nn.functional.cosine_similarity(
                grad_log_ratio.flatten(1), (y[candidate] - D_yn).flatten(1))
        loss_tdsm = lambda_tdsm * per_sample / num_selected * len(candidate)
        return loss_tdsm, real_prediction

