import torch
import torch.nn as nn

class AuxLoss(nn.Module):
    """Combine a main loss with a down-weighted auxiliary loss.

    The network output is expected to be a pair ``(main_logits, aux_logits)``.
    The total loss is::

        loss_func1(main_logits, labels) + aux_weight * loss_func2(aux_logits, labels)

    Args:
        loss_func1: Criterion applied to the main head's logits.
        loss_func2: Criterion applied to the auxiliary head's logits.
        aux_weight: Scalar multiplier for the auxiliary loss.
    """

    def __init__(self, loss_func1, loss_func2, aux_weight):
        super().__init__()
        self.loss_func1 = loss_func1
        self.loss_func2 = loss_func2
        self.aux_weight = aux_weight

    def forward(self, logits, labels):
        """Return the weighted sum of the main and auxiliary losses.

        Args:
            logits: A 2-sequence ``(main_logits, aux_logits)``.
            labels: Targets passed unchanged to both criteria.
        """
        logits1, logits2 = logits
        loss1 = self.loss_func1(logits1, labels)
        # The auxiliary head is scored by its own criterion. Previously
        # loss_func1 was applied twice and loss_func2 was silently ignored.
        loss2 = self.loss_func2(logits2, labels)
        return loss1 + self.aux_weight * loss2

class NormalizedCE(nn.Module):
    """Cross-entropy divided by its own detached value.

    The forward value is ~1, but the gradient equals the plain CE gradient
    scaled by ``1 / CE``. Each step therefore gets a comparable gradient
    magnitude regardless of the raw loss value.

    Args:
        ignore_index: Target value excluded from the loss.
    """

    def __init__(self, ignore_index):
        super().__init__()
        self.criterion = nn.CrossEntropyLoss(ignore_index=ignore_index)

    def forward(self, logits, labels):
        loss = self.criterion(logits, labels)
        # .item() returns a plain Python float, so the divisor carries no grad.
        scale = loss.item()
        if scale == 0:
            # 0 / 0 would yield NaN and poison every gradient; a loss of
            # exactly zero is returned unchanged instead.
            return loss
        return loss / scale

class BootstrappedCE(nn.Module):
    """Hard-pixel-mining cross-entropy.

    Per-pixel CE losses are computed first. When more than ``min_K`` pixels
    have a loss above ``loss_th``, the mean over exactly those pixels is
    returned. Otherwise the mean of the ``min_K`` largest losses is returned
    (or of all pixels, if there are fewer than ``min_K``).

    Args:
        min_K: Minimum number of pixel losses to average.
        loss_th: Per-pixel loss above which a pixel counts as "hard".
        ignore_index: Target value excluded from CE. Such pixels get a
            per-pixel loss of 0, so they can still appear in the top-k set.
    """

    def __init__(self, min_K, loss_th, ignore_index):
        super().__init__()
        self.K = min_K
        self.threshold = loss_th
        self.criterion = nn.CrossEntropyLoss(
            ignore_index=ignore_index, reduction="none"
        )

    def forward(self, logits, labels):
        pixel_losses = self.criterion(logits, labels).contiguous().view(-1)

        mask = pixel_losses > self.threshold
        if torch.sum(mask).item() > self.K:
            return pixel_losses[mask].mean()
        # torch.topk raises if k exceeds the number of elements (small crops,
        # tiny batches); in that case fall back to every available pixel.
        k = min(self.K, pixel_losses.numel())
        top_losses, _ = torch.topk(pixel_losses, k)
        return top_losses.mean()

class WeightedCE(nn.Module):
    """Class-weighted cross-entropy that scales the logits instead of the loss.

    Unlike ``CrossEntropyLoss(weight=...)``, each class's logit is multiplied
    by its class weight before the softmax. The intent is that every class
    contributes a gradient of roughly similar size during backprop.

    The forward value is somewhat unusual because the logits scale linearly
    with the class weights. The original author expected this to make little
    difference after the softmax.

    Args:
        ignore_label: Target value excluded from the loss.
        weight: Tensor of per-class weights, one entry per class.
    """

    def __init__(self, ignore_label, weight):
        super(WeightedCE, self).__init__()
        self.ignore_label = ignore_label
        self.weight = weight.view(1, -1, 1, 1)
        self.criterion = nn.CrossEntropyLoss(ignore_index=ignore_label)

    def forward(self, logits, labels):
        # Lay the weights along the class axis (dim 1) for any number of
        # trailing spatial dims: (N, C), (N, C, H, W), (N, C, D, H, W), ...
        if logits.dim() < 2:
            shape = (-1,)
        else:
            shape = (1, -1) + (1,) * (logits.dim() - 2)
        # Match the logits' device/dtype so a CPU weight works with CUDA logits.
        weight = self.weight.reshape(shape).to(logits)
        return self.criterion(logits * weight, labels)

