import torch


def get_grad_log_ratio(discriminator, x, sigma, img_resolution, S_clip_min, S_clip_max, labels, logit=False):
    """Compute the discriminator-guidance correction term for a batch.

    The correction is ``-sigma**2 * d/dx log(D / (1 - D))``, where ``D`` is the
    clipped discriminator output produced by ``get_log_ratio``.

    Args:
        discriminator: Callable invoked as ``discriminator(x, sigma, labels)``,
            or ``None`` to disable guidance.
        x: Batched input tensor; its first dimension is used as the batch size.
        sigma: Scalar noise level (compared against the clip bounds and
            broadcast to one value per batch element).
        img_resolution: Unused here; kept for interface compatibility.
        S_clip_min: Lower bound on ``sigma`` for applying guidance.
        S_clip_max: Upper bound on ``sigma`` for applying guidance.
        labels: Forwarded unchanged to the discriminator.
        logit: If True, also return the discriminator's clipped prediction.

    Returns:
        The correction tensor (same shape as ``x``), or ``(correction, pred)``
        when ``logit`` is True. Outside ``[S_clip_min, S_clip_max]`` the
        correction is all zeros and ``pred`` is all ones. When
        ``discriminator`` is ``None`` the correction is all zeros and ``pred``
        is ``None``, matching what ``get_log_ratio`` reports in that case.
    """
    if sigma > S_clip_max or sigma < S_clip_min:
        if logit:
            return torch.zeros_like(x), torch.ones(x.shape[0], device=x.device)
        return torch.zeros_like(x)

    if discriminator is None:
        # Without a discriminator the log-ratio is a constant zero with no
        # autograd graph: torch.autograd.grad would raise on it, and its
        # gradient is zero anyway.
        zero_score = torch.zeros_like(x)
        return (zero_score, None) if logit else zero_score

    with torch.enable_grad():
        # Fresh float32 leaf so gradients are taken w.r.t. this copy only,
        # even when the caller runs under torch.no_grad().
        x_ = x.float().clone().detach().requires_grad_()
        sigma_ = torch.ones(x.shape[0], device=x.device) * sigma

        log_ratio, pred = get_log_ratio(discriminator, x_, sigma_, labels)
        # Summing yields per-sample gradients provided each entry of log_ratio
        # depends only on its own sample (NOTE(review): assumes no cross-batch
        # coupling inside the discriminator — confirm).
        grad_log_ratio = torch.autograd.grad(outputs=log_ratio.sum(), inputs=x_)[0]

    discriminator_guidance_score = grad_log_ratio * -(sigma ** 2)
    if logit:
        return discriminator_guidance_score, pred
    return discriminator_guidance_score


def get_log_ratio(discriminator, input, sigma, labels):
    if discriminator == None:
        return torch.zeros(input.shape[0], device=input.device), None
    else:
        logits = discriminator(input, sigma, labels)
        prediction = torch.clip(logits, 1e-5, 1. - 1e-5)
        log_ratio = torch.log(prediction / (1. - prediction))
        return log_ratio, prediction

