import torch
import torch.nn as nn
import math
import torch.nn.functional as F
class AttentionFusion(nn.Module):
    """Bidirectional cross-attention between a 5-D feature map and a 3-D feature sequence.

    ``rgb`` is projected with 1x1x1 convolutions and pooled over its two last (spatial)
    dims to one descriptor per step of its third dim. ``boxes`` is projected with 1x1
    Conv1d layers to the same ``inter_channels`` width. One attention map
    (boxes-entries x rgb-steps, softmax over rgb-steps) is used:
      * directly, so ``boxes`` attends to ``rgb`` -> residual update ``Z_obj``;
      * transposed, so ``rgb`` attends to ``boxes`` -> per-channel, per-step gate ``Z_rgb``.

    :param in_channels: channel count of ``rgb``.
    :param obj_in_channels: channel count of ``boxes``.
    :param inter_channels: width of the shared attention/value space; defaults to
        ``in_channels // 2`` (at least 1).
    :param dimensions_rgb: stored on the module; not used by this class.
    :param dimensions_obj: stored on the module; not used by this class.
    :param sub_sample: stored on the module; not used by this class.
    :param bn_layer: if True, each output projection is Conv1d followed by a
        zero-initialised BatchNorm1d; otherwise a zero-initialised Conv1d.
    :param scale_dim: ``d`` in the ``1 / sqrt(d)`` attention temperature. Defaults to
        512, the value previously hard-coded. Standard scaled dot-product attention
        would use ``inter_channels``.
    """

    def __init__(self, in_channels,
                 obj_in_channels,
                 inter_channels=None,
                 dimensions_rgb=3,
                 dimensions_obj=1,
                 sub_sample=False,
                 bn_layer=True,
                 scale_dim=512):
        super(AttentionFusion, self).__init__()

        self.dimensions_rgb = dimensions_rgb
        self.dimensions_obj = dimensions_obj
        self.sub_sample = sub_sample
        self.in_channels = in_channels
        self.obj_in_channels = obj_in_channels
        self.scale_dim = scale_dim

        if inter_channels is None:
            inter_channels = max(1, in_channels // 2)
        self.inter_channels = inter_channels

        # A single LayerNorm (over the inter_channels dim) is shared by both directions.
        self.LN = torch.nn.LayerNorm(self.inter_channels)

        # Value projections.
        self.vV = nn.Conv3d(in_channels=self.in_channels, out_channels=self.inter_channels,
                            kernel_size=1, stride=1, padding=0)
        self.vO = nn.Conv1d(in_channels=self.obj_in_channels, out_channels=self.inter_channels,
                            kernel_size=1, stride=1, padding=0)

        # Key/query projections (the same projection serves as key or query per direction).
        self.kqV = nn.Conv3d(in_channels=self.in_channels, out_channels=self.inter_channels,
                             kernel_size=1, stride=1, padding=0)
        self.kqO = nn.Conv1d(in_channels=self.obj_in_channels, out_channels=self.inter_channels,
                             kernel_size=1, stride=1, padding=0)

        # Output projections back to each modality's channel count. Both branches now
        # define oW / vW, which forward() requires (previously bn_layer=False only set
        # self.W and forward() crashed).
        self.oW = self._output_projection(self.obj_in_channels, bn_layer)
        self.vW = self._output_projection(self.in_channels, bn_layer)

    def _output_projection(self, out_channels, bn_layer):
        """Build a zero-initialised 1x1 projection from inter_channels to out_channels."""
        conv = nn.Conv1d(in_channels=self.inter_channels, out_channels=out_channels,
                         kernel_size=1, stride=1, padding=0)
        if bn_layer:
            bn = nn.BatchNorm1d(out_channels)
            # Zeroing the BN affine params makes the projection output 0 at init.
            nn.init.constant_(bn.weight, 0)
            nn.init.constant_(bn.bias, 0)
            return nn.Sequential(conv, bn)
        nn.init.constant_(conv.weight, 0)
        nn.init.constant_(conv.bias, 0)
        return conv

    @staticmethod
    def _spatial_pool(x):
        """Collapse the last two dims of a 5-D tensor: avg-pool + max-pool -> (b, c, t)."""
        b, c, _, h, w = x.shape
        return (F.avg_pool3d(x, (1, h, w)).view(b, c, -1)
                + F.max_pool3d(x, (1, h, w)).view(b, c, -1))

    def forward(self, rgb, boxes):
        '''
        :param rgb: (b, in_channels, T, H, W)
        :param boxes: (b, obj_in_channels, L). L need not equal T (presumably one entry
            per object/box; confirm against the caller).
        :return: (Z_rgb, Z_obj) with the same shapes as ``rgb`` and ``boxes``.
        '''
        batch_size, ch, T = rgb.shape[:3]

        kqO = self.kqO(boxes)                       # (b, inter, L)
        kqV = self._spatial_pool(self.kqV(rgb))     # (b, inter, T)
        vV = self._spatial_pool(self.vV(rgb))       # (b, inter, T)

        ## objects attending to rgb: softmax over the T rgb steps.
        OVT = torch.matmul(kqO.transpose(1, 2), kqV) / math.sqrt(self.scale_dim)  # (b, L, T)
        OVT = F.softmax(OVT, dim=-1)
        OattV = F.relu(self.LN(torch.matmul(OVT, vV.transpose(1, 2))))  # (b, L, inter)
        obj_att = self.oW(OattV.transpose(1, 2))    # (b, obj_in_channels, L)
        Z_obj = boxes + obj_att                     # residual: identity at init

        ## rgb attending to objects, reusing the same attention map transposed.
        # NOTE(review): rows of VOT are not normalised over L (the softmax was over T).
        # Confirm this shared, unnormalised map is intended.
        VOT = OVT.transpose(1, 2)                   # (b, T, L)
        vO = self.vO(boxes)                         # (b, inter, L)
        VattO = F.relu(self.LN(torch.matmul(VOT, vO.transpose(1, 2))))  # (b, T, inter)
        rgb_att = self.vW(VattO.transpose(1, 2))    # (b, in_channels, T)
        # NOTE(review): this is a multiplicative gate, and vW is zero-initialised, so
        # Z_rgb == 0 at init. A residual or sigmoid gate may have been intended.
        Z_rgb = rgb * rgb_att.view(batch_size, ch, T, 1, 1).expand_as(rgb)
        return Z_rgb, Z_obj
