import torch
from torchvision.models.efficientnet import FusedMBConvConfig, MBConvConfig, _efficientnet
from torchvision.models import resnet, ResNet18_Weights, ResNet34_Weights, ResNet50_Weights, ResNet101_Weights, ResNet152_Weights
import torch.nn as nn
import torch.nn.functional as F

from functools import partial
import math

from config import DefaultConfig

# Module-wide configuration object. ``config.device`` selects the torch device
# below, and ``config.actual_screen_size`` is read later in this module when
# converting gaze points to pixels.
config = DefaultConfig()
device = torch.device(config.device)

# Per-variant scaling multipliers for EfficientNet-B0..B7, keyed by model name.
# _efficientnet_conf passes them to MBConvConfig as ``width_mult`` (channel
# scaling) and ``depth_mult`` (layer-count scaling).
width_mult = dict(
    efficientnet_b0=1.0,
    efficientnet_b1=1.0,
    efficientnet_b2=1.1,
    efficientnet_b3=1.2,
    efficientnet_b4=1.4,
    efficientnet_b5=1.6,
    efficientnet_b6=1.8,
    efficientnet_b7=2.0,
)

depth_mult = dict(
    efficientnet_b0=1.0,
    efficientnet_b1=1.1,
    efficientnet_b2=1.2,
    efficientnet_b3=1.4,
    efficientnet_b4=1.8,
    efficientnet_b5=2.2,
    efficientnet_b6=2.6,
    efficientnet_b7=3.1,
)

# Eye-encoder settings: per-model EfficientNetV2 stage tables (returned by
# _efficientnet_conf when eye=True) plus "last_channel", which
# EfficientNetEncoder passes to torchvision as the final feature width.
# NOTE(review): the later MBConv stages all use stride 1; if these tables were
# copied from torchvision's stock V2 builders, confirm the stride changes are
# intentional.
eye_configs = dict(
    efficientnet_v2_s=[
        FusedMBConvConfig(1, 3, 1, 24, 24, 2),
        FusedMBConvConfig(4, 3, 2, 24, 48, 4),
        FusedMBConvConfig(4, 3, 2, 48, 64, 4),
        MBConvConfig(4, 3, 2, 64, 128, 6),
        MBConvConfig(6, 3, 1, 128, 160, 9),
        MBConvConfig(6, 3, 1, 160, 256, 15),
    ],
    efficientnet_v2_m=[
        FusedMBConvConfig(1, 3, 1, 24, 24, 3),
        FusedMBConvConfig(4, 3, 2, 24, 48, 5),
        FusedMBConvConfig(4, 3, 2, 48, 80, 5),
        MBConvConfig(4, 3, 2, 80, 160, 7),
        MBConvConfig(6, 3, 1, 160, 176, 14),
        MBConvConfig(6, 3, 1, 176, 304, 18),
        MBConvConfig(6, 3, 1, 304, 512, 5),
    ],
    efficientnet_v2_l=[
        FusedMBConvConfig(1, 3, 1, 32, 32, 4),
        FusedMBConvConfig(4, 3, 2, 32, 64, 7),
        FusedMBConvConfig(4, 3, 2, 64, 96, 7),
        MBConvConfig(4, 3, 2, 96, 192, 10),
        MBConvConfig(6, 3, 1, 192, 224, 19),
        MBConvConfig(6, 3, 1, 224, 384, 25),
        MBConvConfig(6, 3, 1, 384, 640, 7),
    ],
    last_channel=128,
)

