import torch
import torch.nn as nn
import torch.nn.functional as F
import torch.fft
import numpy as np


def gaussian(sigma,sz):
    """Return an unnormalised 2-D Gaussian bump as a float32 array.

    Args:
        sigma: Standard deviation of the Gaussian, in pixels.
        sz: Two-element sequence. ``sz[0]`` is the number of samples along
            the x axis and ``sz[1]`` the number along the y axis.

    Returns:
        ``np.float32`` array of shape ``(sz[1], sz[0])`` (``np.meshgrid`` is
        used with its default indexing, so the first array axis follows y).
        The value is ``exp(-(x**2 + y**2) / (2 * sigma**2))`` with the origin
        placed at index ``(floor(sz[1] / 2), floor(sz[0] / 2))``, where the
        value is 1.
    """
    # Coordinates measured from the centre sample of each axis.
    x_coords = np.arange(0, sz[0]) - np.floor(float(sz[0]) / 2)
    y_coords = np.arange(0, sz[1]) - np.floor(float(sz[1]) / 2)
    grid_x, grid_y = np.meshgrid(x_coords, y_coords)

    sq_dist = grid_x ** 2 + grid_y ** 2
    bump = np.exp(-0.5 / (sigma ** 2) * sq_dist)
    return bump.astype(np.float32)


