import numpy as np
import math
import sys

# Make modules under the project root importable from this file.
# NOTE(review): "..." looks like a placeholder — set the real project path.
project_path = "..."
sys.path.extend([project_path])

def eval_m1(violations_matrix, pixels_matrix, z_score):
    """
    Return the percentage of valid squares classified as watermarked, plus the mask.

    A square is valid when neither its pixel count nor its violation count is -1
    (the sentinel for an excluded square). A valid square counts as watermarked
    when its violation count is strictly below ``n/2 + (z_score/2) * sqrt(n)``,
    where ``n`` is the square's pixel count, i.e. the mean plus ``z_score``
    standard deviations of a Binomial(n, 0.5) count.

    Parameters:
    - violations_matrix (array-like): Violation count per square; -1 marks an
      excluded square.
    - pixels_matrix (array-like): Pixel count per square, same shape as
      ``violations_matrix``; -1 marks an excluded square.
    - z_score (float): Number of standard deviations above the mean used for
      the cutoff.

    Returns:
    - tuple: ``(percentage, mask)``. ``percentage`` is in [0, 100]. ``mask`` is a
      boolean array (shape of the inputs) marking watermarked squares.
      ``percentage`` is 0.0 when nothing is watermarked or no square is valid.
    """
    # Accept plain lists as well as arrays: the -1 comparisons below must be
    # element-wise, which only holds for NumPy arrays.
    violations = np.asarray(violations_matrix)
    pixels = np.asarray(pixels_matrix)

    valid_pixels_mask = (pixels != -1) & (violations != -1)

    # The threshold stays 0 for invalid squares. Those are masked out below
    # anyway, and sqrt is never evaluated on the -1 sentinel.
    violation_threshold = np.zeros_like(pixels, dtype=float)
    valid_pixels = pixels[valid_pixels_mask]
    violation_threshold[valid_pixels_mask] = (
        (z_score / 2) * np.sqrt(valid_pixels) + valid_pixels / 2
    )

    # Watermarked = valid and strictly below the threshold.
    mask = (violations < violation_threshold) & valid_pixels_mask

    num_watermarked = np.sum(mask)
    print(f"num watermarked {num_watermarked}")

    valid_squares = np.sum(valid_pixels_mask)

    # Always return a (percentage, mask) pair so callers can unpack the result
    # unconditionally.
    if num_watermarked == 0 or valid_squares == 0:
        return 0.0, mask
    return (num_watermarked / valid_squares) * 100, mask



############################################################################
def calculate_z_score(num_violations, num_pixels, p0=0.5):
    """
    Calculate the one-sample z-score for the proportion of violating pixels.

    Uses the normal approximation
    ``z = (p_hat - p0) / sqrt(p0 * (1 - p0) / num_pixels)`` with
    ``p_hat = num_violations / num_pixels``.

    Parameters:
    - num_violations (int): Number of violating pixels.
    - num_pixels (int): Total number of pixels; must be positive.
    - p0 (float): Hypothesized proportion in [0, 1] (default is 0.5).

    Returns:
    - float: The calculated z-score.

    Raises:
    - ValueError: if ``num_pixels`` is not positive, if ``p0`` is outside
      [0, 1], or if the standard error is zero (``p0`` exactly 0 or 1).
    """
    # Check <= 0 rather than == 0: a negative count would otherwise surface as
    # an opaque "math domain error" from math.sqrt below.
    if num_pixels <= 0:
        raise ValueError("Number of pixels must be greater than 0.")
    # Same reason: p0 outside [0, 1] makes p0 * (1 - p0) negative.
    # Written as a negated range check so NaN is rejected too.
    if not 0 <= p0 <= 1:
        raise ValueError("p0 must be between 0 and 1.")

    p_hat = num_violations / num_pixels
    standard_error = math.sqrt(p0 * (1 - p0) / num_pixels)

    if standard_error == 0:
        raise ValueError("Standard error is zero; check input values.")

    return (p_hat - p0) / standard_error



############################################################################
def eval_m2(violations_matrix, pixels_matrix, z_score):
    """
    Return the overall violation rate, as a percentage, over all valid squares.

    Squares whose violation count or pixel count is -1 are excluded, following
    the same sentinel convention as ``eval_m1``. This keeps the sentinel values
    out of the sums.

    Parameters:
    - violations_matrix (array-like): Violation count per square; -1 marks an
      excluded square.
    - pixels_matrix (array-like): Pixel count per square, same shape as
      ``violations_matrix``; -1 marks an excluded square.
    - z_score (float): Unused here; presumably kept for signature parity with
      ``eval_m1``.

    Returns:
    - float: ``100 * sum(violations) / sum(pixels)`` over the valid squares, or
      0.0 when the valid squares contain no pixels.
    """
    violations = np.asarray(violations_matrix)
    pixels = np.asarray(pixels_matrix)
    valid = (pixels != -1) & (violations != -1)

    total_violations = np.sum(violations[valid])
    total_pixels = np.sum(pixels[valid])

    # Avoid 0/0, which NumPy turns into NaN plus a RuntimeWarning.
    if total_pixels == 0:
        return 0.0
    return (total_violations / total_pixels) * 100