# Face-encoder settings: the same per-model EfficientNetV2 stage tables as the
# eye encoder (returned by _efficientnet_conf when eye=False), with a narrower
# "last_channel" that EfficientNetEncoder passes to torchvision as the final
# feature width.
# NOTE(review): the later MBConv stages all use stride 1; if these tables were
# copied from torchvision's stock V2 builders, confirm the stride changes are
# intentional.
face_configs = dict(
    efficientnet_v2_s=[
        FusedMBConvConfig(1, 3, 1, 24, 24, 2),
        FusedMBConvConfig(4, 3, 2, 24, 48, 4),
        FusedMBConvConfig(4, 3, 2, 48, 64, 4),
        MBConvConfig(4, 3, 2, 64, 128, 6),
        MBConvConfig(6, 3, 1, 128, 160, 9),
        MBConvConfig(6, 3, 1, 160, 256, 15),
    ],
    efficientnet_v2_m=[
        FusedMBConvConfig(1, 3, 1, 24, 24, 3),
        FusedMBConvConfig(4, 3, 2, 24, 48, 5),
        FusedMBConvConfig(4, 3, 2, 48, 80, 5),
        MBConvConfig(4, 3, 2, 80, 160, 7),
        MBConvConfig(6, 3, 1, 160, 176, 14),
        MBConvConfig(6, 3, 1, 176, 304, 18),
        MBConvConfig(6, 3, 1, 304, 512, 5),
    ],
    efficientnet_v2_l=[
        FusedMBConvConfig(1, 3, 1, 32, 32, 4),
        FusedMBConvConfig(4, 3, 2, 32, 64, 7),
        FusedMBConvConfig(4, 3, 2, 64, 96, 7),
        MBConvConfig(4, 3, 2, 96, 192, 10),
        MBConvConfig(6, 3, 1, 192, 224, 19),
        MBConvConfig(6, 3, 1, 224, 384, 25),
        MBConvConfig(6, 3, 1, 384, 640, 7),
    ],
    last_channel=32,
)

def _efficientnet_conf(model_name, eye=False):
    """
    Return the inverted-residual stage configuration for ``model_name``.

    Names starting with "efficientnet_b" are built on the fly from a B0 stage
    table scaled by ``width_mult`` / ``depth_mult``; ``eye`` is ignored for
    them. Any other name is looked up in ``eye_configs`` (``eye`` truthy) or
    ``face_configs`` (``eye`` falsy).

    Args:
        model_name: str, e.g. "efficientnet_b0" or "efficientnet_v2_s".
        eye: bool, pick the eye table instead of the face table (V2 models).

    Returns:
        list of MBConvConfig / FusedMBConvConfig. For non-B names this is the
        module-level list object itself, not a copy.

    Raises:
        KeyError: if ``model_name`` is not a key of the relevant table.
    """
    if not model_name.startswith("efficientnet_b"):
        table = eye_configs if eye else face_configs
        return table[model_name]

    scaled_block = partial(
        MBConvConfig,
        width_mult=width_mult[model_name],
        depth_mult=depth_mult[model_name],
    )
    # Positional MBConvConfig arguments per stage -- presumably
    # (expand_ratio, kernel, stride, in_channels, out_channels, num_layers);
    # verify against torchvision's MBConvConfig signature.
    # NOTE(review): the third stage uses stride 1; if this table was copied from
    # torchvision's B0 builder, confirm that is intentional.
    b0_stages = (
        (1, 3, 1, 32, 16, 1),
        (6, 3, 2, 16, 24, 2),
        (6, 5, 1, 24, 40, 2),
        (6, 3, 2, 40, 80, 3),
        (6, 5, 1, 80, 112, 3),
        (6, 5, 2, 112, 192, 4),
        (6, 3, 1, 192, 320, 1),
    )
    return [scaled_block(*stage) for stage in b0_stages]