class Correlation(nn.Module):
    """Multi-scale template correlation over a single feature map.

    Five templates (``w``, ``w_up``, ``w_down``, ``w_50``, ``w_30``) are stored
    as convolution kernels, each paired with a zero-mean Gaussian kernel of
    the same spatial size. ``forward`` correlates the input with every
    template at several centre-cropped sizes, filters each response with the
    matching Gaussian crop, and fuses the results by averaging the largest
    values.

    All ten kernels are registered as ``nn.Parameter`` and therefore show up
    in ``parameters()`` / ``state_dict()``.

    NOTE(review): cropping, padding and normalisation only look at the last
    kernel dimension, so templates are assumed to be square with an odd side
    length -- confirm against the caller.
    """

    def __init__(self,w,w_up,w_down,w_50,w_30,w_max,w_min):
        """Build the template and Gaussian kernels.

        Args:
            w, w_up, w_down, w_50, w_30: 3-D float tensors whose last axis
                is moved to the front and used as the convolution input
                channels, i.e. a channels-last ``(H, W, C)`` layout.
            w_max, w_min: Accepted for interface compatibility; not used.
        """
        super(Correlation,self).__init__()

        # Registration order (templates first, then Gaussians) matches the
        # attribute names and ordering existing checkpoints rely on.
        self.weight = nn.Parameter(self._as_kernel(w))
        self.weight_up = nn.Parameter(self._as_kernel(w_up))
        self.weight_down = nn.Parameter(self._as_kernel(w_down))
        self.weight_50 = nn.Parameter(self._as_kernel(w_50))
        self.weight_30 = nn.Parameter(self._as_kernel(w_30))

        # One Gaussian per template, each with its own fixed sigma.
        self.weight_gau = nn.Parameter(self._centered_gaussian(5, w))
        self.weight_gauup = nn.Parameter(self._centered_gaussian(7, w_up))
        self.weight_gaudown = nn.Parameter(self._centered_gaussian(3, w_down))
        self.weight_gau50 = nn.Parameter(self._centered_gaussian(6, w_50))
        self.weight_gau30 = nn.Parameter(self._centered_gaussian(4, w_30))

    @staticmethod
    def _as_kernel(template):
        """Reorder an ``(H, W, C)`` template into a ``(1, C, H, W)`` kernel."""
        return template.permute(2, 0, 1).unsqueeze(0)

    @staticmethod
    def _centered_gaussian(sigma, template):
        """Return a zero-mean ``(1, 1, *, *)`` Gaussian sized after ``template``."""
        g = torch.from_numpy(gaussian(sigma, (template.shape[0], template.shape[1])))
        g = g.unsqueeze(0).unsqueeze(0)
        return g - g.mean()

    def _multiscale_response(self, ft, weight, weight_gau):
        """Correlate ``ft`` with one template at several centre-cropped sizes.

        For every crop offset ``i`` (step 2) the template and its Gaussian are
        cropped by ``i`` pixels on each side, the input is convolved with the
        template crop and then with the Gaussian crop, using padding chosen so
        odd-sized kernels keep the spatial size. The per-scale responses are
        sorted per pixel, the (up to) three largest are averaged, and the
        result is divided by ``full * full`` (the uncropped kernel area).

        Returns:
            Tensor with a single channel at dim 1.
        """
        full = weight.shape[3]
        half = full // 2

        responses = []
        for i in range(0, ((full - 5) // 2) + 2, 2):
            pad = half - i
            resp = F.conv2d(ft, weight=weight[:, :, i:full - i, i:full - i],
                            stride=1, padding=pad)
            resp = F.conv2d(resp, weight=weight_gau[:, :, i:full - i, i:full - i],
                            stride=1, padding=pad)
            # conv2d already returns a fresh tensor, so no clone is needed.
            responses.append(resp)

        ranked, _ = torch.sort(torch.cat(responses, dim=1), dim=1)
        # Ascending sort: the last three channels hold the largest values.
        top = ranked[:, -3:, :, :]
        return torch.mean(top, dim=1, keepdim=True) / (full * full)

    def forward(self,features):
        """Compute the fused correlation map for ``features``.

        Args:
            features: Tensor that is given a leading batch axis and cast to
                float before convolution; expected to be ``(C, H, W)`` with
                ``C`` matching the template channels.

        Returns:
            Tensor with one channel at dim 1: the per-pixel mean of the three
            largest of the five per-template responses.
        """
        ft = features.unsqueeze(0).float()

        output = self._multiscale_response(ft, self.weight, self.weight_gau)
        output_up = self._multiscale_response(ft, self.weight_up, self.weight_gauup)
        output_down = self._multiscale_response(ft, self.weight_down, self.weight_gaudown)
        output_50 = self._multiscale_response(ft, self.weight_50, self.weight_gau50)
        output_30 = self._multiscale_response(ft, self.weight_30, self.weight_gau30)

        out = torch.cat((output_down,output_30,output,output_50,output_up),dim=1)

        # Five channels sorted ascending: indices 2..4 are the three largest.
        out,_ = torch.sort(out,dim=1)
        out = out[:,2:5,:,:]
        out = torch.mean(out, dim=1, keepdim=True)
        return out


class Correlation2(nn.Module):
    """Multi-scale template correlation for a bank of templates.

    Each of the five template groups (``w``, ``w_up``, ``w_down``, ``w_50``,
    ``w_30``) is a sequence of templates stacked into one convolution kernel
    bank, so every template produces its own output channel. Each bank is
    paired with a zero-mean Gaussian kernel repeated once per template and
    applied channel-wise (grouped convolution).

    All ten kernel banks are registered as ``nn.Parameter`` and therefore
    show up in ``parameters()`` / ``state_dict()``.

    NOTE(review): cropping, padding and normalisation only look at the last
    kernel dimension, so templates are assumed to be square with an odd side
    length -- confirm against the caller.
    """

    def __init__(self,w, w_up, w_down, w_50, w_30, w_max, w_min):
        """Build the template banks and their Gaussian kernels.

        Args:
            w, w_up, w_down, w_50, w_30: Sequences of equally shaped 3-D
                float tensors. They are stacked and their last axis is used
                as the convolution input channels, i.e. a channels-last
                ``(H, W, C)`` layout per template.
            w_max, w_min: Accepted for interface compatibility; not used.
        """
        super(Correlation2,self).__init__()

        wp = self._as_kernel_bank(w)
        wp_up = self._as_kernel_bank(w_up)
        wp_down = self._as_kernel_bank(w_down)
        wp_50 = self._as_kernel_bank(w_50)
        wp_30 = self._as_kernel_bank(w_30)

        # Registration order (templates first, then Gaussians) matches the
        # attribute names and ordering existing checkpoints rely on.
        self.weight = nn.Parameter(wp)
        self.weight_up = nn.Parameter(wp_up)
        self.weight_down = nn.Parameter(wp_down)
        self.weight_50 = nn.Parameter(wp_50)
        self.weight_30 = nn.Parameter(wp_30)

        # One Gaussian per bank, each with its own fixed sigma.
        self.weight_gau = nn.Parameter(self._centered_gaussian_bank(5, wp))
        self.weight_gauup = nn.Parameter(self._centered_gaussian_bank(7, wp_up))
        self.weight_gaudown = nn.Parameter(self._centered_gaussian_bank(3, wp_down))
        self.weight_gau50 = nn.Parameter(self._centered_gaussian_bank(6, wp_50))
        self.weight_gau30 = nn.Parameter(self._centered_gaussian_bank(4, wp_30))

    @staticmethod
    def _as_kernel_bank(templates):
        """Stack ``(H, W, C)`` templates into an ``(N, C, H, W)`` kernel bank."""
        return torch.stack(templates, dim=0).permute(0, 3, 1, 2)

    @staticmethod
    def _centered_gaussian_bank(sigma, bank):
        """Return a zero-mean Gaussian repeated once per kernel in ``bank``."""
        g = torch.from_numpy(gaussian(sigma, (bank.shape[2], bank.shape[3])))
        g = g.unsqueeze(0).unsqueeze(0)
        g = g - g.mean()
        return g.repeat(bank.shape[0], 1, 1, 1)

    def _multiscale_response(self, ft, weight, weight_gau):
        """Correlate ``ft`` with one template bank at several cropped sizes.

        For every crop offset ``i`` (step 2) the bank and its Gaussians are
        cropped by ``i`` pixels on each side. The input is convolved with the
        cropped bank (one output channel per template) and each channel is
        then filtered with its own Gaussian crop via a grouped convolution.
        The per-scale responses are stacked on a new dim 2, sorted, the (up
        to) three largest are averaged, and the result is divided by
        ``full * full`` (the uncropped kernel area).

        Returns:
            Tensor with a singleton dim 2 (the reduced scale axis).
        """
        full = weight.shape[3]
        half = full // 2
        n_templates = weight.shape[0]

        responses = []
        for i in range(0, ((full - 5) // 2) + 2, 2):
            pad = half - i
            resp = F.conv2d(ft, weight=weight[:, :, i:full - i, i:full - i],
                            stride=1, padding=pad)
            resp = F.conv2d(resp, weight=weight_gau[:, :, i:full - i, i:full - i],
                            stride=1, padding=pad, groups=n_templates)
            # conv2d already returns a fresh tensor, so no clone is needed.
            responses.append(resp)

        ranked, _ = torch.sort(torch.stack(responses, dim=2), dim=2)
        # Ascending sort: the last three entries hold the largest values.
        top = ranked[:, :, -3:, :, :]
        return torch.mean(top, dim=2, keepdim=True) / (full * full)

    def forward(self,features):
        """Compute the fused correlation maps for ``features``.

        Args:
            features: Tensor that is given a leading batch axis and cast to
                float before convolution; expected to be ``(C, H, W)`` with
                ``C`` matching the template channels.

        Returns:
            Tensor with one channel per template at dim 1: the per-pixel mean
            of the three largest of the five per-bank responses.
        """
        ft = features.unsqueeze(0).float()

        # The helper's intermediates are released when it returns; the cache
        # is emptied after every branch, as before, to hand freed CUDA blocks
        # back to the driver.
        output = self._multiscale_response(ft, self.weight, self.weight_gau)
        torch.cuda.empty_cache()
        output_up = self._multiscale_response(ft, self.weight_up, self.weight_gauup)
        torch.cuda.empty_cache()
        output_down = self._multiscale_response(ft, self.weight_down, self.weight_gaudown)
        torch.cuda.empty_cache()
        output_50 = self._multiscale_response(ft, self.weight_50, self.weight_gau50)
        torch.cuda.empty_cache()
        output_30 = self._multiscale_response(ft, self.weight_30, self.weight_gau30)
        torch.cuda.empty_cache()

        out = torch.cat((output_down,output_30,output,output_50,output_up), dim=2)
        del output_down,output_30,output,output_50,output_up
        torch.cuda.empty_cache()

        # Five entries sorted ascending: keep the three largest and average.
        out, _ = torch.sort(out, dim=2)
        out = out[:,:,-3:,:,:]
        out = torch.mean(out, dim=2, keepdim=False)
        return out