from typing import Optional
import torch
import torch.nn as nn
import torch.nn.functional as F
from torch import Tensor
from torch.autograd import Function


class FuzzyClassificationLoss(nn.Module):
    """Fuzzy classification loss.

    Per pixel, the class scores along dim 1 are combined as
    ``1 - prod_c (1 - p_c) ** w_c`` (a probabilistic OR over the classes,
    optionally weighted per class) and the loss is the negative log of that
    value. The reductions below treat the input as 4-D with the class axis at
    dim 1, e.g. ``(N, C, H, W)``.

    Args:
        weight: optional non-negative per-class exponents. Presumably a 1-D
            tensor of length ``C``. It is given two trailing singleton dims so
            it broadcasts over the two spatial dims.
        reduction: ``'mean'`` averages over batch and pixels (scalar).
            ``'sum'`` averages over pixels and sums over the batch (scalar).
            Any other value returns the per-sample pixel mean, shape ``(N,)``.
        softmax: apply ``softmax`` over dim 1 first (use when the input is
            raw logits).

    Raises:
        ValueError: if ``weight`` contains a negative entry.
    """

    def __init__(self, weight: Optional[Tensor] = None, reduction: str = 'mean', softmax: bool = False):
        super().__init__()
        if weight is not None and weight.min() < 0:
            raise ValueError('"weight" should be greater than or equal to 0.')
        # (C,) -> (C, 1, 1) so the weight broadcasts against (N, C, H, W).
        self.weight = weight.unsqueeze(-1).unsqueeze(-1) if weight is not None else weight
        self.reduction = reduction
        self.softmax = softmax

    def forward(self, input: Tensor) -> Tensor:
        """Return the reduced fuzzy classification loss for ``input``."""
        if self.softmax:
            input = F.softmax(input, dim=1)
        complement = 1 - input
        if self.weight is not None:
            # ``weight`` is a plain attribute, not a registered buffer, so
            # ``module.to(device)`` does not move it. Follow the input instead.
            complement = complement.pow(self.weight.to(input.device))
        # Collapse the class axis (dim 1) into one loss value per pixel.
        pixel_loss = -(1 - torch.prod(complement, dim=1)).log()
        if self.reduction == 'mean':
            return pixel_loss.mean(dim=(0, 1, 2))
        per_sample = pixel_loss.mean(dim=(1, 2))
        if self.reduction == 'sum':
            return per_sample.sum(dim=0)
        return per_sample


class DiceCoeff(Function):
    """Dice coefficient for a single example.

    Computes ``(2 * <input, target> + eps) / (sum(input) + sum(target) + eps)``
    over the flattened tensors, with ``eps = 0.0001``.

    NOTE(review): ``forward``/``backward`` are instance methods, which is the
    legacy autograd-Function style. Recent PyTorch releases only dispatch
    static-method Functions through ``apply``, so this custom ``backward`` may
    be unreachable there. Calling ``DiceCoeff().forward(a, b)`` directly still
    returns the value, and autograd differentiates through the ordinary tensor
    ops. Confirm against the installed PyTorch version.
    """

    def forward(self, input, target):
        self.save_for_backward(input, target)
        eps = 0.0001
        # Cached on the instance because backward() reuses them.
        self.inter = torch.dot(input.view(-1), target.view(-1))
        self.union = torch.sum(input) + torch.sum(target) + eps

        t = (2 * self.inter.float() + eps) / self.union.float()
        return t

    # This function has only a single output, so it gets only one gradient
    def backward(self, grad_output):
        """Gradient of the Dice coefficient w.r.t. ``input``; ``target`` gets none.

        Uses d/d(input) of ``2 * inter / union`` = ``2 * (target * union - inter) / union**2``.
        This ignores the small ``eps`` term in the numerator.
        """
        # ``saved_tensors`` replaces the deprecated ``saved_variables`` alias.
        input, target = self.saved_tensors
        grad_input = grad_target = None

        if self.needs_input_grad[0]:
            grad_input = grad_output * 2 * (target * self.union - self.inter) / (self.union * self.union)
        # ``target`` is treated as a constant, so ``grad_target`` stays None.
        return grad_input, grad_target


