import numpy as np
import pywt

import sys

# Make the project root importable so the `utils` import below resolves.
# NOTE(review): "..." looks like a placeholder — set this to the real project root.
project_path = "..."
sys.path.append(project_path)

from utils import calculate_seed, generate_sequence_from_seed, is_region_allowed


def check_large_block_dwt(IMAGE_block, diag, interval_length, dwt_lvl, block_height, block_width,
                          dct_block_size, num_regions_per_channel, scale, round=30, num_keys=0,
                          keys=None):
    """Count allowed-region violations in the DWT LL band of ``IMAGE_block``.

    A seed is derived from the mean of ``IMAGE_block`` via ``calculate_seed``.
    That seed is passed to ``generate_sequence_from_seed`` to build the region
    sequence consumed by ``process_dwt_blocks``.

    Args:
        IMAGE_block: Array to analyse. It is sliced with two indices in
            ``process_dwt_blocks``, so presumably 2-D (TODO confirm).
        diag, interval_length, dwt_lvl, block_height, block_width,
        dct_block_size, scale: Forwarded unchanged to ``process_dwt_blocks``.
        num_regions_per_channel: Forwarded to ``generate_sequence_from_seed``.
        round: Forwarded to ``calculate_seed``. It shadows the builtin but is
            kept for API compatibility.
        num_keys: ``0`` uses a single keyless sequence (``key=None``). A
            positive value tries the first ``num_keys`` keys and keeps the
            lowest violation count. Values larger than the number of
            available keys are capped to that number.
        keys: Optional sequence of keys to try instead of the built-in
            defaults. Only used when ``num_keys > 0``.

    Returns:
        ``(violations, num_checked)``. For ``num_keys > 0``, ``violations`` is
        the minimum over all tried keys.

    Raises:
        ValueError: If ``num_keys`` is negative, or no keys are available to try.
    """
    mean_color = np.mean(IMAGE_block)
    seed = int(calculate_seed(mean_color, round=round))

    def _count(key):
        # Evaluate the block against the sequence generated for one key.
        random_sequence = generate_sequence_from_seed(seed, num_regions_per_channel, key=key)
        return process_dwt_blocks(IMAGE_block, diag, interval_length, dwt_lvl, block_height,
                                  block_width, dct_block_size, random_sequence, scale)

    if num_keys == 0:
        return _count(None)
    if num_keys < 0:
        raise ValueError(f"num_keys must be >= 0, got {num_keys}")

    default_keys = (2468, 123, 7890, 55555, 999)
    candidate_keys = list(default_keys if keys is None else keys)[:num_keys]
    if not candidate_keys:
        # Without this guard the result below would be undefined.
        raise ValueError("no keys available to try")

    results = [_count(key) for key in candidate_keys]
    min_violations = min(violations for violations, _ in results)
    # The checked-coefficient count does not depend on the sequence (it is
    # incremented regardless of is_region_allowed's verdict), so the value
    # from the last run is representative.
    num_pixels = results[-1][1]
    return min_violations, num_pixels

def _ll_triangle_indices(ll_shape, diag):
    """Yield the (row, col) LL positions selected by ``diag``, row by row.

    With ``diag > 0`` the triangle ``i + j < min(diag, n_cols)`` is limited to
    the first ``min(diag, n_rows)`` rows. With ``diag == -1`` the full triangle
    ``i + j < n_cols`` over all rows is used. Any other ``diag`` selects nothing.
    """
    n_rows, n_cols = ll_shape
    if diag > 0:
        n_rows, n_cols = min(diag, n_rows), min(diag, n_cols)
    elif diag != -1:
        return
    for i in range(n_rows):
        # Rows past n_cols produce an empty range, matching the original loop.
        for j in range(n_cols - i):
            yield i, j


def process_dwt_blocks(IMAGE_block, diag, interval_length, dwt_lvl, block_height, block_width,
                       dct_block_size, random_sequence, scale):
    """Count LL-band coefficients rejected by ``is_region_allowed``.

    The ``block_height`` x ``block_width`` area of ``IMAGE_block`` is tiled
    into ``dct_block_size`` x ``dct_block_size`` sub-blocks. Partial tiles at
    the bottom/right edges are skipped. Each tile is cast to float64 and
    decomposed with a ``dwt_lvl``-level 2-D Haar DWT (``pywt.wavedec2``). The
    approximation (LL) coefficients in its upper-left triangle (see
    ``_ll_triangle_indices``) are checked one by one.

    Args:
        IMAGE_block: Array indexed as ``[row, col]``.
        diag: ``> 0`` limits the triangle to ``diag`` rows/cols. ``-1`` uses
            the whole LL triangle. Any other value checks nothing.
        interval_length, random_sequence, scale: Forwarded to ``is_region_allowed``.
        dwt_lvl: Decomposition level passed to ``pywt.wavedec2``.
        block_height, block_width: Extent of the area to tile.
        dct_block_size: Tile edge length.

    Returns:
        ``(violations, num_checked)``: the number of coefficients for which
        ``is_region_allowed`` returned a falsy value, and the total number
        of coefficients checked.
    """
    violations = 0
    num_checked = 0

    # Stopping at `extent - size + 1` keeps exactly the tiles that fit fully,
    # i.e. the same tiles the explicit size check would accept.
    for sub_y in range(0, block_height - dct_block_size + 1, dct_block_size):
        for sub_x in range(0, block_width - dct_block_size + 1, dct_block_size):
            tile = IMAGE_block[sub_y:sub_y + dct_block_size,
                               sub_x:sub_x + dct_block_size].astype(np.float64)
            # wavedec2 returns [cA_n, details...]; the LL band is always first.
            ll = pywt.wavedec2(tile, wavelet='haar', level=dwt_lvl)[0]

            for i, j in _ll_triangle_indices(ll.shape, diag):
                num_checked += 1
                if not is_region_allowed(ll[i, j], random_sequence, scale, interval_length):
                    violations += 1

    return violations, num_checked
