import torch
from torch import nn
import math
from .commons import eye_configs, face_configs, ECA_Module, SelfAttentionModule, EfficientNetEncoder, ResNetEncoder, to_screen_coordinates
from torchvision.models import resnet
from .losses import angular_error, PoG_loss, consistency_loss
from config import DefaultConfig

# Shared configuration instance; the model classes in this module read their
# hyperparameters (backbone name, dropout, loss coefficients, ...) from it.
config = DefaultConfig()

# Torch device built from config.device; used below when allocating loss
# accumulators and placeholder tensors.
device = torch.device(config.device)

# Per-encoder channel counts that are summed to size the fused feature vector
# when a "resnet*" backbone is selected (see the resnet branch of the __init__ methods).
RESNET_FEATURE_DIM = {
    "eye_encoder": 128,
    "face_encoder": 32
}

class STGazeNet(nn.Module):
    """
    ST gaze network that uses an EfficientNet or ResNet encoder for eye and face patches.
    Concatenates the features and passes them through a channel attention module (ECA),
    a self-attention module, a GRU and a fully connected layer.
    
    Args:
        model_name: str, model name for EfficientNet encoder.
        dropout: float, dropout rate.

    Returns:
        Gaze prediction in radians.
    """
    def __init__(self, model_name=None, dropout=None):
        """
        Build the eye/face encoders, ECA, self-attention, GRU and gaze regression head.

        Args:
            model_name: str or None, backbone name; must start with "efficientnet" or
                "resnet". Falls back to config.st_net_model_name when None or empty.
            dropout: float or None, dropout rate. Falls back to config.st_net_dropout
                only when None, so an explicit 0.0 is honoured.

        Raises:
            ValueError: if the resolved model name matches neither backbone family.
        """
        super(STGazeNet, self).__init__()
        self.model_name = model_name or config.st_net_model_name
        # `dropout or default` would silently replace an explicit 0.0 with the config value.
        self.dropout = config.st_net_dropout if dropout is None else dropout
        if self.model_name.startswith("efficientnet"):
            self.eye_encoder = EfficientNetEncoder(self.model_name, eye=True, dropout=self.dropout)
            self.face_encoder = EfficientNetEncoder(self.model_name, eye=False, dropout=self.dropout)
            self.feature_dim = eye_configs["last_channel"] + face_configs["last_channel"]
        elif self.model_name.startswith("resnet"):
            self.eye_encoder = ResNetEncoder(self.model_name, eye=True)
            self.face_encoder = ResNetEncoder(self.model_name, eye=False)
            self.feature_dim = RESNET_FEATURE_DIM["eye_encoder"] + RESNET_FEATURE_DIM["face_encoder"]
        else:
            # Fail here with a clear message instead of an AttributeError on
            # self.feature_dim a few lines further down.
            raise ValueError(
                f"Unsupported model name '{self.model_name}': "
                "expected a name starting with 'efficientnet' or 'resnet'."
            )
        # Sequence length handed to the self-attention module
        # (presumably the 8x8 spatial positions noted in forward() -- TODO confirm).
        self.seq_len = 64
        self.eca = ECA_Module(self.feature_dim)
        self.sam = SelfAttentionModule(
            feature_dim=self.feature_dim,
            seq_len=self.seq_len,
            ff_dim=config.st_net_transformer_ffn_dim,
            dropout=self.dropout,
            num_heads=config.st_net_transformer_num_heads,
            num_layers=config.st_net_transformer_num_layers
        )
        self.gru = nn.GRU(
            input_size=self.feature_dim,
            hidden_size=self.feature_dim,
            num_layers=config.st_net_gru_num_cells,
            batch_first=True
        )
        self.avgpool = nn.AdaptiveAvgPool1d(1)
        self.fc_to_gaze = nn.Sequential(
            nn.Linear(self.feature_dim, config.st_net_static_num_features),
            nn.SELU(inplace=True),
            nn.Linear(config.st_net_static_num_features, 2, bias=False),
            nn.Tanh(),
        )
        # Set gaze layer weights to small non-zero values as otherwise this can explode early in training
        nn.init.constant_(self.fc_to_gaze[2].weight, 1e-4)
        
    def forward(self, input_dict, side, hidden_state=None, flip=False):
        """
        Run one eye patch plus the face patch through the network.
        (shape notes below are only valid when using 'config = DefaultConfig()')

        Args:
            input_dict (dict): Input tensors, keyed as follows.
            - {side}_eye_patch (Tensor): Cropped images of the {side} eye (B, 3, 128, 128).
            - face_patch (Tensor): Same shape as eye_patch; only read when
              "face_features" is absent.
            - face_features (Tensor, optional): Precomputed face features (B, C, H, W).
              When absent, they are computed and written back into input_dict.
            side (str): 'left' or 'right', indicating which eye to process.
            hidden_state: Tensor (2, B, 160), previous GRU hidden state, or None.
            flip: Bool, whether to negate the second gaze component (index 1).

        Returns:
            gaze_prediction: Tensor (B, 2) containing predicted gaze angles in radians.
            hidden_state: Tensor (2, B, 160) containing the updated GRU hidden state.
        """
        eye_feats = self.eye_encoder(input_dict[f"{side}_eye_patch"])  # (B, 128, 8, 8)

        # Encode the face only once: cache the result in input_dict so a later
        # call with the same dict can reuse it.
        if "face_features" not in input_dict:
            input_dict["face_features"] = self.face_encoder(input_dict["face_patch"])  # (B, 32, 8, 8)
        face_feats = input_dict["face_features"]

        fused = torch.cat([eye_feats, face_feats], dim=1)  # (B, 160, 8, 8)
        tokens = self.sam(self.eca(fused))  # ECA keeps (B, 160, 8, 8); SAM yields (B, 64, 160)

        if config.st_net_pool_before_gru:
            # Collapse the sequence first, then feed a length-1 sequence to the GRU
            pooled = self.avgpool(tokens.transpose(1, 2)).transpose(1, 2)  # (B, 1, 160)
            recurrent, hidden_state = self.gru(pooled, hidden_state)  # (B, 1, 160)
            pooled = recurrent.transpose(1, 2)  # (B, 160, 1)
        else:
            # Run the GRU over the whole sequence, then average its outputs
            recurrent, hidden_state = self.gru(tokens, hidden_state)  # (B, 64, 160)
            pooled = self.avgpool(recurrent.transpose(1, 2))  # (B, 160, 1)

        # Tanh output in [-1, 1] is scaled by pi/2 to obtain the final prediction, (B, 2)
        gaze_prediction = math.pi / 2 * self.fc_to_gaze(pooled.squeeze(-1))

        if flip:
            # Mirror the gaze prediction around the y-axis
            gaze_prediction[:, 1] = -gaze_prediction[:, 1]

        # Detach gradients if network is frozen and in salient training mode
        if config.st_net_frozen:
            gaze_prediction = gaze_prediction.detach()
            hidden_state = hidden_state.detach()

        return gaze_prediction, hidden_state
    
    def loss(self, input_dict, output_dict, reduction='mean'):
        """
        Compute angular error on gaze direction and PoG losses (pixels and centimeters).

        Per-eye terms are always reported. Only when two sides are configured are the
        left/right averages scored and combined, together with a left/right consistency
        term, into the weighted total stored under "loss"; otherwise the total stays zero.

        Args:
            input_dict: dict, input dictionary containing ground truth gaze
                ({side}_g_tobii, {side}_PoG_tobii, their *_validity entries and
                millimeters_per_pixel).
            output_dict: dict, output dictionary containing predicted gaze
                ({side}_gaze, {side}_PoG_px, {side}_PoG_mm, optionally face_gaze / face_PoG_px).
            reduction: str, reduction method for loss ('mean', 'sum', None).

        Returns:
            loss: Dict of loss terms, each reduced according to `reduction`.
        """
        pog_reference = input_dict["left_PoG_tobii"]
        batch, seq_len, _ = pog_reference.shape
        # Allocate the accumulator directly on the device holding the data rather than
        # building it on the CPU and copying it to the globally configured device.
        # Assumes the loss helpers return tensors on the same device as their inputs.
        loss_dict = {"loss": torch.zeros(batch * seq_len, device=pog_reference.device)}

        # Per-eye terms: reported for monitoring, not added to the total here.
        for side in config.sides:
            gaze = input_dict[f"{side}_g_tobii"]
            gaze_prediction = output_dict[f"{side}_gaze"]
            loss_dict[f"{side}_loss_ang"] = angular_error(gaze, gaze_prediction, input_dict[f"{side}_g_tobii_validity"])
            loss_dict[f"{side}_loss_pog_px"] = PoG_loss(input_dict[f"{side}_PoG_tobii"], output_dict[f"{side}_PoG_px"], input_dict[f"{side}_PoG_tobii_validity"])
            # px * (0.1 * mm/px) -> cm; predictions in mm are scaled by 0.1 to match
            pog_cm_tobii = torch.mul(input_dict[f"{side}_PoG_tobii"], 0.1*input_dict['millimeters_per_pixel'])
            loss_dict[f"{side}_loss_pog_cm"] = PoG_loss(pog_cm_tobii, output_dict[f"{side}_PoG_mm"]*0.1, input_dict[f"{side}_PoG_tobii_validity"])

        if len(config.sides) == 2:
            # Angular error on the left/right average, valid only where both eyes are valid
            gaze = (input_dict["left_g_tobii"] + input_dict["right_g_tobii"]) / 2.
            gaze_prediction = (output_dict["left_gaze"] + output_dict["right_gaze"]) / 2.
            validity = input_dict["left_g_tobii_validity"] * input_dict["right_g_tobii_validity"]
            loss_dict["loss_ang"] = angular_error(gaze, gaze_prediction, validity)
            loss_dict["loss"] += loss_dict["loss_ang"] * config.loss_coeff_g_ang_initial

            # PoG in pixels on the left/right average; from here on `validity` is the PoG mask
            pog_px = (input_dict["left_PoG_tobii"] + input_dict["right_PoG_tobii"]) / 2.
            pog_px_prediction = (output_dict["left_PoG_px"] + output_dict["right_PoG_px"]) / 2.
            validity = input_dict["left_PoG_tobii_validity"] * input_dict["right_PoG_tobii_validity"]
            loss_dict["loss_pog_px"] = PoG_loss(pog_px, pog_px_prediction, validity)
            loss_dict["loss"] += loss_dict["loss_pog_px"] * config.loss_coeff_PoG_px_initial

            # Face-level terms are reported only; they do not contribute to the total
            if "face_PoG_px" in output_dict:
                loss_dict["loss_face_pog_px"] = PoG_loss(input_dict["face_PoG_tobii"], output_dict["face_PoG_px"], input_dict["face_PoG_tobii_validity"])
                loss_dict["loss_face_ang"] = angular_error(input_dict["face_g_tobii"], output_dict["face_gaze"], input_dict["face_g_tobii_validity"])

            pog_cm = torch.mul(pog_px, 0.1*input_dict['millimeters_per_pixel'])
            pog_cm_prediction = (output_dict["left_PoG_mm"] + output_dict["right_PoG_mm"]) / 2. * 0.1
            loss_dict["loss_pog_cm"] = PoG_loss(pog_cm, pog_cm_prediction, validity)
            loss_dict["loss"] += loss_dict["loss_pog_cm"] * config.loss_coeff_PoG_cm_initial

            loss_dict["loss_consistency"] = consistency_loss(output_dict["left_PoG_px"], output_dict["right_PoG_px"], validity)
            loss_dict["loss"] += loss_dict["loss_consistency"] * config.loss_coeff_PoG_cons_initial

        # Normalise the total by the sum of the four base coefficients
        coeff_sum = (
            config.loss_coeff_g_ang_initial
            + config.loss_coeff_PoG_px_initial
            + config.loss_coeff_PoG_cons_initial
            + config.loss_coeff_PoG_cm_initial
        )
        loss_dict["loss"] = loss_dict["loss"] / coeff_sum

        if reduction == 'mean':
            # Average loss over batch before returning
            return {k: v.mean() for k, v in loss_dict.items()}
        if reduction == 'sum':
            # Sum loss over batch before returning
            return {k: v.sum() for k, v in loss_dict.items()}
        # No reduction
        return loss_dict
        
    def loss_eyes(self, input_dict, output_dict, reduction='mean'):
        """
        Variant of the loss with separate weights for per-eye and combined terms.

        When config.loss_separate is set, each side's pixel PoG and angular terms are
        added to the total, scaled by config.loss_coeff_g_eyes. When two sides are
        configured, the combined (left/right average) angular and pixel PoG terms are
        scaled by config.loss_coeff_g_face, and the combined angular term is also
        masked by face_g_tobii_validity when that key is present. The centimeter PoG
        and consistency terms use their base coefficients only.

        Args:
            input_dict: dict, input dictionary containing ground truth gaze.
            output_dict: dict, output dictionary containing predicted gaze.
            reduction: str, reduction method for loss ('mean', 'sum', None).

        Returns:
            loss: Dict of loss terms, each reduced according to `reduction`.
        """
        pog_reference = input_dict["left_PoG_tobii"]
        batch, seq_len, _ = pog_reference.shape
        # Allocate the accumulator directly on the device holding the data rather than
        # building it on the CPU and copying it to the globally configured device.
        # Assumes the loss helpers return tensors on the same device as their inputs.
        loss_dict = {"loss": torch.zeros(batch * seq_len, device=pog_reference.device)}

        for side in config.sides:
            gaze = input_dict[f"{side}_g_tobii"]
            gaze_prediction = output_dict[f"{side}_gaze"]
            loss_dict[f"{side}_loss_ang"] = angular_error(gaze, gaze_prediction, input_dict[f"{side}_g_tobii_validity"])
            loss_dict[f"{side}_loss_pog_px"] = PoG_loss(input_dict[f"{side}_PoG_tobii"], output_dict[f"{side}_PoG_px"], input_dict[f"{side}_PoG_tobii_validity"])
            # px * (0.1 * mm/px) -> cm; predictions in mm are scaled by 0.1 to match
            pog_cm_tobii = torch.mul(input_dict[f"{side}_PoG_tobii"], 0.1*input_dict['millimeters_per_pixel'])
            loss_dict[f"{side}_loss_pog_cm"] = PoG_loss(pog_cm_tobii, output_dict[f"{side}_PoG_mm"]*0.1, input_dict[f"{side}_PoG_tobii_validity"])
            if config.loss_separate:
                # Per-eye supervision, weighted by the eye coefficient
                loss_dict["loss"] += loss_dict[f"{side}_loss_pog_px"] * config.loss_coeff_PoG_px_initial * config.loss_coeff_g_eyes
                loss_dict["loss"] += loss_dict[f"{side}_loss_ang"] * config.loss_coeff_g_ang_initial * config.loss_coeff_g_eyes

        if len(config.sides) == 2:
            gaze = (input_dict["left_g_tobii"] + input_dict["right_g_tobii"]) / 2.
            gaze_prediction = (output_dict["left_gaze"] + output_dict["right_gaze"]) / 2.
            validity = input_dict["left_g_tobii_validity"] * input_dict["right_g_tobii_validity"]
            if "face_g_tobii_validity" in input_dict:
                validity = validity * input_dict["face_g_tobii_validity"]
            loss_dict["loss_ang"] = angular_error(gaze, gaze_prediction, validity)
            loss_dict["loss"] += loss_dict["loss_ang"] * config.loss_coeff_g_ang_initial * config.loss_coeff_g_face

            # From here on `validity` is the PoG mask
            pog_px = (input_dict["left_PoG_tobii"] + input_dict["right_PoG_tobii"]) / 2.
            pog_px_prediction = (output_dict["left_PoG_px"] + output_dict["right_PoG_px"]) / 2.
            validity = input_dict["left_PoG_tobii_validity"] * input_dict["right_PoG_tobii_validity"]
            loss_dict["loss_pog_px"] = PoG_loss(pog_px, pog_px_prediction, validity)
            loss_dict["loss"] += loss_dict["loss_pog_px"] * config.loss_coeff_PoG_px_initial * config.loss_coeff_g_face

            # Face-level terms are reported only; they do not contribute to the total
            if "face_PoG_px" in output_dict:
                loss_dict["loss_face_pog_px"] = PoG_loss(input_dict["face_PoG_tobii"], output_dict["face_PoG_px"], input_dict["face_PoG_tobii_validity"])
                loss_dict["loss_face_ang"] = angular_error(input_dict["face_g_tobii"], output_dict["face_gaze"], input_dict["face_g_tobii_validity"])

            pog_cm = torch.mul(pog_px, 0.1*input_dict['millimeters_per_pixel'])
            pog_cm_prediction = (output_dict["left_PoG_mm"] + output_dict["right_PoG_mm"]) / 2. * 0.1
            loss_dict["loss_pog_cm"] = PoG_loss(pog_cm, pog_cm_prediction, validity)
            loss_dict["loss"] += loss_dict["loss_pog_cm"] * config.loss_coeff_PoG_cm_initial

            loss_dict["loss_consistency"] = consistency_loss(output_dict["left_PoG_px"], output_dict["right_PoG_px"], validity)
            loss_dict["loss"] += loss_dict["loss_consistency"] * config.loss_coeff_PoG_cons_initial

        # Normalise by the sum of the four base coefficients
        # (the eyes/face multipliers are not part of the denominator).
        coeff_sum = (
            config.loss_coeff_g_ang_initial
            + config.loss_coeff_PoG_px_initial
            + config.loss_coeff_PoG_cons_initial
            + config.loss_coeff_PoG_cm_initial
        )
        loss_dict["loss"] = loss_dict["loss"] / coeff_sum

        if reduction == 'mean':
            # Average loss over batch before returning
            return {k: v.mean() for k, v in loss_dict.items()}
        if reduction == 'sum':
            # Sum loss over batch before returning
            return {k: v.sum() for k, v in loss_dict.items()}
        # No reduction
        return loss_dict