def dice_coeff(input, target):
    """Dice coefficient averaged over a batch.

    Each pair ``(input[i], target[i])`` is flattened and scored as
    ``(2 * <a, b> + eps) / (sum(a) + sum(b) + eps)`` with ``eps = 0.0001``.
    This is the same formula as ``DiceCoeff.forward``. The per-pair scores are
    then averaged.

    Args:
        input: batch of predictions, iterated along dim 0.
        target: batch of ground truths, paired with ``input`` along dim 0.

    Returns:
        A 1-element float32 tensor on ``input``'s device.

    Raises:
        ValueError: if there is no ``(input, target)`` pair to score.
    """
    eps = 0.0001
    # Allocate on input's own device, not whichever CUDA device is current.
    total = torch.zeros(1, dtype=torch.float32, device=input.device)
    count = 0
    for pred, truth in zip(input, target):
        # Computed inline rather than via DiceCoeff().forward(...), because
        # instantiating an autograd Function is deprecated in recent PyTorch.
        inter = torch.dot(pred.reshape(-1), truth.reshape(-1))
        union = torch.sum(pred) + torch.sum(truth) + eps
        total = total + (2 * inter.float() + eps) / union.float()
        count += 1
    if count == 0:
        raise ValueError('dice_coeff needs at least one (input, target) pair.')
    return total / count


class FocalLoss(nn.Module):
    """Sigmoid focal loss computed directly from logits.

    Args:
        alpha: weight of the positive term; the negative term gets ``1 - alpha``.
        gamma: focusing exponent applied to ``|label - sigmoid(logits)|``.
        reduction: ``'mean'`` or ``'sum'``. Any other value returns the
            element-wise loss unreduced.
    """

    def __init__(self, alpha=0.25, gamma=2, reduction='mean'):
        super().__init__()
        self.alpha, self.gamma, self.reduction = alpha, gamma, reduction

    def forward(self, logits, label):
        '''
        Takes the same (logits, label) arguments as nn.BCEWithLogitsLoss:
            >>> criteria = FocalLoss()
            >>> logits = torch.randn(8, 19, 384, 384)
            >>> lbs = torch.randint(0, 2, (8, 19, 384, 384)).float()
            >>> loss = criteria(logits, lbs)
        '''
        is_nonneg = logits >= 0
        # log(sigmoid(x)) and log(1 - sigmoid(x)). For each sign of x, the
        # softplus form whose exponent is non-positive is selected.
        log_p = torch.where(is_nonneg,
                            F.softplus(logits, -1, 50),
                            logits - F.softplus(logits, 1, 50))
        log_not_p = torch.where(is_nonneg,
                                -logits + F.softplus(logits, -1, 50),
                                -F.softplus(logits, 1, 50))
        # The modulating factor is negated so the final product is non-negative.
        modulator = torch.abs(label - torch.sigmoid(logits)).pow(self.gamma).neg()
        per_elem = (label * self.alpha * log_p
                    + (1. - label) * (1. - self.alpha) * log_not_p) * modulator

        if self.reduction == 'mean':
            return per_elem.mean()
        if self.reduction == 'sum':
            return per_elem.sum()
        return per_elem


def IoU(input: Tensor, target: Tensor) -> Tensor:
    """Intersection over Union, averaged over a batch.

    For each pair ``(input[i], target[i])``:

    * the intersection is the dot product of the flattened tensors;
    * the union is the sum of their element-wise maximum;
    * the score is ``(intersection + eps) / (union + eps)`` with ``eps = 0.0001``,
      so two all-zero tensors score 1.

    Args:
        input: batch of predictions, iterated along dim 0.
        target: batch of ground truths with the same shape as ``input``.

    Returns:
        A 1-element float32 tensor on ``input``'s device.

    Raises:
        AssertionError: if the shapes differ (not checked under ``python -O``).
        ValueError: if the batch is empty.
    """
    assert input.shape == target.shape, 'input and target should have same shape, ' \
        f'but shape of input is {input.shape} while target is {target.shape}'
    if len(input) == 0:
        raise ValueError('IoU needs at least one (input, target) pair.')

    eps = 0.0001
    # Allocate on input's own device, not whichever CUDA device is current.
    total = torch.zeros(1, dtype=torch.float32, device=input.device)
    for pred, truth in zip(input, target):
        # reshape (unlike view) also accepts non-contiguous slices.
        intersection = torch.dot(pred.reshape(-1), truth.reshape(-1))
        union = torch.maximum(pred, truth).sum()
        total += (intersection + eps) / (union + eps)

    return total / len(input)


