import numpy as np

# Shape (rows, cols) of the binary pattern sent to the projector;
# binary_gs crops its returned pattern to this size.
proj_shape = (1080, 1920)
# Full simulation grid: twice the projector height, so gen_symmetric_target can
# place the target in one half while the conjugate-symmetric copy occupies the
# other. Derived from proj_shape so the two cannot drift apart.
total_shape = (2 * proj_shape[0], proj_shape[1])

def binarize(img):
    """Threshold ``img`` on its real part.

    Returns 1.0 where ``Re(img) > 0`` and 0.0 everywhere else (including
    where the real part is exactly zero or NaN), as floating-point values.
    """
    positive = np.real(img) > 0
    return 1.0 * positive

def gen_symmetric_target(target, flip=True):
    """Embed a target pattern into the full ``total_shape`` Fourier-plane grid.

    The Fourier transform of a real (binary) pattern is conjugate symmetric,
    so only one of the two mirrored copies can be chosen freely. ``flip``
    selects which copy the returned target describes.

    Args:
        target: 2-D array (or nested sequence) of target values. It is expected
            to have shape ``proj_shape``; it only needs to broadcast into the
            destination region.
        flip: If False, write ``target`` unchanged into the top-left
            ``proj_shape`` region (the "top copy"). If True (default), write
            ``target`` flipped on both axes into the bottom half (the
            "conjugate symmetric copy").

    Returns:
        Float array of shape ``total_shape``, zero outside the written region.

    Raises:
        ValueError: If ``target`` is 1-D. This previously printed a message
            and called ``exit()``, which killed the whole process.
    """
    target = np.asarray(target)
    if target.ndim == 1:
        raise ValueError(
            "gen_symmetric_target: unsupported 1d target; expected a 2-D array"
        )

    target_amp = np.zeros((total_shape[0], total_shape[1]))
    if flip:
        # NOTE(review): after fftshift on an even-length axis, the DC bin is at
        # index N // 2, and point reflection about it maps index k -> (N - k) % N.
        # flipud/fliplr placed in the bottom half maps k -> N - 1 - k instead,
        # so this copy may sit one pixel off the exact conjugate position of
        # the top copy. Confirm whether that offset matters.
        target_amp[proj_shape[0]:] = np.flipud(np.fliplr(target))
    else:
        target_amp[:proj_shape[0], :proj_shape[1]] = target

    return target_amp

def _dmd_field(shape, use_dist, dist_path):
    """Build the complex field that multiplies the binary DMD pattern.

    Entries outside the modeled DMD area are zero. Pixels there therefore
    neither contribute to the far field nor survive re-binarization in the
    GS loop.
    """
    field = np.zeros(shape, dtype=np.complex128)
    if use_dist:
        # Calibration phase map. Its extent defines the aperture; presumably it
        # matches proj_shape -- confirm against the calibration script.
        raw_phase = np.load(dist_path)
        field[:raw_phase.shape[0], :raw_phase.shape[1]] = np.exp(1j * raw_phase)
    else:
        # Unit-amplitude aperture over the displayed region only. The returned
        # pattern is cropped to proj_shape. A plain np.ones(shape) let the loop
        # optimize padded rows that are never displayed, and made the ret=True
        # simulation disagree with the returned pattern.
        field[:proj_shape[0], :proj_shape[1]] = 1.0
    return field


def binary_gs(target, iters=10, ret=False, use_dist=True, flip=True, *,
              dist_path="../distort/calib_dmd_phase.npy", rng=None):
    """Compute a binary DMD pattern with the Gerchberg-Saxton algorithm.

    Each iteration does the following:

    1. Propagates the current pattern (times the DMD field) to the Fourier
       plane.
    2. Keeps the phase there and imposes the target amplitude.
    3. Propagates back and removes the calibration phase.
    4. Re-binarizes on the real part.

    Args:
        target: Target projector intensity pattern (2-D). Its square root is
            used as the Fourier-plane amplitude; see ``gen_symmetric_target``.
        iters: Number of GS iterations.
        ret: If True, also return the simulated Fourier-plane intensity.
        use_dist: If True, load a phase-distortion map from ``dist_path`` and
            include it (with its extent as the aperture) in the model.
        flip: Which conjugate copy to optimize for; see
            ``gen_symmetric_target``.
        dist_path: ``.npy`` phase map used when ``use_dist`` is True. The
            default is resolved relative to the current working directory.
        rng: Optional ``numpy.random.Generator`` for the random initial
            pattern. When None, the global ``np.random`` state is used, so
            ``np.random.seed`` keeps working.

    Returns:
        ``o``, an int array of 0/1 values: the pattern cropped to
        ``proj_shape``. If ``ret`` is True, returns ``(o, out_T)`` instead,
        where ``out_T`` is the fftshift-ed intensity
        ``|FFT(pattern * field)|**2`` with the DC pixel set to 0.
    """
    target_amp = np.sqrt(gen_symmetric_target(target, flip))
    shape = target_amp.shape
    calib_dist = _dmd_field(shape, use_dist, dist_path)

    if rng is None:
        bH = np.random.randint(0, 2, size=shape) * 1.0
    else:
        bH = rng.integers(0, 2, size=shape) * 1.0

    for _ in range(iters):
        T = np.fft.fftshift(np.fft.fftn(bH * calib_dist))
        # Keep the far-field phase, replace the amplitude with the target's.
        T = target_amp * np.exp(1j * np.angle(T))
        # Back-propagate and undo the calibration phase before thresholding.
        H = np.fft.ifftn(np.fft.ifftshift(T)) * np.conj(calib_dist)
        bH = binarize(H)

    # bH only holds 0.0/1.0 (random init or binarize output), so a direct cast
    # matches the old normalize-by-max-and-round step.
    o = bH.astype(int)[:proj_shape[0], :proj_shape[1]]
    if not ret:
        return o

    out_T = np.fft.fftshift(np.abs(np.fft.fftn(bH * calib_dist))) ** 2
    # After fftshift the zero-frequency bin sits at index N // 2 on each axis.
    out_T[shape[0] // 2, shape[1] // 2] = 0
    return o, out_T