class WrappedSGN(nn.Module):
    """
    Wrapper for STGazeNet to allow for easy use with EVE test bed.

    Instead of returning tensors, forward() writes its results into a
    caller-supplied output dictionary under side-prefixed keys.
    """
    def __init__(self, model_name=None, dropout=None):
        """
        Args:
            model_name: str or None, forwarded to STGazeNet.
            dropout: float or None, forwarded to STGazeNet.
        """
        super(WrappedSGN, self).__init__()
        self.model = STGazeNet(model_name=model_name, dropout=dropout)

    def forward(self, input_dict, output_dict=None, previous_output_dict=None, side=None):
        """
        Forward pass of the network.

        Args:
            input_dict: dict, input dictionary containing eye and face patches.
            output_dict: dict or None, dictionary that receives the results. A new
                dictionary is created when None is passed.
            previous_output_dict: dict or None, output dictionary of the previous step;
                its "{side}_hidden_state" entry seeds the GRU. None starts from scratch.
            side: str, side of the gaze prediction ('left' or 'right'). Required.

        Returns:
            dict: output_dict, updated in place with "{side}_g_initial",
            "{side}_hidden_state" and a zero-filled "{side}_pupil_size" placeholder.

        Raises:
            ValueError: if side is None.
        """
        if side is None:
            # The default only exists because of parameter ordering; without a side the
            # lookup below would fail with a misleading KeyError on "None_eye_patch".
            raise ValueError("side must be 'left' or 'right'.")
        if output_dict is None:
            # The declared default used to crash on item assignment; create the
            # dictionary and return it so the results are not lost.
            output_dict = {}

        hidden_state = None
        if previous_output_dict is not None:
            hidden_state = previous_output_dict[f"{side}_hidden_state"]

        flip = side == "right" and config.flip_right_eye
        gaze_prediction, hidden_state = self.model(input_dict, side, hidden_state=hidden_state, flip=flip)
        output_dict[f"{side}_g_initial"] = gaze_prediction
        output_dict[f"{side}_hidden_state"] = hidden_state
        # Add placeholder for pupil size, allocated next to the prediction it accompanies
        output_dict[side + '_pupil_size'] = torch.zeros(
            (gaze_prediction.size(0), 1), device=gaze_prediction.device
        )
        return output_dict

