import tensorflow as tf
import numpy as np
import os, tfcv, georeg, math, cv2, imageio
import georegdata as grd
# import tensorflow_graphics.geometry.transformation

class SaveGroundDeformPixels:
    """Debug hook that paints the sampled ground-image pixels of every block.

    For each block that exposes a ``point-pillar/ogpoints-ground-pixels-{i}``
    layer, the sampled pixel positions are drawn onto the (masked) ground
    images, colored by height bin and blended with the pillar weights.
    """

    def __init__(self, model_constants, loss_params, height_range):
        # loss_params is part of the signature but is not read by this class.
        self.height_range = height_range
        self.model_constants = model_constants

        # Exponent applied to the clipped pillar weights before they are used as blend alpha.
        self.pillar_weights_vis_rescale_pow = 0.3

    def get_output(self, model, model_input, loss_input, config, model_params):
        """Collect the layer outputs that ``debug_fn`` needs, keyed by name.

        Only blocks for which an ``ogpoints-ground-pixels`` layer exists are
        included. The BEV pixels of a block are taken from the closest block
        index ``<= i`` that defines a ``gpoints-bev-pixels`` layer.
        """
        # Other layers that could be exported here:
        # "point-pillar/gpoints-ground-pixels": model.get_layer("point-pillar/gpoints-ground-pixels").output,
        # "point-pillar/gpoints-bev-pixels": model.get_layer("point-pillar/gpoints-bev-pixels").output,
        # "point-pillar/gpoints-bch": model.get_layer("point-pillar/gpoints-bch").output,
        outputs = {}

        available = [layer.name for layer in model.layers]
        for block in range(1, self.model_constants.blocks + 1):
            pixels_key = f"point-pillar/ogpoints-ground-pixels-{block}"
            if pixels_key not in available:
                continue

            for key in (pixels_key, f"point-pillar/ogpoints-bch-{block}", f"pillar-weights-{block}"):
                outputs[key] = model.get_layer(key).output

            # Walk back to the most recent block that created BEV pixels.
            source_block = block
            while (bev_layer := f"point-pillar/gpoints-bev-pixels-{source_block}") not in available:
                source_block -= 1
                assert source_block > 0
            outputs[f"point-pillar/gpoints-bev-pixels-{block}"] = model.get_layer(bev_layer).output
        return outputs

    def debug_fn(self, model_output, model_input, loss_input, frames, path, names):
        """Write one ``{name}-grounddeform-c{camera}-offset{i}.jpg`` per block, batch entry and camera."""
        colormap = cv2.applyColorMap(np.asarray([np.arange(256).astype("uint8")]), cv2.COLORMAP_JET)[0, :, ::-1] # [256, 3]

        def save(ground_pixels, batches_cameras_heightbins, pillar_weights, name=None):
            suffix = f"-{name}" if name is not None else ""
            has_weights = pillar_weights is not None

            # ground_pixels: p f
            # batches_cameras_heightbins: p f
            # pillar_weights: p (after averaging over its leading axis)
            if has_weights:
                pillar_weights = tf.reduce_mean(pillar_weights, axis=0)
                num_points = tf.shape(ground_pixels)[0]
                pillar_weights = tfcv.model.einops.apply("p -> p 1", pillar_weights[:num_points])

            # Keep only the points that land inside the ground image.
            max_pixel = tf.cast(tf.shape(model_input.ground_images)[-3:-1][tf.newaxis, :] - 1, "float32")
            points_mask = tf.math.reduce_all(
                tf.math.logical_and(ground_pixels >= 0.0, ground_pixels <= max_pixel),
                axis=-1,
            )
            ground_pixels = tf.boolean_mask(ground_pixels, points_mask, axis=0)
            batches_cameras_heightbins = tf.boolean_mask(batches_cameras_heightbins, points_mask, axis=0)
            if has_weights:
                pillar_weights = tf.boolean_mask(pillar_weights, points_mask, axis=0)

            # (batch, camera, row, col) index of every remaining point; shared by all scatters below.
            canvas_shape = tf.shape(model_input.ground_images[..., :1])
            scatter_indices = tf.concat(
                [batches_cameras_heightbins[..., :2], tf.cast(ground_pixels, "int32")[..., :2]],
                axis=-1,
            )

            hits = tf.scatter_nd(
                scatter_indices,
                tf.ones(tf.shape(ground_pixels[:, :1]), dtype="int32"),
                canvas_shape,
            )
            mask = hits > 0 # b c s... 1

            # Largest height bin that hits each pixel.
            height_bins = tf.tensor_scatter_nd_max(
                tf.zeros(canvas_shape, dtype="int32"),
                scatter_indices,
                batches_cameras_heightbins[:, -1:],
            ) # b c s... 1

            if has_weights:
                # Largest pillar weight that hits each pixel.
                weights = tf.tensor_scatter_nd_max(
                    tf.zeros(canvas_shape, dtype="float32"),
                    scatter_indices,
                    pillar_weights,
                ) # b c s... 1

            for b in range(model_input.batches):
                for c in range(model_input.ground_images_num[b]):
                    background = np.where(model_input.ground_images_mask[b, c, ..., tf.newaxis], model_input.ground_images[b, c], 0.0)
                    hit = mask[b, c].numpy()

                    bin_index = height_bins[b, c, ..., 0].numpy() # s...
                    lut_index = (bin_index.astype("float32") / (self.height_range[2] - 1) * 255).astype("uint8")
                    height_color = colormap[lut_index]

                    if has_weights:
                        alpha = weights[b, c, ..., 0].numpy()[:, :, np.newaxis] # s... 1
                        alpha = np.power(np.clip(alpha, 0.0, 1.0), self.pillar_weights_vis_rescale_pow)
                        overlay = height_color * alpha + (1 - alpha) * background
                    else:
                        overlay = height_color
                    color = np.where(hit, overlay, background)

                    imageio.imwrite(os.path.join(path, f"{names[b]}-grounddeform-c{c}{suffix}.jpg"), color.astype("uint8"))

        # p = tf.shape(model_output[f"point-pillar/ogpoints-ground-pixels-1"])[0]
        #
        # point_indices = tf.random.shuffle(tf.range(p))[:10].numpy()
        # def save_individual(ground_pixels, bev_pixels, batches_cameras_heightbins, pillar_weights, block, name):
        #     if ground_pixels.shape[0] == 0:
        #         return
        #     if not pillar_weights is None:
        #         pillar_weights = tf.reduce_mean(pillar_weights, axis=0)
        #
        #     # ground_pixels: p f
        #     # bev_pixels: p f
        #     # batches_cameras_heightbins: p f
        #     # pillar_weights: p
        #
        #     # if not pillar_weights is None:
        #     #     h = tf.shape(ground_pixels)[0]
        #     #     r = tf.shape(ground_pixels)[2]
        #     #     pillar_weights = tfcv.model.einops.apply("p -> h p r", pillar_weights, h=h, r=r)
        #
        #     aerial_pixels, _ = georeg.model.util.project.bevpixels_to_aerialpixels_onimage(
        #         bev_shape=self.model_constants.bev_shapes[-1] // self.model_constants.ground_attn_strides[block],
        #         aerial_shape=self.model_constants.aerial_image_shape,
        #         loss_input=loss_input,
        #         bev_pixels=bev_pixels[tf.newaxis],
        #         model_constants=self.model_constants,
        #     )
        #     aerial_pixels = aerial_pixels[0] # p f
        #
        #     colors = batches_cameras_heightbins[:, 2].numpy()
        #     colors = colormap[(colors.astype("float32") / (self.height_range[2] - 1) * 255).astype("uint8")]
        #
        #     for point_index in point_indices:
        #         matching_point_indices = np.where(np.all(bev_pixels == bev_pixels[point_index][np.newaxis], axis=-1))[0]
        #
        #         b = batches_cameras_heightbins[point_index, 0].numpy()
        #         aerial_image = np.copy(model_input.aerial_image[b])
        #         path2 = os.path.join(path, f"{names[b]}-indgrounddeform")
        #         if not os.path.isdir(path2):
        #             os.makedirs(path2)
        #         name2 = f"p{point_index}"
        #
        #         grd.visualize.draw_points(aerial_image, aerial_pixels[point_index].numpy(), np.asarray([255, 0, 0]), radius=4.0)
        #         imageio.imwrite(os.path.join(path2, f"{name2}-aerial.jpg"), aerial_image.astype("uint8"))
        #
        #         for c in range(model_input.ground_images_num[b]):
        #             ground_image = np.where(model_input.ground_images_mask[b, c, ..., tf.newaxis], model_input.ground_images[b, c], 0.0)
        #
        #             save = False
        #             for matching_point_index in matching_point_indices:
        #                 if batches_cameras_heightbins[matching_point_index, 1] == c:
        #                     ground_pixels_c = ground_pixels[matching_point_index].numpy()
        #
        #                     if pillar_weights is None:
        #                         colors_c = colors[:, matching_point_index].reshape([-1, 3])
        #                     else:
        #                         weight = pillar_weights[matching_point_index].numpy()
        #                         weight = np.power(weight, self.pillar_weights_vis_rescale_pow)
        #                         colors_c = colormap[(weight.astype("float32") * 255).astype("uint8")]
        #
        #                     grd.visualize.draw_points(ground_image, ground_pixels_c, colors_c, radius=1.1)
        #                     save = True
        #             if save:
        #                 imageio.imwrite(os.path.join(path2, f"{name2}-c{c}-{name}.jpg"), ground_image.astype("uint8"))

        for i in range(1, self.model_constants.blocks + 1):
            pixels_key = f"point-pillar/ogpoints-ground-pixels-{i}"
            if pixels_key not in model_output:
                continue

            # Only consumed by the disabled save_individual call below.
            bev_pixels = model_output[f"point-pillar/gpoints-bev-pixels-{i}"]

            save(
                model_output[pixels_key],
                model_output[f"point-pillar/ogpoints-bch-{i}"],
                model_output[f"pillar-weights-{i}"],
                name=f"offset{i}",
            )
            # save_individual(model_output[f"point-pillar/ogpoints-ground-pixels-{i}"], bev_pixels, model_output[f"point-pillar/ogpoints-bch-{i}"], model_output[f"pillar-weights-{i}"], i - 1, name=f"offset{i}")