class GHMLossBase(nn.Module):
    """Base class for Gradient Harmonizing Mechanism (GHM) style losses.

    Each call does the following:

    * Computes a per-element gradient-norm proxy ``g = |_custom_loss_grad(x, target)|``.
    * Buckets ``g`` into ``bins`` equal-width bins.
    * Smooths the per-bin counts across calls with an exponential moving
      average: ``alpha`` weights the previous counts.
    * Passes per-element weights ``beta = N / (bin_count * nonempty_bins)``
      to ``_custom_loss``.

    Subclasses implement ``_custom_loss(x, target, weight)`` and
    ``_custom_loss_grad(x, target)``.

    Args:
        bins: number of gradient-norm bins.
        alpha: EMA factor; the weight given to the previous call's bin counts.
    """

    def __init__(self, bins, alpha):
        super().__init__()
        self._bins = bins
        self._alpha = alpha
        # EMA of per-bin counts; None until the first forward call.
        self._last_bin_count = None

    def _g2bin(self, g):
        # The -0.0001 keeps g == 1 inside the last bin.
        return torch.floor(g * (self._bins - 0.0001)).long()

    def _custom_loss(self, x, target, weight):
        """Return the loss given per-element weights shaped like ``g``."""
        raise NotImplementedError

    def _custom_loss_grad(self, x, target):
        """Return the gradient-norm proxy used for binning; it is abs()'d and detached."""
        raise NotImplementedError

    def forward(self, x, target):
        g = torch.abs(self._custom_loss_grad(x, target)).detach()

        # Clamp so a g slightly outside [0, 1] lands in an edge bin instead of
        # indexing past the end of ``beta``.
        bin_idx = self._g2bin(g).clamp(0, self._bins - 1)

        # Build one vectorised histogram on g's device. This replaces a Python
        # loop with one host sync per bin. Keeping the counts on the same
        # device as bin_idx also makes ``beta[bin_idx]`` valid for GPU inputs.
        bin_count = torch.bincount(bin_idx.flatten(), minlength=self._bins).float()

        # NOTE(review): N is x.size(0) * x.size(1), which equals g.numel() only
        # when g spans x's first two dims. For GHMCELoss with (B, C) logits,
        # g has B elements, so every weight is scaled by C. This is kept as-is
        # to preserve existing loss magnitudes; confirm the intent.
        N = (x.size(0) * x.size(1))

        if self._last_bin_count is not None:
            # The stored EMA may live on another device if the module moved.
            previous = self._last_bin_count.to(bin_count.device)
            bin_count = self._alpha * previous + (1 - self._alpha) * bin_count
        self._last_bin_count = bin_count

        nonempty_bins = (bin_count > 0).sum().item()

        # The clamp avoids division by zero for empty bins; those entries are
        # never indexed because no element falls into them.
        gd = torch.clamp(bin_count * nonempty_bins, min=0.0001)
        beta = N / gd

        return self._custom_loss(x, target, beta[bin_idx].to(target.device))


class GHMCELoss(GHMLossBase):
    """GHM-weighted multi-class cross-entropy.

    The gradient-norm proxy used for binning is ``1 - p_target``, where
    ``p_target`` is the softmax probability (over dim 1) of the true class.
    The weighted per-sample cross-entropy is averaged.
    """

    def __init__(self, bins=10, alpha=0.5):
        super().__init__(bins, alpha)

    def _custom_loss(self, x, target, weight):
        per_sample = F.cross_entropy(x, target, reduction='none')
        return per_sample.mul(weight).mean()

    def _custom_loss_grad(self, x, target):
        true_class_prob = F.softmax(x, dim=1).gather(1, target.unsqueeze(1))
        return (1 - true_class_prob).squeeze(1)