import numpy as np
import tensorflow as tf
import georeg, cv2, skimage.transform

def _overlay_probabilities(aerial_image, tracker, probs, recalibrate):
    """Return ``aerial_image`` blended with a JET heatmap of the best angle's probabilities.

    Args:
        aerial_image: uint8 image the heatmap is blended into (not modified).
        tracker: Provides ``predictor.angles`` and ``predictor.corr_shape``,
            which define how ``probs`` is reshaped.
        probs: Tensor reshapeable to ``[len(angles), corr_shape[0], corr_shape[1]]``.
        recalibrate: If True, rescale the selected map so its maximum is 1.

    Returns:
        A new uint8 image with the same shape as ``aerial_image``.

    Raises:
        ValueError: If the probability map is larger than the aerial image.
    """
    predictor = tracker.predictor
    probs = tf.reshape(probs, [len(predictor.angles), predictor.corr_shape[0], predictor.corr_shape[1]])

    # Select the angle slice whose peak probability is the highest.
    peak_per_angle = tf.reduce_max(tf.reshape(probs, [tf.shape(probs)[0], -1]), axis=-1)  # [angles]
    best_angle_index = tf.argmax(peak_per_angle, axis=0)  # []
    probs = tf.gather(probs, best_angle_index, axis=0).numpy()  # [corr_h, corr_w]

    # Both spatial axes are flipped before drawing (kept from the original implementation).
    probs = np.clip(probs[::-1, ::-1], 0.0, 1.0)
    if recalibrate:
        peak = np.amax(probs)
        # An all-zero map would otherwise become 0/0 = NaN, whose uint8 cast is undefined.
        if peak > 0:
            probs = probs / peak

    # Zero-pad the map so it is centered in, and exactly matches, the aerial image.
    # Any odd remainder goes to the trailing side so the shapes always agree.
    total_padding = np.asarray(aerial_image.shape[:2]) - np.asarray(probs.shape[:2])
    if np.any(total_padding < 0):
        raise ValueError(
            f"probability map of shape {probs.shape[:2]} does not fit into aerial image of shape {aerial_image.shape[:2]}"
        )
    pad_before = total_padding // 2
    pad_after = total_padding - pad_before
    probs = np.pad(
        probs,
        [[pad_before[0], pad_after[0]], [pad_before[1], pad_after[1]]],
        mode="constant",
        constant_values=0.0,
    )

    # 256-entry JET lookup table; OpenCV returns BGR, so reverse the channel order.
    colormap = cv2.applyColorMap(np.asarray([np.arange(256).astype("uint8")]), cv2.COLORMAP_JET)[0, :, ::-1]  # [256, 3]
    heat_color = colormap[(probs * 255.0).astype("uint8")]

    image_weight = 0.7
    return (image_weight * aerial_image + (1 - image_weight) * heat_color).astype("uint8")


def _blend_projected_points(aerial_image, tracking_state, points_alpha, points_color):
    """Alpha-blend the ground frame's BEV pixels, projected into the aerial image, in place.

    Args:
        aerial_image: uint8 image of shape [H, W, C]; modified in place.
        tracking_state: Provides ``frame.ground_frame.bev_pixels`` and
            ``probability_distributions["posterior"].curr_to_ref``.
        points_alpha: Weight of ``points_color`` in the blend.
        points_color: Array of per-channel color values.
    """
    bevpixels_to_aerialpixels = tracking_state.probability_distributions["posterior"].curr_to_ref
    image_shape = np.asarray(aerial_image.shape[:2])

    aerial_pixels = bevpixels_to_aerialpixels(tracking_state.frame.ground_frame.bev_pixels.astype("float32"))
    # Shift by half the image size; presumably curr_to_ref yields center-relative coordinates.
    aerial_pixels = aerial_pixels + image_shape[np.newaxis, :] / 2
    # Floor (rather than truncate toward zero) so coordinates in (-1, 0) fail the bounds check
    # instead of being painted onto row/column 0.
    aerial_pixels = np.floor(aerial_pixels).astype("int32")

    inside = np.all((aerial_pixels >= 0) & (aerial_pixels < image_shape[np.newaxis, :]), axis=-1)
    rows, cols = aerial_pixels[inside].T

    blended = points_alpha * points_color[np.newaxis, :] + (1.0 - points_alpha) * aerial_image[rows, cols]
    aerial_image[rows, cols] = np.clip(blended, 0.0, 255.0).astype("uint8")


def draw_aerial_image(tracking_state, tracker, probs=None, points_alpha=0.2, points_color=(255, 242, 0), imu=True, recalibrate=False):
    """Render the current aerial frame with optional probability and point overlays.

    Draws on a uint8 copy of ``tracking_state.frame.aerial_frame.image``:

    * if ``probs`` is given, a JET-colormapped heatmap of the highest-peaking
      angle slice is blended in (image weight 0.7);
    * if ``points_alpha > 0``, the ground frame's BEV pixels are mapped through
      the posterior's ``curr_to_ref`` transform and alpha-blended with
      ``points_color``.

    Args:
        tracking_state: Tracker state; reads ``frame.aerial_frame.image``,
            ``frame.ground_frame.bev_pixels`` and
            ``probability_distributions["posterior"].curr_to_ref``.
        tracker: Only used when ``probs`` is given; its ``predictor.angles`` and
            ``predictor.corr_shape`` define the layout of ``probs``.
        probs: Optional tensor reshapeable to
            ``[len(angles), corr_shape[0], corr_shape[1]]``.
        points_alpha: Blend weight of ``points_color``; points are skipped when <= 0.
        points_color: Per-channel color of the projected points (presumably RGB,
            matching the BGR->RGB-flipped colormap).
        imu: Currently unused; kept for backward compatibility.
        recalibrate: If True, rescale the heatmap so its maximum becomes 1.

    Returns:
        A new uint8 image; the source image in ``tracking_state`` is not modified.

    Raises:
        ValueError: If the probability map is larger than the aerial image.
    """
    # astype copies, so the in-place point blending below never touches the source frame.
    aerial_image = tracking_state.frame.aerial_frame.image.astype("uint8")

    if probs is not None:
        aerial_image = _overlay_probabilities(aerial_image, tracker, probs, recalibrate)

    if points_alpha > 0:
        _blend_projected_points(aerial_image, tracking_state, points_alpha, np.asarray(points_color))

    return aerial_image