class DrawBEV:
    """Debug hook that visualises the BEV mask and the pillar weights on the aerial image.

    Per block and per attention head (plus the mean and max over heads) it
    writes a normalized-entropy overlay and a weighted-height overlay.
    """

    def __init__(self, model_constants, loss_params, height_range):
        # loss_params is part of the signature but is not read by this class.
        self.model_constants = model_constants
        self.height_range = height_range

    def get_output(self, model, model_input, loss_input, config, model_params):
        """Collect the layer outputs that ``debug_fn`` needs, keyed by name.

        The ``initial-bev-mask`` layer is always exported. Per-block entries
        are only added for blocks that have an ``ogpoints-ground-pixels``
        layer; their BEV pixels come from the closest block index ``<= i``
        that defines a ``gpoints-bev-pixels`` layer.
        """
        result = {}

        layer_names = [layer.name for layer in model.layers]
        result["initial-bev-mask"] = model.get_layer("initial-bev-mask").output
        for i in range(1, self.model_constants.blocks + 1):
            if f"point-pillar/ogpoints-ground-pixels-{i}" not in layer_names:
                continue

            result[f"point-pillar/ogpoints-bch-{i}"] = model.get_layer(f"point-pillar/ogpoints-bch-{i}").output
            result[f"pillar-weights-{i}"] = model.get_layer(f"pillar-weights-{i}").output

            # Walk back to the most recent block that created BEV pixels.
            j = i
            while (n := f"point-pillar/gpoints-bev-pixels-{j}") not in layer_names:
                j -= 1
                assert j > 0
            result[f"point-pillar/gpoints-bev-pixels-{i}"] = model.get_layer(n).output

        return result

    def debug_fn(self, model_output, model_input, loss_input, frames, path, names):
        """Write the BEV mask and the per-head pillar-weight overlays into ``path``.

        Files written per batch entry ``b``:
        ``{names[b]}-bev-mask.jpg``,
        ``{names[b]}-pillarentropy-block{k}-{head}.jpg`` and
        ``{names[b]}-weightedheight-block{k}-{head}.jpg``.
        """
        colormap = cv2.applyColorMap(np.asarray([np.arange(256).astype("uint8")]), cv2.COLORMAP_JET)[0, :, ::-1] # [256, 3]

        initial_bev_mask = model_output["initial-bev-mask"].numpy()
        for b in range(model_input.batches):
            image = initial_bev_mask[b][..., np.newaxis] * 255
            imageio.imwrite(os.path.join(path, f"{names[b]}-bev-mask.jpg"), image.astype("uint8"))

        # for b, frame in enumerate(frames):
        #     print(f"X1 {names[b]} {frame.ground_frame.ego_to_world} {frame.ground_frame.ego_to_pixels} {frame.ground_frame.latlon}")

        def save(batches_cameras_heightbins, pillar_weights, bev_pixels, iteration):
            # batches_cameras_heightbins: p 3
            # pillar_weights: h p
            # bev_pixels: p 2

            iteration_name = f"block{iteration}"

            # Floor division: bev_shape is used as a tensor shape below, so it
            # has to stay integral (true division would produce floats).
            bev_shape = self.model_constants.bev_shapes[-1] // self.model_constants.ground_attn_strides[iteration]

            def save_head(pillar_weights, head_name):
                # pillar_weights: p (weights of a single head or of a reduction over heads)

                # Scatter the per-point weights into a dense (batch, height, bev) volume.
                pillar_weights = tf.scatter_nd(
                    tf.concat([tf.cast(batches_cameras_heightbins[:, :1], "int32"), tf.cast(batches_cameras_heightbins[:, -1:], "int32"), tf.cast(bev_pixels, "int32")], axis=-1),
                    pillar_weights[..., tf.newaxis],
                    [model_input.batches, self.height_range[2], bev_shape[0], bev_shape[1], 1],
                )[..., 0] # b height s...

                # Look up the BEV volume at every aerial pixel.
                aerial_shape = tf.shape(model_input.aerial_image)[-3:-1]
                aerialpixels = georeg.model.util.project.generate_bev_pixels(model_input.batches, aerial_shape) # b s... 2
                aerialpixels_onbev, aerialpixels_onbev_mask = georeg.model.util.project.aerialpixels_to_bevpixels_onimage(aerial_shape, bev_shape, loss_input, aerialpixels, self.model_constants)

                pillar_weights = tf.gather_nd(
                    tfcv.model.einops.apply("b h s... -> b s... h", pillar_weights),
                    tf.cast(aerialpixels_onbev, "int32"), # b p 2
                    batch_dims=1,
                ) # b (s...) h
                pillar_weights = tf.where(aerialpixels_onbev_mask[..., tf.newaxis], pillar_weights, 0.0)
                pillar_weights = tfcv.model.einops.apply("b (s...) h -> b h s...", pillar_weights, s=aerial_shape)
                # Normalize over the height axis so that every pillar sums to one (or stays all-zero).
                pillar_weights = tf.math.divide_no_nan(pillar_weights, tfcv.model.einops.apply("b h s... -> b 1 s...", pillar_weights, reduction="sum"))

                num_valid_heights = tf.math.count_nonzero(pillar_weights, axis=1)

                # Entropy per pillar, divided by log(#non-zero heights) and then by the per-image maximum.
                entropy = tf.where(pillar_weights > 0, -pillar_weights * tf.math.log(pillar_weights), 0.0)
                entropy = tf.reduce_sum(entropy, axis=1)
                entropy = tf.math.divide_no_nan(
                    entropy,
                    tf.where(num_valid_heights > 0, tf.math.log(tf.cast(num_valid_heights, "float32")), 0.0),
                )
                entropy = tf.math.divide_no_nan(entropy, tfcv.model.einops.apply("b s... -> b 1...", entropy, output_ndims=3, reduction="max"))
                for b in range(model_input.batches):
                    color = colormap[(entropy[b].numpy() * 255.0).astype("uint8")]
                    aerial_image = model_input.aerial_image[b]

                    s = 0.5
                    image = s * aerial_image + (1 - s) * color
                    imageio.imwrite(os.path.join(path, f"{names[b]}-pillarentropy-{iteration_name}-{head_name}.jpg"), image.astype("uint8"))

                # Estimated height: weighted sum of the height samples with the normalized weights.
                heights = tf.linspace(self.height_range[0], self.height_range[1], self.height_range[2])
                height = tfcv.model.einops.apply("h, b h s... -> b s...", heights, pillar_weights)
                for b in range(model_input.batches):
                    color = colormap[(georeg.model.util.rescale(height[b]).numpy() * 255.0).astype("uint8")]
                    aerial_image = model_input.aerial_image[b]

                    s = 0.5
                    image = s * aerial_image + (1 - s) * color
                    imageio.imwrite(os.path.join(path, f"{names[b]}-weightedheight-{iteration_name}-{head_name}.jpg"), image.astype("uint8"))

            for h in range(pillar_weights.shape[0]):
                save_head(pillar_weights[h], f"head{h}")
            save_head(tf.reduce_mean(pillar_weights, axis=0), "headmean")
            save_head(tf.reduce_max(pillar_weights, axis=0), "headmax")

        for i in range(1, self.model_constants.blocks + 1):
            if f"point-pillar/ogpoints-ground-pixels-{i}" in model_output:
                save(
                    model_output[f"point-pillar/ogpoints-bch-{i}"],
                    model_output[f"pillar-weights-{i}"],
                    model_output[f"point-pillar/gpoints-bev-pixels-{i}"],
                    iteration=i - 1,
                )

