import torch
from torch import autograd

class Sparse_NHWC(autograd.Function):
    """Mask a dense weight in forward; straight-through gradients in backward.

    Forward returns ``weight * round(Mask)``, so entries whose rounded mask
    value is 0 are pruned from the output.

    Backward returns:
      * for ``weight``: the incoming gradient unchanged, plus
        ``decay * (1 - round(Mask)) * weight`` (the decay term is non-zero
        only where the rounded mask is 0);
      * for ``Mask``: ``grad_output * weight``;
      * for ``decay``: ``None`` (not differentiable).
    """

    @staticmethod
    def forward(ctx, weight, Mask, decay=0.0002):
        """Return ``weight`` multiplied element-wise by the rounded ``Mask``.

        Args:
            ctx: autograd context used to stash values for ``backward``.
            weight: dense weight tensor.
            Mask: tensor that is rounded with ``torch.round`` before use.
            decay: coefficient of the extra gradient term applied to pruned
                weights in ``backward``.
        """
        binary_mask = torch.round(Mask)
        ctx.save_for_backward(weight)
        ctx.mask = binary_mask
        ctx.decay = decay
        # The product allocates a new tensor, so `weight` itself is untouched.
        return weight * binary_mask

    @staticmethod
    def backward(ctx, grad_output):
        """Return gradients for ``(weight, Mask, decay)``."""
        weight, = ctx.saved_tensors
        grad_weight = grad_output + ctx.decay * (1 - ctx.mask) * weight
        grad_mask = grad_output * weight
        # One entry per forward input is required when `decay` is passed to
        # `apply`; autograd drops surplus trailing `None`s when it is not.
        return grad_weight, grad_mask, None

class ste(autograd.Function):
    """Round in the forward pass; pass the gradient through unchanged in backward.

    Forward applies ``torch.round`` element-wise. Backward returns the
    incoming gradient as-is (a straight-through estimator), ignoring the fact
    that rounding has zero gradient almost everywhere.
    """

    @staticmethod
    def forward(ctx, weight):
        """Return ``weight`` rounded element-wise."""
        rounded = weight.round()
        return rounded

    @staticmethod
    def backward(ctx, grad_output):
        """Return ``grad_output`` untouched as the gradient for ``weight``."""
        return grad_output

class MaskGivenThreshold(torch.autograd.Function):
    """Binarise ``scores`` against a single threshold with a pass-through gradient.

    Forward computes ``round(clamp(scores / (2 * threshold), 0, 1))``. For a
    positive threshold this yields 1 where a score exceeds the threshold and
    0 elsewhere. Backward hands the incoming gradient straight to ``scores``
    and returns ``None`` for ``threshold``.
    """

    @staticmethod
    def forward(ctx, scores, threshold):
        """Return the 0/1 mask obtained by thresholding ``scores``."""
        # Dividing by twice the threshold maps the threshold itself to 0.5,
        # which is the decision point of the subsequent rounding.
        scaled = scores / (threshold * 2)
        return torch.clamp(scaled, 0, 1).round()

    @staticmethod
    def backward(ctx, g):
        """Return ``(g, None)``: gradient for ``scores``, none for ``threshold``."""
        return g, None

class BinarySoftActivation(torch.autograd.Function):
    """Mark the maximum along dim 1 in forward; pass-through gradient plus decay in backward.

    Forward returns a float mask that is 1 wherever ``input`` equals its
    maximum along dimension 1 (ties all receive 1) and 0 elsewhere.

    Backward returns ``grad_output + decay * mask * input`` for ``input`` and
    ``None`` for the forward ``decay`` argument.
    """

    @staticmethod
    def forward(ctx, input, decay=0.002):
        """Return the float arg-max mask of ``input`` along dimension 1.

        Args:
            ctx: autograd context used to stash tensors for ``backward``.
            input: tensor with at least two dimensions.
            decay: accepted for interface compatibility.
                NOTE(review): this value is not used by ``backward``, which
                applies its own default of 0.0002 - confirm which is intended.
        """
        mask = (input == input.max(dim=1, keepdim=True)[0]).float()
        ctx.save_for_backward(input, mask)
        return mask

    @staticmethod
    def backward(ctx, grad_output, decay=0.0002):
        """Return gradients for ``(input, decay)``."""
        weight, mask = ctx.saved_tensors
        grad_input = grad_output + decay * mask * weight
        # One entry per forward input is required when `decay` is passed to
        # `apply`; autograd drops the surplus trailing `None` when it is not.
        return grad_input, None