class EfficientNetEncoder(nn.Module):
    """
    EfficientNet feature encoder (convolutional trunk only).

    Builds a torchvision EfficientNet from the stage configuration returned by
    ``_efficientnet_conf`` and keeps only its ``features`` sub-module; the
    classifier head that torchvision builds is discarded.

    Args:
        model_name: str, "efficientnet_b0" .. "efficientnet_b7", or one of
            "efficientnet_v2_s", "efficientnet_v2_m", "efficientnet_v2_l".
        eye: bool, use the eye settings instead of the face settings. This
            selects the V2 stage table and ``last_channel``
            (``eye_configs["last_channel"]`` vs ``face_configs["last_channel"]``),
            which torchvision uses as the width of the final feature conv.
        dropout: float, forwarded to torchvision's ``_efficientnet``.
            NOTE(review): presumably only used by the classifier head, which is
            dropped here -- confirm.
        weights: forwarded unchanged to torchvision's ``_efficientnet``
            (default None). This is not a file path: torchvision expects a
            ``WeightsEnum`` member or None. NOTE(review): stock pretrained
            weights are unlikely to load into this modified architecture (custom
            ``last_channel``); confirm before passing anything other than None.
    """

    def __init__(self, model_name, eye=False, dropout=0.2, weights=None):
        super().__init__()
        self.config = _efficientnet_conf(model_name, eye)
        self.last_channel = eye_configs["last_channel"] if eye else face_configs["last_channel"]
        self.dropout = dropout
        # Build the full torchvision model (BatchNorm eps=1e-3) and keep only
        # the feature trunk.
        self.model = _efficientnet(
            self.config,
            dropout=self.dropout,
            last_channel=self.last_channel,
            weights=weights,
            progress=True,
            norm_layer=partial(nn.BatchNorm2d, eps=1e-03),
        ).features
        self.model_name = model_name
        self.eye = eye

    def forward(self, x):
        """Run the EfficientNet feature trunk on ``x``."""
        return self.model(x)
    
class ResNetEncoder(nn.Module):
    """
    ResNet feature encoder truncated before its last three children.

    A torchvision ResNet is built, its last three top-level children are dropped
    (``layer4``, ``avgpool``, ``fc`` in torchvision's ResNet), and a 1x1
    convolution projects the remaining feature map to 128 channels for eye
    inputs or 32 channels for face inputs.

    Args:
        model_name: str, one of "resnet18", "resnet34", "resnet50",
            "resnet101", "resnet152".
        eye: bool, project to 128 channels (eye) instead of 32 (face).
        pretrained: bool, load torchvision's ``DEFAULT`` weights for the
            backbone (the 1x1 projection is always freshly initialised).

    Raises:
        ValueError: if ``model_name`` is not one of the supported names.
    """

    def __init__(self, model_name, eye=False, pretrained=True):
        super().__init__()
        # name -> (torchvision constructor, weights enum)
        backbones = {
            "resnet18": (resnet.resnet18, ResNet18_Weights),
            "resnet34": (resnet.resnet34, ResNet34_Weights),
            "resnet50": (resnet.resnet50, ResNet50_Weights),
            "resnet101": (resnet.resnet101, ResNet101_Weights),
            "resnet152": (resnet.resnet152, ResNet152_Weights),
        }
        if model_name not in backbones:
            raise ValueError("Invalid model name.")
        build, weights_enum = backbones[model_name]
        full_model = build(weights=weights_enum.DEFAULT if pretrained else None)

        modules = list(full_model.children())[:-3]
        # modules[-1] is the last kept residual stage; [-1] its final block.
        last_channel = self._block_out_channels(modules[-1][-1])
        # 1x1 conv reduces the channel dimension to 128 for eye or 32 for face.
        out_channels = 128 if eye else 32
        modules.append(nn.Conv2d(last_channel, out_channels, kernel_size=1, stride=1, padding=0))
        self.model = nn.Sequential(*modules)
        self.eye = eye
        self.model_name = model_name

    @staticmethod
    def _block_out_channels(block):
        """Return the number of channels a torchvision residual block outputs."""
        # Bottleneck blocks (resnet50/101/152) end with bn3; their bn2 only has
        # the narrow inner width, so using it would make the 1x1 projection
        # reject the backbone output. BasicBlock (resnet18/34) ends with bn2.
        final_bn = block.bn3 if hasattr(block, "bn3") else block.bn2
        return final_bn.num_features

    def forward(self, x):
        """Return the projected feature map (128 or 32 channels)."""
        return self.model(x)

# ST Gaze Net model: eye and face encoders whose outputs are concatenated and passed through an ECA (efficient channel attention) module (to be implemented), a GRU cell and an FC layer

