import torch
import torch.nn.functional as F
from .commons import pitchyaw_to_vector
from torch import pi

def angular_error(y_true, y_pred, validity):
    """
    Compute angular error between ground truth and predicted gaze vectors.

    Both inputs are converted to 3D gaze vectors with ``pitchyaw_to_vector``
    and compared through their cosine similarity, so the vectors do not need
    to be unit-length.

    Args:
        y_true: Tensor of shape (B, S, 2) containing pitchyaw angles
            (units as expected by ``pitchyaw_to_vector`` -- presumably radians).
        y_pred: Tensor of shape (B, S, 2) containing predicted pitchyaw angles.
        validity: Tensor of shape (B, S,) containing validity flags (0 or 1).

    Returns:
        ae: Tensor of shape (B*S,) containing angular errors in degrees.
            Entries whose validity flag is 0 are set to 0.
    """
    # Convert pitchyaw angles to 3D gaze vectors and flatten batch/sequence.
    v_true = pitchyaw_to_vector(y_true).reshape(-1, 3)
    v_pred = pitchyaw_to_vector(y_pred).reshape(-1, 3)
    cosin = F.cosine_similarity(v_true, v_pred, dim=-1)
    # Keep the cosine strictly inside (-1, 1). d/dx acos(x) = -1/sqrt(1 - x^2)
    # is infinite at +/-1, which produces inf/NaN gradients. The NaN also
    # survives the multiplicative validity mask below, because inf * 0 = NaN.
    # A fixed 1e-8 margin rounds away in float32 (1 - 1e-8 == 1.0f), so use at
    # least the dtype's machine epsilon. float64 keeps the original 1e-8.
    # Side effect: in float32, identical vectors report roughly 0.03 deg
    # instead of exactly 0.
    eps = max(1e-8, torch.finfo(cosin.dtype).eps)
    cosin = torch.clamp(cosin, -1.0 + eps, 1.0 - eps)
    # Angular error in radians, converted to degrees.
    ae = torch.acos(cosin) * 180.0 / pi
    # Flatten the validity flags to match ae and zero out invalid samples.
    return ae * validity.reshape(-1)

def PoG_loss(y_true, y_pred, validity):
    """
    Compute the masked squared distance between true and predicted Point of Gaze (PoG).

    Coordinates are expressed in on-screen pixels. For each sample, the
    squared errors of the two coordinates are summed (not averaged), so the
    value is the squared Euclidean distance between the two points.

    Args:
        y_true: Tensor of shape (B, S, 2) containing PoG coordinates as estimated by the Tobii Pro Spectrum device.
        y_pred: Tensor of shape (B, S, 2) containing predicted PoG coordinates.
        validity: Tensor of shape (B, S,) containing validity flags (0 or 1).

    Returns:
        loss: Tensor of shape (B*S,) with the per-sample squared distance,
            multiplied by the flattened validity flags.
    """
    # Element-wise squared error for each coordinate, shape (B, S, 2).
    per_coord_sq = F.mse_loss(y_true, y_pred, reduction='none')
    # Collapse the coordinate axis, then flatten batch and sequence together.
    sq_dist = per_coord_sq.sum(dim=-1).reshape(-1)
    flat_mask = validity.reshape(-1)
    return sq_dist * flat_mask

def consistency_loss(right_PoG, left_PoG, validity):
    """
    Compute consistency loss between left and right PoG predictions.

    Coordinates are expressed in on-screen pixels. The loss for each sample
    is the Euclidean distance between the two predicted points (not a mean
    squared error).

    Args:
        right_PoG: Tensor of shape (B, S, 2) containing predicted PoG coordinates for the right eye.
        left_PoG: Tensor of shape (B, S, 2) containing predicted PoG coordinates for the left eye.
        validity: Tensor of shape (B, S,) containing validity flags (0 or 1).

    Returns:
        loss: Tensor of shape (B*S,) containing the per-sample Euclidean
            distance. Entries whose validity flag is 0 are set to 0.
    """
    # The original used sqrt(sum(d**2)). Where both points coincide, its
    # gradient is 0 * inf = NaN, which poisons training even for masked
    # samples. vector_norm gives the same forward value with a zero
    # (sub)gradient at d == 0.
    dist = torch.linalg.vector_norm(right_PoG - left_PoG, dim=-1).reshape(-1)
    # Flatten the validity flags to match dist and zero out invalid samples.
    return dist * validity.reshape(-1)