from torch import nn
import torch
from torch.nn import functional as F
from torchvision.ops import DeformConv2d

class DeformConv(nn.Module):
    """Modulated deformable convolution whose offsets and mask come from a guide tensor.

    ``forward`` takes a pair ``(x, guide)``. ``x`` is sampled by a torchvision
    ``DeformConv2d``. ``guide`` is fed through two ordinary convolutions that
    predict the sampling offsets and the modulation mask (squashed with a
    sigmoid). Both predictors are zero-initialised, so at initialisation the
    offsets are 0 and the mask is sigmoid(0) = 0.5 everywhere.

    Args:
        in_channels: input channels of the deformable conv and of the
            offset/mask predictors (so ``guide`` must have this many channels).
        out_channels, kernel_size, stride, padding, dilation: geometry of the
            deformable conv. The predictors use the same geometry so their
            outputs have the same spatial size as the deformable conv's output.
        deformable_groups: number of offset groups. It is not passed to
            ``DeformConv2d`` directly; torchvision derives it from the offset
            channel count (2 * groups * kh * kw).
    """

    def __init__(self, in_channels, out_channels, kernel_size, stride, padding, dilation=1, deformable_groups=1):
        super().__init__()
        if isinstance(kernel_size, int):
            kernel_size = (kernel_size, kernel_size)
        taps = kernel_size[0] * kernel_size[1]
        offset_out_channels = 2 * deformable_groups * taps
        mask_out_channels = deformable_groups * taps
        self.deform_conv = DeformConv2d(in_channels, out_channels, kernel_size, stride, padding, dilation, bias=True)
        # The predictors must produce one offset/mask vector per output
        # location of deform_conv, so they share its dilation as well as its
        # kernel/stride/padding. Without it, any dilation != 1 gives an offset
        # map of the wrong spatial size.
        self.offset_conv = nn.Conv2d(in_channels, offset_out_channels, kernel_size, stride, padding,
                                     dilation=dilation, bias=True)
        self.mask_conv = nn.Conv2d(in_channels, mask_out_channels, kernel_size, stride, padding,
                                   dilation=dilation, bias=True)
        # Zero init: start as a plain (undeformed) convolution with a constant 0.5 mask.
        for predictor in (self.offset_conv, self.mask_conv):
            nn.init.zeros_(predictor.weight)
            nn.init.zeros_(predictor.bias)

    def forward(self, x):
        """Apply the deformable conv to ``x[0]``, steered by ``x[1]``.

        Args:
            x: a pair ``(feat, guide)``. ``guide`` has ``in_channels`` channels;
                its spatial size is assumed to match ``feat`` (TODO confirm
                with callers).

        Returns:
            The deformable-conv output for ``feat``.
        """
        feat, guide = x
        offset = self.offset_conv(guide)
        mask = torch.sigmoid(self.mask_conv(guide))
        return self.deform_conv(feat, offset, mask)


def steep_transform(r):
    """Remap ``r`` with a continuous piecewise-linear curve that is steep near half-integers.

    Let ``whole`` be ``r`` minus its fractional part ``frac`` (``r % 1``). The
    fractional part is mapped segment by segment, and the integer part is kept:

        [0.00, 0.45) -> [0.2 * frac / 0.45]   slope 0.2 / 0.45
        [0.45, 0.55) -> [0.2 .. 0.8)          slope 6
        [0.55, 1.00) -> [0.8 .. 1.0)          slope 0.2 / 0.45

    Only the choice of slope and intercept runs under ``no_grad``. The
    returned value is linear in ``r``, so its gradient is the active slope.
    The segment is picked with a Python ``if``, so ``r`` must be a single
    value (e.g. a 0-dim tensor).
    """
    with torch.no_grad():
        frac = r % 1
        whole = r - frac
        if frac < 0.45:
            # lower shallow segment, passes through (whole, whole)
            gain = 0.2 / 0.45
            shift = -whole * gain
        elif frac < 0.55:
            # steep middle segment, from (whole + 0.45, whole + 0.2) to (whole + 0.55, whole + 0.8)
            gain = 0.6 / 0.1
            shift = 0.2 - (0.45 + whole) * gain
        else:
            # upper shallow segment, from (whole + 0.55, whole + 0.8) to (whole + 1, whole + 1)
            gain = 0.2 / 0.45
            shift = 0.8 - (0.55 + whole) * gain
    # Computed outside no_grad so autograd sees d(out)/dr == gain.
    return gain * r + shift + whole