def percentile(t, q):
    """Return the ``q``-th percentile of all elements of ``t`` as a Python number.

    The rank is ``1 + round(q / 100 * (numel - 1))`` and the value is the
    element of that rank in ascending order (no interpolation).

    Args:
        t: tensor whose elements are ranked; any shape or memory layout.
        q: percentile, expected in the range 0-100.

    Returns:
        The selected element, extracted with ``.item()``.
    """
    k = 1 + round(.01 * float(q) * (t.numel() - 1))
    # `reshape` (unlike `view`) also accepts non-contiguous tensors such as
    # transposed views; `detach` keeps the lookup out of the autograd graph.
    return t.detach().reshape(-1).kthvalue(k).values.item()


def mask_given_sparsity(scores, sparsity, k_val=1e-5):
    """Build a 0/1 mask over ``scores`` from one global threshold.

    Args:
        scores: tensor of scores to threshold.
        sparsity: fraction used to pick the threshold. When it is 0 the
            threshold is 0.99 times the smallest score; otherwise it is the
            ``sparsity * 100``-th percentile of ``scores``.
        k_val: accepted for interface compatibility; the value passed in is
            always replaced by the computed threshold.

    Returns:
        The result of ``MaskGivenThreshold.apply(scores, threshold)``.
    """
    if sparsity != 0.:
        threshold = percentile(scores, sparsity * 100)
    else:
        # Just below the minimum score.
        # NOTE(review): presumably meant to keep every entry - holds for
        # positive scores; confirm.
        threshold = torch.min(scores) * 0.99
    return MaskGivenThreshold.apply(scores, threshold)

def mask_given_per_sparsity(scores, sparsity):
    """Build a 0/1 mask over ``scores`` with one threshold per leading-dim slice.

    Args:
        scores: tensor indexed along dimension 0; each ``scores[i]`` gets its
            own threshold.
        sparsity: fraction used to pick each threshold. When it is 0 the
            threshold is 0.99 times the smallest score of the slice; otherwise
            it is the ``sparsity * 100``-th percentile of the slice.

    Returns:
        The result of ``PerMaskGivenThreshold.apply(scores, thresholds)``.
    """
    num_slices = scores.size(0)
    if sparsity == 0.:
        thresholds = [
            torch.min(scores[idx]) * 0.99 for idx in range(num_slices)
        ]
    else:
        thresholds = [
            percentile(scores[idx], sparsity * 100) for idx in range(num_slices)
        ]

    return PerMaskGivenThreshold.apply(scores, thresholds)

class PerMaskGivenThreshold(torch.autograd.Function):
    """Binarise each leading-dim slice of ``scores`` against its own threshold.

    For every index ``i`` covered by ``threshold``, forward computes
    ``round(clamp(scores[i] / (2 * threshold[i]), 0, 1))``. Slices beyond
    ``len(threshold)`` keep their original score values, because the result
    starts as a copy of ``scores``. Backward hands the incoming gradient
    straight to ``scores`` and returns ``None`` for ``threshold``.
    """

    @staticmethod
    def forward(ctx, scores, threshold):
        """Return the per-slice thresholded copy of ``scores``."""
        result = scores.clone()
        for row, row_threshold in enumerate(threshold):
            # Twice the threshold maps the threshold itself to 0.5, the
            # decision point of the rounding below.
            scaled = scores[row] / (row_threshold * 2)
            result[row] = torch.clamp(scaled, 0, 1).round()
        return result

    @staticmethod
    def backward(ctx, g):
        """Return ``(g, None)``: gradient for ``scores``, none for ``threshold``."""
        return g, None

def sparseFunction(x, s, activation=torch.relu, f=torch.sigmoid):
    """Shrink the magnitude of ``x`` by ``f(s)`` while keeping its sign.

    Computes ``sign(x) * activation(|x| - f(s))``. With the default
    ``activation`` (ReLU), entries whose magnitude does not exceed ``f(s)``
    become zero.

    Args:
        x: input tensor.
        s: threshold parameter, transformed by ``f`` before use.
        activation: callable applied to the shifted magnitude.
        f: callable mapping ``s`` to the amount subtracted from ``|x|``.
    """
    signs = torch.sign(x)
    shifted_magnitude = torch.abs(x) - f(s)
    return signs * activation(shifted_magnitude)

def initialize_sInit():
    """Return a new ``1 x 1`` tensor filled with ``-3200``."""
    # A float fill value keeps the result in the default floating dtype.
    return torch.full((1, 1), -3200.0)
