import tensorflow as tf
import numpy as np
import tfcv

def bilinear_sample(ground_pixels, image_features, batches_cameras):
    """Bilinearly sample ``image_features`` at fractional pixel positions.

    Each of the ``p`` query points selects one image through its pair of
    leading indices in ``batches_cameras`` and a fractional position in
    ``ground_pixels``. The four integer neighbours of that position are
    gathered with ``tf.gather_nd`` and blended with bilinear weights.

    Positions outside the image are clamped to the border on both sides
    (edge replication), so every gather index stays within the spatial extent.

    Args:
        ground_pixels: float tensor ``p 2``. Fractional indices: component 0
            indexes axis 2 of ``image_features``, component 1 indexes axis 3.
        image_features: tensor ``b c s... f``. The gather index has four
            components (two from ``batches_cameras``, two from
            ``ground_pixels``), so ``s...`` is expected to be exactly two
            spatial axes (axes 2 and 3) -- NOTE(review): confirm with callers.
        batches_cameras: integer tensor ``p 2``. Indices into axes 0 and 1 of
            ``image_features`` (presumably batch and camera) for each point.

    Returns:
        Tensor ``p f`` of interpolated features (assuming ``image_features``
        is ``b c h w f``).
    """
    # Largest valid index along each spatial axis (axes 2 and 3). The upper
    # side is clamped just like the lower side is clamped at 0; otherwise a
    # point on the last row/column (or beyond) yields an out-of-range index,
    # which tf.gather_nd rejects on CPU and silently maps to zeros on GPU.
    spatial_max = tf.shape(image_features)[2:4] - 1

    ground_pixels_floor = tf.cast(tf.math.floor(ground_pixels), "int32")
    ground_pixels_lower = tf.stop_gradient(tf.math.minimum(tf.math.maximum(ground_pixels_floor, 0), spatial_max))
    ground_pixels_upper = tf.stop_gradient(tf.math.minimum(ground_pixels_lower + 1, spatial_max))
    # Fractional offset from the (clamped) lower corner. Clipping to [0, 1]
    # makes points outside the image take the border value.
    ground_pixels_alpha = tf.clip_by_value(ground_pixels - tf.cast(ground_pixels_lower, ground_pixels.dtype), 0.0, 1.0)

    # Corner order (component 0, component 1):
    # (lower, lower), (lower, upper), (upper, lower), (upper, upper).
    ground_pixels_corners = tf.stack([
        ground_pixels_lower,
        tf.stack([ground_pixels_lower[..., 0], ground_pixels_upper[..., 1]], axis=-1),
        tf.stack([ground_pixels_upper[..., 0], ground_pixels_lower[..., 1]], axis=-1),
        ground_pixels_upper,
    ], axis=0) # corners p 2

    # tf.concat requires matching dtypes; align the leading indices with the
    # int32 pixel indices (no-op if they are already int32).
    batches_cameras = tf.cast(batches_cameras, ground_pixels_lower.dtype)

    # Full 4-component gather index per (corner, point): leading indices
    # repeated for every corner, followed by that corner's pixel index.
    features_at_corners = tf.gather_nd(
        image_features,
        tf.concat([
            tfcv.model.einops.apply("p f -> (corners p) f", batches_cameras, corners=4),
            tfcv.model.einops.apply("corners p f -> (corners p) f", ground_pixels_corners),
        ], axis=-1),
        batch_dims=0,
    )
    features_at_corners = tfcv.model.einops.apply("(corners p) f -> corners p f", features_at_corners, corners=4)

    features00 = features_at_corners[0] # p f
    features01 = features_at_corners[1]
    features10 = features_at_corners[2]
    features11 = features_at_corners[3]

    alpha = ground_pixels_alpha # p 2
    # Interpolate along component 1 first, then blend the two results along
    # component 0. Slicing with 1:2 / 0:1 keeps a trailing axis so alpha
    # broadcasts against the feature axis.
    features0 = features00 * (1 - alpha[..., 1:2]) + features01 * alpha[..., 1:2]
    features1 = features10 * (1 - alpha[..., 1:2]) + features11 * alpha[..., 1:2]
    features  = features0  * (1 - alpha[..., 0:1]) + features1  * alpha[..., 0:1] # p f

    return features