class PointGrid3D:
    """Grid of 3D points above every BEV pixel, projected into the ground cameras.

    The constructor builds one point per (BEV pixel, height sample), projects
    the points with ``georeg.model.util.project.bevpixels_to_groundpixels_onimage``
    and keeps only the points selected by the returned mask and by a circular
    BEV mask. The surviving points are exposed as flat ``p ...`` tensors.

    Attributes set by the constructor:
        stride, height_range, meters_per_pixel: stored constructor arguments.
        gpoints_ground_pixels: ground-image pixel position per kept point.
        gpoints_bev_pixels: BEV pixel per kept point.
        gpoints_batches_cameras_heightbins: (batch, camera, height bin) index per kept point.
        gpoints_depths: depth per kept point, as returned by the projection helper.
        bev_mask: ``b s...`` mask, also exported under the name ``initial-bev-mask``.
    """

    def __init__(self, stride, bev_shape, meters_per_pixel, height_range, model_input, ego_to_fixedego, iteration, model_constants, model_params):
        """Build the point grid for block ``iteration`` (layer names use ``iteration + 1``).

        ``height_range`` is read as ``(start, stop, number of samples)``.
        ``model_params`` receives the resolved ``"half-offset"`` setting.
        """
        self.stride = stride
        self.height_range = height_range
        self.meters_per_pixel = meters_per_pixel

        # The sample count is multiplied by `tf.shape(...)[0] * 0 + 1`, which always evaluates to 1.
        # NOTE(review): presumably this ties the value to the model input graph — confirm before simplifying.
        heights = tf.cast(tf.linspace(self.height_range[0], self.height_range[1], self.height_range[2] * (tf.shape(model_input.ground_images)[0] * 0 + 1)), "float32")
        heights = tfcv.model.einops.apply("h -> b s... h", heights, b=model_input.batches, s=bev_shape)

        bev_pixels = georeg.model.util.project.generate_bev_pixels(model_input.batches, bev_shape)

        # Transform bev pixels to screen
        # PP_BEVPIXEL_HALF_OFFSET=1 (the default when unset) shifts the BEV pixel coordinates by 0.5.
        half_offset = int(os.environ["PP_BEVPIXEL_HALF_OFFSET"]) == 1 if "PP_BEVPIXEL_HALF_OFFSET" in os.environ else True
        model_params["half-offset"] = half_offset
        half_offset = 0.5 if half_offset else 0.0

        # Last axis: the BEV pixel coordinates followed by the height divided by meters_per_pixel.
        bev_pixels_3d = tf.concat([
            tfcv.model.einops.apply("b s... f -> b s... h f", tf.cast(bev_pixels, "float32") + half_offset, h=self.height_range[2]),
            heights[..., tf.newaxis] / meters_per_pixel,
        ], axis=-1)

        gpoints_ground_pixels, gpoints_depths, gpoints_mask = georeg.model.util.project.bevpixels_to_groundpixels_onimage(
            bev_shape=bev_shape,
            model_constants=model_constants,
            model_input=model_input,
            ego_to_fixedego=ego_to_fixedego,
            points=bev_pixels_3d,
        )
        # NOTE: this value of p is not read; p is rebound further down.
        p = tf.shape(gpoints_mask)[2]
        # Flatten batch, camera and point axes into a single point axis.
        gpoints_ground_pixels = tfcv.model.einops.apply("b c p f -> (b c p) f", gpoints_ground_pixels)
        gpoints_depths = tfcv.model.einops.apply("b c p -> (b c p)", gpoints_depths)
        gpoints_mask = tfcv.model.einops.apply("b c p -> (b c p)", gpoints_mask)
        # Per-point bookkeeping, repeated over cameras and height bins.
        # NOTE(review): assumes the projection helper's p axis enumerates (s... h) in this order — confirm.
        gpoints_bev_pixels = tfcv.model.einops.apply("b s... f -> (b c s... h) f", bev_pixels, s=bev_shape, h=self.height_range[2], c=model_input.max_cameras)
        gpoints_batches_cameras_heightbins = tf.stack([
            tfcv.model.einops.apply("b -> (b c s... h)", tf.range(model_input.batches), s=bev_shape, h=self.height_range[2], c=model_input.max_cameras),
            tfcv.model.einops.apply("c -> (b c s... h)", tf.range(model_input.max_cameras), s=bev_shape, h=self.height_range[2], b=model_input.batches),
            tfcv.model.einops.apply("h -> (b c s... h)", tf.range(self.height_range[2]), s=bev_shape, h=self.height_range[2], c=model_input.max_cameras, b=model_input.batches),
        ], axis=-1)

        # A BEV cell is valid if any of its points is valid in any camera.
        bev_mask = tfcv.model.einops.apply("(b c s... h) -> b s...", gpoints_mask, reduction="any", s=bev_shape, h=self.height_range[2], c=model_input.max_cameras, b=model_input.batches)

        # Additionally restrict the mask to a disc: coordinates are shifted and scaled by
        # half the BEV shape and kept when their squared norm is at most 1.
        p = georeg.model.util.project.generate_bev_pixels(model_input.batches, tf.shape(bev_mask)[-2:])
        p = tf.cast(p, "float32")
        radius = tf.cast(bev_shape, p.dtype)[tf.newaxis, tf.newaxis, tf.newaxis, :] / 2
        p = (p - radius) / radius
        bev_mask = tf.math.logical_and(
            bev_mask,
            tf.reduce_sum(tf.math.square(p), axis=-1) <= 1.0,
        )
        bev_mask = tfcv.model.util.set_name(bev_mask, f"initial-bev-mask")
        # Drop points whose BEV cell was removed by the mask above.
        gpoints_mask = tf.math.logical_and(
            tfcv.model.einops.apply("b s... -> (b c s... h)", bev_mask, s=bev_shape, h=self.height_range[2], c=model_input.max_cameras, b=model_input.batches),
            gpoints_mask,
        )

        # Keep only the valid points.
        gpoints_ground_pixels = tf.boolean_mask(gpoints_ground_pixels, gpoints_mask, axis=0)
        gpoints_bev_pixels = tf.boolean_mask(gpoints_bev_pixels, gpoints_mask, axis=0)
        gpoints_depths = tf.boolean_mask(gpoints_depths, gpoints_mask, axis=0)
        gpoints_batches_cameras_heightbins = tf.boolean_mask(gpoints_batches_cameras_heightbins, gpoints_mask, axis=0)

        # Name the tensors so they can be looked up as model layers (1-based block index).
        gpoints_ground_pixels = tfcv.model.util.set_name(gpoints_ground_pixels, f"point-pillar/gpoints-ground-pixels-{iteration + 1}")
        gpoints_bev_pixels = tfcv.model.util.set_name(gpoints_bev_pixels, f"point-pillar/gpoints-bev-pixels-{iteration + 1}")
        gpoints_batches_cameras_heightbins = tfcv.model.util.set_name(gpoints_batches_cameras_heightbins, f"point-pillar/gpoints-bch-{iteration + 1}")

        self.gpoints_ground_pixels = gpoints_ground_pixels
        self.gpoints_bev_pixels = gpoints_bev_pixels
        self.gpoints_batches_cameras_heightbins = gpoints_batches_cameras_heightbins
        self.gpoints_depths = gpoints_depths
        self.bev_mask = bev_mask


