import numpy as np
import cv2
import pywt
import math
import sys

# Add the project directory to the import path so that the local helper
# modules imported below (utils, detect_utils, entropy_utils) can be found.
# NOTE(review): "..." looks like a placeholder; it must be replaced with the
# real project directory for those imports to resolve.
project_path = "..."
sys.path.append(project_path)

from utils import *
from detect_utils import check_large_block_dwt
from entropy_utils import compute_entropy

def detect_watermark(method, image_array, diag=-1, interval_length=1, dwt_lvl=-1,
                     image_block_size=96, dct_block_size=8, scale=4000,
                     entropy_threshold=0, round=10, num_keys=0):
    """Count watermark-rule violations over a 2-D image, one large block at a time.

    The image is tiled into ``image_block_size`` x ``image_block_size`` blocks
    (96x96 by default). Blocks on the bottom/right edge are cropped to fit the
    image. For each block:

    1. Compute its entropy with ``compute_entropy``.
    2. If the entropy is at least ``entropy_threshold``, pass the block to
       ``check_large_block_dwt``. That function returns
       ``(num_violations, num_pixels)`` for the block.
    3. Otherwise skip the block and mark it with ``-1``.

    Per the original design notes, ``check_large_block_dwt`` presumably walks
    the block in ``dct_block_size`` sub-blocks. It applies a DWT to each one
    and counts violations, using the large block's mean to decide which
    functions are allowed. That logic lives in ``detect_utils`` and is not
    verified here.

    Args:
        method: Label of the detection method. It is only used in the log
            message printed at the start.
        image_array: 2-D array of shape ``(height, width)``. It is converted
            to ``float64`` before processing.
        diag, dwt_lvl, num_keys: Forwarded unchanged to
            ``check_large_block_dwt``; see that function for their meaning.
        round: Forwarded unchanged to ``check_large_block_dwt``. The name
            shadows the builtin ``round`` inside this function. It is kept so
            that keyword callers continue to work.
        interval_length: The number of regions per channel is computed as
            ``int(scale / interval_length)``, so this must be nonzero. It is
            also forwarded to ``check_large_block_dwt``.
        image_block_size: Side length of the large blocks.
        dct_block_size: Sub-block size, forwarded to ``check_large_block_dwt``.
        scale: Used to compute the number of regions per channel. It is also
            forwarded to ``check_large_block_dwt``.
        entropy_threshold: Blocks whose entropy is below this value are
            skipped.

    Returns:
        A tuple ``(overall_violation_percentage, violations_matrix, pixels_matrix)``:

        * ``overall_violation_percentage`` equals
          ``100 * sum(violations) / sum(pixels)`` over all checked blocks.
          It is ``0`` when no pixels were counted.
        * ``violations_matrix`` and ``pixels_matrix`` are float arrays of
          shape ``(ceil(height / image_block_size), ceil(width / image_block_size))``.
          They hold the per-block counts; skipped blocks are ``-1`` in both.
    """
    print(f"Detecting watermark for interval_length {interval_length}. Method: {method}")
    image_array = image_array.astype(np.float64)

    height, width = image_array.shape
    num_regions_per_channel = int(scale / interval_length)

    # Size of the block grid; partial blocks at the edges still get a cell.
    num_blocks_y = math.ceil(height / image_block_size)
    num_blocks_x = math.ceil(width / image_block_size)

    # Running totals over all blocks that passed the entropy check.
    total_violations = 0
    total_pixels = 0

    violations_matrix = np.zeros((num_blocks_y, num_blocks_x), dtype=float)
    pixels_matrix = np.zeros((num_blocks_y, num_blocks_x), dtype=float)

    for i, y in enumerate(range(0, height, image_block_size)):
        for j, x in enumerate(range(0, width, image_block_size)):
            # Crop the block so it does not run past the image border.
            block_height = min(image_block_size, height - y)
            block_width = min(image_block_size, width - x)
            block = image_array[y:y + block_height, x:x + block_width]

            if compute_entropy(block) >= entropy_threshold:
                num_violations, num_pixels = check_large_block_dwt(
                    block, diag, interval_length, dwt_lvl, block_height,
                    block_width, dct_block_size, num_regions_per_channel,
                    scale, round, num_keys=num_keys)

                violations_matrix[i, j] = num_violations
                pixels_matrix[i, j] = num_pixels

                total_pixels += num_pixels
                total_violations += num_violations
            else:
                # -1 marks a block that was skipped because its entropy was
                # below the threshold.
                violations_matrix[i, j] = -1
                pixels_matrix[i, j] = -1

    if total_pixels != 0:
        overall_violation_percentage = (total_violations / total_pixels) * 100
    else:
        # NOTE(review): this branch also runs if checked blocks reported zero
        # pixels, not only when every block was skipped for low entropy.
        print("All blocks had lower entropy!")
        overall_violation_percentage = 0

    return overall_violation_percentage, violations_matrix, pixels_matrix