class round_transform(torch.autograd.Function):
    """Round to the nearest integer, but pass gradients through unchanged.

    This is a straight-through estimator. The forward pass is ``round``
    (whose true gradient is zero almost everywhere), and the backward pass
    treats it as the identity. Use via ``round_transform.apply(t)``.
    """

    @staticmethod
    def forward(ctx, values):
        return values.round()

    @staticmethod
    def backward(ctx, upstream):
        # Identity gradient: d(round(v))/dv is treated as 1.
        return upstream


class LearnableDilatedConvHelper(nn.Module):
    """3x3 convolution whose dilation rate is a single learnable scalar.

    A ``DeformConv2d`` with fixed geometry (kernel 3, padding 1, dilation 1) is
    driven by a hand-built offset field. The field pushes each of the 8 outer
    taps away from the centre by ``dilation_rate - 1`` pixels, so the sampled
    grid matches a dilated 3x3 conv while the dilation stays differentiable.

    Args:
        in_channels, out_channels, stride, groups: forwarded to ``DeformConv2d``.
        initial_dilation_rate: starting value of the learnable dilation.
        transform: how the raw parameter is mapped before use. One of:
            ``"none"`` (identity), ``"steep"`` (``steep_transform``), or
            ``"round"`` (``round_transform``, straight-through rounding).

    Raises:
        NotImplementedError: for any other ``transform`` value.
    """

    # Offset sign for each (dy, dx) channel pair of the 3x3 kernel. Tap
    # k = 3*i + j (row i, column j) occupies channels 2k (dy) and 2k+1 (dx),
    # and is displaced by (i - 1, j - 1) * (dilation - 1). Equivalent to
    # [-r,-r, -r,0, -r,r, 0,-r, 0,0, 0,r, r,-r, r,0, r,r].
    _OFFSET_SIGNS = (
        -1, -1, -1, 0, -1, 1,
        0, -1, 0, 0, 0, 1,
        1, -1, 1, 0, 1, 1,
    )

    def __init__(self, in_channels, out_channels, stride, groups, initial_dilation_rate, transform):
        super().__init__()
        kernel_size = 3
        padding = 1
        self.dilation_rate = torch.nn.Parameter(
            torch.tensor(initial_dilation_rate, dtype=torch.float32)
        )
        self.deform_conv = DeformConv2d(in_channels, out_channels, kernel_size,
                                        stride, padding, groups=groups, bias=False)
        self.stride = stride
        if transform == "none":
            self.transform = None
        elif transform == "steep":
            self.transform = steep_transform
        elif transform == "round":
            self.transform = round_transform.apply
        else:
            raise NotImplementedError(f"unknown dilation transform: {transform!r}")

    def get_dilation_rate(self):
        """Return the (possibly transformed) dilation rate as a 0-dim tensor."""
        if self.transform is None:
            return self.dilation_rate
        return self.transform(self.dilation_rate)

    def forward(self, x):
        offset = self.generate_offset(x)
        return self.deform_conv(x, offset)

    def generate_offset(self, x):
        """Build the offset field that turns the 3x3 deform conv into a dilated one.

        Args:
            x: input of shape (N, C, H, W).

        Returns:
            Tensor of shape (N, 18, outH, outW) with the dtype and device of
            ``x``, where outH/outW are the output sizes of a kernel-3, padding-1
            conv with ``self.stride``. It is differentiable w.r.t.
            ``self.dilation_rate``.
        """
        N, _, H, W = x.shape
        # Output size of kernel 3, padding 1, dilation 1: (H + 2 - 3) // s + 1.
        outH = (H - 1) // self.stride + 1
        outW = (W - 1) // self.stride + 1
        # Cast to x's dtype: DeformConv2d needs offset and input dtypes to
        # agree, and the parameter is created as float32.
        r = (self.get_dilation_rate() - 1).to(dtype=x.dtype)
        signs = torch.tensor(self._OFFSET_SIGNS, dtype=x.dtype, device=x.device)
        offset = (signs * r).reshape(1, 18, 1, 1)
        return offset.repeat(N, 1, outH, outW)