# For the ablation study let's create a testing class allowing us to omit part of the network
class STGazeNetAblation(nn.Module):
    """
    Implementation of STGazeNet for ablation study.
    Each part of the network can be omitted by setting the corresponding flag.
    
    Args:
        model_name: str, model name for EfficientNetV2 encoder.
        dropout: float, dropout rate.
        eye_encoder: bool, whether to use eye encoder.
        face_encoder: bool, whether to use face encoder.
        eca: bool, whether to use ECA module.
        sam: bool, whether to use self-attention module.
        gru: bool, whether to use GRU cell.
        
    Returns:
        Gaze prediction in radians.
    """
    def __init__(self, model_name=None, dropout=None, eye_encoder=True, face_encoder=True, eca=True, sam=True, gru=True):
        """
        Build the ablation network; each disabled component is stored as None.

        Args:
            model_name: str or None, backbone name; must start with "efficientnet" or
                "resnet". Falls back to config.st_net_model_name when None or empty.
            dropout: float or None, dropout rate. Falls back to config.st_net_dropout
                only when None, so an explicit 0.0 is honoured.
            eye_encoder: bool, whether to use eye encoder.
            face_encoder: bool, whether to use face encoder.
            eca: bool, whether to use ECA module.
            sam: bool, whether to use self-attention module.
            gru: bool, whether to use GRU cell.

        Raises:
            ValueError: if both encoders are disabled, or if the resolved model name
                matches neither backbone family.
        """
        if not (eye_encoder or face_encoder):
            # With no encoder there is nothing to concatenate in forward() and the
            # feature dimension would be zero.
            raise ValueError("At least one of eye_encoder or face_encoder must be True.")
        super(STGazeNetAblation, self).__init__()
        self.model_name = model_name or config.st_net_model_name
        # `dropout or default` would silently replace an explicit 0.0 with the config value.
        self.dropout = config.st_net_dropout if dropout is None else dropout
        if self.model_name.startswith("efficientnet"):
            self.eye_encoder = EfficientNetEncoder(self.model_name, eye=True, dropout=self.dropout) if eye_encoder else None
            self.face_encoder = EfficientNetEncoder(self.model_name, eye=False, dropout=self.dropout) if face_encoder else None
            eye_channels = eye_configs["last_channel"]
            face_channels = face_configs["last_channel"]
        elif self.model_name.startswith("resnet"):
            self.eye_encoder = ResNetEncoder(self.model_name, eye=True, pretrained=config.st_net_load_pretrained) if eye_encoder else None
            self.face_encoder = ResNetEncoder(self.model_name, eye=False, pretrained=config.st_net_load_pretrained) if face_encoder else None
            eye_channels = RESNET_FEATURE_DIM["eye_encoder"]
            face_channels = RESNET_FEATURE_DIM["face_encoder"]
        else:
            # Fail here with a clear message instead of an AttributeError on
            # self.feature_dim a few lines further down.
            raise ValueError(
                f"Unsupported model name '{self.model_name}': "
                "expected a name starting with 'efficientnet' or 'resnet'."
            )
        # Only enabled encoders contribute channels to the fused feature vector
        self.feature_dim = (eye_channels if eye_encoder else 0) + (face_channels if face_encoder else 0)
        self.seq_len = 64
        self.eca = ECA_Module(self.feature_dim) if eca else None
        self.sam = SelfAttentionModule(
            feature_dim=self.feature_dim,
            seq_len=self.seq_len,
            ff_dim=config.st_net_transformer_ffn_dim,
            dropout=self.dropout,
            num_heads=config.st_net_transformer_num_heads,
            num_layers=config.st_net_transformer_num_layers
        ) if sam else None
        self.gru = nn.GRU(
            input_size=self.feature_dim,
            hidden_size=self.feature_dim,
            num_layers=config.st_net_gru_num_cells,
            batch_first=True
        ) if gru else None
        # Add average pooling layer
        self.avgpool = nn.AdaptiveAvgPool1d(1)
        self.fc_to_gaze = nn.Sequential(
            nn.Linear(self.feature_dim, config.st_net_static_num_features),
            nn.SELU(inplace=True),
            nn.Linear(config.st_net_static_num_features, 2, bias=False),
            nn.Tanh(),
        )
        # Set gaze layer weights to small non-zero values as otherwise this can explode early in training
        nn.init.constant_(self.fc_to_gaze[2].weight, 1e-4)

    def forward(self, input_dict, side, hidden_state=None, flip=False):
        """
        Run the ablation network, skipping every component that was disabled.
        (tensor shapes depend on the configuration -- TODO confirm against the caller)

        Args:
            input_dict (dict): Input tensors, keyed as follows.
            - {side}_eye_patch (Tensor): Eye crop; read only when the eye encoder is enabled.
            - face_patch (Tensor): Face crop; read only when the face encoder is enabled
              and "face_features" is absent.
            - face_features (Tensor, optional): Precomputed face features. When absent
              and the face encoder is enabled, they are computed and written back.
            side (str): 'left' or 'right', indicating which eye to process.
            hidden_state: previous GRU hidden state, or None.
            flip: Bool, whether to negate the second gaze component (index 1).

        Returns:
            gaze_prediction: Tensor with two gaze angles per sample, in radians.
            hidden_state: updated GRU hidden state; passed through unchanged
                when the GRU is disabled.
        """
        feature_maps = []
        if self.eye_encoder is not None:
            feature_maps.append(self.eye_encoder(input_dict[f"{side}_eye_patch"]))
        if self.face_encoder is not None:
            # Encode the face only once and cache it in input_dict for later calls
            if "face_features" not in input_dict:
                input_dict["face_features"] = self.face_encoder(input_dict["face_patch"])
            feature_maps.append(input_dict["face_features"])
        fused = torch.cat(feature_maps, dim=1)

        if self.eca is not None:
            fused = self.eca(fused)

        if self.sam is None:
            # Without self-attention, flatten the trailing dims into a sequence axis
            tokens = fused.flatten(start_dim=2).transpose(1, 2)
        else:
            tokens = self.sam(fused)

        if self.gru is not None:
            tokens, hidden_state = self.gru(tokens, hidden_state)

        pooled = self.avgpool(tokens.transpose(1, 2))

        # Tanh output in [-1, 1] is scaled by pi/2 to obtain the final prediction
        gaze_prediction = math.pi / 2 * self.fc_to_gaze(pooled.squeeze(-1))

        if flip:
            # Mirror the gaze prediction around the y-axis
            gaze_prediction[:, 1] = -gaze_prediction[:, 1]

        # Detach gradients if network is frozen
        if config.st_net_frozen:
            gaze_prediction = gaze_prediction.detach()
            if hidden_state is not None:
                hidden_state = hidden_state.detach()

        return gaze_prediction, hidden_state
    
    def loss(self, input_dict, output_dict, reduction='mean'):
        """
        Compute angular error on gaze direction and PoG losses (pixels and centimeters).

        Per-eye terms are always reported. Only when two sides are configured are the
        left/right averages scored and combined, together with a left/right consistency
        term, into the weighted total stored under "loss"; otherwise the total stays zero.

        Args:
            input_dict: dict, input dictionary containing ground truth gaze
                ({side}_g_tobii, {side}_PoG_tobii, their *_validity entries and
                millimeters_per_pixel).
            output_dict: dict, output dictionary containing predicted gaze
                ({side}_gaze, {side}_PoG_px, {side}_PoG_mm, optionally face_gaze / face_PoG_px).
            reduction: str, reduction method for loss ('mean', 'sum', None).

        Returns:
            loss: Dict of loss terms, each reduced according to `reduction`.
        """
        pog_reference = input_dict["left_PoG_tobii"]
        batch, seq_len, _ = pog_reference.shape
        # Allocate the accumulator directly on the device holding the data rather than
        # building it on the CPU and copying it to the globally configured device.
        # Assumes the loss helpers return tensors on the same device as their inputs.
        loss_dict = {"loss": torch.zeros(batch * seq_len, device=pog_reference.device)}

        # Per-eye terms: reported for monitoring, not added to the total here.
        for side in config.sides:
            gaze = input_dict[f"{side}_g_tobii"]
            gaze_prediction = output_dict[f"{side}_gaze"]
            loss_dict[f"{side}_loss_ang"] = angular_error(gaze, gaze_prediction, input_dict[f"{side}_g_tobii_validity"])
            loss_dict[f"{side}_loss_pog_px"] = PoG_loss(input_dict[f"{side}_PoG_tobii"], output_dict[f"{side}_PoG_px"], input_dict[f"{side}_PoG_tobii_validity"])
            # px * (0.1 * mm/px) -> cm; predictions in mm are scaled by 0.1 to match
            pog_cm_tobii = torch.mul(input_dict[f"{side}_PoG_tobii"], 0.1*input_dict['millimeters_per_pixel'])
            loss_dict[f"{side}_loss_pog_cm"] = PoG_loss(pog_cm_tobii, output_dict[f"{side}_PoG_mm"]*0.1, input_dict[f"{side}_PoG_tobii_validity"])

        if len(config.sides) == 2:
            # Angular error on the left/right average, valid only where both eyes are valid
            gaze = (input_dict["left_g_tobii"] + input_dict["right_g_tobii"]) / 2.
            gaze_prediction = (output_dict["left_gaze"] + output_dict["right_gaze"]) / 2.
            validity = input_dict["left_g_tobii_validity"] * input_dict["right_g_tobii_validity"]
            loss_dict["loss_ang"] = angular_error(gaze, gaze_prediction, validity)
            loss_dict["loss"] += loss_dict["loss_ang"] * config.loss_coeff_g_ang_initial

            # PoG in pixels on the left/right average; from here on `validity` is the PoG mask
            pog_px = (input_dict["left_PoG_tobii"] + input_dict["right_PoG_tobii"]) / 2.
            pog_px_prediction = (output_dict["left_PoG_px"] + output_dict["right_PoG_px"]) / 2.
            validity = input_dict["left_PoG_tobii_validity"] * input_dict["right_PoG_tobii_validity"]
            loss_dict["loss_pog_px"] = PoG_loss(pog_px, pog_px_prediction, validity)
            loss_dict["loss"] += loss_dict["loss_pog_px"] * config.loss_coeff_PoG_px_initial

            # Face-level terms are reported only; they do not contribute to the total
            if "face_PoG_px" in output_dict:
                loss_dict["loss_face_pog_px"] = PoG_loss(input_dict["face_PoG_tobii"], output_dict["face_PoG_px"], input_dict["face_PoG_tobii_validity"])
                loss_dict["loss_face_ang"] = angular_error(input_dict["face_g_tobii"], output_dict["face_gaze"], input_dict["face_g_tobii_validity"])

            pog_cm = torch.mul(pog_px, 0.1*input_dict['millimeters_per_pixel'])
            pog_cm_prediction = (output_dict["left_PoG_mm"] + output_dict["right_PoG_mm"]) / 2. * 0.1
            loss_dict["loss_pog_cm"] = PoG_loss(pog_cm, pog_cm_prediction, validity)
            loss_dict["loss"] += loss_dict["loss_pog_cm"] * config.loss_coeff_PoG_cm_initial

            loss_dict["loss_consistency"] = consistency_loss(output_dict["left_PoG_px"], output_dict["right_PoG_px"], validity)
            loss_dict["loss"] += loss_dict["loss_consistency"] * config.loss_coeff_PoG_cons_initial

        # Normalise the total by the sum of the four base coefficients
        coeff_sum = (
            config.loss_coeff_g_ang_initial
            + config.loss_coeff_PoG_px_initial
            + config.loss_coeff_PoG_cons_initial
            + config.loss_coeff_PoG_cm_initial
        )
        loss_dict["loss"] = loss_dict["loss"] / coeff_sum

        if reduction == 'mean':
            # Average loss over batch before returning
            return {k: v.mean() for k, v in loss_dict.items()}
        if reduction == 'sum':
            # Sum loss over batch before returning
            return {k: v.sum() for k, v in loss_dict.items()}
        # No reduction
        return loss_dict
        
    def loss_eyes(self, input_dict, output_dict, reduction='mean'):
        """
        Variant of the loss with separate weights for per-eye and combined terms.

        Each side's pixel PoG and angular terms are always added to the total, scaled
        by config.loss_coeff_g_eyes. When two sides are configured, the combined
        angular term compares the averaged left/right prediction against face_g_tobii
        when that key is present (otherwise against the averaged left/right ground
        truth) and is masked by face_g_tobii_validity when present. The combined
        angular and pixel PoG terms are scaled by config.loss_coeff_g_face; the
        centimeter PoG and consistency terms use their base coefficients only.

        Args:
            input_dict: dict, input dictionary containing ground truth gaze.
            output_dict: dict, output dictionary containing predicted gaze.
            reduction: str, reduction method for loss ('mean', 'sum', None).

        Returns:
            loss: Dict of loss terms, each reduced according to `reduction`.
        """
        pog_reference = input_dict["left_PoG_tobii"]
        batch, seq_len, _ = pog_reference.shape
        # Allocate the accumulator directly on the device holding the data rather than
        # building it on the CPU and copying it to the globally configured device.
        # Assumes the loss helpers return tensors on the same device as their inputs.
        loss_dict = {"loss": torch.zeros(batch * seq_len, device=pog_reference.device)}

        for side in config.sides:
            gaze = input_dict[f"{side}_g_tobii"]
            gaze_prediction = output_dict[f"{side}_gaze"]
            loss_dict[f"{side}_loss_ang"] = angular_error(gaze, gaze_prediction, input_dict[f"{side}_g_tobii_validity"])
            loss_dict[f"{side}_loss_pog_px"] = PoG_loss(input_dict[f"{side}_PoG_tobii"], output_dict[f"{side}_PoG_px"], input_dict[f"{side}_PoG_tobii_validity"])
            # Per-eye supervision, weighted by the eye coefficient (pixel term first)
            loss_dict["loss"] += loss_dict[f"{side}_loss_pog_px"] * config.loss_coeff_PoG_px_initial * config.loss_coeff_g_eyes
            # px * (0.1 * mm/px) -> cm; predictions in mm are scaled by 0.1 to match
            pog_cm_tobii = torch.mul(input_dict[f"{side}_PoG_tobii"], 0.1*input_dict['millimeters_per_pixel'])
            loss_dict[f"{side}_loss_pog_cm"] = PoG_loss(pog_cm_tobii, output_dict[f"{side}_PoG_mm"]*0.1, input_dict[f"{side}_PoG_tobii_validity"])
            loss_dict["loss"] += loss_dict[f"{side}_loss_ang"] * config.loss_coeff_g_ang_initial * config.loss_coeff_g_eyes

        if len(config.sides) == 2:
            # Prefer the face-level ground truth as the target of the combined angular term
            if "face_g_tobii" in input_dict:
                gaze = input_dict["face_g_tobii"]
            else:
                gaze = (input_dict["left_g_tobii"] + input_dict["right_g_tobii"]) / 2.
            gaze_prediction = (output_dict["left_gaze"] + output_dict["right_gaze"]) / 2.
            validity = input_dict["left_g_tobii_validity"] * input_dict["right_g_tobii_validity"]
            if "face_g_tobii_validity" in input_dict:
                validity = validity * input_dict["face_g_tobii_validity"]
            loss_dict["loss_ang"] = angular_error(gaze, gaze_prediction, validity)
            loss_dict["loss"] += loss_dict["loss_ang"] * config.loss_coeff_g_ang_initial * config.loss_coeff_g_face

            # From here on `validity` is the PoG mask
            pog_px = (input_dict["left_PoG_tobii"] + input_dict["right_PoG_tobii"]) / 2.
            pog_px_prediction = (output_dict["left_PoG_px"] + output_dict["right_PoG_px"]) / 2.
            validity = input_dict["left_PoG_tobii_validity"] * input_dict["right_PoG_tobii_validity"]
            loss_dict["loss_pog_px"] = PoG_loss(pog_px, pog_px_prediction, validity)
            loss_dict["loss"] += loss_dict["loss_pog_px"] * config.loss_coeff_PoG_px_initial * config.loss_coeff_g_face

            # Face-level terms are reported only; they do not contribute to the total
            if "face_PoG_px" in output_dict:
                loss_dict["loss_face_pog_px"] = PoG_loss(input_dict["face_PoG_tobii"], output_dict["face_PoG_px"], input_dict["face_PoG_tobii_validity"])
                loss_dict["loss_face_ang"] = angular_error(input_dict["face_g_tobii"], output_dict["face_gaze"], input_dict["face_g_tobii_validity"])

            pog_cm = torch.mul(pog_px, 0.1*input_dict['millimeters_per_pixel'])
            pog_cm_prediction = (output_dict["left_PoG_mm"] + output_dict["right_PoG_mm"]) / 2. * 0.1
            loss_dict["loss_pog_cm"] = PoG_loss(pog_cm, pog_cm_prediction, validity)
            loss_dict["loss"] += loss_dict["loss_pog_cm"] * config.loss_coeff_PoG_cm_initial

            loss_dict["loss_consistency"] = consistency_loss(output_dict["left_PoG_px"], output_dict["right_PoG_px"], validity)
            loss_dict["loss"] += loss_dict["loss_consistency"] * config.loss_coeff_PoG_cons_initial

        # Normalise by the sum of the four base coefficients
        # (the eyes/face multipliers are not part of the denominator).
        coeff_sum = (
            config.loss_coeff_g_ang_initial
            + config.loss_coeff_PoG_px_initial
            + config.loss_coeff_PoG_cons_initial
            + config.loss_coeff_PoG_cm_initial
        )
        loss_dict["loss"] = loss_dict["loss"] / coeff_sum

        if reduction == 'mean':
            # Average loss over batch before returning
            return {k: v.mean() for k, v in loss_dict.items()}
        if reduction == 'sum':
            # Sum loss over batch before returning
            return {k: v.sum() for k, v in loss_dict.items()}
        # No reduction
        return loss_dict

        
