import tensorflow as tf
import numpy as np
import random, tfcv, math, georeg, random

class ModelInput:
    """Container for the batched inputs of the model.

    An instance holds either concrete numpy arrays (built by ``from_frames``) or symbolic Keras
    inputs (built by ``keras``). The field order is identical in the constructor, ``to_list``,
    ``keras`` and ``tf_signature``, so those lists line up positionally.
    """

    def __init__(self, aerial_image, bev_pixels, bev_pixels_mask, ground_images, ground_images_mask,
                 ground_images_shape, ground_images_num, ground_ego_to_cam, ground_ego_to_pixels,
                 ground_intr, ground_pixels, ground_pixels_mask, ground_depths, angles, corr_shape,
                 training_mask):
        self.aerial_image_i = aerial_image
        self.bev_pixels_i = bev_pixels
        self.bev_pixels_mask_i = bev_pixels_mask
        self.ground_images_i = ground_images
        self.ground_images_mask_i = ground_images_mask
        self.ground_images_shape_i = ground_images_shape
        self.ground_images_num_i = ground_images_num
        self.ground_ego_to_cam_i = ground_ego_to_cam
        self.ground_ego_to_pixels_i = ground_ego_to_pixels
        self.ground_intr_i = ground_intr
        self.ground_pixels_i = ground_pixels
        self.ground_pixels_mask_i = ground_pixels_mask
        self.ground_depths_i = ground_depths
        self.angles_i = angles
        self.corr_shape_i = corr_shape
        self.training_mask_i = training_mask

    def to_list(self, preprocess_aerial=lambda x: x, preprocess_ground=lambda x: x):
        """Return all fields as a list in constructor order.

        ``preprocess_aerial`` is applied to the aerial image and ``preprocess_ground`` to the
        ground images; every other field is returned unchanged.
        """
        return [preprocess_aerial(self.aerial_image_i), self.bev_pixels_i, self.bev_pixels_mask_i, preprocess_ground(self.ground_images_i), self.ground_images_mask_i, self.ground_images_shape_i, self.ground_images_num_i, self.ground_ego_to_cam_i, self.ground_ego_to_pixels_i, self.ground_intr_i, self.ground_pixels_i, self.ground_pixels_mask_i, self.ground_depths_i, self.angles_i, self.corr_shape_i, self.training_mask_i]

    aerial_image = property(lambda self: self.aerial_image_i)
    bev_pixels = property(lambda self: self.bev_pixels_i)
    bev_pixels_mask = property(lambda self: self.bev_pixels_mask_i)
    ground_images = property(lambda self: self.ground_images_i)
    ground_images_mask = property(lambda self: self.ground_images_mask_i)
    ground_images_shape = property(lambda self: self.ground_images_shape_i)
    ground_images_num = property(lambda self: self.ground_images_num_i)
    ground_ego_to_cam = property(lambda self: self.ground_ego_to_cam_i)
    ground_ego_to_pixels = property(lambda self: self.ground_ego_to_pixels_i)
    ground_intr = property(lambda self: self.ground_intr_i)
    ground_pixels = property(lambda self: self.ground_pixels_i)
    ground_pixels_mask = property(lambda self: self.ground_pixels_mask_i)
    ground_depths = property(lambda self: self.ground_depths_i)
    angles = property(lambda self: self.angles_i)
    # from_frames replicates these two values across the batch, so the first entry is returned.
    corr_shape = property(lambda self: self.corr_shape_i[0]) # Same for all batches
    training_mask = property(lambda self: self.training_mask_i[0]) # Same for all batches

    batches = property(lambda self: tf.shape(self.aerial_image_i)[0])
    max_cameras = property(lambda self: tf.shape(self.ground_images_i)[1])
    points = property(lambda self: tf.shape(self.ground_pixels_i)[2])

    @staticmethod
    def keras():
        """Return a ``ModelInput`` whose fields are symbolic ``tf.keras.Input`` placeholders."""
        aerial_image_i          = tf.keras.Input((None, None, 3),       dtype="float32",    name="aerial_image")
        bev_pixels_i            = tf.keras.Input((None, 2),             dtype="float32",    name="bev_pixels") # [batch, points, 2]
        bev_pixels_mask_i       = tf.keras.Input((None,),               dtype="bool",       name="bev_pixels_mask") # [batch, points]
        ground_images_i         = tf.keras.Input((None, None, None, 3), dtype="float32",    name="ground_images") # [batch, cameras, dims..., rgb]
        ground_images_mask_i    = tf.keras.Input((None, None, None),    dtype="bool",       name="ground_images_mask") # [batch, cameras, dims...]
        ground_images_shape_i   = tf.keras.Input((None, 2),             dtype="int32",      name="ground_images_shape") # [batch, cameras, 2]
        ground_images_num_i     = tf.keras.Input((),                    dtype="int32",      name="ground_images_num") # [batch]
        ground_ego_to_cam_i     = tf.keras.Input((None, 4, 4),          dtype="float32",    name="ground_ego_to_cam") # [batch, cameras, 4, 4]
        ground_ego_to_pixels_i  = tf.keras.Input((4, 4),                dtype="float32",    name="ground_ego_to_pixels") # [batch, 4, 4]
        ground_intr_i           = tf.keras.Input((None, 3, 3),          dtype="float32",    name="ground_intr") # [batch, cameras, 3, 3]
        ground_pixels_i         = tf.keras.Input((None, None, 2),       dtype="float32",    name="ground_pixels") # [batch, cameras, points, 2]
        ground_pixels_mask_i    = tf.keras.Input((None, None,),         dtype="bool",       name="ground_pixels_mask") # [batch, cameras, points]
        ground_depths_i         = tf.keras.Input((None, None,),         dtype="float32",    name="ground_depths") # [batch, cameras, points]
        angles_i                = tf.keras.Input((None,),               dtype="float32",    name="angles") # [batch, angles]
        corr_shape_i            = tf.keras.Input((2,),                  dtype="int32",      name="corr_shape") # [batch, 2]
        training_mask_i         = tf.keras.Input((),                    dtype="bool",       name="training_mask") # [batch]

        return ModelInput(aerial_image_i, bev_pixels_i, bev_pixels_mask_i, ground_images_i, ground_images_mask_i, ground_images_shape_i, ground_images_num_i, ground_ego_to_cam_i, ground_ego_to_pixels_i, ground_intr_i, ground_pixels_i, ground_pixels_mask_i, ground_depths_i, angles_i, corr_shape_i, training_mask_i)

    @staticmethod
    def tf_signature(points_num):
        """Return the ``tf.TensorSpec`` list matching ``to_list`` order, with ``points_num`` points fixed."""
        return [
            tf.TensorSpec([None, None, None, 3],                                                dtype="float32",    name="aerial_image"),
            tf.TensorSpec([None, points_num, 2],                                                dtype="float32",    name="bev_pixels"),
            tf.TensorSpec([None, points_num],                                                   dtype="bool",       name="bev_pixels_mask"),
            tf.TensorSpec([None, None, None, None, 3],                                          dtype="float32",    name="ground_images"),
            tf.TensorSpec([None, None, None, None],                                             dtype="bool",       name="ground_images_mask"),
            tf.TensorSpec([None, None, 2],                                                      dtype="int32",      name="ground_images_shape"),
            tf.TensorSpec([None],                                                               dtype="int32",      name="ground_images_num"),
            tf.TensorSpec([None, None, 4, 4],                                                   dtype="float32",    name="ground_ego_to_cam"),
            tf.TensorSpec([None, 4, 4],                                                         dtype="float32",    name="ground_ego_to_pixels"),
            tf.TensorSpec([None, None, 3, 3],                                                   dtype="float32",    name="ground_intr"),
            tf.TensorSpec([None, None, points_num, 2],                                          dtype="float32",    name="ground_pixels"),
            tf.TensorSpec([None, None, points_num],                                             dtype="bool",       name="ground_pixels_mask"),
            tf.TensorSpec([None, None, points_num],                                             dtype="float32",    name="ground_depths"),
            tf.TensorSpec([None, None],                                                         dtype="float32",    name="angles"),
            tf.TensorSpec([None, 2],                                                            dtype="int32",      name="corr_shape"),
            tf.TensorSpec([None],                                                               dtype="bool",       name="training_mask"),
        ]

    @staticmethod
    def _point_indices(points_num_in, points_num_out, shuffle):
        """Return the row indices used to resample a per-point array to ``points_num_out`` rows.

        - ``in < out``: destination rows in the output, spread evenly over ``[0, out)``.
          The input points are scattered there.
        - ``in > out``: source rows in the input, spread evenly over ``[0, in)``.
          Those input points are gathered.
        - ``in == out``: the identity permutation.

        If ``shuffle`` is true, the indices are permuted in place with ``np.random.shuffle``.
        If exactly one of the counts is 0, an empty index array is returned (no division by zero).
        """
        if points_num_in == points_num_out:
            indices = np.arange(points_num_out)
        else:
            smaller = min(points_num_in, points_num_out)
            larger = max(points_num_in, points_num_out)
            if smaller == 0:
                indices = np.zeros([0], dtype="int32")
            else:
                # The +1e-2 nudges values up before truncation so that round-off like 2.9999 does
                # not drop an index to the previous integer.
                indices = (np.linspace(0, float(smaller - 1), num=smaller) * float(larger) / float(smaller) + 1e-2).astype("int32")
        if shuffle:
            np.random.shuffle(indices)
        return indices

    @staticmethod
    def _ensure_points_num(tensor, indices, points_num_in, points_num_out, fill):
        """Resample ``tensor`` along axis 0 to ``points_num_out`` rows using ``_point_indices`` output.

        When upsampling, rows not written by the scatter are set to ``fill``.
        Otherwise ``tensor[indices]`` is gathered, which also applies the (possibly shuffled)
        permutation when the counts already match.
        """
        if points_num_in < points_num_out:
            padded = np.full((points_num_out, *tensor.shape[1:]), fill, dtype=tensor.dtype)
            padded[indices] = tensor
            tensor = padded
        else:
            tensor = tensor[indices]

        assert tensor.shape[0] == points_num_out
        return tensor

    @staticmethod
    def from_frames(frames, angles, align_bearing, shuffle_points, points_num, corr_shape, training_mask, augment_aerial_image=None, augment_ground_image=None):
        """Build a numpy-backed ``ModelInput`` batch from a list of frames.

        Processing applied to the batch:
        - Each per-point array is resampled to exactly ``points_num`` rows. The same indices are
          used for every per-point array of a frame, so point correspondences are preserved.
        - Each frame is padded to the largest camera count in the batch by repeating its first
          camera. Padded cameras get an all-False image mask and points mask;
          ``ground_images_num`` holds the real camera count.
        - Ground images and masks are padded (top-left aligned) to the largest ground image size
          in the batch, rounded up to a multiple of 32.

        Args:
            frames: non-empty list of frames exposing ``aerial_frame`` and ``ground_frame``.
            angles: candidate angles. A radian offset is added to them when ``align_bearing`` is
                set, so presumably they are in radians — TODO confirm.
            align_bearing: if true, offset the angles per frame by
                ``radians(aerial bearing - ground bearing)`` plus one random entry of ``angles``.
            shuffle_points: if true, randomly permute the point order of each frame.
            points_num: number of points per frame in the output.
            corr_shape: value replicated for every batch element.
            training_mask: value replicated for every batch element.
            augment_aerial_image: optional callable applied to each aerial image (identity if None).
            augment_ground_image: optional callable applied to each ground image (identity if None).

        Returns:
            A ``ModelInput`` whose fields are numpy arrays with the dtypes of ``tf_signature``.
        """
        if augment_aerial_image is None:
            augment_aerial_image = lambda x: x
        if augment_ground_image is None:
            augment_ground_image = lambda x: x

        batchsize = len(frames)
        angles = np.asarray(angles)

        cameras_num = max([len(frame.ground_frame.cameras) for frame in frames])
        max_ground_image_size = np.amax([camera.image.shape[:2] for frame in frames for camera in frame.ground_frame.cameras], axis=0)
        multiplier = 32
        max_ground_image_size = (max_ground_image_size + multiplier - 1) // multiplier * multiplier
        # Loop-invariant: the same padding transform is applied to every ground image and mask.
        pad_ground = tfcv.image.transform.pad(max_ground_image_size, location="topleft", ndim=2)

        aerial_images = []
        bev_pixels = []
        bev_pixels_mask = []
        ground_images = []
        ground_images_mask = []
        ground_images_shape = []
        ground_images_num = []
        ground_ego_to_cam = []
        ground_ego_to_pixels = []
        ground_intr = []
        ground_pixels = []
        ground_pixels_mask = []
        ground_depths = []
        new_angles = []
        for frame in frames:
            ground_frame = frame.ground_frame
            cameras = ground_frame.cameras
            points_num_in = ground_frame.points_num
            indices = ModelInput._point_indices(points_num_in, points_num, shuffle_points)

            def ensure_points_num(tensor, fill):
                # Only called within this iteration, so binding the loop's indices is intended.
                return ModelInput._ensure_points_num(tensor, indices, points_num_in, points_num, fill)

            if align_bearing:
                angle_offset = math.radians(frame.aerial_frame.bearing - ground_frame.bearing) + random.choice(angles)
            else:
                angle_offset = 0.0
            new_angles.append(angles + angle_offset)

            aerial_images.append(augment_aerial_image(frame.aerial_frame.image)) # Aerial images always have the same shape
            bev_pixels.append(ensure_points_num(ground_frame.bev_pixels, 0.0))
            bev_pixels_mask.append(ensure_points_num(np.ones([ground_frame.bev_pixels.shape[0]], dtype="bool"), False))

            frame_ground_images = []
            frame_ground_images_mask = []
            frame_ground_images_shape = []
            frame_ground_ego_to_cam = []
            frame_ground_intr = []
            frame_ground_pixels = []
            frame_ground_pixels_mask = []
            frame_ground_depths = []
            for camera in cameras:
                frame_ground_images.append(pad_ground((augment_ground_image(camera.image), "color"))) # Ground images can have different shapes
                frame_ground_images_mask.append(pad_ground((camera.image_mask, "mask"))) # Ground images can have different shapes
                frame_ground_images_shape.append(camera.image.shape[:2])
                frame_ground_ego_to_cam.append(camera.ego_to_camera.to_matrix())
                frame_ground_intr.append(camera.intr)
                frame_ground_pixels.append(ensure_points_num(camera.pixels, 0.0))
                frame_ground_pixels_mask.append(ensure_points_num(camera.points_mask, False))
                frame_ground_depths.append(ensure_points_num(camera.points_depth, -1.0))
            # Pad to cameras_num by repeating the first camera. The padded copies are marked
            # invalid in both the image mask and the points mask, so their points are not
            # treated as real observations.
            for _ in range(cameras_num - len(cameras)):
                frame_ground_images.append(frame_ground_images[0])
                frame_ground_images_mask.append(np.logical_and(frame_ground_images_mask[0], False))
                frame_ground_images_shape.append(frame_ground_images_shape[0])
                frame_ground_ego_to_cam.append(frame_ground_ego_to_cam[0])
                frame_ground_intr.append(frame_ground_intr[0])
                frame_ground_pixels.append(frame_ground_pixels[0])
                frame_ground_pixels_mask.append(np.logical_and(frame_ground_pixels_mask[0], False))
                frame_ground_depths.append(frame_ground_depths[0])
            ground_images.append(frame_ground_images)
            ground_images_mask.append(frame_ground_images_mask)
            ground_images_shape.append(frame_ground_images_shape)
            ground_images_num.append(len(cameras))
            ground_ego_to_cam.append(frame_ground_ego_to_cam)
            ground_intr.append(frame_ground_intr)
            ground_pixels.append(frame_ground_pixels)
            ground_pixels_mask.append(frame_ground_pixels_mask)
            ground_depths.append(frame_ground_depths)
            ground_ego_to_pixels.append(ground_frame.ego_to_pixels.to_matrix())

        return ModelInput(
            aerial_image=np.asarray(aerial_images).astype("float32"),
            bev_pixels=np.asarray(bev_pixels).astype("float32"),
            bev_pixels_mask=np.asarray(bev_pixels_mask).astype("bool"),
            ground_images=np.asarray(ground_images).astype("float32"),
            ground_images_mask=np.asarray(ground_images_mask).astype("bool"),
            ground_images_shape=np.asarray(ground_images_shape).astype("int32"),
            ground_images_num=np.asarray(ground_images_num).astype("int32"),
            ground_ego_to_cam=np.asarray(ground_ego_to_cam).astype("float32"),
            ground_ego_to_pixels=np.asarray(ground_ego_to_pixels).astype("float32"),
            ground_intr=np.asarray(ground_intr).astype("float32"),
            ground_pixels=np.asarray(ground_pixels).astype("float32"),
            ground_pixels_mask=np.asarray(ground_pixels_mask).astype("bool"),
            ground_depths=np.asarray(ground_depths).astype("float32"),
            angles=np.asarray(new_angles).astype("float32"),
            corr_shape=np.asarray([corr_shape] * batchsize).astype("int32"),
            training_mask=np.asarray([training_mask] * batchsize).astype("bool"),
        )

