import torch
import torch.nn as nn
import torch.nn.functional as F

class CrossAttention(nn.Module):
    """Dot-product attention across the positions of a 5-D feature map.

    Despite the name, query, key and value are all projected from the same
    input tensor by three independent 1x1x1 ``nn.Conv3d`` layers. The
    ``depth * height * width`` positions are flattened into one axis, every
    position attends to every other position, and the result is reshaped back
    to the input's spatial layout.

    Args:
        in_channels: Number of channels of the input tensor.
        out_channels: Number of channels produced by each projection and
            therefore of the output tensor.
    """

    def __init__(self, in_channels, out_channels):
        super().__init__()
        self.in_channels = in_channels
        self.out_channels = out_channels

        self.query_conv = nn.Conv3d(in_channels, out_channels, kernel_size=1)
        self.key_conv = nn.Conv3d(in_channels, out_channels, kernel_size=1)
        self.value_conv = nn.Conv3d(in_channels, out_channels, kernel_size=1)

    def forward(self, x):
        """Apply attention to ``x``.

        Args:
            x: Tensor of shape ``(batch, in_channels, depth, height, width)``.

        Returns:
            Tensor of shape ``(batch, out_channels, depth, height, width)``.
        """
        batch_size, _, depth, height, width = x.size()

        # Collapse the three spatial axes into one "positions" axis N so the
        # batched matrix products below operate on 3-D tensors.
        query = self.query_conv(x).flatten(2).permute(0, 2, 1)  # (B, N, C)
        key = self.key_conv(x).flatten(2)                       # (B, C, N)
        value = self.value_conv(x).flatten(2)                   # (B, C, N)

        # scores[b, i, j] is the similarity of query position i and key
        # position j; softmax normalises over the key positions j.
        # No scaling factor is applied to the scores.
        attention_scores = torch.bmm(query, key)                # (B, N, N)
        attention_probs = F.softmax(attention_scores, dim=-1)

        # output[b, c, i] = sum_j value[b, c, j] * attention_probs[b, i, j]
        output = torch.bmm(value, attention_probs.transpose(1, 2))  # (B, C, N)

        return output.view(batch_size, self.out_channels, depth, height, width)

# Usage example: run a random 5-D tensor through the attention module and
# report the shape of the result.
input_tensor = torch.randn((5, 1280, 1, 8, 8))  # (batch, channels, depth, height, width)
cross_attn = CrossAttention(1280, 1280)  # in_channels, out_channels
output = cross_attn(input_tensor)
print(output.size())  # Output shape should be [5, 1280, 1, 8, 8]