class ECA_Module(nn.Module):
    """
    Efficient Channel Attention (ECA) module.

    Re-weights the channels of a (B, C, H, W) feature map. Each channel's
    spatial mean is mixed with its neighbours by a bias-free 1D convolution
    running along the channel axis. The result is squashed with a sigmoid and
    multiplied back onto the input.

    Args:
        channels: int, number of input channels C; determines the kernel size.
        gamma: float, divisor in the adaptive kernel-size formula.
        b: float, offset in the adaptive kernel-size formula.
    """

    def __init__(self, channels, gamma=2, b=1):
        super().__init__()
        # Adaptive kernel size from x = |log2(C)/gamma + b/gamma|: floor(x) if
        # that is odd, otherwise floor(x) + 1. An odd size allows symmetric
        # "same" padding.
        span = abs((math.log2(channels) / gamma) + (b / gamma))
        self.kernel_size = int(span // 2 * 2 + 1)
        self.conv1d = nn.Conv1d(
            1,
            1,
            kernel_size=self.kernel_size,
            padding=(self.kernel_size - 1) // 2,
            bias=False,
        )
        self.sigmoid = nn.Sigmoid()

    def forward(self, x):
        n, ch, _, _ = x.size()
        # Squeeze: per-channel spatial mean, (B, C, H, W) -> (B, C, 1, 1).
        squeezed = torch.mean(x, dim=(2, 3), keepdim=True)
        # Treat the C means as a 1-channel sequence, (B, 1, C), and gate them.
        gates = self.sigmoid(self.conv1d(squeezed.view(n, 1, ch)))
        # (B, 1, C) -> (B, C, 1, 1) so the gates broadcast over H and W.
        return x * gates.view(n, ch, 1, 1)

class TransformerEncoderBlock(nn.Module):
    """
    Post-norm Transformer encoder block operating on batch-first sequences.

    Two residual sub-layers, each followed by LayerNorm: multi-head
    self-attention, then a two-layer ReLU feed-forward network. Input and
    output are (batch, seq, embed_dim).

    Args:
        embed_dim: int, token dimension.
        num_heads: int, number of attention heads.
        ff_dim: int, hidden width of the feed-forward network.
        dropout: float, dropout rate used on the attention weights, inside the
            feed-forward network, and on each sub-layer output before the
            residual add.
    """

    def __init__(self, embed_dim=160, num_heads=8, ff_dim=512, dropout=0.1):
        super().__init__()
        self.self_attn = nn.MultiheadAttention(embed_dim, num_heads, dropout=dropout, batch_first=True)
        self.norm1, self.norm2 = nn.LayerNorm(embed_dim), nn.LayerNorm(embed_dim)

        expand = nn.Linear(embed_dim, ff_dim)
        project = nn.Linear(ff_dim, embed_dim)
        self.ff = nn.Sequential(expand, nn.ReLU(), nn.Dropout(dropout), project)

        self.dropout1 = nn.Dropout(dropout)
        self.dropout2 = nn.Dropout(dropout)

    def forward(self, x):
        # Sub-layer 1: self-attention (query = key = value = x), residual, norm.
        attended = self.self_attn(x, x, x)[0]
        x = self.norm1(x + self.dropout1(attended))
        # Sub-layer 2: position-wise feed-forward, residual, norm.
        return self.norm2(x + self.dropout2(self.ff(x)))

class SelfAttentionModule(nn.Module):
    """
    Self-attention over the spatial positions of a feature map.

    The (B, C, H, W) input is flattened into H*W tokens of dimension C. A
    learnable token (``cls_token``, "f_i") is prepended and a learnable
    positional encoding is added. The sequence then passes through
    ``num_layers`` TransformerEncoderBlocks. The learnable token's output is
    dropped, so the result has one token per spatial position.

    Args:
        feature_dim: int, feature dimension C (must match the input channels).
        seq_len: int, sequence length H*W (must match the input).
        ff_dim: int, feedforward dimension of each encoder block.
        dropout: float, dropout rate of each encoder block.
        num_heads: int, number of attention heads.
        num_layers: int, number of transformer encoder blocks.
    """

    def __init__(self, feature_dim=160, seq_len=64, ff_dim=512, dropout=0.1, num_heads=8, num_layers=4):
        super().__init__()
        self.seq_len = seq_len
        self.feature_dim = feature_dim
        self.ff_dim = ff_dim
        self.dropout = dropout
        self.num_heads = num_heads
        self.num_layers = num_layers
        # Learnable f_i token prepended to every sequence.
        self.cls_token = nn.Parameter(torch.randn(1, 1, feature_dim))
        # Learnable positional encoding covering the f_i token plus seq_len tokens.
        self.positional_encoding = nn.Parameter(torch.randn(1, seq_len + 1, feature_dim))
        self.transformers = nn.ModuleList([
            TransformerEncoderBlock(feature_dim, num_heads, ff_dim, dropout)
            for _ in range(num_layers)
        ])

    def forward(self, Y_i):
        """
        Args:
            Y_i: Tensor (B, C, H, W) with C == feature_dim and H*W == seq_len.

        Returns:
            Tensor (B, seq_len, feature_dim), one output token per position.

        Raises:
            ValueError: if C or H*W does not match the configured sizes.
        """
        B, C, H, W = Y_i.shape
        # Explicit checks rather than ``assert``, which is stripped under
        # ``python -O``. Without them a wrong C with matching C*H*W would
        # reshape silently into garbage.
        if C != self.feature_dim:
            raise ValueError(f"Expected feature dimension {self.feature_dim}, but got {C}.")
        if H * W != self.seq_len:
            raise ValueError(f"Expected sequence length {self.seq_len}, but got {H * W}.")
        # (B, C, H, W) -> (B, S, C): one C-dimensional token per spatial position.
        tokens = Y_i.reshape(B, self.feature_dim, self.seq_len).transpose(1, 2)
        cls_token = self.cls_token.expand(B, 1, C)      # (B, 1, C)
        x = torch.cat([cls_token, tokens], dim=1)       # (B, S+1, C)
        x = x + self.positional_encoding
        for layer in self.transformers:
            x = layer(x)
        # Drop the f_i token's output.
        return x[:, 1:, :]
    
def pitchyaw_to_vector(py):
    """
    Convert pitch/yaw angles to 3D unit vectors.

    Convention: x = cos(pitch) * sin(yaw), y = sin(pitch),
    z = cos(pitch) * cos(yaw), so (0, 0) maps to (0, 0, 1).

    Args:
        py: Tensor of shape (..., 2) holding (pitch, yaw) in radians along the
            last axis, e.g. (N, 2) or (N, S, 2). Any leading shape is
            accepted; previously only 2-D and 3-D inputs worked.

    Returns:
        Tensor of shape (..., 3) containing the unit vectors.
    """
    pitch, yaw = py[..., 0], py[..., 1]
    cos_pitch = torch.cos(pitch)
    return torch.stack([
        cos_pitch * torch.sin(yaw),
        torch.sin(pitch),
        cos_pitch * torch.cos(yaw),
    ], dim=-1)

def vector_to_pitchyaw(v):
    """
    Convert 3D direction vectors to pitch/yaw angles.

    Inverse of the pitchyaw_to_vector convention:
    pitch = atan2(y, sqrt(x^2 + z^2)), yaw = atan2(x, z). The vectors need not
    be unit length.

    Args:
        v: Tensor of shape (..., 3) with (x, y, z) along the last axis, e.g.
            (N, 3) or (N, S, 3). Any leading shape is accepted; previously only
            2-D and 3-D inputs worked.

    Returns:
        Tensor of shape (..., 2) containing (pitch, yaw) in radians.
    """
    x, y, z = v[..., 0], v[..., 1], v[..., 2]
    pitch = torch.atan2(y, torch.sqrt(x ** 2 + z ** 2))
    yaw = torch.atan2(x, z)
    return torch.stack([pitch, yaw], dim=-1)

### Code from EVE below

"""Copyright 2020 ETH Zurich, Seonwook Park

Permission is hereby granted, free of charge, to any person obtaining a copy
of this software and associated documentation files (the "Software"), to deal
in the Software without restriction, including without limitation the rights
to use, copy, modify, merge, publish, distribute, sublicense, and/or sell
copies of the Software, and to permit persons to whom the Software is
furnished to do so, subject to the following conditions:

The above copyright notice and this permission notice shall be included in all
copies or substantial portions of the Software.

THE SOFTWARE IS PROVIDED "AS IS", WITHOUT WARRANTY OF ANY KIND, EXPRESS OR
IMPLIED, INCLUDING BUT NOT LIMITED TO THE WARRANTIES OF MERCHANTABILITY,
FITNESS FOR A PARTICULAR PURPOSE AND NONINFRINGEMENT. IN NO EVENT SHALL THE
AUTHORS OR COPYRIGHT HOLDERS BE LIABLE FOR ANY CLAIM, DAMAGES OR OTHER
LIABILITY, WHETHER IN AN ACTION OF CONTRACT, TORT OR OTHERWISE, ARISING FROM,
OUT OF OR IN CONNECTION WITH THE SOFTWARE OR THE USE OR OTHER DEALINGS IN THE
SOFTWARE.
"""

# Module-level cache for the plane tensors used by get_intersect_with_zero
# (the plane normal and a point on the plane). Both stay None until that
# function first populates them.
nn_plane_normal = None
nn_plane_other = None

def to_screen_coordinates(origin, direction, rotation, inv_camera_transformation, pixels_per_millimeter,
                          screen_size=None):
    """
    Convert gaze rays to points of gaze on the screen.

    The gaze direction is turned into a vector, negated and de-rotated. Both
    the direction and the origin are then mapped with
    ``inv_camera_transformation``, and the ray is intersected with the plane
    z = 0.

    Args:
        origin: Tensor of shape (N, 3) containing origin of gaze rays.
        direction: Tensor of shape (N, 2) containing gaze directions as pitchyaw angles.
        rotation: Tensor of shape (N, 3, 3) containing the rotation correction applied to
            the raw gaze direction vector; its transpose is applied here to undo it.
        inv_camera_transformation: Tensor of shape (N, 4, 4) containing the inverse camera
            transformation matrix.
        pixels_per_millimeter: Tensor of shape (N, 2) containing pixels per millimeter for
            width and height.
        screen_size: optional (width, height) in pixels used as the upper clamp bound for
            PoG_px. Defaults to ``config.actual_screen_size``.

    Returns:
        PoG_mm: Tensor of shape (N, 2) containing gaze points in millimeters (not clamped).
        PoG_px: Tensor of shape (N, 2) containing gaze points in pixels, each coordinate
            clamped to [0, screen_size[i]].
    """
    if screen_size is None:
        screen_size = config.actual_screen_size

    direction = pitchyaw_to_vector(direction)

    # Send direction back to camera
    direction = -direction

    # Undo the rotation correction (transpose of the rotation matrix)
    inv_rotation = torch.transpose(rotation, 1, 2)
    direction = direction.reshape(-1, 3, 1)
    direction = torch.matmul(inv_rotation, direction)

    # Transform values: directions are only rotated, origins fully transformed
    direction = apply_rotation(inv_camera_transformation, direction)
    origin = apply_transformation(inv_camera_transformation, origin)

    # Intersect with z = 0
    recovered_target_2D = get_intersect_with_zero(origin, direction)
    PoG_mm = recovered_target_2D

    # Convert back from mm to pixels
    ppm_w = pixels_per_millimeter[:, 0]
    ppm_h = pixels_per_millimeter[:, 1]
    PoG_px = torch.stack([
        torch.clamp(recovered_target_2D[:, 0] * ppm_w,
                    0.0, float(screen_size[0])),
        torch.clamp(recovered_target_2D[:, 1] * ppm_h,
                    0.0, float(screen_size[1]))
    ], dim=-1)

    return PoG_mm, PoG_px

def apply_rotation(T, vec):
    """
    Rotate vectors by the upper-left 3x3 block of transformation matrices.

    Any translation stored in ``T`` is ignored.

    Args:
        T: Tensor (N, 4, 4) of transformation matrices.
        vec: Tensor whose second dimension is 3 (e.g. (N, 3) or (N, 3, 1)), or
            an (N, 2) tensor of pitch/yaw angles, which is first converted with
            ``pitchyaw_to_vector``.

    Returns:
        Tensor (N, 3) of rotated vectors.
    """
    as_vectors = pitchyaw_to_vector(vec) if vec.shape[1] == 2 else vec
    rotation = T[:, :3, :3]
    rotated = rotation @ as_vectors.reshape(-1, 3, 1)
    return rotated.reshape(-1, 3)

def apply_transformation(T, vec):
    """
    Apply homogeneous 4x4 transformation matrices to 3D points.

    Args:
        T: Tensor (N, 4, 4) of transformation matrices.
        vec: Tensor whose second dimension is 3 (e.g. (N, 3)), or an (N, 2)
            tensor of pitch/yaw angles, which is first converted with
            ``pitchyaw_to_vector``.

    Returns:
        Tensor (N, 3): the transformed points (rotation and translation applied).
    """
    if vec.shape[1] == 2:
        vec = pitchyaw_to_vector(vec)
    column = vec.reshape(-1, 3, 1)
    # Append the homogeneous coordinate 1 -> (N, 4, 1) so translation applies.
    homogeneous = torch.cat([column, torch.ones_like(column[:, :1])], dim=1)
    return (T @ homogeneous)[:, :3, 0]

def pitchyaw_to_rotation(a):
    """
    Build per-sample 3x3 rotation matrices from pitch/yaw angles.

    Returns ``M_yaw @ M_pitch`` with
    ``M_pitch = [[1, 0, 0], [0, cos p, sin p], [0, -sin p, cos p]]`` and
    ``M_yaw = [[cos y, 0, sin y], [0, 1, 0], [-sin y, 0, cos y]]``.

    Args:
        a: Tensor (N, 2) of (pitch, yaw) angles, or (N, 3) direction vectors,
            which are first converted with ``vector_to_pitchyaw``.

    Returns:
        Tensor (N, 3, 3).
    """
    angles = vector_to_pitchyaw(a) if a.shape[1] == 3 else a
    cos_a, sin_a = torch.cos(angles), torch.sin(angles)
    cos_p, sin_p = cos_a[:, 0], sin_a[:, 0]
    cos_y, sin_y = cos_a[:, 1], sin_a[:, 1]
    one = torch.ones_like(cos_p)
    zero = torch.zeros_like(cos_p)
    pitch_entries = (one, zero, zero,
                     zero, cos_p, sin_p,
                     zero, -sin_p, cos_p)
    yaw_entries = (cos_y, zero, sin_y,
                   zero, one, zero,
                   -sin_y, zero, cos_y)
    pitch_mat = torch.stack(pitch_entries, dim=1).view(-1, 3, 3)
    yaw_mat = torch.stack(yaw_entries, dim=1).view(-1, 3, 3)
    return torch.matmul(yaw_mat, pitch_mat)

def get_intersect_with_zero(o, g):
    """
    Intersects a given gaze ray (origin o and direction g) with the plane z = 0.

    Solves o + t * g for the t at which the z component is zero. A 1e-7
    epsilon is added to the denominator to avoid division by exactly zero
    (rays nearly parallel to the plane still yield very large t).

    Args:
        o: Tensor reshapeable to (N, 3, 1) containing origin of gaze rays.
        g: Tensor reshapeable to (N, 3, 1) containing gaze directions.

    Returns:
        Tensor of shape (N, 2) containing the x, y of the intersection points.
    """
    global nn_plane_normal, nn_plane_other
    # The plane tensors are cached module-wide. Rebuild them whenever the
    # inputs live on a different device than the cache; otherwise the
    # element-wise products below fail with a device mismatch (e.g. inputs on
    # CPU while the cache was built on a GPU).
    if nn_plane_normal is None or nn_plane_normal.device != o.device:
        nn_plane_normal = torch.tensor([0, 0, 1], dtype=torch.float32, device=o.device).view(1, 3, 1)
        nn_plane_other = torch.tensor([1, 0, 0], dtype=torch.float32, device=o.device).view(1, 3, 1)

    # Plane z = 0: normal n = (0, 0, 1) passing through the point a = (1, 0, 0)
    n = nn_plane_normal
    a = nn_plane_other
    g = g.view(-1, 3, 1)
    o = o.view(-1, 3, 1)
    # The ray o + t * g meets the plane where t * (g . n) = (a - o) . n
    numer = torch.sum(torch.mul(a - o, n), dim=1)

    # Intersect with plane using provided 3D origin
    denom = torch.sum(torch.mul(g, n), dim=1) + 1e-7
    t = torch.div(numer, denom).view(-1, 1, 1)
    return (o + torch.mul(t, g))[:, :2, 0]

def apply_offset_augmentation(gaze_direction, head_rotation, kappa, inverse_kappa=False):
    """
    Re-apply a gaze direction on top of a per-sample ``kappa`` offset, in the
    head-relative frame, and return the result as pitch/yaw angles.

    Steps (vectors are handled as (N, 3, 1) columns):
      1. Convert ``gaze_direction`` to a vector, negate it, multiply by the
         transpose of ``head_rotation`` and negate again, giving a
         head-relative gaze.
      2. Convert ``kappa`` to a vector; with ``inverse_kappa`` its x and y
         components are negated.
      3. Build a rotation from the head-relative gaze's pitch/yaw
         (``pitchyaw_to_rotation``) and apply it to the kappa vector.
      4. Negate, multiply by ``head_rotation``, negate again, and convert back
         to pitch/yaw.

    NOTE(review): ``kappa`` is presumably the angle-kappa offset between the
    optical and visual axes (inferred from the name only) -- confirm with the
    caller.

    Args:
        gaze_direction: Tensor of pitch/yaw angles, assumed (N, 2) -- TODO confirm.
        head_rotation: Tensor (N, 3, 3); its transpose is used as the inverse.
        kappa: Tensor of pitch/yaw angles, converted to (-1, 3, 1) vectors.
        inverse_kappa: bool, negate the x/y components of the kappa vector.

    Returns:
        Tensor (N, 2) of pitch/yaw angles.
    """
    gaze_direction = pitchyaw_to_vector(gaze_direction)

    # Negate gaze vector (to camera perspective)
    # NOTE: since matmul is linear, this negation and the one after the
    # de-rotation cancel mathematically; the same holds for steps after the
    # final re-rotation.
    gaze_direction = -gaze_direction

    # De-rotate gaze vector (transpose used as the inverse rotation)
    inv_head_rotation = torch.transpose(head_rotation, 1, 2)
    gaze_direction = gaze_direction.reshape(-1, 3, 1)
    gaze_direction = torch.matmul(inv_head_rotation, gaze_direction)

    # Negate gaze vector back (to user perspective)
    gaze_direction = -gaze_direction

    # Apply kappa to frontal vector [0 0 1] (pitch/yaw (0, 0) maps to [0 0 1])
    kappa_vector = pitchyaw_to_vector(kappa).reshape(-1, 3, 1)
    if inverse_kappa:
        # Negate x and y, keep z: (N, 2, 1) and (N, 1, 1) concatenated on dim 1
        kappa_vector = torch.cat([
            -kappa_vector[:, :2, :], kappa_vector[:, 2, :].reshape(-1, 1, 1),
        ], axis=1)

    # Apply head-relative gaze to rotated frontal vector
    head_relative_gaze_rotation = pitchyaw_to_rotation(vector_to_pitchyaw(gaze_direction.squeeze(-1)))
    gaze_direction = torch.matmul(head_relative_gaze_rotation, kappa_vector)

    # Negate gaze vector (to camera perspective)
    gaze_direction = -gaze_direction

    # Rotate gaze vector back into the original (non-head-relative) frame
    gaze_direction = gaze_direction.reshape(-1, 3, 1)
    gaze_direction = torch.matmul(head_rotation, gaze_direction)

    # Negate gaze vector back (to user perspective)
    gaze_direction = -gaze_direction

    # (N, 3, 1) -> (N, 3) -> (N, 2) pitch/yaw
    gaze_direction = vector_to_pitchyaw(gaze_direction.squeeze(-1))
    return gaze_direction