import tensorflow as tf
import numpy as np
import cv2, georeg, tfcv

def draw_correlation_image(corr, valid_corr, aerial_image, power=1.0, corr_alpha=0.5, calibrated=False):
    """Blend the best-angle correlation map over an aerial image as a JET heat map.

    The angle slice of ``corr`` that contains the global maximum is selected,
    normalised, colour-mapped, centre-cropped or zero-padded to the spatial
    size of ``aerial_image``, flipped along both spatial axes and finally
    alpha-blended with ``aerial_image``. Pixels that are not marked valid
    show the unmodified aerial image.

    Args:
        corr: Correlation volume indexed as ``[angle, row, col]`` (the last
            two axes are reduced to pick the best angle along axis 0).
        valid_corr: Boolean mask with the same layout as ``corr``, or
            ``None`` to treat every entry as valid.
        aerial_image: Image whose first two dimensions define the output
            size; it is blended with an ``[H, W, 3]`` colour image, so it
            is presumably an RGB image of that shape -- confirm with callers.
        power: Exponent applied to the normalised correlation before colour
            mapping.
        corr_alpha: Blend weight of the heat map; ``1 - corr_alpha`` is the
            weight of the aerial image.
        calibrated: If ``True``, ``corr`` is used as-is (expected to lie in
            ``[0, 1]``); otherwise it is rescaled using the min/max of the
            valid entries.

    Returns:
        The blended image as a NumPy array with the spatial size of
        ``aerial_image``.

    Raises:
        AssertionError: If ``corr`` is larger than ``aerial_image`` along one
            spatial axis and smaller along the other.
    """
    # Select the angle whose correlation map contains the global maximum.
    best_angle = tf.argmax(tf.reduce_max(tf.reduce_max(corr, axis=-1), axis=-1), axis=0)
    corr = tf.gather(corr, best_angle, axis=0, batch_dims=0).numpy()
    if valid_corr is None:
        valid_corr = np.ones(corr.shape, dtype=bool)
    else:
        valid_corr = tf.gather(valid_corr, best_angle, axis=0, batch_dims=0).numpy()

    # 256-entry JET lookup table; OpenCV returns BGR, so reverse the channels.
    colormap = cv2.applyColorMap(np.asarray([np.arange(256).astype("uint8")]), cv2.COLORMAP_JET)[0, :, ::-1] # [256, 3]

    if not calibrated:
        only_valid_corr = corr.reshape([-1])[valid_corr.reshape([-1])]
        # Without any valid entry there is no range to rescale to; every
        # pixel falls back to the aerial image in the final masking step.
        if only_valid_corr.size > 0:
            corr = georeg.model.util.rescale(corr, min=np.amin(only_valid_corr), max=np.amax(only_valid_corr))

    # Keep values inside [0, 1] so that the power step cannot produce NaN for
    # negative inputs and the uint8 cast below cannot wrap around.
    corr = np.power(np.clip(corr, 0.0, 1.0), power)
    color_index = np.clip(corr * 255.0, 0.0, 255.0).astype("uint8")
    corr_image = colormap[color_index]

    # Per-side margin between the correlation map and the aerial image.
    # NOTE: an odd size difference on an axis is not handled -- the symmetric
    # crop/pad then leaves the sizes one pixel apart and the blend below fails.
    padding = (np.asarray(corr.shape[:2]) - np.asarray(aerial_image.shape[:2])) // 2
    assert np.all(padding >= 0) or np.all(padding <= 0)
    if np.any(padding > 0):
        # Correlation map is larger on at least one axis: centre-crop. An axis
        # with zero margin is left untouched by the slice.
        crop_rows, crop_cols = padding
        corr_image = corr_image[crop_rows:corr_image.shape[0] - crop_rows, crop_cols:corr_image.shape[1] - crop_cols]
        valid_corr = valid_corr[crop_rows:valid_corr.shape[0] - crop_rows, crop_cols:valid_corr.shape[1] - crop_cols]
    else:
        # Correlation map is smaller or equal on every axis: zero-pad it and
        # mark the padded border as invalid.
        pad_rows, pad_cols = -padding
        corr_image = np.pad(corr_image, [[pad_rows, pad_rows], [pad_cols, pad_cols], [0, 0]], mode="constant", constant_values=0.0)
        valid_corr = np.pad(valid_corr, [[pad_rows, pad_rows], [pad_cols, pad_cols]], mode="constant", constant_values=False)

    valid_corr = valid_corr[::-1, ::-1] # TODO: assuming vehicle_pixel offset is always 0?
    corr_image = corr_image[::-1, ::-1, :] # TODO: assuming vehicle_pixel offset is always 0?
    corr_image = ((1 - corr_alpha) * aerial_image + corr_alpha * corr_image).astype("uint8")
    # Show the heat-map blend only where the correlation is valid.
    corr_image = np.where(valid_corr[:, :, np.newaxis], corr_image, aerial_image)

    return corr_image