class STGazeNetCombined(nn.Module):
    """
    ST gaze network that uses EfficientNet encoder for eye and face patches.
    Concatenates the features and passes them through a channel attention module (ECA), 
    self-attention module, GRU cell and a fully connected layer.
    Difference from STGazeNet: takes both eye patches and the face patch as inputs at the same time.
    Still predicts 2 gaze directions, one for each eye (+ face), then averages them.
    Built in the same style as the Ablation version for modularity.
    
    Args:
        model_name: str, model name for EfficientNet encoder.
        dropout: float, dropout rate.
        eye_encoder: bool, whether to use eye encoder.
        face_encoder: bool, whether to use face encoder.
        eca: bool, whether to use ECA module.
        sam: bool, whether to use self-attention module.
        gru: bool, whether to use GRU cell.

    Returns:
        Gaze prediction in radians.
    """
    def __init__(self, model_name=None, dropout=None, eye_encoder=True, face_encoder=True, eca=True, sam=True, gru=True):
        """
        Build the patch encoders and the shared ECA / self-attention / GRU / FC head.

        Args:
            model_name: str or None, encoder family; must start with "efficientnet" or
                "resnet". Falls back to config.st_net_model_name when falsy.
            dropout: float or None, dropout rate. Falls back to config.st_net_dropout
                only when None, so an explicit 0.0 is honoured.
            eye_encoder: bool, whether to build the eye encoder.
            face_encoder: bool, whether to build the face encoder.
            eca: bool, whether to build the ECA module.
            sam: bool, whether to build the self-attention module.
            gru: bool, whether to build the GRU.

        Raises:
            AssertionError: if both eye_encoder and face_encoder are False.
            ValueError: if the model name is neither an efficientnet nor a resnet.
        """
        assert eye_encoder or face_encoder, "At least one of eye_encoder or face_encoder must be True."
        super(STGazeNetCombined, self).__init__()
        self.model_name = model_name or config.st_net_model_name
        # `dropout or default` would silently discard dropout=0.0, so test for None instead.
        self.dropout = config.st_net_dropout if dropout is None else dropout
        self.eye_encoder = None
        self.face_encoder = None
        self.early_fusion = config.early_fusion

        # Channel counts contributed by each enabled encoder (0 when disabled).
        eye_feature_dim = 0
        face_feature_dim = 0
        if self.model_name.startswith("efficientnet"):
            if eye_encoder:
                self.eye_encoder = EfficientNetEncoder(self.model_name, eye=True, dropout=self.dropout)
                eye_feature_dim = eye_configs["last_channel"]
            if face_encoder:
                self.face_encoder = EfficientNetEncoder(self.model_name, eye=False, dropout=self.dropout)
                face_feature_dim = face_configs["last_channel"]
        elif self.model_name.startswith("resnet"):
            if eye_encoder:
                self.eye_encoder = ResNetEncoder(self.model_name, eye=True)
                eye_feature_dim = RESNET_FEATURE_DIM["eye_encoder"]
            if face_encoder:
                self.face_encoder = ResNetEncoder(self.model_name, eye=False)
                face_feature_dim = RESNET_FEATURE_DIM["face_encoder"]
        else:
            raise ValueError(f"Unsupported model name: {self.model_name}")

        # Early fusion concatenates the features of every eye in config.sides with the
        # face features; late fusion pairs the face features with ONE eye at a time.
        if self.early_fusion and self.eye_encoder is not None:
            eyes_per_sample = len(config.sides)
        else:
            eyes_per_sample = 1
        self.feature_dim = face_feature_dim + eyes_per_sample * eye_feature_dim

        self.seq_len = 64
        self.eca = ECA_Module(self.feature_dim) if eca else None
        self.sam = None
        if sam:
            self.sam = SelfAttentionModule(
                feature_dim=self.feature_dim,
                seq_len=self.seq_len,
                ff_dim=config.st_net_transformer_ffn_dim,
                dropout=self.dropout,
                num_heads=config.st_net_transformer_num_heads,
                num_layers=config.st_net_transformer_num_layers
            )
        self.gru = None
        if gru:
            self.gru = nn.GRU(
                input_size=self.feature_dim,
                hidden_size=self.feature_dim,
                num_layers=config.st_net_gru_num_cells,
                batch_first=True
            )
        # Pools the token axis down to a single feature vector per frame.
        self.avgpool = nn.AdaptiveAvgPool1d(1)
        self.fc_to_gaze = nn.Sequential(
            nn.Linear(self.feature_dim, config.st_net_static_num_features),
            nn.SELU(inplace=True),
            nn.Linear(config.st_net_static_num_features, 2, bias=False),
            nn.Tanh(),
        )
        # Set gaze layer weights to small non-zero values as otherwise this can explode early in training
        nn.init.constant_(self.fc_to_gaze[2].weight, 1e-4)
    
    def forward(self, input_dict, output_dict=None, previous_output_dict=None, get_pog=True):
        """
        Forward pass of the STGazeNetCombined network.

        Three strategies are supported:
          * early fusion (self.early_fusion and an eye encoder): features of every eye in
            config.sides plus the face are concatenated and pushed through the head once;
            the single prediction is copied to every "{side}_gaze".
          * late fusion with eyes: each eye is paired with the face features (when a face
            encoder exists) and predicted separately; "face_gaze" is their mean.
          * face only: the face features alone go through the head.

        Args:
            input_dict (dict): Dictionary containing input tensors as described below.
            - {left/right}_eye_patch (Tensor): For training, (B, T, 3, 128, 128). For inference, (1, 1, 3, 128, 128).
            - face_patch (Tensor): Same shape as eye_patch.
            - face_o / face_R / {side}_o / {side}_R, inv_camera_transformation,
              pixels_per_millimeter: only read when get_pog is True.
            output_dict (dict, optional): Dictionary to store output tensors (updated in place).
            previous_output_dict (dict, optional): Dictionary containing previous output tensors;
                "*_hidden_state" entries are used as initial GRU states.
            get_pog (bool): Whether to compute and return Point of Gaze (PoG) predictions.
        Returns:
            output_dict (dict): Dictionary containing predicted gaze (B, T, 2), GRU hidden
            states, zero-filled pupil-size placeholders and, if requested, PoG in mm and px.
        """
        if output_dict is None:
            output_dict = {}

        # --- EARLY FUSION LOGIC ---
        if self.early_fusion and self.eye_encoder is not None:
            fused = []
            for side in config.sides:
                eye_maps, B, T = self._encode_patches(self.eye_encoder, input_dict[f"{side}_eye_patch"])
                fused.append(eye_maps)
            if self.face_encoder is not None:
                face_maps, B, T = self._encode_patches(self.face_encoder, input_dict["face_patch"])
                fused.append(face_maps)

            # All features are concatenated on the channel axis and processed ONCE.
            gaze_prediction = self._predict_gaze(
                torch.cat(fused, dim=1), B, T, "face_hidden_state", output_dict, previous_output_dict
            ).view(B, T, -1)
            output_dict["face_gaze"] = gaze_prediction

            # The loss expects per-eye entries, so every side gets its own copy.
            for side in config.sides:
                output_dict[f"{side}_gaze"] = gaze_prediction.clone()
                output_dict[f"{side}_pupil_size"] = torch.zeros((gaze_prediction.size(0), 1)).to(device)

        # --- LATE FUSION LOGIC ---
        else:
            face_maps = None
            if self.face_encoder is not None:
                face_maps, B, T = self._encode_patches(self.face_encoder, input_dict["face_patch"])

            if self.eye_encoder is not None:
                B, T = input_dict[f"{config.sides[0]}_eye_patch"].shape[:2]
                mean_gaze = torch.zeros((B, T, 2)).to(device)
                output_dict["face_gaze"] = mean_gaze
                for side in config.sides:
                    eye_maps, B, T = self._encode_patches(self.eye_encoder, input_dict[f"{side}_eye_patch"])
                    # Pair the eye with the face features; without a face encoder the eye
                    # features are used on their own (matches feature_dim from __init__).
                    if face_maps is None:
                        side_maps = eye_maps
                    else:
                        side_maps = torch.cat([eye_maps, face_maps], dim=1)

                    gaze_prediction = self._predict_gaze(
                        side_maps, B, T, f"{side}_hidden_state", output_dict, previous_output_dict
                    )
                    if side == "right" and config.flip_right_eye:
                        # Negate the second gaze component of right-eye predictions.
                        gaze_prediction[:, 1] = -gaze_prediction[:, 1]
                    gaze_prediction = gaze_prediction.view(B, T, -1)  # Reshape to (B, T, 2)
                    output_dict[f"{side}_gaze"] = gaze_prediction
                    output_dict[f"{side}_pupil_size"] = torch.zeros((gaze_prediction.size(0), 1)).to(device)
                    mean_gaze += gaze_prediction

                mean_gaze /= len(config.sides)

            else:  # Only face encoder is active
                if face_maps is None:
                    raise ValueError("At least one of eye_encoder or face_encoder must be available.")
                gaze_prediction = self._predict_gaze(
                    face_maps, B, T, "face_hidden_state", output_dict, previous_output_dict
                ).view(B, T, -1)  # same (B, T, 2) layout as the other branches
                output_dict["face_gaze"] = gaze_prediction
                # Add placeholders
                output_dict["left_pupil_size"] = torch.zeros((gaze_prediction.size(0), 1)).to(device)
                output_dict["right_pupil_size"] = torch.zeros((gaze_prediction.size(0), 1)).to(device)

        # --- PoG Calculations (common to all strategies) ---
        if get_pog:
            inv_camera_transformation = input_dict["inv_camera_transformation"].view(B * T, 4, 4)
            pixels_per_millimeter = input_dict["pixels_per_millimeter"].view(B * T, 2)
            # "face" first (combined gaze), then the individual eyes needed by the loss.
            for name in ["face"] + list(config.sides):
                if f"{name}_gaze" in output_dict:
                    self._write_pog(
                        name, input_dict, output_dict, B, T,
                        inv_camera_transformation, pixels_per_millimeter
                    )

        # Detach gradients if the whole network is frozen
        if config.st_net_frozen:
            for key, value in output_dict.items():
                if torch.is_tensor(value) and value.requires_grad:
                    output_dict[key] = value.detach()

        return output_dict

    def _encode_patches(self, encoder, patches):
        """Fold a (B, T, C, H, W) patch tensor to (B*T, C, H, W), encode it, and return (features, B, T)."""
        B, T, C, H, W = patches.shape
        return encoder(patches.reshape(B * T, C, H, W)), B, T

    def _predict_gaze(self, feature_maps, B, T, state_key, output_dict, previous_output_dict):
        """
        Run encoder feature maps through ECA -> SAM -> (GRU) -> pooling -> FC.

        The initial GRU state is read from previous_output_dict[state_key] (if any) and
        the final state is written to output_dict[state_key]. Returns gaze of shape (B*T, 2),
        scaled to (-pi/2, pi/2) by the Tanh output layer.
        """
        tokens = feature_maps
        if self.eca is not None:
            tokens = self.eca(tokens)
        if self.sam is not None:
            tokens = self.sam(tokens)
        else:
            # Without self-attention, flatten the spatial axes into a token axis by hand.
            tokens = tokens.flatten(start_dim=2).transpose(1, 2)

        if self.gru is None:
            pooled = self.avgpool(tokens.transpose(1, 2)).squeeze(-1)  # (B*T, S, C) -> (B*T, C)
        else:
            hidden_state = None
            if previous_output_dict is not None:
                hidden_state = previous_output_dict.get(state_key, None)
            if config.st_net_pool_before_gru:
                pooled = self.avgpool(tokens.transpose(1, 2)).squeeze(-1)  # (B*T, S, C) -> (B*T, C)
                pooled, hidden_state = self.gru(pooled.reshape(B, T, -1), hidden_state)  # over (B, T, C)
                pooled = pooled.reshape(B * T, -1)  # (B, T, C) -> (B*T, C)
            else:
                _, S, D = tokens.shape
                # (B*T, S, D) -> (B, T*S, D): the GRU walks over every token of every frame.
                sequence = tokens.view(B, T, S, D).flatten(start_dim=1, end_dim=2)
                sequence, hidden_state = self.gru(sequence, hidden_state)
                sequence = sequence.view(B, T, S, D).reshape(B * T, S, D)
                pooled = self.avgpool(sequence.transpose(1, 2)).squeeze(-1)  # (B*T, S, D) -> (B*T, D)
            output_dict[state_key] = hidden_state

        return math.pi / 2 * self.fc_to_gaze(pooled)

    def _write_pog(self, name, input_dict, output_dict, B, T, inv_camera_transformation, pixels_per_millimeter):
        """Project output_dict['{name}_gaze'] to the screen and store '{name}_PoG_mm' / '{name}_PoG_px' as (B, T, 2)."""
        PoG_mm, PoG_px = to_screen_coordinates(
            input_dict[f"{name}_o"].view(B * T, 3),
            output_dict[f"{name}_gaze"].view(B * T, 2),
            input_dict[f"{name}_R"].view(B * T, 3, 3),
            inv_camera_transformation,
            pixels_per_millimeter
        )
        output_dict[f"{name}_PoG_mm"] = PoG_mm.view(B, T, 2)
        output_dict[f"{name}_PoG_px"] = PoG_px.view(B, T, 2)
    
    def loss(self, input_dict, output_dict, reduction='mean'):
        """
        Compute angular and PoG losses and combine them into a weighted total.

        Face losses ("face_loss_*") are always computed. When an eye encoder exists,
        per-eye losses ("{side}_loss_*") are computed as well and the total is built from
        the losses of the eye-AVERAGED prediction against the eye-averaged ground truth;
        the face losses are then reported only. Without an eye encoder the face losses
        feed the total directly. A left/right consistency term is added for two-sided
        late fusion (under early fusion both eye predictions are copies of each other).

        Args:
            input_dict: dict, input dictionary containing ground truth gaze.
            output_dict: dict, output dictionary containing predicted gaze and PoG.
            reduction: str, 'mean' or 'sum'; any other value returns unreduced losses.

        Returns:
            loss_dict: Dict, containing individual and combined loss values. The combined
            value under "loss" is normalised by the sum of the coefficients used.
        """
        loss_dict = {}
        total_coeff = 0.0
        sides = config.sides
        has_eyes = self.eye_encoder is not None

        # Pick any ground-truth tensor to learn batch size, sequence length and device.
        if has_eyes and sides:
            sample_key = f"{sides[0]}_PoG_tobii"
        elif self.face_encoder is not None:
            sample_key = "face_PoG_tobii"
        else:
            raise ValueError("No valid camera frame types ('eyes' or 'face') found in config.")

        batch, seq_len, _ = input_dict[sample_key].shape
        device = input_dict[sample_key].device

        # --- FACE LOSSES ---
        self._add_target_losses("face", input_dict, output_dict, loss_dict)

        # --- EYES LOSSES ---
        if has_eyes:
            num_sides = len(sides)
            avg_pog_pred_px = torch.zeros((batch, seq_len, 2), device=device)
            avg_pog_true_px = torch.zeros((batch, seq_len, 2), device=device)
            val_pog = torch.ones((batch, seq_len), device=device)
            avg_gaze_pred = torch.zeros((batch, seq_len, 2), device=device)
            avg_gaze_true = torch.zeros((batch, seq_len, 2), device=device)
            val_gaze = torch.ones((batch, seq_len), device=device)
            for side in sides:
                avg_pog_pred_px += output_dict[f"{side}_PoG_px"] / num_sides
                avg_pog_true_px += input_dict[f"{side}_PoG_tobii"] / num_sides
                avg_gaze_pred += output_dict[f"{side}_gaze"] / num_sides
                avg_gaze_true += input_dict[f"{side}_g_tobii"] / num_sides
                # An averaged sample is only valid if it is valid for every eye.
                val_pog *= input_dict[f"{side}_PoG_tobii_validity"]
                val_gaze *= input_dict[f"{side}_g_tobii_validity"]
                self._add_target_losses(side, input_dict, output_dict, loss_dict)

            # Losses of the prediction averaged across eyes
            loss_dict["loss_ang"] = angular_error(avg_gaze_true, avg_gaze_pred, val_gaze)
            loss_dict["loss_pog_px"] = PoG_loss(avg_pog_true_px, avg_pog_pred_px, val_pog)
            # px -> cm: millimeters_per_pixel converts to mm, the factor 0.1 to cm.
            px_to_cm = 0.1 * input_dict['millimeters_per_pixel']
            loss_dict["loss_pog_cm"] = PoG_loss(
                torch.mul(avg_pog_true_px, px_to_cm),
                torch.mul(avg_pog_pred_px, px_to_cm),
                val_pog
            )
        else:
            # If no eye encoder, use face losses directly
            loss_dict["loss_ang"] = loss_dict["face_loss_ang"]
            loss_dict["loss_pog_px"] = loss_dict["face_loss_pog_px"]
            loss_dict["loss_pog_cm"] = loss_dict["face_loss_pog_cm"]

        # --- CALCULATE TOTAL LOSS ---
        loss_dict["loss"] = torch.zeros(batch * seq_len, device=device)
        weighted_terms = (
            ("loss_ang", config.loss_coeff_g_ang_initial),
            ("loss_pog_px", config.loss_coeff_PoG_px_initial),
            ("loss_pog_cm", config.loss_coeff_PoG_cm_initial),
        )
        for key, coeff in weighted_terms:
            loss_dict["loss"] += loss_dict[key] * coeff
            total_coeff += coeff

        # --- AUXILIARY CONSISTENCY LOSS (Regularizer) ---
        # This loss does not use a ground truth. It encourages the two eye
        # predictions to be consistent with each other, acting as a regularizer.
        # self.early_fusion is the flag forward() actually branched on.
        if has_eyes and len(sides) == 2 and not self.early_fusion:
            validity = input_dict["left_PoG_tobii_validity"] * input_dict["right_PoG_tobii_validity"]
            loss_dict["loss_consistency"] = consistency_loss(
                output_dict["left_PoG_px"],
                output_dict["right_PoG_px"],
                validity
            )
            loss_dict["loss"] += loss_dict["loss_consistency"] * config.loss_coeff_PoG_cons_initial
            total_coeff += config.loss_coeff_PoG_cons_initial

        # Finalize loss and apply reduction
        if total_coeff > 0:
            loss_dict["loss"] /= total_coeff

        if reduction == 'mean':
            return {k: v.mean() for k, v in loss_dict.items() if torch.is_tensor(v)}
        if reduction == 'sum':
            return {k: v.sum() for k, v in loss_dict.items() if torch.is_tensor(v)}
        return loss_dict  # reduction is None

    def _add_target_losses(self, prefix, input_dict, output_dict, loss_dict):
        """Store the angular, PoG-pixel and PoG-centimetre losses of `prefix` under '{prefix}_loss_*'."""
        pog_true_px = input_dict[f"{prefix}_PoG_tobii"]
        pog_validity = input_dict[f"{prefix}_PoG_tobii_validity"]
        # Angular Loss
        loss_dict[f"{prefix}_loss_ang"] = angular_error(
            input_dict[f"{prefix}_g_tobii"],
            output_dict[f"{prefix}_gaze"],
            input_dict[f"{prefix}_g_tobii_validity"]
        )
        # PoG Loss in Pixels
        loss_dict[f"{prefix}_loss_pog_px"] = PoG_loss(
            pog_true_px,
            output_dict[f"{prefix}_PoG_px"],
            pog_validity
        )
        # PoG Loss in Centimeters (ground truth px -> cm, prediction mm -> cm)
        loss_dict[f"{prefix}_loss_pog_cm"] = PoG_loss(
            torch.mul(pog_true_px, 0.1 * input_dict['millimeters_per_pixel']),
            output_dict[f"{prefix}_PoG_mm"] * 0.1,
            pog_validity
        )