class GatherFeatures:
    """Samples ground-image features at the projected (optionally deformed) pillar points.

    Configuration is read from the environment and mirrored into ``model_params``:
        PP_GG_SHORTCUT_OFFSETS: ``1`` (default) adds the previous block's offsets to the current ones.
        PP_GG_DEFORM_FACTOR: scale of the learned pixel offsets; ``0.0`` (default) disables them.
        PP_GG_KEYVALUE_SOURCE: feature source; only ``"decoded"`` (default) is implemented.
    """

    def __init__(self, blocks, model_constants, model_params):
        self.blocks = blocks
        self.model_constants = model_constants

        self.shortcut_offsets = int(os.environ.get("PP_GG_SHORTCUT_OFFSETS", "1")) == 1
        model_params["shortcut-offsets"] = self.shortcut_offsets
        self.deform_factor = float(os.environ.get("PP_GG_DEFORM_FACTOR", "0.0"))
        model_params["deform-factor"] = self.deform_factor

        # The default has to name the branch that gather() actually implements;
        # the former default "decoder" matched no branch and made gather() fail.
        self.keyvalue_source = str(os.environ.get("PP_GG_KEYVALUE_SOURCE", "decoded"))
        model_params["keyvalue-source"] = self.keyvalue_source

        # Offsets of the previous gather() call, reused when shortcut_offsets is enabled.
        self.last_ogpoints_offsets = None

    def gather(self, bev, ground_featuremaps, point_grid_3d, filters, model_input, name, iteration, config):
        """Return one feature vector per point of ``point_grid_3d`` (shape ``p f``).

        When ``deform_factor > 0`` a 1x1 convolution on ``bev`` predicts a 2D
        pixel offset per (BEV cell, height bin); the offsets are scaled by the
        camera focal length, ``deform_factor`` and ``meters_per_pixel`` and
        added to the projected ground pixels before sampling.

        Raises:
            ValueError: if ``keyvalue_source`` is not a supported source.
        """
        focal_length = 0.5 * (model_input.ground_intr[:, :, 0, 0] + model_input.ground_intr[:, :, 1, 1]) # b c

        ogpoints_batches_cameras_heightbins = point_grid_3d.gpoints_batches_cameras_heightbins # p f
        ogpoints_batches_cameras_heightbins = tfcv.model.util.set_name(ogpoints_batches_cameras_heightbins, f"point-pillar/ogpoints-bch-{iteration + 1}")

        if self.deform_factor > 0:
            ogpoints_offsets = tfcv.model.util.conv(bev, filters=2 * point_grid_3d.height_range[2], kernel_size=1, stride=1, bias=True, name=tfcv.model.util.join(name, "offsets"), config=config)
            # Zero the kernel and bias so that the predicted offsets start out as zero.
            layer = ogpoints_offsets.node.layer
            layer.set_weights([layer.get_weights()[0] * 0, layer.get_weights()[1] * 0])
            ogpoints_offsets = tfcv.model.einops.apply("b s... (height f) -> b s... height f", ogpoints_offsets, height=point_grid_3d.height_range[2], f=2)

            if self.shortcut_offsets:
                if self.last_ogpoints_offsets is not None:
                    # Resize the previous offsets to the current BEV resolution and accumulate.
                    last = tfcv.model.einops.apply("b s... height f -> b s... (height f)", self.last_ogpoints_offsets, f=2)
                    last = tfcv.model.util.resize(last, tf.shape(ogpoints_offsets)[1:3], method="bilinear", config=config)
                    last = tfcv.model.einops.apply("b s... (height f) -> b s... height f", last, f=2)

                    ogpoints_offsets = ogpoints_offsets + last
                self.last_ogpoints_offsets = ogpoints_offsets

            # Pick the offset of every point from its (batch, bev pixel, height bin) cell.
            ogpoints_offsets = tf.gather_nd(
                ogpoints_offsets,
                tf.concat([
                    point_grid_3d.gpoints_batches_cameras_heightbins[:, :1],
                    point_grid_3d.gpoints_bev_pixels,
                    point_grid_3d.gpoints_batches_cameras_heightbins[:, -1:],
                ], axis=-1),
            ) # p 2

            ogpoints_focal_length = tf.gather_nd(
                focal_length[..., tf.newaxis],
                ogpoints_batches_cameras_heightbins[..., :2],
            ) # p 1

            ogpoints_offsets = ogpoints_offsets * (ogpoints_focal_length * (self.deform_factor * point_grid_3d.meters_per_pixel))



        def bilinear_sample(ogpoints_ground_pixels, image_features):
            # Bilinear interpolation of image_features at the given pixel positions.
            # Gradients only flow through the interpolation weights, not through the integer corners.
            # NOTE(review): the upper corner (lower + 1) is not clamped against the feature-map
            # size — confirm that out-of-range indices cannot occur here.
            ogpoints_ground_pixels_lower = tf.stop_gradient(tf.math.maximum(tf.cast(tf.math.floor(ogpoints_ground_pixels), "int32"), 0))
            ogpoints_ground_pixels_upper = tf.stop_gradient(ogpoints_ground_pixels_lower + 1)
            ogpoints_ground_pixels_alpha = tf.clip_by_value(ogpoints_ground_pixels - tf.cast(ogpoints_ground_pixels_lower, ogpoints_ground_pixels.dtype), 0.0, 1.0)

            ogpoints_ground_pixels_corners = tf.stack([
                ogpoints_ground_pixels_lower,
                tf.stack([ogpoints_ground_pixels_lower[..., 0], ogpoints_ground_pixels_upper[..., 1]], axis=-1),
                tf.stack([ogpoints_ground_pixels_upper[..., 0], ogpoints_ground_pixels_lower[..., 1]], axis=-1),
                ogpoints_ground_pixels_upper,
            ], axis=0) # corners p 2

            ogpoints_batches_cameras_corners = tfcv.model.einops.apply("p f -> corners p f", ogpoints_batches_cameras_heightbins[..., :2], corners=4)

            ogpoints_features_at_corners = tf.gather_nd(
                image_features,
                tf.concat([
                    tfcv.model.einops.apply("corners p f -> (corners p) f", ogpoints_batches_cameras_corners),
                    tfcv.model.einops.apply("corners p f -> (corners p) f", ogpoints_ground_pixels_corners),
                ], axis=-1),
                batch_dims=0,
            )
            ogpoints_features_at_corners = tfcv.model.einops.apply("(corners p) f -> corners p f", ogpoints_features_at_corners, corners=4)

            ogpoints_features00 = ogpoints_features_at_corners[0] # p f
            ogpoints_features01 = ogpoints_features_at_corners[1]
            ogpoints_features10 = ogpoints_features_at_corners[2]
            ogpoints_features11 = ogpoints_features_at_corners[3]

            alpha = ogpoints_ground_pixels_alpha # p 2
            # w00 = (1 - alpha[..., 0]) * (1 - alpha[..., 1])
            # w01 = (1 - alpha[..., 0]) * (    alpha[..., 1])
            # w10 = (    alpha[..., 0]) * (1 - alpha[..., 1])
            # w11 = (    alpha[..., 0]) * (    alpha[..., 1])
            # ogpoints_features = w00[..., tf.newaxis] * ogpoints_features00 + w01[..., tf.newaxis] * ogpoints_features01 + w10[..., tf.newaxis] * ogpoints_features10 + w11[..., tf.newaxis] * ogpoints_features11

            # ogpoints_features0 = ogpoints_features00 + alpha[..., 1:2] * (ogpoints_features01 - ogpoints_features00)
            # ogpoints_features1 = ogpoints_features10 + alpha[..., 1:2] * (ogpoints_features11 - ogpoints_features10)
            # ogpoints_features  = ogpoints_features1  + alpha[..., 0:1] * (ogpoints_features1  - ogpoints_features0 ) # p 2

            # Interpolate along the second pixel axis first, then along the first.
            ogpoints_features0 = ogpoints_features00 * (1 - alpha[..., 1:2]) + ogpoints_features01 * alpha[..., 1:2]
            ogpoints_features1 = ogpoints_features10 * (1 - alpha[..., 1:2]) + ogpoints_features11 * alpha[..., 1:2]
            ogpoints_features  = ogpoints_features0  * (1 - alpha[..., 0:1]) + ogpoints_features1  * alpha[..., 0:1] # p f

            return ogpoints_features

        if self.deform_factor > 0:
            ogpoints_ground_pixels = point_grid_3d.gpoints_ground_pixels + ogpoints_offsets
        else:
            ogpoints_ground_pixels = point_grid_3d.gpoints_ground_pixels
        ogpoints_ground_pixels = tfcv.model.util.set_name(ogpoints_ground_pixels, f"point-pillar/ogpoints-ground-pixels-{iteration + 1}")

        if self.keyvalue_source == "decoded":
            x = tfcv.model.util.conv(ground_featuremaps[-1], filters=filters, kernel_size=1, stride=1, name=tfcv.model.util.join(name, "image-features"), config=config)

            # Pixel positions are given at ground-image resolution; rescale them to the feature map.
            factor = tf.cast(tf.shape(x)[-3:-1], "float32") / tf.cast(tf.shape(model_input.ground_images)[-3:-1], "float32")
            ogpoints_features = bilinear_sample(ogpoints_ground_pixels * factor[tf.newaxis, :], x)
        # elif self.keyvalue_source == "multiscale-concat":
        #     ground_featuremaps = ground_featuremaps[:-1]
        #
        #     ogpoints_features = []
        #     levels = len(ground_featuremaps)
        #     for level, featuremap in enumerate(ground_featuremaps):
        #         featuremap = tfcv.model.util.conv(featuremap, filters=filters // levels, kernel_size=1, stride=1, name=tfcv.model.util.join(name, f"level{level + 1}-features"), config=config)
        #         featuremap = tfcv.model.einops.apply("b c s... (h f) -> h b c s... f", featuremap, h=self.heads)
        #
        #         factor = tf.cast(tf.shape(featuremap)[-3:-1], "float32") / tf.cast(tf.shape(model_input.ground_images)[-3:-1], "float32")
        #         ogpoints_features.append(bilinear_sample(ogpoints_ground_pixels * factor[tf.newaxis, tf.newaxis, tf.newaxis, :], featuremap))
        #     ogpoints_features = tf.concat(ogpoints_features, axis=-1)
        # elif self.keyvalue_source == "multiscale-add":
        #     ground_featuremaps = ground_featuremaps[:-1]
        #
        #     ogpoints_features = []
        #     levels = len(ground_featuremaps)
        #     for level, featuremap in enumerate(ground_featuremaps):
        #         featuremap = tfcv.model.util.conv(featuremap, filters=filters, kernel_size=1, stride=1, name=tfcv.model.util.join(name, f"level{level + 1}-features"), config=config)
        #         featuremap = tfcv.model.einops.apply("b c s... (h f) -> h b c s... f", featuremap, h=self.heads)
        #
        #         factor = tf.cast(tf.shape(featuremap)[-3:-1], "float32") / tf.cast(tf.shape(model_input.ground_images)[-3:-1], "float32")
        #         ogpoints_features.append(bilinear_sample(ogpoints_ground_pixels * factor[tf.newaxis, tf.newaxis, tf.newaxis, :], featuremap))
        #     ogpoints_features = tf.tf.math.add_n(ogpoints_features)
        # elif self.keyvalue_source == "multiscale-low-to-high-res":
        #     ground_featuremaps = ground_featuremaps[:-1]
        #
        #     ogpoints_features = []
        #     levels = len(ground_featuremaps)
        #     assert self.blocks <= levels
        #
        #     featuremap = tfcv.model.util.conv(ground_featuremaps[self.blocks - 1 - iteration], filters=filters, kernel_size=1, stride=1, name=tfcv.model.util.join(name, f"image-features"), config=config)
        #     featuremap = tfcv.model.einops.apply("b c s... (h f) -> h b c s... f", featuremap, h=self.heads)
        #
        #     factor = tf.cast(tf.shape(featuremap)[-3:-1], "float32") / tf.cast(tf.shape(model_input.ground_images)[-3:-1], "float32")
        #     ogpoints_features = bilinear_sample(ogpoints_ground_pixels * factor[tf.newaxis, tf.newaxis, tf.newaxis, :], featuremap)
        else:
            # Explicit error instead of `assert False`, which is stripped under `python -O`.
            raise ValueError(f"Unsupported PP_GG_KEYVALUE_SOURCE {self.keyvalue_source!r}; only 'decoded' is implemented")

        return ogpoints_features

