import tensorflow as tf
import tfcv, georeg

def scatter_features_on_image_qkv_unbatched(query, key, value, points, shape, iteration, logits_forward=None):
    """Scatter per-head values onto a grid using query/key attention weights.

    For every head and point, a logit is computed as the dot product of the
    scaled query with the key. The logits are then normalised with a softmax
    over all points that fall onto the same grid cell of the same head, and
    the weighted values are scatter-added onto the grid. Normalisation and
    scattering are delegated to ``scatter_features_on_image_wv_unbatched``.

    Args:
        query: Tensor laid out as ``head p f``.
        key: Tensor laid out as ``head p f``.
        value: Tensor laid out as ``head p f``.
        points: Integer grid indices laid out as ``p i``.
        shape: 1-D shape of the target grid (without head/feature axes).
        iteration: Integer used to name the attention weights
            (``pillar-weights-{iteration + 1}``).
        logits_forward: Optional callable applied to the flattened
            ``(head p) 1`` logits before the softmax.

    Returns:
        Tensor of shape ``shape + [head, f]`` holding the scattered values.
    """
    # query: h p f
    # key: h p f
    # value: h p f
    # points: p i

    heads = tf.shape(value)[0]
    p = tf.shape(points)[0]

    # Out-of-place scaling: `query *= ...` would mutate the caller's object
    # for array types that implement in-place multiplication.
    query = query * (float(query.shape[-1]) ** -0.5)

    query = tfcv.model.einops.apply("head p f -> (head p) f", query)
    key = tfcv.model.einops.apply("head p f -> (head p) f", key)

    # Matmul
    weights = tfcv.model.einops.apply("points f_qk, points f_qk -> points 1", query, key)

    if logits_forward is not None:
        weights = logits_forward(weights)

    # The per-cell softmax and the weighted scatter are identical to the
    # weights/value variant, so hand the logits over in its "h p" layout.
    weights = tfcv.model.einops.apply("(h p) 1 -> h p", weights, h=heads, p=p)

    return scatter_features_on_image_wv_unbatched(weights, value, points, shape, iteration)

def scatter_features_on_image_wv_unbatched(weights, value, points, shape, iteration):
    """Scatter per-head values onto a grid, weighted by a per-cell softmax.

    The given logits are normalised with a softmax taken over all points that
    land on the same grid cell within the same head. The normalised weights
    are tagged with the name ``pillar-weights-{iteration + 1}`` and used to
    scale the values before they are scatter-added onto the grid.

    Args:
        weights: Logits laid out as ``h p``.
        value: Tensor laid out as ``h p f``.
        points: Integer grid indices laid out as ``p i``.
        shape: 1-D shape of the target grid (without head/feature axes).
        iteration: Integer used when naming the attention weights.

    Returns:
        Tensor of shape ``shape + [h, f]``.
    """
    # weights: h p
    # value: h p f
    # points: p i

    num_heads = tf.shape(value)[0]
    num_points = tf.shape(points)[0]

    logits = tfcv.model.einops.apply("h p -> (h p) 1", weights)
    flat_value = tfcv.model.einops.apply("h p f -> (h p) f", value)

    # Append the head index to every point so each head owns its own slice of the grid.
    tiled_points = tfcv.model.einops.apply("p i -> (h p) i", points, h=num_heads)
    head_index = tfcv.model.einops.apply("h -> (h p) 1", tf.range(num_heads), p=num_points)
    cells = tf.concat([tiled_points, head_index], axis=-1) # (h p) i

    grid_shape = tf.concat([shape, [num_heads]], axis=0)
    scalar_grid_shape = tf.concat([grid_shape, [1]], axis=0) # d... h 1

    # Exp: subtract the per-cell maximum for numerical stability.
    cell_max = tf.zeros(scalar_grid_shape, logits.dtype) + tf.reduce_min(logits)
    cell_max = tf.tensor_scatter_nd_max(cell_max, cells, logits)
    point_max = tf.gather_nd(cell_max, cells) # (h p) 1
    exp_logits = tf.math.exp(logits - point_max)

    # Reduce sum: accumulate per cell, then read the sum back for every point.
    cell_sum = tf.scatter_nd(cells, exp_logits, scalar_grid_shape) # d... h 1
    point_sum = tf.gather_nd(cell_sum, cells)

    # Divide
    attention = tf.math.divide_no_nan(exp_logits, point_sum) # (h p) 1
    attention = tfcv.model.einops.apply("(h p) 1 -> h p", attention, h=num_heads, p=num_points)
    attention = tfcv.model.util.set_name(attention, f"pillar-weights-{iteration + 1}")
    attention = tfcv.model.einops.apply("h p -> (h p) 1", attention)

    # Apply attention weights
    feature_grid_shape = tf.concat([grid_shape, [flat_value.shape[-1]]], axis=0)
    return tf.scatter_nd(cells, attention * flat_value, feature_grid_shape) # d... h f