class STGazeNetVectorized(nn.Module):
    def __init__(self, model_name=None, dropout=None, eye_encoder=True, face_encoder=True, eca=True, sam=True, gru=True):
        """
        Build the patch encoders and the ECA / self-attention / GRU / FC head.

        Args:
            model_name: str or None, encoder family; must start with "efficientnet" or
                "resnet". Falls back to config.st_net_model_name when falsy.
            dropout: float or None, dropout rate. Falls back to config.st_net_dropout
                only when None, so an explicit 0.0 is honoured.
            eye_encoder: bool, whether to build the eye encoder.
            face_encoder: bool, whether to build the face encoder.
            eca: bool, whether to build the ECA module.
            sam: bool, whether to build the self-attention module.
            gru: bool, whether to build the GRU.

        Raises:
            ValueError: if both encoders are disabled or the model name is unsupported.
        """
        super(STGazeNetVectorized, self).__init__()
        if not (eye_encoder or face_encoder):
            raise ValueError("At least one of eye_encoder or face_encoder must be True.")
        self.model_name = model_name or config.st_net_model_name
        # `dropout or default` would silently discard dropout=0.0, so test for None instead.
        self.dropout = config.st_net_dropout if dropout is None else dropout
        self.eye_encoder = None
        self.face_encoder = None

        # Channel counts contributed by each enabled encoder (0 when disabled).
        eye_feature_dim = 0
        face_feature_dim = 0
        if self.model_name.startswith("efficientnet"):
            if eye_encoder:
                self.eye_encoder = EfficientNetEncoder(self.model_name, eye=True, dropout=self.dropout)
                eye_feature_dim = eye_configs["last_channel"]
            if face_encoder:
                self.face_encoder = EfficientNetEncoder(self.model_name, eye=False, dropout=self.dropout)
                face_feature_dim = face_configs["last_channel"]
        elif self.model_name.startswith("resnet"):
            # The eye_encoder / face_encoder switches apply to resnet backbones as well.
            if eye_encoder:
                self.eye_encoder = ResNetEncoder(self.model_name, eye=True)
                eye_feature_dim = RESNET_FEATURE_DIM["eye_encoder"]
            if face_encoder:
                self.face_encoder = ResNetEncoder(self.model_name, eye=False)
                face_feature_dim = RESNET_FEATURE_DIM["face_encoder"]
        else:
            # Fail here instead of with an AttributeError on self.feature_dim below.
            raise ValueError(f"Unsupported model name: {self.model_name}")
        self.feature_dim = eye_feature_dim + face_feature_dim

        self.seq_len = 64
        self.eca = ECA_Module(self.feature_dim) if eca else None
        self.sam = SelfAttentionModule(
            feature_dim=self.feature_dim,
            seq_len=self.seq_len,
            ff_dim=config.st_net_transformer_ffn_dim,
            dropout=self.dropout,
            num_heads=config.st_net_transformer_num_heads,
            num_layers=config.st_net_transformer_num_layers
        ) if sam else None
        self.gru = nn.GRU(
            input_size=self.feature_dim,
            hidden_size=self.feature_dim,
            num_layers=config.st_net_gru_num_cells,
            batch_first=True
        ) if gru else None
        # Pools the token axis down to a single feature vector per frame.
        self.avgpool = nn.AdaptiveAvgPool1d(1)
        self.fc_to_gaze = nn.Sequential(
            nn.Linear(self.feature_dim, config.st_net_static_num_features),
            nn.SELU(inplace=True),
            nn.Linear(config.st_net_static_num_features, 2, bias=False),
            nn.Tanh(),
        )
        # Set gaze layer weights to small non-zero values as otherwise this can explode early in training
        nn.init.constant_(self.fc_to_gaze[2].weight, 1e-4)
        
    def forward(self, full_input_dict, side, hidden_state_in=None, flip=None):
        """
        Flexible forward pass for both training and real-time inference.

        Args:
            full_input_dict (dict): Dictionary containing input tensors as described below.
            - {side}_eye_patch (Tensor): For training, (B, T, 3, 128, 128). For inference, (1, 1, 3, 128, 128).
            - face_patch (Tensor): Same shape as eye_patch.
            - face_features (Tensor, optional): Precomputed face encoder output for the folded
              (B*T, ...) batch. When absent and a face encoder exists, it is computed here and
              stored back into full_input_dict so a later call for the other side can reuse it.
            side (str): 'left' or 'right', indicating which eye to process.
            hidden_state_in (Tensor, optional): Previous hidden state for inference. Shape (num_layers, B, hidden_size).
                                                Defaults to None for training.
            flip: Bool, whether to flip the gaze prediction.
        Returns:
            gaze_prediction_sequence (Tensor): Predicted gaze. Shape (B, T, 2).
            hidden_state_out (Tensor or None): Final GRU hidden state, shape
                (num_layers, B, hidden_size); None when the model was built without a GRU.
        """
        eye_key = f"{side}_eye_patch"
        # A face-only model may be called without an eye patch: take the shape from the face patch then.
        shape_key = eye_key if eye_key in full_input_dict else "face_patch"
        B, T, C, H, W = full_input_dict[shape_key].shape

        # Step 1: Feature extraction on frames folded into the batch axis
        encoded = []
        if self.eye_encoder is not None:
            encoded.append(self.eye_encoder(full_input_dict[eye_key].view(B * T, C, H, W)))
        if self.face_encoder is not None:
            if "face_features" not in full_input_dict:
                face_patch_flat = full_input_dict["face_patch"].view(B * T, C, H, W)
                full_input_dict["face_features"] = self.face_encoder(face_patch_flat)
            encoded.append(full_input_dict["face_features"])

        # Step 2: Channel attention and self-attention
        tokens = torch.cat(encoded, dim=1)
        if self.eca is not None:
            tokens = self.eca(tokens)
        if self.sam is not None:
            tokens = self.sam(tokens)
        else:
            # Without self-attention, flatten the spatial axes into a token axis by hand.
            tokens = tokens.flatten(start_dim=2).transpose(1, 2)

        # Step 3: Temporal processing. There is no state to hand back without a GRU.
        hidden_state_out = None
        if self.gru is None:
            fc_input = self.avgpool(tokens.transpose(1, 2)).squeeze(-1)
        elif config.st_net_pool_before_gru:
            pooled = self.avgpool(tokens.transpose(1, 2)).squeeze(-1)  # (B*T, S, D) -> (B*T, D)
            gru_output, hidden_state_out = self.gru(pooled.view(B, T, -1), hidden_state_in)
            fc_input = gru_output.reshape(B * T, -1)  # (B, T, D) -> (B*T, D)
        else:
            _, S, D = tokens.shape
            # (B*T, S, D) -> (B, T*S, D): the GRU walks over every token of every frame.
            gru_input = tokens.view(B, T, S, D).flatten(start_dim=1, end_dim=2)
            gru_output, hidden_state_out = self.gru(gru_input, hidden_state_in)
            fc_input = gru_output.view(B, T, S, D).reshape(B * T, S, D)
            fc_input = self.avgpool(fc_input.transpose(1, 2)).squeeze(-1)  # (B*T, S, D) -> (B*T, D)

        # Step 4: Final fully connected layer; Tanh output is scaled to (-pi/2, pi/2)
        gaze_prediction = math.pi / 2 * self.fc_to_gaze(fc_input)

        # Step 5: Handle flipping by negating the second gaze component
        if flip:
            gaze_prediction[:, 1] = -gaze_prediction[:, 1]

        # Step 6: Reshape output to (B, T, 2)
        gaze_prediction_sequence = gaze_prediction.view(B, T, 2)

        return gaze_prediction_sequence, hidden_state_out
    
    def loss(self, input_dict, output_dict, reduction='mean'):
        """
        Compute angular error loss on gaze direction and PoG losses in pixels and centimetres.

        Per-side losses ("{side}_loss_*") are reported for every side in config.sides.
        The total is built from the losses of the prediction averaged over all configured
        sides (for a single side this is that side's own prediction). With exactly two
        sides a left/right consistency term is added. Face losses are reported (not
        added to the total) when "face_PoG_px" is present in output_dict.

        Args:
            input_dict: dict, input dictionary containing ground truth gaze.
            output_dict: dict, output dictionary containing predicted gaze.
            reduction: str, 'mean' or 'sum'; any other value returns unreduced losses.

        Returns:
            loss: Dict, containing angular error loss and PoG losses; "loss" holds the
            weighted total normalised by the sum of the coefficients used.

        Raises:
            ValueError: if config.sides is empty or all loss coefficients sum to zero.
        """
        sides = list(config.sides)
        if not sides:
            raise ValueError("config.sides is empty, so no losses can be computed.")

        def mean_over_sides(tensors):
            # Left-to-right sum then one division: identical to (left + right) / 2. for two sides.
            total = tensors[0]
            for tensor in tensors[1:]:
                total = total + tensor
            return total / float(len(tensors))

        def joint_validity(key_suffix):
            # An averaged sample is only valid if it is valid for every side.
            validity = input_dict[f"{sides[0]}_{key_suffix}"]
            for other in sides[1:]:
                validity = validity * input_dict[f"{other}_{key_suffix}"]
            return validity

        loss_dict = {}
        sample = input_dict[f"{sides[0]}_PoG_tobii"]
        batch, seq_len, _ = sample.shape
        # Allocate on the inputs' device rather than the globally configured one.
        loss_dict["loss"] = torch.zeros(batch * seq_len, device=sample.device)
        total_coeff = 0.0
        # px -> cm: millimeters_per_pixel converts to mm, the factor 0.1 to cm.
        px_to_cm = 0.1 * input_dict['millimeters_per_pixel']

        for side in sides:
            pog_validity = input_dict[f"{side}_PoG_tobii_validity"]
            loss_dict[f"{side}_loss_ang"] = angular_error(
                input_dict[f"{side}_g_tobii"], output_dict[f"{side}_gaze"], input_dict[f"{side}_g_tobii_validity"]
            )
            loss_dict[f"{side}_loss_pog_px"] = PoG_loss(
                input_dict[f"{side}_PoG_tobii"], output_dict[f"{side}_PoG_px"], pog_validity
            )
            loss_dict[f"{side}_loss_pog_cm"] = PoG_loss(
                torch.mul(input_dict[f"{side}_PoG_tobii"], px_to_cm), output_dict[f"{side}_PoG_mm"] * 0.1, pog_validity
            )

        # Angular loss of the side-averaged gaze
        gaze = mean_over_sides([input_dict[f"{side}_g_tobii"] for side in sides])
        gaze_prediction = mean_over_sides([output_dict[f"{side}_gaze"] for side in sides])
        loss_dict["loss_ang"] = angular_error(gaze, gaze_prediction, joint_validity("g_tobii_validity"))
        loss_dict["loss"] += loss_dict["loss_ang"] * config.loss_coeff_g_ang_initial
        total_coeff += config.loss_coeff_g_ang_initial

        # PoG loss in pixels of the side-averaged PoG
        validity = joint_validity("PoG_tobii_validity")
        pog_px = mean_over_sides([input_dict[f"{side}_PoG_tobii"] for side in sides])
        pog_px_prediction = mean_over_sides([output_dict[f"{side}_PoG_px"] for side in sides])
        loss_dict["loss_pog_px"] = PoG_loss(pog_px, pog_px_prediction, validity)
        loss_dict["loss"] += loss_dict["loss_pog_px"] * config.loss_coeff_PoG_px_initial
        total_coeff += config.loss_coeff_PoG_px_initial

        # PoG loss in centimetres of the side-averaged PoG
        pog_cm = torch.mul(pog_px, px_to_cm)
        pog_cm_prediction = mean_over_sides([output_dict[f"{side}_PoG_mm"] for side in sides]) * 0.1
        loss_dict["loss_pog_cm"] = PoG_loss(pog_cm, pog_cm_prediction, validity)
        loss_dict["loss"] += loss_dict["loss_pog_cm"] * config.loss_coeff_PoG_cm_initial
        total_coeff += config.loss_coeff_PoG_cm_initial

        # Face losses are reported for monitoring only; they do not enter the total.
        if "face_PoG_px" in output_dict:
            loss_dict["face_loss_ang"] = angular_error(
                input_dict["face_g_tobii"], output_dict["face_gaze"], input_dict["face_g_tobii_validity"]
            )
            loss_dict["face_loss_pog_px"] = PoG_loss(
                input_dict["face_PoG_tobii"], output_dict["face_PoG_px"], input_dict["face_PoG_tobii_validity"]
            )
            loss_dict["face_loss_pog_cm"] = PoG_loss(
                torch.mul(input_dict["face_PoG_tobii"], px_to_cm),
                output_dict["face_PoG_mm"] * 0.1,
                input_dict["face_PoG_tobii_validity"]
            )

        # The consistency term compares the two eyes, so it needs exactly two sides.
        if len(sides) == 2:
            loss_dict["loss_consistency"] = consistency_loss(output_dict["left_PoG_px"], output_dict["right_PoG_px"], validity)
            loss_dict["loss"] += loss_dict["loss_consistency"] * config.loss_coeff_PoG_cons_initial
            total_coeff += config.loss_coeff_PoG_cons_initial

        if total_coeff > 0:
            loss_dict["loss"] /= total_coeff
        else:
            raise ValueError("Total coefficient for loss is zero, which may indicate no valid losses were computed.\n Available losses: " + ", ".join(loss_dict.keys()))

        if reduction == 'mean':
            # Average loss over batch before returning
            return {k: v.mean() for k, v in loss_dict.items()}
        if reduction == 'sum':
            # Sum loss over batch before returning
            return {k: v.sum() for k, v in loss_dict.items()}
        # No reduction
        return loss_dict