class FeatureAlign_V2(nn.Module):  # FaPN full version
    """Align an upsampled coarse feature map to a fine one with a deformable conv.

    ``forward(feat_l, feat_s)`` performs these steps:

    1. Upsample ``feat_s`` bilinearly to ``feat_l``'s spatial size.
    2. Pass ``feat_l`` through ``lateral_conv``. The original comment
       describes it as "0~1 * feats", presumably an attention-style gating;
       verify in FeatureSelectionModule.
    3. Project the concatenation of (2) and twice (1) down to ``out_nc``
       channels. The result steers the deformable conv.
    4. Return ``(relu(deform(upsampled, guide)), lateral)``.
    """

    def __init__(self, in_nc=128, out_nc=128):
        super().__init__()
        self.lateral_conv = FeatureSelectionModule(in_nc, out_nc)
        self.project = ConvBnAct(out_nc * 2, out_nc, apply_act=False)
        self.dcpack_L2 = DeformConv(
            out_nc, out_nc, 3,
            stride=1, padding=1, dilation=1, deformable_groups=8,
        )
        self.relu = nn.ReLU(inplace=True)

    def forward(self, feat_l, feat_s):
        target_hw = feat_l.shape[-2:]
        upsampled = F.interpolate(feat_s, size=target_hw, mode='bilinear', align_corners=False)
        lateral = self.lateral_conv(feat_l)
        # Guide for offsets/mask, computed from both inputs so it can reflect
        # their misalignment. The *2 scaling is kept from the original FaPN
        # formulation; its purpose is not documented here.
        guide = self.project(torch.cat([lateral, upsampled * 2], dim=1))
        aligned = self.dcpack_L2([upsampled, guide])
        return self.relu(aligned), lateral

class AlignedModule(nn.Module):
    """SFNet-style flow alignment (SFNet-DFNet).

    Takes ``(low_feature, h_feature)``. Both are reduced to ``outplane``
    channels with 1x1 convs, and the reduced high-level map is upsampled to
    the low-level resolution. A 3x3 conv predicts a 2-channel flow from the
    concatenation. The original (unreduced) ``h_feature`` is then warped onto
    the low-level grid along that flow.

    Args:
        inplane: channels of both inputs.
        outplane: channels of the reduced features used only for flow prediction.
    """

    def __init__(self, inplane, outplane):
        super(AlignedModule, self).__init__()
        self.down_h = nn.Conv2d(inplane, outplane, 1, bias=False)
        self.down_l = nn.Conv2d(inplane, outplane, 1, bias=False)
        self.flow_make = nn.Conv2d(outplane * 2, 2, kernel_size=3, padding=1, bias=False)

    def forward(self, x):
        """Warp ``x[1]`` onto the spatial grid of ``x[0]``.

        Returns:
            A tensor with ``x[1]``'s batch and channel count and ``x[0]``'s
            height and width.
        """
        low_feature, h_feature = x
        h_feature_orign = h_feature
        h, w = low_feature.size()[2:]
        size = (h, w)
        low_feature = self.down_l(low_feature)
        h_feature = self.down_h(h_feature)
        h_feature = F.interpolate(h_feature, size=size, mode="bilinear", align_corners=False)
        flow = self.flow_make(torch.cat([h_feature, low_feature], 1))
        return self.flow_warp(h_feature_orign, flow, size=size)

    def flow_warp(self, input, flow, size):
        """Bilinearly sample ``input`` on an output grid displaced by ``flow``.

        Args:
            input: tensor of shape (n, c, h, w).
            flow: tensor of shape (n, 2, out_h, out_w). Channel 0 is the x
                displacement and channel 1 the y displacement. It is divided by
                (out_w, out_h) before being added to the normalised [-1, 1] grid.
            size: ``(out_h, out_w)``.

        Returns:
            Tensor of shape (n, c, out_h, out_w). With zero flow, this equals
            bilinear resizing with ``align_corners=True`` (identity when the
            sizes match).
        """
        out_h, out_w = size
        n = input.size(0)
        opts = dict(dtype=input.dtype, device=input.device)
        norm = torch.tensor([[[[out_w, out_h]]]], **opts)
        # linspace(-1, 1) puts the first/last samples on the centres of the
        # corner pixels. That is the align_corners=True convention, so
        # grid_sample below must use it too. With align_corners=False these
        # coordinates fall on pixel edges: zero flow would not be identity and
        # borders would be blended with zero padding.
        ys = torch.linspace(-1.0, 1.0, out_h, **opts).view(out_h, 1).expand(out_h, out_w)
        xs = torch.linspace(-1.0, 1.0, out_w, **opts).view(1, out_w).expand(out_h, out_w)
        base = torch.stack((xs, ys), dim=-1).unsqueeze(0).expand(n, -1, -1, -1)
        grid = base + flow.permute(0, 2, 3, 1) / norm
        return F.grid_sample(input, grid, align_corners=True)