class TransformerBlock:
    """Attention block that gathers ground-image features and reduces them onto the BEV grid.

    The configuration is read from environment variables in ``__init__``; every
    resolved value is mirrored into ``model_params["outer"]``.

    ``__call__`` implements two reduction variants, selected by ``self.type``:

    * ``"qkv"``: per-point query/key/value tensors are handed to
      ``georeg.model.util.project.scatter_features_on_image_qkv_unbatched``.
    * ``"wv"``: per-point weights predicted from the BEV features, together with
      per-point values, are handed to
      ``georeg.model.util.project.scatter_features_on_image_wv_unbatched``.

    ``self.last_logits`` keeps the logits/weights stored by the previous call so
    that they can be added to the next call's logits when ``shortcut_logits`` is set.
    """

    def __init__(self, model_params):
        """Read the block configuration from the environment.

        Environment variables:
            PP_TYPE: reduction variant, defaults to ``"qkv"``.
            PP_SHORTCUT_LOGITS: enabled iff its integer value is 1; enabled by default.
            GROUND_ATTN_FILTERS_V: required, comma-separated integers; indexed by
                ``iteration`` in ``__call__``.
            PP_HEADS: number of heads, defaults to 1.
            PP_ZERO_KV: enabled iff its integer value is 1; disabled by default.
            PP_FILTERS_QK: required when the type is ``"qkv"``.

        Args:
            model_params: mapping whose ``"outer"`` entry receives the resolved settings.

        Raises:
            KeyError: if a required environment variable is missing.
            AssertionError: if ``PP_TYPE`` is neither ``"qkv"`` nor starts with ``"wv"``.
        """
        self.type = str(os.environ.get("PP_TYPE", "qkv"))
        model_params["outer"]["type"] = self.type
        self.shortcut_logits = int(os.environ.get("PP_SHORTCUT_LOGITS", "1")) == 1
        model_params["outer"]["shortcut-logits"] = self.shortcut_logits
        self.filters_v = [int(s) for s in os.environ["GROUND_ATTN_FILTERS_V"].split(",")]
        model_params["outer"]["filters-v"] = self.filters_v
        self.heads = int(os.environ.get("PP_HEADS", "1"))
        model_params["outer"]["heads"] = self.heads

        self.add_zero_kv = int(os.environ.get("PP_ZERO_KV", "0")) == 1
        model_params["outer"]["zero-kv"] = self.add_zero_kv

        if self.type == "qkv":
            self.filters_qk = int(os.environ["PP_FILTERS_QK"])
            model_params["outer"]["filters-qk"] = self.filters_qk
        elif self.type.startswith("wv"):
            # NOTE(review): any type starting with "wv" is accepted here, but
            # __call__ only implements the exact type "wv"; other "wv*" values
            # fail when the block is called.
            pass
        else:
            # Explicit raise instead of `assert False`: asserts are stripped under `python -O`.
            raise AssertionError(f"Unsupported PP_TYPE: {self.type!r} (expected 'qkv' or 'wv')")

        # Logits (qkv) or weights (wv) of the previous call, used for the logits shortcut.
        self.last_logits = None

    def __call__(self, x, ground_gather, ground_featuremaps, ground_point_grid, model_input, name, iteration, config):
        """Build the block for one iteration and return the new BEV features.

        Args:
            x: BEV feature map; it is normalized first, and its last dimension
                determines the number of output channels.
            ground_gather: object whose ``gather(...)`` method returns per-point
                features sampled from ``ground_featuremaps``.
            ground_featuremaps: passed through to ``ground_gather.gather``.
            ground_point_grid: provides ``gpoints_batches_cameras_heightbins``,
                ``gpoints_bev_pixels``, ``height_range`` and ``bev_mask``.
            model_input: passed through to ``ground_gather.gather``.
            name: name prefix for all layers created here.
            iteration: index of the current block; selects ``self.filters_v[iteration]``
                and scales the shortcut logits.
            config: passed through to the ``tfcv`` layer helpers.

        Returns:
            The scattered features after a final 1x1 convolution to
            ``x.shape[-1]`` channels. No residual connection is added here.

        Raises:
            AssertionError: if ``self.type`` is neither ``"qkv"`` nor ``"wv"``.
        """
        x_orig = x

        x = tfcv.model.util.norm(x, name=tfcv.model.util.join(name, "norm"), config=config)

        # Ground gather
        gpoints_features = ground_gather.gather(
            bev=x,
            ground_featuremaps=ground_featuremaps,
            point_grid_3d=ground_point_grid,
            filters=self.filters_v[iteration],
            model_input=model_input,
            name=tfcv.model.util.join(name, "gather-ground"),
            iteration=iteration,
            config=config,
        ) # p f
        gpoints_features = tfcv.model.util.conv(gpoints_features, filters=self.filters_v[iteration], kernel_size=1, stride=1, bias=True, name=tfcv.model.util.join(name, "inner", "out"), config=config) # p f

        # Pillar reduce
        if self.type == "qkv":
            # Indices (batch, bev pixel..., height bin) used to look up one query per point.
            points = tf.concat([
                ground_point_grid.gpoints_batches_cameras_heightbins[:, :1],
                tf.cast(ground_point_grid.gpoints_bev_pixels, "int32"),
                ground_point_grid.gpoints_batches_cameras_heightbins[:, -1:],
            ], axis=-1)

            # One query vector per BEV cell and height bin, then gathered per point.
            query = tfcv.model.util.conv(x, filters=ground_point_grid.height_range[2] * self.filters_qk, kernel_size=1, stride=1, name=tfcv.model.util.join(name, "outer", "query"), config=config)
            query = tfcv.model.einops.apply("b s... (height f) -> b s... height f", query, height=ground_point_grid.height_range[2])
            query = tf.gather_nd(
                query,
                points,
                batch_dims=0,
            ) # p f


            key = gpoints_features
            value = gpoints_features

            key = tfcv.model.util.conv(key, filters=self.filters_qk, kernel_size=1, stride=1, name=tfcv.model.util.join(name, "outer", "key"), config=config)
            value = tfcv.model.util.conv(value, filters=self.filters_v[iteration], kernel_size=1, stride=1, name=tfcv.model.util.join(name, "outer", "value"), config=config)


            # Scatter targets: same indices as above but without the height bin.
            points = tf.concat([
                ground_point_grid.gpoints_batches_cameras_heightbins[:, :1],
                tf.cast(ground_point_grid.gpoints_bev_pixels, "int32"),
            ], axis=-1)
            shape = tf.shape(x)[:3]

            if self.add_zero_kv:
                # Append one all-zero query/key/value entry for every BEV cell set in bev_mask.
                new_points = tf.cast(tf.where(ground_point_grid.bev_mask), "int32")
                points = tf.concat([points, new_points], axis=0)

                zeros_qk = tf.zeros([tf.shape(new_points)[0], self.filters_qk])
                query = tf.concat([query, zeros_qk], axis=0)
                key = tf.concat([key, zeros_qk], axis=0)

                zeros_v = tf.zeros([tf.shape(new_points)[0], self.filters_v[iteration]])
                value = tf.concat([value, zeros_v], axis=0)

            query = tfcv.model.einops.apply("p (h f) -> h p f", query, h=self.heads)
            key = tfcv.model.einops.apply("p (h f) -> h p f", key, h=self.heads)
            value = tfcv.model.einops.apply("p (h f) -> h p f", value, h=self.heads)

            # query = query + gpoints_depths_posenc
            # keyvalue = keyvalue + (gpoints_world_pos_enc + gpoints_cam_pos_enc) # + make_gpoints_height_posenc() #  + gpoints_intr_enc

            def logits_forward(logits):
                """Add the scaled logits of the previous call (if enabled) and apply a learned scale."""
                if self.shortcut_logits:
                    if self.last_logits is not None:
                        last_logits = self.last_logits
                        last_logits = tfcv.model.util.ScaleLayer(axis=[], initial_value=1.0, name=tfcv.model.util.join(name, "shortcut-logits-scale"))(last_logits)

                        # logits = tfcv.model.util.ScaleLayer(axis=[], initial_value=0.0, name=tfcv.model.util.join(name, "shortcut-rezero"))(logits)

                        logits = logits + last_logits
                    self.last_logits = logits

                    # The stored logits are a sum over (iteration + 1) terms; divide by sqrt of that count.
                    variance = (iteration + 1)
                    logits = logits * (tf.cast(variance, "float32") ** -0.5)

                logits = tfcv.model.util.ScaleLayer(axis=[1], initial_value=1.0, name=tfcv.model.util.join(name, "logits-scale"))(logits)
                return logits

            x = georeg.model.util.project.scatter_features_on_image_qkv_unbatched(
                query=query,
                key=key,
                value=value,
                points=points,
                shape=shape,
                iteration=iteration,
                logits_forward=logits_forward,
            )
            x = tfcv.model.einops.apply("b s... h f -> b s... (h f)", x, h=self.heads, f=self.filters_v[iteration] // self.heads)

            x = tfcv.model.util.conv(x, filters=x_orig.shape[-1], kernel_size=1, stride=1, name=tfcv.model.util.join(name, "outer", "out"), config=config)

        elif self.type == "wv":
            # Indices (batch, bev pixel..., height bin) used to look up one weight vector per point.
            points = tf.concat([
                ground_point_grid.gpoints_batches_cameras_heightbins[:, :1],
                tf.cast(ground_point_grid.gpoints_bev_pixels, "int32"),
                ground_point_grid.gpoints_batches_cameras_heightbins[:, -1:],
            ], axis=-1)

            weights = tfcv.model.util.conv(x, filters=ground_point_grid.height_range[2] * self.heads, kernel_size=1, stride=1, name=tfcv.model.util.join(name, "outer", "weights"), config=config)
            # b s... (height head)

            if self.shortcut_logits:
                if self.last_logits is not None:
                    last_logits = self.last_logits
                    last_logits = tfcv.model.util.ScaleLayer(axis=[], initial_value=1.0, name=tfcv.model.util.join(name, "shortcut-logits-scale"))(last_logits)

                    # Bring the previous weights to the spatial size of the current ones before adding.
                    last_logits = tfcv.model.util.resize(last_logits, tf.shape(weights)[1:3], method="bilinear", config=config)

                    # weights = tfcv.model.util.ScaleLayer(axis=[], initial_value=0.0, name=tfcv.model.util.join(name, "shortcut-rezero"))(weights)

                    weights = weights + last_logits
                self.last_logits = weights

                # The stored weights are a sum over (iteration + 1) terms; divide by sqrt of that count.
                variance = (iteration + 1)
                weights = weights * (tf.cast(variance, "float32") ** -0.5)

            weights = tfcv.model.einops.apply("b s... (height head) -> b s... height head", weights, height=ground_point_grid.height_range[2], head=self.heads)
            weights = tf.gather_nd(
                weights,
                points,
                batch_dims=0,
            ) # p head
            weights = tfcv.model.util.ScaleLayer(axis=[], initial_value=1.0, name=tfcv.model.util.join(name, "logits-scale"))(weights)

            value = gpoints_features
            value = tfcv.model.util.conv(value, filters=self.filters_v[iteration], kernel_size=1, stride=1, name=tfcv.model.util.join(name, "outer", "value"), config=config)

            # Scatter targets: same indices as above but without the height bin.
            points = tf.concat([
                ground_point_grid.gpoints_batches_cameras_heightbins[:, :1],
                tf.cast(ground_point_grid.gpoints_bev_pixels, "int32"),
            ], axis=-1)
            shape = tf.shape(x)[:3]

            if self.add_zero_kv:
                # Append one all-zero weight/value entry for every BEV cell set in bev_mask.
                new_points = tf.cast(tf.where(ground_point_grid.bev_mask), "int32")
                points = tf.concat([points, new_points], axis=0)

                zeros_w = tf.zeros([tf.shape(new_points)[0], self.heads])
                weights = tf.concat([weights, zeros_w], axis=0)

                zeros_v = tf.zeros([tf.shape(new_points)[0], self.filters_v[iteration]])
                value = tf.concat([value, zeros_v], axis=0)

            weights = tfcv.model.einops.apply("p head -> head p", weights, head=self.heads)
            value = tfcv.model.einops.apply("p (head f) -> head p f", value, head=self.heads)

            x = georeg.model.util.project.scatter_features_on_image_wv_unbatched(
                weights=weights,
                value=value,
                points=points,
                shape=shape,
                iteration=iteration,
            )
            x = tfcv.model.einops.apply("b s... head f -> b s... (head f)", x, head=self.heads, f=self.filters_v[iteration] // self.heads)

            x = tfcv.model.util.conv(x, filters=x_orig.shape[-1], kernel_size=1, stride=1, name=tfcv.model.util.join(name, "outer", "out"), config=config)
        else:
            # Explicit raise instead of `assert False`: asserts are stripped under `python -O`.
            raise AssertionError(f"Unsupported TransformerBlock type: {self.type!r} (expected 'qkv' or 'wv')")

        return x

class GroundAttn:
    """Ground attention stage: owns the feature gatherer, the transformer block and the point-grid cache.

    State kept between calls:
        last_ground_point_grid: most recently created ``PointGrid3D``; reused while
            the ground attention stride stays the same.
        last_bev_mask / last_bev_mask_origsize: BEV masks stored by the call for the
            last block (``iteration == model_constants.blocks - 1``).
        last_up_vectors / last_height: only touched by the disabled
            ``FIX_TRANSFORM`` code path.
    """

    def __init__(self, model_constants, model_params):
        """Read the height range and flags from the environment and create the sub-modules.

        Environment variables:
            PP_HEIGHT_MIN: float, defaults to -2.0.
            PP_HEIGHT_MAX: float, defaults to 15.0.
            PP_HEIGHT_NUM: int, defaults to 32.
            FIX_TRANSFORM: enabled iff its integer value is 1; disabled by default.

        Args:
            model_constants: provides ``blocks`` here; ``__call__`` additionally reads
                ``ground_attn_strides``, ``bev_shapes`` and ``meters_per_pixel``.
            model_params: nested mapping that receives the resolved settings under
                ``"height"`` and ``"fix_transform"``; its ``"ground-point-grid"``/``"gather"``
                and ``"ground-attn"`` entries are handed to the sub-modules.
                NOTE(review): entries are indexed without prior creation, so this
                presumably is an auto-creating mapping; verify against the caller.
        """
        self.model_constants = model_constants
        self.model_params = model_params

        # (min, max, number of bins)
        self.height_range = (
            float(os.environ.get("PP_HEIGHT_MIN", -2.0)),
            float(os.environ.get("PP_HEIGHT_MAX", 15.0)),
            int(os.environ.get("PP_HEIGHT_NUM", 32)),
        )
        model_params["height"]["min"] = self.height_range[0]
        model_params["height"]["max"] = self.height_range[1]
        model_params["height"]["num"] = self.height_range[2]

        self.fix_transform = int(os.environ.get("FIX_TRANSFORM", "0")) == 1
        model_params["fix_transform"] = self.fix_transform

        self.ground_gather = GatherFeatures(
            blocks=self.model_constants.blocks,
            model_constants=model_constants,
            model_params=model_params["ground-point-grid"]["gather"],
        )

        self.ground_attn = TransformerBlock(
            model_params=model_params["ground-attn"],
        )

        self.last_bev_mask = None
        self.last_bev_mask_origsize = None
        self.last_ground_point_grid = None

        self.last_up_vectors = None
        self.last_height = None

    def get_modules(self, model_constants, loss_params):
        """Return the debug/visualization modules that share this stage's height range."""
        return [
            SaveGroundDeformPixels(model_constants, loss_params, self.height_range),
            DrawBEV(model_constants, loss_params, self.height_range),
        ]

    def __call__(self, bev, ground_featuremaps, model_input, stride, iteration, name, config): # TODO: move ground features computation into this class?
        """Apply one ground attention block to the BEV features.

        Args:
            bev: BEV feature map fed into the transformer block.
            ground_featuremaps: ground-image feature maps passed to the transformer block.
            model_input: model inputs, passed to the point grid and the transformer block.
            stride: stride of the BEV features; used to resize the stored mask on the last block.
            iteration: index of the current block.
            name: name prefix for the layers of the transformer block.
            config: passed through to the ``tfcv`` layer helpers.

        Returns:
            The transformer block output with all cells outside
            ``ground_point_grid.bev_mask`` set to 0.

        Raises:
            AssertionError: if ``FIX_TRANSFORM`` is enabled (that code path is disabled).
        """
        print("Ground attention block")
        ground_attn_stride = int(self.model_constants.ground_attn_strides[iteration])

        if self.fix_transform:
            # Explicit raise instead of `assert False`: asserts are stripped under
            # `python -O`, which would let execution reach the code below, where
            # `rotation` is undefined because its computation is commented out.
            raise AssertionError("FIX_TRANSFORM is not supported: the ego-to-fixed-ego rotation computation is disabled")
            # NOTE: everything below in this branch is unreachable and kept for reference only.
            if iteration == 0:
                # Predict up vectors in camera coordinates and height in ego coordinates
                x = tfcv.model.einops.apply("b c s... f -> b c f", ground_featuremaps[-2], reduction="mean")
                x = tfcv.model.util.conv_act(x, filters=64, kernel_size=1, stride=1, bias=True, name=tfcv.model.util.join("ground", f"fix-transform-{iteration + 1}", "1"), config=config)
                x = tfcv.model.util.conv(x, filters=4, kernel_size=1, stride=1, bias=True, name=tfcv.model.util.join("ground", f"fix-transform-{iteration + 1}", "2"), config=config) # b c 4
                layer = x.node.layer
                layer.set_weights([layer.get_weights()[0] * 0, layer.get_weights()[1] * 0])
                up_vectors = x[..., :3]
                height = x[..., 3]

                height = tf.reduce_mean(height, axis=1) # Reduce camera axis

                # Transform up vectors to ego coordinates
                cam_to_ego = tf.transpose(model_input.ground_ego_to_cam[:, :, :3, :3], (0, 1, 3, 2))
                up_vectors = (cam_to_ego @ up_vectors[..., tf.newaxis])[..., 0] # b c 3

                # Normalize with default up vector
                default_up_vector = tfcv.model.einops.apply("i0 -> b 1 i0", tf.cast(tf.convert_to_tensor([0, 0, tf.shape(model_input.ground_images)[0] * 0 + 1]), "float32"), b=model_input.batches)
                default_up_vector = tfcv.model.util.ScaleLayer(axis=[], initial_value=1.0)(default_up_vector)
                up_vectors = tf.concat([up_vectors, default_up_vector], axis=1)
                up_vectors = tf.reduce_sum(up_vectors, axis=1)
                up_vectors = tf.linalg.l2_normalize(up_vectors, axis=-1)
            else:
                # Predict up vectors and height in ego coordinates
                x = tfcv.model.einops.apply("b s... f -> b f", bev, reduction="mean")
                x = tfcv.model.util.conv_act(x, filters=64, kernel_size=1, stride=1, bias=True, name=tfcv.model.util.join("ground", f"fix-transform-{iteration + 1}", "1"), config=config)
                x = tfcv.model.util.conv(x, filters=4, kernel_size=1, stride=1, bias=True, name=tfcv.model.util.join("ground", f"fix-transform-{iteration + 1}", "2"), config=config) # b c 3
                layer = x.node.parent_nodes[0].layer
                layer.set_weights([layer.get_weights()[0] * 0, np.asarray([0.0, 0.0, 1.0, 0.0])])
                up_vectors = x[..., :3]
                height = x[..., 3]

                up_vectors = tf.linalg.l2_normalize(up_vectors, axis=-1)

            if self.last_up_vectors is not None:
                up_vectors = 0.5 * (self.last_up_vectors + up_vectors)
                height = 0.5 * (self.last_height + height)
            self.last_up_vectors = up_vectors
            self.last_height = height

            # Find transformation to previous ego
            translation = tf.stack([tf.zeros_like(height), tf.zeros_like(height), height], axis=-1) # b 3

            up_vectors_ego = tfcv.model.einops.apply("i0 -> b i0", np.asarray([0.0, 0.0, 1.0]).astype("float32"), b=model_input.batches)
            up_vectors_fixedego = up_vectors
            # rotation = tensorflow_graphics.geometry.transformation.quaternion.between_two_vectors_3d(up_vectors_ego, up_vectors_fixedego)
            # rotation = tensorflow_graphics.geometry.transformation.rotation_matrix_3d.from_quaternion(rotation) # b 3 3

            ego_to_fixedego = tf.concat([rotation, translation[:, :, tf.newaxis]], axis=2) # b 3 4
            def asd(ego_to_fixedego):
                return georeg.model.util.pad_matrix(ego_to_fixedego, rank=4)
            ego_to_fixedego = tf.keras.layers.Lambda(asd)(ego_to_fixedego) # b 4 4
        else:
            ego_to_fixedego = None



        ground_attn_out = bev
        # ground_attn_out = georeg.model.util.backbone.resize(ground_attn_out, stride, ground_attn_stride, name=tfcv.model.util.join(name, "bev-into-ground-attn"), config=config)

        # Reuse the cached point grid unless there is none yet, the stride changed,
        # or a per-call ego transform is given.
        if self.last_ground_point_grid is None or ground_attn_stride != self.last_ground_point_grid.stride or ego_to_fixedego is not None:
            print("Creating new ground point grid")
            ground_point_grid = PointGrid3D(
                stride=ground_attn_stride,
                bev_shape=self.model_constants.bev_shapes[-1] // ground_attn_stride,
                meters_per_pixel=self.model_constants.meters_per_pixel[-1] * ground_attn_stride,
                height_range=self.height_range,
                model_input=model_input,
                ego_to_fixedego=ego_to_fixedego,
                iteration=iteration,
                model_constants=self.model_constants,
                model_params=self.model_params[f"ground-point-grid-{iteration}"],
            )
            self.last_ground_point_grid = ground_point_grid
        else:
            ground_point_grid = self.last_ground_point_grid

        ground_attn_out = self.ground_attn(
            ground_attn_out,
            ground_gather=self.ground_gather,
            ground_featuremaps=ground_featuremaps,
            ground_point_grid=ground_point_grid,
            model_input=model_input,
            name=name,
            iteration=iteration,
            config=config,
        )
        # ground_attn_out = georeg.model.util.backbone.resize(ground_attn_out, ground_attn_stride, stride, name=tfcv.model.util.join(name, "bev-outof-ground-attn"), config=config) # b s... f


        # Zero out every BEV cell that is not set in the point grid's mask.
        bev_mask = ground_point_grid.bev_mask
        ground_attn_out = tf.where(bev_mask[..., tf.newaxis], ground_attn_out, 0.0)
        if iteration == self.model_constants.blocks - 1:
            # Last block: keep the mask at its own size and resized by
            # ground_attn_stride // stride; after bilinear resizing, any value > 0 counts as set.
            self.last_bev_mask_origsize = bev_mask
            bev_shape = tf.shape(bev_mask)[-2:] * ground_attn_stride // stride
            bev_mask = tf.math.greater(tfcv.model.util.resize(tf.where(bev_mask[..., tf.newaxis], 1.0, 0.0), bev_shape, method="bilinear", config=config)[..., 0], 0.0)
            self.last_bev_mask = bev_mask

        return ground_attn_out