def scatter_features_on_image_mean_unbatched(value, points, shape, return_mask=False):
    """Scatter values onto a grid and average those that share a cell.

    A column of ones is appended to the values so that a single scatter-add
    yields both the per-cell feature sums and the per-cell hit counts. Cells
    that receive no point end up as zeros.

    Args:
        value: Tensor laid out as ``p f``.
        points: Integer grid indices laid out as ``p i``.
        shape: 1-D shape of the target grid (without the feature axis).
        return_mask: When true, also return a boolean grid marking cells
            that received at least one point.

    Returns:
        The averaged grid ``d... f``, or ``(grid, mask)`` if ``return_mask``.
    """
    # value: p f
    # points: p i

    num_features = value.shape[-1]

    # Extra channel counts how many points hit each cell.
    hit_channel = tf.ones_like(value[..., :1])
    augmented = tf.concat([value, hit_channel], axis=-1)

    grid_shape = tf.concat([shape, [augmented.shape[-1]]], axis=0)
    grid_shape = tf.cast(grid_shape, "int32")

    accumulated = tf.scatter_nd(points, augmented, grid_shape) # d... f+1

    if return_mask:
        mask = accumulated[..., -1] > 0

    feature_sums = accumulated[..., :-1]
    hit_counts = accumulated[..., -1:]
    mean = tf.math.divide_no_nan(feature_sums, hit_counts)

    # Restore the static feature dimension lost through the dynamic scatter shape.
    leading_dims = [None] * (len(mean.shape) - 1)
    mean = tf.ensure_shape(mean, leading_dims + [num_features])

    if return_mask:
        return (mean, mask)
    return mean



def generate_bev_pixels(batches, shape):
    """Build an int32 grid holding the index pair of every BEV pixel.

    Args:
        batches: Batch size used as the leading dimension.
        shape: Indexable with at least two entries giving the grid extent.

    Returns:
        Tensor ``b s0 s1 2`` whose last axis is ``(index along s0, index along s1)``.
    """
    extent0, extent1 = shape[0], shape[1]
    full_shape = tf.stack([batches, extent0, extent1])

    index0 = tf.range(extent0, dtype="int32")
    index1 = tf.range(extent1, dtype="int32")

    grid0 = tf.broadcast_to(index0[tf.newaxis, :, tf.newaxis], full_shape)
    grid1 = tf.broadcast_to(index1[tf.newaxis, tf.newaxis, :], full_shape)

    return tf.stack([grid0, grid1], axis=-1) # b s... f

def bevpixels_to_aerialpixels_onimage(bev_shape, aerial_shape, loss_input, bev_pixels, model_constants):
    """Map BEV pixel coordinates to aerial-image pixel coordinates.

    The pixels are centred on the BEV image, pushed through the transform
    returned by ``loss_input.get_bevpixels_to_aerialpixels_with_shape`` (its
    upper-left 2x2 block is applied as a linear map and the first two entries
    of its third column as a translation), and then shifted to the aerial
    image centre.

    Args:
        bev_shape: Shape of the BEV image; castable to a tensor.
        aerial_shape: Shape of the aerial image; must support ``aerial_shape - 1``.
        loss_input: Object providing ``get_bevpixels_to_aerialpixels_with_shape``.
        bev_pixels: Tensor laid out as ``b s... f``.
        model_constants: Forwarded to the transform getter.

    Returns:
        ``(pixels, pixels_mask)``: the flattened ``b p f`` aerial coordinates
        and a boolean ``b p`` mask that is true where every coordinate lies
        in ``[0, aerial_shape - 1]``.
    """
    flat_pixels = tfcv.model.einops.apply("b s... f -> b (s...) f", bev_pixels)
    flat_pixels = tf.cast(flat_pixels, "float32")

    transform = loss_input.get_bevpixels_to_aerialpixels_with_shape(
        bev_shape=bev_shape,
        aerial_shape=aerial_shape,
        model_constants=model_constants,
    )

    # Move the origin to the centre of the BEV image.
    bev_center = tf.cast(bev_shape, flat_pixels.dtype)[tf.newaxis, tf.newaxis, :] / 2.0
    centered = flat_pixels - bev_center

    # Affine map: linear part plus translation.
    linear = transform[:, :2, :2]
    translation = transform[:, :2, 2][:, tf.newaxis, :]
    mapped = tfcv.model.einops.apply(
        "b i0 i1, b p i1 -> b p i0",
        linear,
        tf.cast(centered, transform.dtype),
    ) + translation

    # Move the origin back to the corner of the aerial image.
    aerial_center = tf.cast(aerial_shape, mapped.dtype)[tf.newaxis, tf.newaxis, :] / 2.0
    pixels = mapped + aerial_center

    upper_bound = tf.cast(aerial_shape - 1, pixels.dtype)[tf.newaxis, tf.newaxis, :]
    inside = tf.math.logical_and(pixels >= 0, pixels <= upper_bound)
    pixels_mask = tf.reduce_all(inside, axis=-1)

    return pixels, pixels_mask