def forward_check():
    """Compare LearnableDilatedConvHelper (fixed dilation 2) against a dense dilated conv.

    Reseeds the RNG before building each module, so both draw their weights
    from the same random state. The helper uses stride 2, groups 2, dilation 2
    and the "none" transform; the reference is an ``nn.Conv2d`` with
    dilation=2 and padding=2. The function prints whether the outputs match,
    then whether the weight gradients match after backpropagating the mean of
    each output.
    """
    torch.manual_seed(0)
    learnable = LearnableDilatedConvHelper(10, 10, 2, 2, 2, "none")
    torch.manual_seed(0)
    reference = nn.Conv2d(10, 10, 3, 2, 2, bias=False, dilation=2, groups=2)
    inputs = torch.randn(2, 10, 30, 30)
    out_learnable, out_reference = learnable(inputs), reference(inputs)
    print(torch.allclose(out_learnable, out_reference))
    for out in (out_learnable, out_reference):
        torch.mean(out).backward()
    print(torch.allclose(reference.weight.grad, learnable.deform_conv.weight.grad))


def backward_check():
    """Print every parameter (and its gradient, if any) of a tiny rounded-dilation conv.

    Builds a 2->2 LearnableDilatedConvHelper with the "round" transform and an
    initial dilation of 1.5. It runs one forward/backward pass on a random
    (2, 2, 3, 3) input, then prints each parameter's name and value, and its
    gradient when one was produced.
    """
    torch.manual_seed(0)
    conv1 = LearnableDilatedConvHelper(2, 2, 1, 1, 1.5, "round")
    x = torch.randn(2, 2, 3, 3)
    y = conv1(x)
    torch.mean(y).backward()
    for name, p in conv1.named_parameters():
        print(name)
        print(p)
        if p.grad is not None:
            print(p.grad)


def backward_check2():
    """Check that scattering a differentiable scalar into a tensor keeps the graph.

    This mimics the original index-assignment offset construction on a bare
    tensor. It writes -r / +r (r = dilation_rate - 1) into an 18-element
    vector, which gives the layout
    [-r,-r, -r,0, -r,r, 0,-r, 0,0, 0,r, r,-r, r,0, r,r].
    It then prints the vector and the gradient (still None), backpropagates
    ``mean(offset - 1)``, and prints the gradient again.
    """
    dilation_rate = torch.tensor(2.0, requires_grad=True)
    offset = torch.zeros(18)
    r = dilation_rate - 1
    scatter_plan = (
        ([0, 1, 2, 4, 7, 13], -r),
        ([5, 11, 12, 14, 16, 17], r),
    )
    for positions, value in scatter_plan:
        offset[torch.tensor(positions)] = value
    print(offset)
    print(dilation_rate.grad)
    (offset - 1).mean().backward()
    print(dilation_rate.grad)


if __name__ == "__main__":
    # Only the gradient check runs by default; call forward_check() instead to
    # compare against a dense dilated nn.Conv2d.
    backward_check()