class WeightedCE2(nn.Module):
    """Symmetric class-weighted cross-entropy.

    Each non-ignored pixel's CE is multiplied by
    ``max(weight[gt_class], sum_c weight[c] * p_c)``, the larger of its
    ground-truth class weight and the weight expected under the predicted
    distribution. This penalises both false negatives (rare ground-truth
    class) and false positives (probability mass placed on a rare class).
    The pixel weights are computed under ``no_grad``, so they act as
    constants in backprop.

    Args:
        ignore_label: Target value excluded from the loss.
        weight: Tensor of per-class weights, one entry per class.
    """

    def __init__(self, ignore_label, weight):
        super(WeightedCE2, self).__init__()
        self.ignore_label = ignore_label
        self.weight = weight.view(-1)
        self.criterion = nn.CrossEntropyLoss(ignore_index=ignore_label, reduction="none")

    def forward(self, logits, labels):
        mask = labels != self.ignore_label
        if not mask.any():
            # Every pixel is ignored: the mean of an empty selection is NaN.
            # Return a zero that stays attached to the graph so backward()
            # still works and yields zero gradients.
            return logits.sum() * 0.0
        losses = self.criterion(logits, labels)[mask]
        with torch.no_grad():
            # Match the logits' device/dtype (handles CPU or integer weights).
            weight = self.weight.to(logits)
            gt_pixel_weights = weight[labels[mask]]
            probs = torch.softmax(logits, dim=1)
            # Broadcast along the class axis for any number of spatial dims.
            shape = (1, -1) + (1,) * (logits.dim() - 2)
            pred_pixel_weights = torch.sum(weight.view(shape) * probs, dim=1)[mask]
            pixel_weights = torch.maximum(gt_pixel_weights, pred_pixel_weights)
        return torch.mean(losses * pixel_weights)

def to_one_hot(target, num_classes, ignore_index):
    """One-hot encode an integer label map, routing ignored pixels to an extra channel.

    Args:
        target: Integer label tensor of shape (N, H, W).
        num_classes: Number of real classes C.
        ignore_index: Label value marking pixels to ignore. These pixels are
            written to the extra channel C.

    Returns:
        Float tensor of shape (N, C + 1, H, W) on ``target.device``. Each
        pixel has a 1 in the channel of its class (channel C for ignored
        pixels). ``target`` itself is left unmodified.
    """
    one_hot = torch.zeros(
        target.shape[0], num_classes + 1, target.shape[1], target.shape[2],
        device=target.device,
    )
    # unsqueeze returns a view, so remapping in place would overwrite the
    # caller's labels. Take an int64 copy instead: scatter_ also requires
    # int64 indices, which lets uint8/int32 label maps work too.
    index = target.unsqueeze(1).to(torch.long, copy=True)
    index[index == ignore_index] = num_classes
    one_hot.scatter_(1, index, 1)
    return one_hot
class DownsampledCE(nn.Module):
    """Cross-entropy against soft targets obtained by pooling a one-hot label map.

    The full-resolution labels are one-hot encoded, with an extra channel for
    ``ignore_index``. They are then average-pooled to the logits' spatial
    size and used as class-probability targets. Output pixels whose pooling
    window contains any ignored label are excluded from the mean.

    Args:
        ignore_index: Label value to ignore.
        normalize: If True, divide the loss by its own detached value. The
            forward value becomes ~1 and the gradient is rescaled by 1/loss.
    """

    def __init__(self, ignore_index, normalize=False):
        super().__init__()
        self.ignore_index = ignore_index
        self.criterion = nn.CrossEntropyLoss(reduction="none")
        self.normalize = normalize

    def forward(self, logits, target):
        # logits: (N, C, H', W'); target: (N, H, W) integer labels.
        num_classes = logits.shape[1]
        out_size = tuple(logits.shape[2:])
        one_hot = torch.zeros(
            target.shape[0], num_classes + 1, target.shape[1], target.shape[2],
            device=target.device,
        )
        # Remap ignored labels on a copy so the caller's target is untouched.
        index = target.unsqueeze(1).to(torch.long, copy=True)
        index[index == self.ignore_index] = num_classes
        one_hot.scatter_(1, index, 1)
        one_hot = nn.functional.adaptive_avg_pool2d(one_hot, out_size)
        # Keep only output pixels whose pooled window holds no ignored label.
        mask = one_hot[:, num_classes] == 0
        # Soft targets follow the logits' dtype (e.g. half precision).
        soft_target = one_hot[:, :num_classes].to(logits.dtype)
        pixel_losses = self.criterion(logits, soft_target)[mask]
        loss = pixel_losses.mean()
        if self.normalize:
            scale = loss.item()
            # Skip the division when the loss is exactly 0 (0 / 0 = NaN).
            if scale != 0:
                loss = loss / scale
        return loss


def test_class_inbalance():
    """Smoke-test WeightedCE2 on a tiny all-zero batch and print the gradients.

    The first two columns of the 3x3 label map are set to the ignore label
    (255), leaving one column of class-0 pixels. Prints the loss shape, the
    loss, its gradient w.r.t. the logits, and the per-class absolute gradient
    mass.
    """
    criterion = WeightedCE2(255, torch.tensor([2, 1, 1]))
    logits = torch.zeros((1, 3, 3, 3), requires_grad=True)
    labels = torch.zeros((1, 3, 3), dtype=torch.int64)
    labels[:, :3, :2] = 255

    loss = criterion(logits, labels)
    print(loss.shape)
    print(loss)

    loss.sum().backward()
    print(logits.grad)
    print(logits.grad.abs().sum(dim=(2, 3)))
def test_downsampled_ce():
    """Smoke-test DownsampledCE with random 19-class data and print the gradient."""
    num_classes = 19
    batch_size = 2
    crop_h, crop_w = 100, 100

    logits = torch.randn(batch_size, num_classes, crop_h, crop_w, requires_grad=True)
    target = torch.randint(0, num_classes, (batch_size, crop_h, crop_w))

    criterion = DownsampledCE(255)
    criterion(logits, target).backward()
    print(logits.grad)
if __name__ == "__main__":
    # Running the module directly executes the DownsampledCE smoke test.
    test_downsampled_ce()