def aerialpixels_to_bevpixels_onimage(aerial_shape, bev_shape, loss_input, aerial_pixels, model_constants):
    """Map aerial-image pixel coordinates to BEV pixel coordinates.

    Uses the matrix inverse of the transform returned by
    ``loss_input.get_bevpixels_to_aerialpixels_with_shape``: the pixels are
    centred on the aerial image, transformed (upper-left 2x2 block as linear
    map, first two entries of the third column as translation), and then
    shifted to the BEV image centre.

    Args:
        aerial_shape: Shape of the aerial image; castable to a tensor.
        bev_shape: Shape of the BEV image; must support ``bev_shape - 1``.
        loss_input: Object providing ``get_bevpixels_to_aerialpixels_with_shape``.
        aerial_pixels: Tensor laid out as ``b s... f``.
        model_constants: Forwarded to the transform getter.

    Returns:
        ``(pixels, pixels_mask)``: the flattened ``b p f`` BEV coordinates and
        a boolean ``b p`` mask that is true where every coordinate lies in
        ``[0, bev_shape - 1]``.
    """
    flat_pixels = tfcv.model.einops.apply("b s... f -> b (s...) f", aerial_pixels)
    flat_pixels = tf.cast(flat_pixels, "float32")

    bev_to_aerial = loss_input.get_bevpixels_to_aerialpixels_with_shape(
        bev_shape=bev_shape,
        aerial_shape=aerial_shape,
        model_constants=model_constants,
    )
    aerial_to_bev = tf.linalg.inv(bev_to_aerial)

    # Move the origin to the centre of the aerial image.
    aerial_center = tf.cast(aerial_shape, flat_pixels.dtype)[tf.newaxis, tf.newaxis, :] / 2.0
    centered = flat_pixels - aerial_center

    # Affine map: linear part plus translation.
    linear = aerial_to_bev[:, :2, :2]
    translation = aerial_to_bev[:, :2, 2][:, tf.newaxis, :]
    mapped = tfcv.model.einops.apply(
        "b i0 i1, b p i1 -> b p i0",
        linear,
        tf.cast(centered, aerial_to_bev.dtype),
    ) + translation

    # Move the origin back to the corner of the BEV image.
    bev_center = tf.cast(bev_shape, mapped.dtype)[tf.newaxis, tf.newaxis, :] / 2.0
    pixels = mapped + bev_center

    upper_bound = tf.cast(bev_shape - 1, pixels.dtype)[tf.newaxis, tf.newaxis, :]
    inside = tf.math.logical_and(pixels >= 0, pixels <= upper_bound)
    pixels_mask = tf.reduce_all(inside, axis=-1)

    return pixels, pixels_mask