class LossInput:
    """Container for the ground-truth inputs consumed by the loss.

    Holds one batched field, ``bevpixels_to_aerialpixels``: a 3x3 matrix per batch element
    (shape ``[batch, 3, 3]`` per ``keras``/``tf_signature``).
    """

    def __init__(self, bevpixels_to_aerialpixels):
        self.bevpixels_to_aerialpixels_i = bevpixels_to_aerialpixels

    def to_list(self):
        """Return the fields as a list, in the order used by ``keras``/``tf_signature``."""
        return [self.bevpixels_to_aerialpixels_i]

    def get_bevpixels_to_aerialpixels_with_shape(self, bev_shape, aerial_shape, model_constants):
        """Return the stored transform re-expressed for the given BEV and aerial grid sizes.

        The result is ``aerial_scale @ bevpixels_to_aerialpixels @ bev_scale``, where:
        - ``bev_scale`` is a uniform 2D scaling by ``bev_shapes[-1] / bev_shape[0]``;
        - ``aerial_scale`` is a uniform 2D scaling by ``aerial_shape[0] / aerial_image_shape``.
        Both are padded to 3x3 via ``georeg.model.util.pad_matrix``.

        Only the first entry of ``bev_shape``/``aerial_shape`` is used, so the grids are presumably
        square — TODO confirm.
        """
        bev_factor = tf.cast(model_constants.bev_shapes[-1], "float32") / tf.cast(bev_shape[0], "float32")
        bevpixels_to_bevimagepixels = georeg.model.util.pad_matrix(
            tf.eye(2, dtype="float32") * bev_factor,
            rank=3,
        )
        aerial_factor = tf.cast(aerial_shape[0], "float32") / tf.cast(model_constants.aerial_image_shape, "float32")
        aerialimagepixels_to_aerialpixels = georeg.model.util.pad_matrix(
            tf.eye(2, dtype="float32") * aerial_factor,
            rank=3,
        )

        return aerialimagepixels_to_aerialpixels @ self.bevpixels_to_aerialpixels_i @ bevpixels_to_bevimagepixels

    batches = property(lambda self: tf.shape(self.bevpixels_to_aerialpixels_i)[0])

    @staticmethod
    def keras():
        """Return a ``LossInput`` whose field is a symbolic ``tf.keras.Input`` of shape ``[batch, 3, 3]``."""
        bevpixels_to_aerialpixels_i  = tf.keras.Input((3, 3), dtype="float32", name="bevpixels_to_aerialpixels")

        return LossInput(bevpixels_to_aerialpixels_i)

    @staticmethod
    def tf_signature():
        """Return the ``tf.TensorSpec`` list matching ``to_list`` order."""
        return [
            tf.TensorSpec([None, 3, 3], dtype="float32", name="bevpixels_to_aerialpixels"),
        ]

    @staticmethod
    def from_frames(frames):
        """Build a numpy-backed ``LossInput`` from a list of frames.

        Each frame's ``bevpixels_to_aerialpixels.to_matrix()`` result is stacked into a float32
        array. An empty ``frames`` list yields a ``(0, 3, 3)`` array, matching the layout of
        ``tf_signature``.
        """
        matrices = [frame.bevpixels_to_aerialpixels.to_matrix() for frame in frames]
        if not matrices:
            # np.asarray([]) would have shape (0,), losing the [batch, 3, 3] layout.
            bevpixels_to_aerialpixels = np.zeros((0, 3, 3), dtype="float32")
        else:
            bevpixels_to_aerialpixels = np.asarray(matrices).astype("float32")

        return LossInput(
            bevpixels_to_aerialpixels=bevpixels_to_aerialpixels,
        )
