import numpy as np
import matplotlib.pyplot as plt
import pywt


def find_closest_allowed_color(dct_coeff, sequence, interval_length, scale, adjust_to_center=False, max_change=3):
    """Move ``dct_coeff`` into the nearest region whose ``sequence`` bit is 1.

    The value axis starting at ``-scale/2`` is split into consecutive regions
    of width ``interval_length``. The region holding ``dct_coeff`` is located,
    and the nearest region (at most ``max_change`` regions away) whose entry in
    ``sequence`` equals 1 is searched for.

    Args:
        dct_coeff: The coefficient value to adjust (presumably a scalar
            DCT coefficient -- TODO confirm against callers).
        sequence: Per-region flags; entries equal to 1 mark allowed regions.
        interval_length: Width of one region.
        scale: Total span of the value axis; regions start at ``-scale/2``.
        adjust_to_center: If True and a move happens, return the centre of
            the destination region instead of the shifted coefficient.
        max_change: Maximum number of regions the coefficient may move.

    Returns:
        The original ``dct_coeff`` when the search yields a shift of 0
        (already allowed, or no allowed region within reach). Otherwise
        ``dct_coeff`` shifted by whole regions, or the destination region's
        centre when ``adjust_to_center`` is True.

    Raises:
        AssertionError: If the shifted coefficient does not land in an
            allowed region (internal consistency check).
    """
    current_region = get_idx(dct_coeff, scale, interval_length)
    shift = find_closest_index(sequence, interval_length, current_region, max_change)

    # A zero shift means nothing to do: keep the coefficient exactly as given
    # (note: it is NOT centred in this case, even with adjust_to_center).
    if shift == 0:
        return dct_coeff

    # Shifting by whole regions keeps the coefficient's offset inside its region.
    moved_coeff = dct_coeff + shift * interval_length
    assert is_region_allowed(moved_coeff, sequence, scale, interval_length)

    if not adjust_to_center:
        return moved_coeff

    # Midpoint between the lower and upper edges of the destination region.
    target_region = current_region + shift
    lower_edge = -(scale / 2) + target_region * interval_length
    upper_edge = -(scale / 2) + (target_region + 1) * interval_length
    return (lower_edge + upper_edge) / 2


def get_idx(coeff, scale, interval_length):
    """Return the index of the region of width ``interval_length`` holding ``coeff``.

    Regions are counted upward from ``-scale/2``. For example, with
    ``scale=4000`` the first region starts at -2000. Uses floor division, so
    values below ``-scale/2`` give negative indices. With float operands the
    result is an integral float.
    """
    lowest_value = -(scale / 2)
    offset_from_start = coeff - lowest_value
    return offset_from_start // interval_length


def show_image(image, title="Image"):
    """Draw ``image`` with matplotlib under ``title``, hide the axes, and show it."""
    plt.imshow(image)
    plt.title(title)
    plt.axis("off")  # no ticks or frame around the picture
    plt.show()

def calculate_seed(mean_color, round=30):
    """Quantize ``mean_color`` to the nearest multiple of ``round``.

    Uses ``np.round``, so exact halves round to the nearest even multiple.
    The result is a float (or float array).
    """
    # The parameter name shadows the builtin `round`; alias it for readability.
    step = round
    multiples = np.round(mean_color / step)
    return multiples * step

def generate_sequence_from_seed(seed, num_regions_per_channel, key=None):
    """Generate a random 0/1 flag per region.

    Args:
        seed: Seed for NumPy's legacy global RNG; used only when ``key`` is
            None. Must be acceptable to ``np.random.seed`` (e.g. an int or a
            1-D array of ints).
        num_regions_per_channel: Length of the returned sequence.
        key: Optional seed for an independent ``np.random.default_rng``
            Generator. It may be an int or an array-like of ints; when given,
            ``seed`` is ignored.

    Returns:
        An integer array of length ``num_regions_per_channel`` with values
        in {0, 1}.
    """
    # `is None` rather than `== None`: an array-valued key would make `==`
    # elementwise, and the resulting array in the `if` raises ValueError.
    if key is None:
        # Legacy path: this re-seeds NumPy's *global* RNG as a side effect.
        # It is kept for backward compatibility with the original behaviour.
        np.random.seed(seed)
        return np.random.randint(0, 2, size=num_regions_per_channel)

    # Keyed path: a private Generator, so the global RNG state is untouched.
    rng = np.random.default_rng(key)
    return rng.integers(0, 2, size=num_regions_per_channel)



def is_region_allowed(dct_coeff, sequence, scale, interval_length):
    """Return whether the region containing ``dct_coeff`` is marked allowed.

    Args:
        dct_coeff: Coefficient value to classify.
        sequence: Per-region flags; an entry equal to 1 marks an allowed region.
        scale: Total span of the value axis; regions start at ``-scale/2``.
        interval_length: Width of one region.

    Returns:
        ``1`` if the region index lies outside ``[0, len(sequence))``, since
        such regions carry no flag and are treated as allowed. Otherwise the
        result of ``sequence[index] == 1``.
    """
    region_index = get_idx(dct_coeff, scale, interval_length)
    # Check both bounds explicitly. The previous `len(sequence) <= abs(idx)`
    # test let negative indices through, and `sequence[-k]` then silently
    # read a flag from the END of the sequence.
    if not 0 <= region_index < len(sequence):
        return 1  # out of bounds: treated as allowed
    # get_idx uses floor division, so the index is integral; convert it to a
    # plain int for indexing.
    return sequence[int(round(region_index))] == 1



def find_closest_index(sequence, interval_length, curr_index, max_change):
    """Return the smallest signed step to a region whose flag equals 1.

    Candidate steps run from ``-max_change`` to ``max_change``. They are tried
    in order of increasing absolute value; on ties the negative step is tried
    first, because the sort is stable. Only candidates inside
    ``[0, len(sequence))`` are considered.

    Args:
        sequence: Per-region flags; an entry equal to 1 marks an allowed region.
        interval_length: Unused; kept for interface compatibility.
        curr_index: Index of the current region (may be an integral float).
        max_change: Largest step, in regions, that may be taken.

    Returns:
        The step (0 when the current region is already allowed). Also returns
        0 when no allowed region is within reach, or when ``curr_index`` is
        itself outside the sequence: such regions carry no flag, so there is
        nothing to correct.
    """
    n_regions = len(sequence)
    if not 0 <= curr_index < n_regions:
        return 0

    # Steps ordered by distance; sorted() is stable, so -k precedes +k.
    directions = sorted(np.arange(-max_change, max_change + 1, 1), key=abs)
    for direction in directions:
        new_index = int(round(curr_index + direction))
        # Skip out-of-range candidates. A negative index would otherwise wrap
        # to the end of `sequence`, and an index >= len would raise IndexError.
        if 0 <= new_index < n_regions and sequence[new_index] == 1:
            return direction  # number of regions to move
    return 0



def inverse_dwt(LL, LH, HL, HH):
    """Reconstruct an image from one level of 2-D Haar wavelet sub-bands.

    ``LL`` is passed as the approximation band. ``(LH, HL, HH)`` are passed, in
    that order, as pywt's detail tuple, which pywt documents as (horizontal,
    vertical, diagonal). Returns the array produced by ``pywt.waverec2``.
    """
    detail_bands = (LH, HL, HH)
    return pywt.waverec2([LL, detail_bands], wavelet='haar')