def bevpixels_to_groundpixels_onimage(bev_shape, model_constants, model_input, ego_to_fixedego, points):
    """Project 3-D BEV points into the images of the ground cameras.

    The points are rescaled to the resolution of the last entry of
    ``model_constants.bev_shapes``, centred, transformed per camera by
    ``intrinsics @ ego_to_cam @ [ego_to_fixedego @] inv(ground_ego_to_pixels)``
    and perspective-divided by their depth.

    Args:
        bev_shape: Shape of the BEV grid the points are expressed in; only
            ``bev_shape[0]`` is used to derive the scale factor.
        model_constants: Object exposing ``bev_shapes``.
        model_input: Object exposing ``ground_ego_to_cam``,
            ``ground_ego_to_pixels``, ``ground_intr``, ``max_cameras``,
            ``ground_images_shape`` and ``ground_images_mask``.
        ego_to_fixedego: Optional batch of matrices (``b f0 f1``) inserted
            into the transform chain; skipped when ``None``.
        points: Tensor laid out as ``b s... h f``; the first two features are
            treated as planar coordinates and the remainder as height.

    Returns:
        ``(points, depths, mask)``:
        ``points`` is ``b c points 2`` with the first two projected
        coordinates in reversed order (zero where ``mask`` is false),
        ``depths`` is ``b c points``, and ``mask`` is a boolean ``b c points``
        tensor that is true where the projection lies inside the image, has
        positive depth, and hits a true entry of ``ground_images_mask``.
    """
    points = tf.cast(points, "float32")
    points = tfcv.model.einops.apply("b s... h f -> b (s... h) f", points)

    # Rescale from the resolution of `bev_shape` to that of the last bev_shapes entry.
    factor = tf.cast(model_constants.bev_shapes[-1][0], "float32") / tf.cast(bev_shape[0], "float32")
    points = points * factor

    # Centre the planar coordinates; `offset` already has points.dtype.
    offset = tf.cast(model_constants.bev_shapes[-1], points.dtype) / 2
    points_2d = points[..., :2] - offset[tf.newaxis, tf.newaxis, :]
    heights = points[..., 2:]

    # NOTE(review): the identity patterns below presumably act as shape checks
    # (planar part "b p i0", exactly one height channel) - kept unchanged.
    points = tf.concat([
        tfcv.model.einops.apply("b p i0 -> b p i0", points_2d),
        tfcv.model.einops.apply("b p 1 -> b p 1", heights),
    ], axis=-1)

    ego_to_cam = model_input.ground_ego_to_cam
    pixels_to_ego = tf.linalg.inv(model_input.ground_ego_to_pixels)
    pixels_to_ego = tfcv.model.einops.apply("b f0 f1 -> b c f0 f1", pixels_to_ego, c=model_input.max_cameras)

    # Pad the intrinsics so they can be chained with the other matrices.
    # NOTE(review): helper name and Lambda wrapper kept as in the original.
    def asd(ground_intr):
        return georeg.model.util.pad_matrix(ground_intr, pad=1)
    cam_to_screen = tf.keras.layers.Lambda(asd)(model_input.ground_intr)

    if ego_to_fixedego is not None:
        ego_to_fixedego = tfcv.model.einops.apply("b f0 f1 -> b c f0 f1", ego_to_fixedego, c=model_input.max_cameras)
        transform = cam_to_screen @ ego_to_cam @ ego_to_fixedego @ pixels_to_ego
    else:
        transform = cam_to_screen @ ego_to_cam @ pixels_to_ego
    transform = transform[..., :3, :]

    # Apply the 3x3 linear part and add the translation column.
    points = tfcv.model.einops.apply("b c i0 i1, b points i1 -> b c points i0", transform[..., :, :3], points, i0=3) + transform[..., :, 3][..., tf.newaxis, :]

    depths = points[..., 2]
    # Perspective divide. divide_no_nan yields 0 instead of inf/NaN for a zero
    # depth; such entries are rejected by the `depths > 0` test below, so the
    # returned values are unchanged while non-finite intermediates are avoided.
    points = tf.math.divide_no_nan(points[..., :2][..., ::-1], depths[..., tf.newaxis])

    # Inside the image bounds ...
    mask = tf.math.reduce_all(tf.math.logical_and(
        0.0 <= points,
        points <= tf.cast(model_input.ground_images_shape[:, :, tf.newaxis, :] - 1, points.dtype),
    ), axis=-1)
    # ... and in front of the camera.
    mask = tf.math.logical_and(
        mask,
        depths > 0,
    )
    # Zero out rejected points so the integer lookup below always uses a valid index.
    points = tf.where(mask[..., tf.newaxis], points, 0)
    mask = tf.math.logical_and(
        mask,
        tf.gather_nd(model_input.ground_images_mask[..., tf.newaxis], tf.cast(points, "int32"), batch_dims=2)[..., 0],
    )

    return points, depths, mask
