import tensorflow as tf
import tensorflow_addons as tfa
import tfcv, georeg, math, cosy, os, copy, imageio, cv2, re, types
import numpy as np
from functools import partial
from distinctipy import distinctipy
import georegdata as grd

class SaveCorrRgb:
    """Debug module that visualizes correlation maps by fusing them with the aerial image.

    For every block ``i`` whose ``correlation-logits-{i}`` output is present, the
    correlation logits are unrotated, resized to ``model_input.corr_shape``,
    normalized with ``georeg.model.correlation.math.softmax`` over the valid
    entries and then used as a (spatially flipped) kernel that is cross-correlated
    with the rotated, downsampled aerial image. The result is summed over angles,
    center-cropped to the final BEV shape, masked with the initial BEV mask and
    written as ``{name}-fusedcorr-block{i}-h{head}.jpg`` per batch element and head.
    """

    def __init__(self, model_constants, loss_params):
        # loss_params is accepted for interface parity with the other modules
        # returned by get_modules(); it is not used here.
        self.model_constants = model_constants

    def get_output(self, model, model_input, loss_input, config, model_params):
        """Collect the intermediate tensors that ``debug_fn`` needs.

        Returns a dict with ``initial-bev-mask`` and, for every block ``i`` whose
        layers exist in ``model``, both ``correlation-logits-{i}`` and
        ``correlation-mask-{i}`` (``debug_fn`` reads the pair together).
        """
        result = {}

        layer_names = {layer.name for layer in model.layers}

        result["initial-bev-mask"] = model.get_layer("initial-bev-mask").output

        for i in range(1, self.model_constants.blocks + 1):
            logits_name = f"correlation-logits-{i}"
            mask_name = f"correlation-mask-{i}"
            # Export logits and mask together; debug_fn indexes both keys.
            if logits_name in layer_names and mask_name in layer_names:
                result[logits_name] = model.get_layer(logits_name).output
                result[mask_name] = model.get_layer(mask_name).output
        return result

    def debug_fn(self, model_output, model_input, loss_input, frames, path, names):
        """Write one fused aerial/correlation JPEG per batch element, head and block."""
        bev_mask = model_output["initial-bev-mask"]

        def rotate(inputs):
            # Rotate every image of a flattened (b h a) batch by its own angle.
            image, angles = inputs
            return tfa.image.rotate(image, angles, interpolation="bilinear", fill_mode="constant", fill_value=0.0)

        def save(corr, valid_corr, name):
            # corr, valid_corr: b h a s...
            b = tf.shape(corr)[0]
            h = tf.shape(corr)[1]
            a = tf.shape(corr)[2]
            num_angles = tf.shape(model_input.angles)[1]

            # One negated angle per flattened (b h a) entry; shared by the
            # correlation unrotation and the aerial image rotation below.
            angles = tfcv.model.einops.apply("b a -> (b h a)", -model_input.angles, h=h)

            # Unrotate
            corr = tfcv.model.einops.apply("b h a s... -> (b h a) s...", corr)
            corr = tf.keras.layers.Lambda(rotate)([corr[..., tf.newaxis], angles])[..., 0] # (b h a) s...
            corr = tfcv.model.einops.apply("(b h a) s... -> b h a s...", corr, h=h, a=a)

            # Resize logits and validity mask to model_input.corr_shape
            corr = tfcv.model.einops.apply("b h a s... -> (b h a) s... 1", corr)
            corr = tfcv.model.util.resize(corr, model_input.corr_shape, method="bilinear")
            corr = tfcv.model.einops.apply("(b h a) s... 1 -> b h a s...", corr, b=b, h=h, a=a)

            valid_corr = tfcv.model.einops.apply("b h a s... -> (b h a) s... 1", valid_corr)
            valid_corr = tf.math.greater(tfcv.model.util.resize(tf.where(valid_corr, 1.0, 0.0), model_input.corr_shape, method="bilinear"), 0.0)
            valid_corr = tfcv.model.einops.apply("(b h a) s... 1 -> b h a s...", valid_corr, b=b, h=h, a=a)

            corr = georeg.model.correlation.math.softmax(corr, valid_corr)

            # Downsample the aerial image to the correlation stride and rotate one copy per (b h a)
            aerial_image = model_input.aerial_image
            aerial_image = tfcv.model.util.resize(aerial_image, tf.shape(aerial_image)[-3:-1] // self.model_constants.aerial_stride, method="bilinear")
            aerial_image = tfcv.model.einops.apply("b s... f -> (b h a) s... f", aerial_image, a=num_angles, h=h)
            aerial_image = tf.keras.layers.Lambda(rotate)([aerial_image, angles]) # (b h a) s... f

            # Zero-pad the correlation map (centered) to the aerial image size
            corr = tfcv.model.einops.apply("b h a s... -> (b h a) s...", corr)
            pad_size = tf.shape(aerial_image)[-3:-1] - tf.shape(corr)[-2:]
            front = pad_size // 2
            back = pad_size - front
            paddings = [[0, 0], [front[0], back[0]], [front[1], back[1]]]
            corr = tf.pad(corr, paddings, mode="CONSTANT", constant_values=0.0)

            corr = tfcv.model.einops.apply("(b h a) s... -> b h a s... f", corr, f=aerial_image.shape[-1], a=num_angles, h=h)
            aerial_image = tfcv.model.einops.apply("(b h a) s... f -> b h a s... f", aerial_image, a=num_angles, h=h)

            # Cross-correlate with the spatially flipped kernel, then sum over angles
            fused = tf.keras.layers.Lambda(lambda x: georeg.model.correlation.math.cross_correlate_2d(x[0], x[1], method="fft"))([aerial_image, corr[:, :, :, ::-1, ::-1, :]])
            fused = tfcv.model.einops.apply("b h a s... f -> b h s... f", fused, reduction="sum")

            # Center-crop to the final BEV shape
            fused_shape = tf.cast(tf.shape(fused)[-3:-1], "int64")
            pad_size = fused_shape - tf.cast(self.model_constants.bev_shapes[-1], "int64")
            front = pad_size // 2
            back = fused_shape - (pad_size - front)
            fused = fused[:, :, front[0]:back[0], front[1]:back[1], :]

            # Zero out pixels outside the (resized) initial BEV mask
            bev_mask_resized = tf.math.greater(tfcv.model.util.resize(tf.where(bev_mask[..., tf.newaxis], 1.0, 0.0), tf.shape(fused)[-3:-1], method="bilinear")[..., 0], 0.0)
            fused = tf.where(bev_mask_resized[:, tf.newaxis, :, :, tf.newaxis], fused, 0.0)

            # Clip before the uint8 cast: values outside [0, 255] (e.g. round-off from
            # the FFT-based correlation) have no well-defined uint8 conversion.
            fused = np.clip(fused.numpy(), 0.0, 255.0).astype("uint8")

            # Loop indices are named so they do not shadow the tensors b/h above.
            for batch_index in range(model_input.batches):
                for head_index in range(fused.shape[1]):
                    imageio.imwrite(os.path.join(path, f"{names[batch_index]}-fusedcorr-{name}-h{head_index}.jpg"), fused[batch_index, head_index])

        for i in range(1, self.model_constants.blocks + 1):
            if f"correlation-logits-{i}" in model_output:
                save(model_output[f"correlation-logits-{i}"], model_output[f"correlation-mask-{i}"], name=f"block{i}")

class DrawCorrelationImage:
    """Debug module that draws the normalized correlation volume on top of the aerial image.

    For every block ``i`` whose ``correlation-logits-{i}`` output is present, the
    logits and validity mask are resized to ``model_input.corr_shape``. The logits
    are then softmax-normalized jointly over all valid (angle, position) entries of
    each batch element and head, and rendered with
    ``georeg.model.correlation.visualize.draw_correlation_image`` into
    ``{name}-corr-pow{power}-block{i}-h{head}.jpg``.
    """

    def __init__(self, model_constants):
        self.model_constants = model_constants

    def get_output(self, model, model_input, loss_input, config, model_params):
        """Collect ``correlation-logits-{i}`` and ``correlation-mask-{i}`` for every block.

        A block is exported only if both layers exist in ``model``. ``debug_fn``
        needs the pair, and requesting a missing layer would make ``get_layer`` fail.
        """
        result = {}

        layer_names = {layer.name for layer in model.layers}

        for i in range(1, self.model_constants.blocks + 1):
            logits_name = f"correlation-logits-{i}"
            mask_name = f"correlation-mask-{i}"
            if logits_name in layer_names and mask_name in layer_names:
                result[logits_name] = model.get_layer(logits_name).output
                result[mask_name] = model.get_layer(mask_name).output
        return result

    def debug_fn(self, model_output, model_input, loss_input, frames, path, names):
        """Write one correlation visualization per batch element, head, block and power."""
        # Exponents forwarded as `power=` to draw_correlation_image (previously also 10.0, 100.0).
        powers = [1.0]

        def save(corr, valid_corr, name):
            # corr, valid_corr: b h a s...
            b = tf.shape(corr)[0]
            h = tf.shape(corr)[1]
            a = tf.shape(corr)[2]

            # Resize logits and validity mask to model_input.corr_shape
            corr = tfcv.model.einops.apply("b h a s... -> (b h a) s... 1", corr)
            corr = tfcv.model.util.resize(corr, model_input.corr_shape, method="bilinear")
            corr = tfcv.model.einops.apply("(b h a) s... 1 -> b h a s...", corr, b=b, h=h, a=a)

            valid_corr = tfcv.model.einops.apply("b h a s... -> (b h a) s... 1", valid_corr)
            valid_corr = tf.math.greater(tfcv.model.util.resize(tf.where(valid_corr, 1.0, 0.0), model_input.corr_shape, method="bilinear"), 0.0)
            valid_corr = tfcv.model.einops.apply("(b h a) s... 1 -> b h a s...", valid_corr, b=b, h=h, a=a)

            # Masked softmax over all (a s...) entries of each (batch, head) pair
            corr = tfcv.model.einops.apply("b h a s... -> b h (a s...)", corr, h=h, a=a)
            valid_corr = tfcv.model.einops.apply("b h a s... -> b h (a s...)", valid_corr, h=h, a=a)
            corr = tf.where(valid_corr, corr, 0.0)

            # Subtract the maximum over valid entries for numerical stability
            numerical_offset = tf.reduce_max(tf.where(valid_corr, corr, tf.reduce_min(corr, axis=-1, keepdims=True)), axis=-1, keepdims=True)
            corr = corr - numerical_offset
            corr = tf.math.exp(corr)
            corr = tf.where(valid_corr, corr, 0.0)
            # divide_no_nan yields zeros for (batch, head) pairs without any valid entry
            corr = tf.math.divide_no_nan(corr, tf.reduce_sum(corr, axis=-1, keepdims=True))

            corr = tfcv.model.einops.apply("b h (a s...) -> b h a s...", corr, h=h, a=a, s=model_input.corr_shape)
            valid_corr = tfcv.model.einops.apply("b h (a s...) -> b h a s...", valid_corr, h=h, a=a, s=model_input.corr_shape)

            aerial_image = model_input.aerial_image
            aerial_image = tfcv.model.util.resize(aerial_image, tf.shape(aerial_image)[-3:-1] // self.model_constants.aerial_stride, method="bilinear")

            # Loop indices are named so they do not shadow the tensors b/h above, and
            # `power` avoids shadowing the builtin pow(). Each call gets freshly
            # converted arrays, as in the original per-slice .numpy() calls.
            for batch_index in range(model_input.batches):
                for head_index in range(tf.shape(corr)[1].numpy()):
                    for power in powers:
                        image = georeg.model.correlation.visualize.draw_correlation_image(
                            corr[batch_index, head_index].numpy(),
                            valid_corr[batch_index, head_index].numpy(),
                            aerial_image[batch_index].numpy(),
                            power=power,
                        )
                        imageio.imwrite(os.path.join(path, f"{names[batch_index]}-corr-pow{power:.1f}-{name}-h{head_index}.jpg"), image)

        for i in range(1, self.model_constants.blocks + 1):
            if f"correlation-logits-{i}" in model_output:
                save(model_output[f"correlation-logits-{i}"], model_output[f"correlation-mask-{i}"], name=f"block{i}")

# class SoftmaxUnderflowMetric:
#     def __init__(self, model_constants):
#         self.name = "softmax-underflow-metric"
#         self.model_constants = model_constants
#
#     def get_output(self, model, model_input, loss_input, config, model_params):
#         result = {}
#
#         layer_names = [l.name for l in model.layers]
#
#         for i in range(1, self.model_constants.blocks + 1):
#             if f"correlation-unrotated-{i}" in layer_names:
#                 result[f"correlation-unrotated-{i}"] = model.get_layer(f"correlation-unrotated-{i}").output
#                 result[f"correlation-mask-{i}"] = model.get_layer(f"correlation-mask-{i}").output
#             if f"aerial-attn-bev-{i}" in layer_names:
#                 result[f"aerial-attn-bev-{i}"] = model.get_layer(f"aerial-attn-bev-{i}").output
#         return result
#
#     def loss_fn(self, model_output, model_input, loss_input, schedule_factor, metrics):
#         def add(corr, valid_corr, i):
#             corr = tfcv.model.einops.apply("b a s... -> b (a s...)", corr)
#             valid_corr = tfcv.model.einops.apply("b a s... -> b (a s...)", valid_corr)
#
#             metrics[f"softmax-underflow-ratio-{i + 1}"] = tf.math.divide_no_nan(
#                 tf.cast(tf.math.count_nonzero(tf.math.logical_and(corr == 0, valid_corr), axis=1), "float32"),
#                 tf.cast(tf.math.count_nonzero(valid_corr, axis=1), "float32"),
#             )
#
#         for i in range(1, self.model_constants.blocks + 1):
#             if f"correlation-unrotated-{i}" in model_output:
#                 add(model_output[f"correlation-unrotated-{i}"], model_output[f"correlation-mask-unrotated-{i}"], i - 1)
#
#         # def add(aerial_attn_output, i):
#         #     # aerial_attn_output: b s... (h f)
#         #
#         #     metrics[f"mdA1-{i + 1}"] = tf.math.divide_no_nan(
#         #         tf.cast(tfcv.model.einops.apply("b s... f -> b", aerial_attn_output == 0, reduction="count_nonzero"), "float32"),
#         #         tf.cast(tf.math.reduce_prod(tf.shape(aerial_attn_output)[1:]), "float32"),
#         #     )
#         #
#         #     metrics[f"mdA2-{i + 1}"] = tfcv.model.einops.apply("b s... f -> b", aerial_attn_output, reduction="sum")
#         #
#         # for i in range(1, self.model_constants.blocks + 1):
#         #     if f"aerial-attn-bev-{i}" in model_output:
#         #         add(model_output[f"aerial-attn-bev-{i}"], i - 1)
#
#         return None

# TODO: requires cameras masking for batchsize > 1
class PointPillar:
    """Ground-to-aerial correlation model built around a point-pillar ground-attention module.

    ``predict`` encodes the ground images and the aerial image with networks from
    the same backbone builder. It refines a BEV feature grid (zeros plus a learned
    positional encoding) over ``model_constants.blocks`` blocks using
    ``self.ground_attn`` and optional spatial mixing. Finally it correlates the BEV
    features with the decoded aerial features via
    ``georeg.model.aerial_corr.TransformerBlock``.
    """
    def __init__(self, model_constants, model_params):
        """Create the ground-attention submodule.

        Mutates the caller's ``model_params`` by setting ``model_params["type"]``,
        then hands the ``model_params["point-pillar"]`` sub-dict to ``GroundAttn``.
        """
        self.name = "point-pillar"
        model_params["type"] = self.name
        model_params = model_params[self.name]
        self.model_constants = model_constants

        self.ground_attn = georeg.model.ground_attn.pointpillar.GroundAttn(
            model_constants=self.model_constants,
            model_params=model_params,
        )

    def get_modules(self, model_constants, loss_params):
        """Return this model's auxiliary modules.

        These are the two correlation visualizers plus whatever ``self.ground_attn``
        contributes, with ``None`` entries dropped.
        """
        return [m for m in [
            SaveCorrRgb(model_constants, loss_params),
            DrawCorrelationImage(model_constants),
            # SoftmaxUnderflowMetric(model_constants),
        ] + self.ground_attn.get_modules(model_constants, loss_params) if not m is None]

    def predict(self, model_input, model_params, config):
        """Build the Keras graph for one forward pass.

        Most hyper-parameters are read from environment variables, each with a
        default. ``AERIAL_DECODER_FILTERS`` and ``BEV_FILTERS`` have no default, so
        a missing value raises ``KeyError``. Most of the chosen values are recorded
        into ``model_params``.

        Returns:
            ``(corr_logits, valid_corr, preprocess_aerial, preprocess_ground)``.
            ``corr_logits`` is averaged over the head axis and ``valid_corr`` is
            reduced over it with "all"; both end up as ``b a s...``. The two
            ``preprocess_*`` callables are the selected backbone builder's ``preprocess``.

        NOTE(review): ``model_params`` here appears to be the top-level params dict
        (it is indexed with "point-pillar", "ground", "aerial" and "aerial-attn"). Confirm against the caller.
        """
        # ########### Predict ground and aerial image features ###########
        backbone = str(os.environ["BACKBONE"]) if "BACKBONE" in os.environ else "facebookresearch.convnext_tiny_imagenet1k_224"
        model_params["backbone"] = backbone

        def context_pool(x, name):
            """Add a global-context term to ``x``.

            The per-channel spatial mean goes through norm, conv_act and conv (1x1),
            is broadcast back to ``x``'s shape, and is added to a 1x1 conv of ``x``.
            A final norm is applied only if ``backbone_includes_norm``.

            ``name`` is currently unused. ``backbone_includes_norm`` is a closure
            variable assigned further down in ``predict``; it is looked up at call time.
            """
            x_orig = x
            x = tfcv.model.einops.apply("b s... f -> b f", x, reduction="mean")
            x = tfcv.model.util.norm(x, config=config)
            x = tfcv.model.util.conv_act(x, filters=x_orig.shape[-1], kernel_size=1, stride=1, config=config)
            x = tfcv.model.util.conv(x, filters=x_orig.shape[-1], kernel_size=1, stride=1, config=config)
            x = tfcv.model.einops.apply("b f -> b s... f", x, output_shape=tf.shape(x_orig))
            x = tfcv.model.util.conv(x_orig, filters=x_orig.shape[-1], kernel_size=1, stride=1, config=config) + x
            if backbone_includes_norm:
                x = tfcv.model.util.norm(x, config=config)
            return x

        def replace_stochastic_depth(x):
            """Rewrite the graph of ``x`` so every upstream ``DropPath`` layer is bypassed.

            The replacement block returns its input unchanged. Asserts that no
            ``DropPath`` remains.
            """
            def block(x, layer, layer_weights):
                # print(f"Replacing {layer.name} with input {x}")
                return x
            def pred(layer):
                return isinstance(layer, tfcv.model.stochasticdepth.DropPath)
            x = tfcv.model.graph.replace(x, pred=pred, block=block)
            assert len(tfcv.model.graph.get_all(x, pred=lambda layer: isinstance(layer, tfcv.model.stochasticdepth.DropPath))) == 0
            return x


        # (batch index, camera index) pairs, plus a mask of the cameras whose
        # index is below ground_images_num (i.e. cameras actually present).
        ground_images_batch_camera = tf.concat([
            tfcv.model.einops.apply("b -> b c 1", tf.range(tf.shape(model_input.ground_images)[0]), c=tf.shape(model_input.ground_images)[1]),
            tfcv.model.einops.apply("c -> b c 1", tf.range(tf.shape(model_input.ground_images)[1]), b=tf.shape(model_input.ground_images)[0]),
        ], axis=-1) # b c 2
        ground_images_mask = tf.math.less(ground_images_batch_camera[..., 1], model_input.ground_images_num[:, tf.newaxis]) # b c


        # NO_GROUND_IMAGES=1 multiplies the ground images by 0 (ablation switch).
        no_ground_images = int(os.environ["NO_GROUND_IMAGES"]) == 1 if "NO_GROUND_IMAGES" in os.environ else False
        model_params["no_ground_images"] = no_ground_images

        # Flatten (b c) and keep only the present cameras.
        ground_images = (model_input.ground_images * 0) if no_ground_images else model_input.ground_images
        ground_images = tfcv.model.einops.apply("b c s... f -> (b c) s... f", ground_images)
        ground_images_batch_camera = tfcv.model.einops.apply("b c f -> (b c) f", ground_images_batch_camera)
        ground_images_shape = tfcv.model.einops.apply("b c f -> (b c) f", model_input.ground_images_shape)
        ground_images_mask = tfcv.model.einops.apply("b c -> (b c)", ground_images_mask)

        # NOTE(review): ground_images_shape is not boolean-masked like the two tensors
        # below; Layer.call indexes it with the masked index bc. Confirm this is intended.
        ground_images = tf.boolean_mask(ground_images, ground_images_mask, axis=0) # bc s... f
        ground_images_batch_camera = tf.boolean_mask(ground_images_batch_camera, ground_images_mask, axis=0) # bc 2


        # BACKBONE is "<module>.<builder>". Pretrained builders come from
        # tfcv.model.pretrained; otherwise the model function is taken from
        # tfcv.model and paired with the convnext_nano preprocessing.
        b0, b1 = backbone.split(".")
        pretrained = b0 in vars(tfcv.model.pretrained)
        if pretrained:
            builder = vars(vars(tfcv.model.pretrained)[b0])[b1]
        else:
            builder = types.SimpleNamespace(
                create=lambda x, name: vars(vars(tfcv.model)[b0])[b1](x, name=name, config=config),
                preprocess=tfcv.model.pretrained.timm.convnext_nano_imagenet1k_224.preprocess,
            )
        if "convnext" in backbone:
            backbone_includes_norm = False
        elif "segformer" in backbone:
            backbone_includes_norm = True
        else:
            # NOTE(review): assert is stripped under `python -O`; backbone_includes_norm
            # would then be unbound. Consider raising ValueError for unknown backbones.
            assert False
        preprocess_ground = builder.preprocess
        preprocess_aerial = builder.preprocess

        aerial_feature_map = builder.create(model_input.aerial_image, name="aerial-backbone")
        aerial_feature_map = replace_stochastic_depth(aerial_feature_map)

        ground_context = str(os.environ["PP_GC_TYPE"]) if "PP_GC_TYPE" in os.environ else "global-mean"
        model_params["ground"]["context"]["type"] = ground_context
        assert ground_context in ["none", "global-mean"]

        # Standalone ground encoder: outputs of block1..block3 plus block4 (with
        # optional global context). It is applied per image inside Layer.call below.
        # NOTE(review): block4 is matched with "/block4" but block1-3 with "block{i}"
        # (no slash). Confirm the looser suffix cannot match an unintended layer.
        input = tf.keras.layers.Input((None, None, 3))
        x = builder.create(input, name="ground-backbone")
        x = tfcv.model.graph.get_unique(x, pred=lambda layer: layer.name.endswith("/block4"))
        x = replace_stochastic_depth(x)

        if ground_context == "global-mean":
            x = context_pool(x, name=tfcv.model.util.join("ground", "context"))
        xs = [tfcv.model.graph.get_unique(x, pred=lambda layer: layer.name.endswith(f"block{i}")) for i in [1, 2, 3]] + [x]
        ground_enc_model = tf.keras.Model(inputs=[input], outputs=xs)

        ground_decoder_filters = int(os.environ["GROUND_DECODER_FILTERS"]) if "GROUND_DECODER_FILTERS" in os.environ else 512
        model_params["ground"]["decoder"]["filters"] = ground_decoder_filters
        ground_decoder_stride = int(os.environ["GROUND_DECODER_STRIDE"]) if "GROUND_DECODER_STRIDE" in os.environ else 1
        model_params["ground"]["decoder"]["stride"] = ground_decoder_stride

        # Standalone ground decoder fed with the encoder's multi-level outputs.
        inputs = [tf.keras.layers.Input((None, None, x.shape[-1])) for x in ground_enc_model.outputs]
        xs = inputs
        x = georeg.model.util.backbone.decode(
            xs,
            filters=[ground_decoder_filters],
            norm=not backbone_includes_norm,
            image=None,
            strides=[ground_decoder_stride],
            name="ground",
            config=config,
        )[0]
        ground_dec_model = tf.keras.Model(inputs=inputs, outputs=[x])

        # Number of encoder feature levels (block1..block4).
        levels = 4

        class Layer(tf.keras.layers.Layer):
            """Run the ground encoder/decoder image-by-image via tf.map_fn.

            Each image is cropped to its own ground_images_shape before encoding.
            Outputs are zero-padded back to the original (padded) image size
            divided by the respective stride, so all map_fn outputs share one shape.
            """
            def __init__(self, ground_enc_model, ground_dec_model, **kwargs):
                super().__init__(**kwargs)
                self.ground_enc_model = ground_enc_model
                self.ground_dec_model = ground_dec_model

            def call(self, ground_images, ground_images_shape, orig_ground_images_shape):
                def map_fn(bc):
                    # Crop the padded image to its actual size.
                    image = ground_images[bc]
                    cropped_shape = ground_images_shape[bc]
                    image = image[:cropped_shape[0], :cropped_shape[1]]

                    ground_featuremaps = self.ground_enc_model(image[tf.newaxis]) # level 1 s... f
                    ground_image_features = self.ground_dec_model(ground_featuremaps)[0] # s... f
                    ground_featuremaps = [g[0] for g in ground_featuremaps] # level s... f

                    def pad(x, stride):
                        # Zero-pad bottom/right to orig_ground_images_shape // stride.
                        paddings = [
                            [0, orig_ground_images_shape[0] // stride - tf.shape(x)[0]],
                            [0, orig_ground_images_shape[1] // stride - tf.shape(x)[1]],
                            [0, 0],
                        ]
                        return tf.pad(x, paddings, mode="CONSTANT", constant_values=0)

                    # Crop and pad image features
                    cropped_shape = cropped_shape // ground_decoder_stride
                    ground_image_features = ground_image_features[:cropped_shape[0], :cropped_shape[1]]
                    ground_image_features = pad(ground_image_features, ground_decoder_stride)

                    # Pad block features
                    # NOTE(review): assumes encoder level l has stride 4 * 2**l (4, 8, 16, 32) — confirm for non-convnext backbones.
                    ground_featuremaps = tuple(pad(ground_featuremaps[l], 4 * (2 ** l)) for l in range(levels))

                    return ground_image_features, ground_featuremaps

                ground_image_features, ground_featuremaps = tf.map_fn(
                    fn=map_fn,
                    elems=tf.range(tf.shape(ground_images)[0]),
                    fn_output_signature=(
                        tf.TensorSpec([None, None, self.ground_dec_model.outputs[0].shape[-1]], dtype="float32"),
                        tuple(tf.TensorSpec([None, None, self.ground_enc_model.outputs[l].shape[-1]], dtype="float32") for l in range(levels)),
                    ),
                )

                return ground_image_features, ground_featuremaps
        ground_image_features, ground_featuremaps = Layer(ground_enc_model, ground_dec_model)(ground_images, ground_images_shape, tf.shape(model_input.ground_images)[2:4])

        def scatter(x):
            # Scatter the per-present-camera results back into a dense (b c ...)
            # tensor; tf.scatter_nd leaves absent cameras at zero.
            return tf.scatter_nd(
                indices=ground_images_batch_camera,
                updates=x,
                shape=tf.concat([tf.shape(model_input.ground_images)[:2], tf.shape(x)[1:]], axis=0),
            )
        ground_image_features = scatter(ground_image_features)
        ground_featuremaps = [scatter(ground_featuremaps[l]) for l in range(levels)]

        # model_constants = self.model_constants
        # class Test(tf.keras.layers.Layer):
        #     def __init__(self):
        #         super(Test, self).__init__()
        #
        #     def call(self, inputs, training):
        #         x, model_input_training = inputs
        #         tf.debugging.assert_equal(training, model_input_training, "Leaking")
        #         return x
        # ground_image_features = Test()([ground_image_features, model_input.training_mask])

        # def asd(inputs):
        #     ground_images, ground_image_features = inputs
        #     tf.debugging.assert_equal(tf.shape(ground_images)[:-1], tf.shape(ground_image_features)[:-1], "ASD")
        #     return ground_images, ground_image_features
        # ground_images, ground_image_features = tf.keras.layers.Lambda(asd)([model_input.ground_images, ground_image_features])

        ground_image_features = tf.keras.layers.Lambda(lambda x: tf.debugging.assert_all_finite(x, f"Got non-finite ground_image_features"))(ground_image_features)
        ground_image_features = tfcv.model.util.set_name(ground_image_features, "ground/image_features")

        ground_featuremaps = [tfcv.model.util.set_name(ground_featuremaps[l], f"ground/featuremap{l + 1}") for l in range(levels)]


        aerial_corr = georeg.model.aerial_corr.TransformerBlock(
            model_input=model_input,
            model_constants=self.model_constants,
            model_params=model_params["aerial-attn"],
        )


        # Aerial branch: block4 features, global context, then decode.
        aerial_feature_map = tfcv.model.graph.get_unique(aerial_feature_map, pred=lambda layer: layer.name.endswith("/block4"))
        model_params["aerial"]["context"]["type"] = "global-mean"
        aerial_feature_map = context_pool(aerial_feature_map, name=tfcv.model.util.join("aerial", "context"))


        # NOTE(review): AERIAL_DECODER_FILTERS is required (no default) and, unlike
        # GROUND_DECODER_FILTERS, is not recorded in model_params.
        aerial_resnet_shortcut = int(os.environ["AERIAL_RESNET_SHORTCUT"]) == 1 if "AERIAL_RESNET_SHORTCUT" in os.environ else True
        aerial_features = georeg.model.util.backbone.decode(
            aerial_feature_map,
            filters=[int(os.environ["AERIAL_DECODER_FILTERS"])],
            norm=not backbone_includes_norm,
            image=model_input.aerial_image if aerial_resnet_shortcut else None,
            strides=[1],
            name="aerial",
            config=config,
        )[0]















        # # Distance positional encoding
        # gpoints_depths_posenc = gpoints_depths[tf.newaxis, :, tf.newaxis]
        # # gpoints_depths_posenc = tfcv.model.util.conv_act(gpoints_depths_posenc, filters=16, kernel_size=1, stride=1, name=tfcv.model.util.join("depth-posenc", "1"), config=config)
        # gpoints_depths_posenc = tfcv.model.util.conv(gpoints_depths_posenc, filters=???, kernel_size=1, stride=1, name=tfcv.model.util.join("depth-posenc", "2"), config=config)
        # gpoints_depths_posenc = gpoints_depths_posenc[0]

        # Comma-separated channel count per block, e.g. "256,128"; indexed by block below.
        bev_filters = [int(s) for s in os.environ["BEV_FILTERS"].split(",")]
        model_params["bev_filters"] = bev_filters

        # Compute initial bev queries
        # bev: zeros of shape [batches, *(bev_final_shape // ground_attn_strides[0]), bev_filters[0]]
        shape = np.asarray(self.model_constants.bev_final_shape) // self.model_constants.ground_attn_strides[0]
        bev = tf.zeros(shape=[model_input.batches, shape[0], shape[1], bev_filters[0]], dtype="float32")

        # Construct bev positional encoding
        pos_enc_type = str(os.environ["PP_POSENC"]) if "PP_POSENC" in os.environ else "learned"
        assert pos_enc_type == "learned"
        model_params["point-pillar"]["bev"]["pos-enc"]["type"] = pos_enc_type
        if pos_enc_type == "learned":
            # t = int(os.environ["PP_POSENC_DOWN"]) if "PP_POSENC_DOWN" in os.environ else 4
            pos_enc_shape = (self.model_constants.bev_shapes[0][0], self.model_constants.bev_shapes[0][1]) # (self.model_constants.bev_shapes[-1][0] // t, self.model_constants.bev_shapes[-1][1] // t)
            model_params["point-pillar"]["bev"]["pos-enc"]["shape"] = [int(pos_enc_shape[0]), int(pos_enc_shape[1])]
            bev_pos_enc_layer = georeg.model.util.backbone.PositionalEncodingLayer((pos_enc_shape[0], pos_enc_shape[1], bev.shape[-1]))
            def get_bev_pos_enc(bev_shape):
                return bev_pos_enc_layer(bev_shape)
        else:
            assert False











        spatial_mixing = str(os.environ["PP_MIX"]) if "PP_MIX" in os.environ else "none"
        model_params["point-pillar"]["mix"] = spatial_mixing
        def spatial_mix(bev, bev_mask, name):
            """Optionally mix BEV features spatially according to PP_MIX.

            Supported values:
            - "<n>convnext-<kernel>-<factor>-gp": n ConvNeXt blocks, plus a masked
              global-mean term added back through 1x1 convs.
            - "<n>segformer-sr<r>-mlp<m>": n SegFormer blocks.
            - "none".

            ``bev_mask`` defaults to all-true when None.
            """
            if spatial_mixing != "none":
                print("Self attention block")
            if bev_mask is None:
                bev_mask = tf.ones(tf.shape(bev)[:-1], dtype="bool")
            if match := re.match("([0-9]+)convnext-([0-9]+)-([0-9]+)-gp", spatial_mixing):
                for b in range(int(match.group(1))):
                    bev = tfcv.model.convnext.block(
                        bev,
                        factor=int(match.group(3)),
                        kernel_size=int(match.group(2)),
                        name=tfcv.model.util.join(name, f"{b}", "convnext"),
                        config=config,
                    )
                # Mean of bev over the valid (masked) positions
                bev_mean = tf.math.divide_no_nan(
                    tfcv.model.einops.apply("b s... f -> b f", tf.where(bev_mask[..., tf.newaxis], bev, 0.0), reduction="sum"),
                    tf.cast(tfcv.model.einops.apply("b s... -> b 1", bev_mask, reduction="count_nonzero"), "float32"),
                ) # b f
                bev = tfcv.model.util.conv(bev, kernel_size=1, stride=1, bias=True, name=tfcv.model.util.join(name, "gp-1"), config=config)
                bev_mean = tfcv.model.util.conv(bev_mean, kernel_size=1, stride=1, bias=True, name=tfcv.model.util.join(name, "gp-2"), config=config)
                bev = bev + tfcv.model.einops.apply("b f -> b 1... f", bev_mean, output_ndims=4)
            elif match := re.match("([0-9]+)segformer-sr([0-9]+)-mlp([0-9]+)", spatial_mixing):
                for b in range(int(match.group(1))):
                    bev = tfcv.model.segformer.block(
                        bev,
                        mlp_ratio=int(match.group(3)),
                        kernel_size=3,
                        sr_ratio=int(match.group(2)),
                        heads=4,
                        # shortcut=georeg.model.util.backbone.rezero_shortcut_add,
                        name=tfcv.model.util.join(name, f"{b}", "segformer"),
                        config=config,
                    )
            else:
                assert spatial_mixing == "none"
            return bev

        mask_before_qkv = int(os.environ["PP_MASK_BEFORE"]) == 1 if "PP_MASK_BEFORE" in os.environ else True
        model_params["point-pillar"]["mask-before-qkv"] = mask_before_qkv

        posenc_per_block = int(os.environ["PP_POSENC_PER_BLOCK"]) == 1 if "PP_POSENC_PER_BLOCK" in os.environ else False
        model_params["point-pillar"]["posenc-per-block"] = posenc_per_block

        # Order of the per-block specialists (case-insensitive); see the fuse loop below.
        fuse_specialists = str(os.environ["FUSE_SPECIALISTS"]) if "FUSE_SPECIALISTS" in os.environ else "GAS"
        model_params["point-pillar"]["fuse_specialists"] = fuse_specialists

        mlp_factor = int(os.environ["MLP_FACTOR"]) if "MLP_FACTOR" in os.environ else 2
        model_params["point-pillar"]["mlp_factor"] = mlp_factor

        # One "0"/"1" flag per specialist; "1" enables the shortcut in g().
        cross_attention_shortcut = str(os.environ["CROSS_ATTENTION_SHORTCUT"]) if "CROSS_ATTENTION_SHORTCUT" in os.environ else ("1" * len(fuse_specialists))
        model_params["point-pillar"]["cross_attention_shortcut"] = cross_attention_shortcut


        for block in range(self.model_constants.blocks):
            name = tfcv.model.util.join(f"block-{block + 1}")
            stride = int(self.model_constants.aerial_attn_strides[block])
            # NOTE(review): last_stride is only referenced by the commented-out resize below.
            last_stride = int(self.model_constants.aerial_attn_strides[block - 1 if block > 0 else 0])

            print(f"####### Block {block} at stride {stride} #######")

            # Project to this block's channel count if it changed
            if bev.shape[-1] != bev_filters[block]:
                bev = tfcv.model.util.conv(bev, filters=bev_filters[block], kernel_size=1, stride=1, name=tfcv.model.util.join(name, "down"), config=config)

            if posenc_per_block or block == 0:
                bev = bev + get_bev_pos_enc(tf.shape(bev)[-3:-1])

            # bev = georeg.model.util.backbone.resize(bev, last_stride, stride, name=tfcv.model.util.join(name, "bev-up"), config=config)

            def ga(bev):
                # Ground attention from bev over all encoder levels plus the decoded ground features.
                return self.ground_attn(
                    bev,
                    ground_featuremaps=ground_featuremaps + [ground_image_features],
                    model_input=model_input,
                    stride=stride,
                    iteration=block,
                    name=tfcv.model.util.join(name, "ground-attn"),
                    config=config,
                )
            def mlp(x, n=None):
                # MLP with hidden width mlp_factor * channels; identity when mlp_factor <= 0.
                if mlp_factor > 0:
                    return georeg.model.util.backbone.mlp(
                        x,
                        filters=mlp_factor * x.shape[-1],
                        name=tfcv.model.util.join(name, n, "mlp"),
                        # shortcut=rezero,
                        config=config,
                    )
                else:
                    return x

            # Compete
            def g(bev, ca):
                # Ground-attention specialist: skipped when this block has no value filters.
                if self.ground_attn.ground_attn.filters_v[block] > 0:
                    a = ga(bev)
                    a = tfcv.model.util.set_name(a, f"{block}-GA")
                    if ca:
                        # NOTE(review): ca is a Python bool (True here), so this is the residual
                        # `bev + a`; it relies on TF accepting a bool scalar factor.
                        a = ca * bev + a
                        a = tfcv.model.util.set_name(a, f"{block}-GA-SC")
                    a = mlp(a, n="ground")
                    a = tfcv.model.util.set_name(a, f"{block}-GA-SC-MLP")

                    bev = a
                return bev
            def s(bev):
                # Spatial-mixing specialist, optionally masked by the last BEV mask
                # computed by ground_attn; the last block is also resized from stride
                # ground_attn_strides[block] to 1.
                if mask_before_qkv and not self.ground_attn.last_bev_mask_origsize is None:
                    bev = tf.where(self.ground_attn.last_bev_mask_origsize[..., tf.newaxis], bev, 0.0)
                bev = spatial_mix(bev, self.ground_attn.last_bev_mask_origsize, tfcv.model.util.join(name, "spatial-mix"))
                bev = tfcv.model.util.set_name(bev, f"{block}-SA")

                if block == self.model_constants.blocks - 1:
                    bev = georeg.model.util.backbone.resize(bev, int(self.model_constants.ground_attn_strides[block]), 1, name="last-upsample-into-aerialattn", config=config)

                return bev
            bev = tfcv.model.util.set_name(bev, f"{block}-INIT")
            # Apply the specialists in FUSE_SPECIALISTS order: "g" = ground attention,
            # "s" = spatial mixing.
            # NOTE(review): the default FUSE_SPECIALISTS is "GAS", but "a" has no branch
            # here and reaches `assert False`, so the default only works when
            # FUSE_SPECIALISTS is set explicitly (or under `python -O`, where "a"
            # becomes a silent no-op). zip() also silently truncates if
            # CROSS_ATTENTION_SHORTCUT is shorter than FUSE_SPECIALISTS.
            for f, c in zip(fuse_specialists.lower(), cross_attention_shortcut):
                c = int(c) == 1
                if f == "g":
                    bev = g(bev, c)
                elif f == "s":
                    assert c == 1
                    bev = s(bev)
                else:
                    assert False
            bev = tfcv.model.util.set_name(bev, f"{block}-FINAL")

        # NOTE(review): `name` and `block` below are left over from the last loop
        # iteration; they are unbound if model_constants.blocks == 0.
        corr_logits, valid_corr = aerial_corr(
            bev=bev,
            bev_mask=self.ground_attn.last_bev_mask,
            aerial=aerial_features,
            corr_shape=model_input.corr_shape,
            name=tfcv.model.util.join(name, "aerial-corr"),
            iteration=block,
            config=config,
        ) # corr_logits, valid_corr: b h a s... (head axis reduced below)

        corr_logits = tfcv.model.einops.apply("b h a s... -> b a s...", corr_logits, reduction="mean")
        valid_corr = tfcv.model.einops.apply("b h a s... -> b a s...", valid_corr, reduction="all")

        return corr_logits, valid_corr, preprocess_aerial, preprocess_ground
