import tensorflow as tf
import numpy as np
import os, tfcv, georeg, math, cv2, imageio
from . import util

class GroundAttn:
    """Lift ground-image features into a bird's-eye-view (BEV) grid.

    Features from the last ground feature map are projected with a 1x1 conv.
    They are bilinearly sampled at ``model_input.ground_pixels`` and scattered
    into a BEV grid at the matching ``model_input.bev_pixels`` via
    ``scatter_features_on_image_mean_unbatched`` (presumably a per-cell mean,
    judging by the name — confirm in georeg).

    The grid is then resized to the requested output stride. Finally it is
    masked to the elliptical region inscribed in the grid.

    Per-iteration channel counts come from the comma-separated environment
    variable ``GROUND_ATTN_FILTERS_V``.
    """

    def __init__(self, model_constants, model_params):
        """Keep the model configuration and parse the per-iteration filter counts.

        Args:
            model_constants: Object read in ``__call__`` for
                ``ground_attn_strides`` and ``bev_shapes``.
            model_params: Nested dict. The parsed filter list is written to
                ``model_params["outer"]["filters-v"]``.

        Raises:
            KeyError: If ``GROUND_ATTN_FILTERS_V`` is missing from the environment.
        """
        self.model_constants = model_constants
        self.model_params = model_params

        raw_filters = os.environ["GROUND_ATTN_FILTERS_V"]
        self.filters_v = list(map(int, raw_filters.split(",")))
        # The caller's params dict shares the very same list object.
        model_params["outer"]["filters-v"] = self.filters_v

        # Final BEV mask from the most recent __call__; None before the first call.
        self.last_bev_mask = None

    def get_modules(self, model_constants, loss_params):
        """Return the extra modules owned by this component (there are none)."""
        return []

    def __call__(self, bev, ground_featuremaps, model_input, stride, iteration, name, config):
        """Build a BEV feature grid from ground-image features.

        TODO: move ground features computation into this class?

        Args:
            bev: Not read. The returned BEV is built only from
                ``ground_featuremaps`` and ``model_input``.
            ground_featuremaps: Sequence of ground-image feature maps. Only the
                last entry is used.
            model_input: Provides ``ground_images``, ``ground_pixels``,
                ``ground_pixels_mask``, ``bev_pixels``, ``batches``,
                ``max_cameras`` and ``points``.
            stride: Stride of the returned BEV grid.
            iteration: Index into ``ground_attn_strides`` and ``filters_v``.
            name: Name prefix for the layers created here.
            config: Forwarded to the tfcv/georeg layer helpers.

        Returns:
            The BEV features, zeroed wherever the final mask is false. The
            mask itself is stored in ``self.last_bev_mask``.
        """
        constants = self.model_constants
        attn_stride = int(constants.ground_attn_strides[iteration])
        attn_bev_shape = constants.bev_shapes[-1] // attn_stride

        image_features = tfcv.model.util.conv(
            ground_featuremaps[-1],
            filters=self.filters_v[iteration],
            kernel_size=1,
            stride=1,
            name=tfcv.model.util.join(name, "image-features"),
            config=config,
        )

        image_pixels, bev_points, batch_camera_idx = self._valid_point_correspondences(
            image_features, model_input, attn_stride, attn_bev_shape,
        )
        point_features = util.bilinear_sample(image_pixels, image_features, batch_camera_idx)

        # Scatter onto a (batches, *attn_bev_shape) grid.
        # Column 0 of batch_camera_idx is the batch index.
        scatter_points = tf.concat([
            batch_camera_idx[..., :1],
            tf.cast(bev_points, "int32"),
        ], axis=-1)
        scatter_shape = tf.concat([
            [model_input.batches],
            attn_bev_shape,
        ], axis=0)
        bev, bev_mask = georeg.model.util.project.scatter_features_on_image_mean_unbatched(
            point_features,
            points=scatter_points,
            shape=scatter_shape,
            return_mask=True,
        )

        # Bring features (and, if needed, the mask) to the requested output stride.
        bev = georeg.model.util.backbone.resize(bev, attn_stride, stride, name=tfcv.model.util.join(name, "bev-outof-ground-attn"), config=config) # b s... f
        if stride != attn_stride:
            bev_mask = self._resize_mask(bev_mask, attn_stride, stride, config)

        disk = self._inscribed_disk(model_input.batches, tf.shape(bev_mask)[-2:])
        bev_mask = tf.math.logical_and(bev_mask, disk)
        bev = tf.where(bev_mask[..., tf.newaxis], bev, 0.0)
        self.last_bev_mask = bev_mask

        return bev

    @staticmethod
    def _valid_point_correspondences(image_features, model_input, attn_stride, attn_bev_shape):
        """Flatten per-point image/BEV coordinates and keep only points flagged in ground_pixels_mask."""
        einops = tfcv.model.einops
        # Feature-map size over input-image size along axes [-3:-1].
        # Assumes those are the spatial (h, w) axes — TODO confirm.
        scale = tf.cast(tf.shape(image_features)[-3:-1], "float32") / tf.cast(tf.shape(model_input.ground_images)[-3:-1], "float32")
        image_pixels = einops.apply("b c p i0 -> (b c p) i0", model_input.ground_pixels) * scale[tf.newaxis, :]
        valid = einops.apply("b c p -> (b c p)", model_input.ground_pixels_mask)
        # BEV pixels are divided by the attention stride and then offset by half the grid size.
        bev_points = einops.apply("b p i0 -> (b c p) i0", model_input.bev_pixels // attn_stride, c=model_input.max_cameras) + tf.cast(attn_bev_shape[tf.newaxis, :], "float32") / 2
        # One (batch, camera) index pair per flattened point.
        batch_camera_idx = tf.stack([
            einops.apply("b -> (b c p)", tf.range(model_input.batches), c=model_input.max_cameras, p=model_input.points),
            einops.apply("c -> (b c p)", tf.range(model_input.max_cameras), b=model_input.batches, p=model_input.points),
        ], axis=-1)
        return tuple(
            tf.boolean_mask(values, valid, axis=0)
            for values in (image_pixels, bev_points, batch_camera_idx)
        )

    @staticmethod
    def _resize_mask(mask, from_stride, to_stride, config):
        """Resize a boolean mask by from_stride/to_stride. A cell stays set if any bilinear weight reaches it."""
        target_shape = tf.shape(mask)[-2:] * from_stride // to_stride
        as_float = tf.where(mask[..., tf.newaxis], 1.0, 0.0)
        resized = tfcv.model.util.resize(as_float, target_shape, method="bilinear", config=config)
        return tf.math.greater(resized[..., 0], 0.0)

    @staticmethod
    def _inscribed_disk(batches, shape):
        """Boolean mask that is true inside the ellipse inscribed in a grid of the given shape."""
        pixels = tf.cast(georeg.model.util.project.generate_bev_pixels(batches, shape), "float32")
        # Half-extent per axis. The broadcast assumes pixels is (b, h, w, 2) — TODO confirm.
        radius = tf.cast(shape / 2, pixels.dtype)[tf.newaxis, tf.newaxis, tf.newaxis, :]
        normalized = (pixels - radius) / radius
        return tf.reduce_sum(tf.math.square(normalized), axis=-1) <= 1.0
