import tensorflow as tf
import math

def rotation_matrix_to_angle(rotation_matrix):
    """Recover planar rotation angles (radians) from rotation matrices.

    Args:
        rotation_matrix: Tensor indexed as ``[batch..., row, col]``; only the
            entries ``[0, 0]`` and ``[1, 0]`` are read.

    Returns:
        Tensor of shape ``[batch...]`` holding ``atan2(R[1, 0], R[0, 0])``.
    """
    sin_part = rotation_matrix[..., 1, 0]
    cos_part = rotation_matrix[..., 0, 0]
    return tf.math.atan2(sin_part, cos_part)

def angle_to_rotation_matrix(angle):
    """Convert planar rotation angles (radians) into 2x2 rotation matrices.

    Args:
        angle: Scalar or ``[batch...]`` tensor-like of angles in radians. Python
            numbers without an inherent dtype default to float32, so integer
            literals such as ``0`` are accepted.

    Returns:
        Tensor of shape ``[batch..., 2, 2]`` equal to
        ``[[cos a, -sin a], [sin a, cos a]]`` for every angle ``a``.
    """
    if not tf.is_tensor(angle):
        # dtype_hint keeps float inputs' dtype but turns bare Python ints into
        # floats, since tf.cos/tf.sin reject integer tensors.
        angle = tf.convert_to_tensor(angle, dtype_hint=tf.float32)
    cos_a = tf.cos(angle)
    sin_a = tf.sin(angle)
    row0 = tf.stack([cos_a, -sin_a], axis=-1)[..., tf.newaxis, :]  # [batch..., 1, cols]
    row1 = tf.stack([sin_a, cos_a], axis=-1)[..., tf.newaxis, :]  # [batch..., 1, cols]
    rotation_matrix = tf.concat([row0, row1], axis=-2)  # [batch..., rows, cols]
    # Internal sanity check: exactly two trailing matrix dims are appended to the angle's shape.
    assert len(rotation_matrix.shape) == len(angle.shape) + 2
    return rotation_matrix

class Rigid:
    """Static helpers for rigid transforms stored as ``(rotation, translation)`` tuples.

    A transform maps a point ``p`` to ``rotation @ p + translation`` (see
    ``Rigid.transform``). ``rotation`` has shape ``[batch..., rank, rank]`` and
    ``translation`` has shape ``[batch..., rank]``; both share one dtype.
    """

    @staticmethod
    def create(rotation=None, translation=None, dtype=None, rank=None, batch_shape=None):
        """Build and validate a ``(rotation, translation)`` transform tuple.

        Args:
            rotation: ``[batch..., rank, rank]`` orthogonal matrices, or None for identity.
            translation: ``[batch..., rank]`` offsets, or None for zeros.
            dtype: Target dtype; inferred from rotation, then translation, when omitted.
            rank: Spatial dimension; inferred from translation, then rotation, when omitted.
            batch_shape: Leading batch dims (list/tuple/1-D int tensor); inferred from
                translation, then rotation, when omitted.

        Returns:
            Tuple ``(rotation, translation)`` cast to ``dtype``.

        Raises:
            ValueError: If dtype, rank or batch_shape cannot be inferred.
            tf.errors.InvalidArgumentError: (eager mode) if shapes mismatch, values are
                non-finite, or ``rotation`` is not orthogonal.
        """
        if rotation is not None and not tf.is_tensor(rotation):
            rotation = tf.convert_to_tensor(rotation)
        if translation is not None and not tf.is_tensor(translation):
            translation = tf.convert_to_tensor(translation)

        if dtype is None:
            if rotation is not None:
                dtype = rotation.dtype
            elif translation is not None:
                dtype = translation.dtype
            else:
                raise ValueError("Expected dtype argument when rotation and translation are not given")

        if rank is None:
            if translation is not None:
                rank = tf.shape(translation)[-1]
            elif rotation is not None:
                rank = tf.shape(rotation)[-1]
            else:
                raise ValueError("Expected rank argument when rotation and translation are not given")

        if batch_shape is None:
            if translation is not None:
                batch_shape = tf.shape(translation)[:-1]
            elif rotation is not None:
                batch_shape = tf.shape(rotation)[:-2]
            else:
                raise ValueError("Expected batch_shape argument when rotation and translation are not given")
        else:
            # A caller-supplied shape such as () would otherwise convert to an empty
            # float32 tensor and fail to concatenate with the integer dims below.
            batch_shape = tf.cast(batch_shape, tf.int32)

        if translation is None:
            translation = tf.zeros(tf.concat([batch_shape, [rank]], axis=0), dtype=dtype)
        if rotation is None:
            rotation = tf.eye(rank, batch_shape=batch_shape, dtype=dtype)

        rotation = tf.cast(rotation, dtype)
        translation = tf.cast(translation, dtype)

        tf.debugging.assert_equal(tf.shape(rotation), tf.concat([batch_shape, [rank, rank]], axis=0), message=f"Got invalid rotation shape {rotation.shape}")
        tf.debugging.assert_equal(tf.shape(translation), tf.concat([batch_shape, [rank]], axis=0), message=f"Got invalid translation shape {translation.shape}")
        tf.debugging.assert_all_finite(rotation, message="Got invalid rotation in Rigid")
        tf.debugging.assert_all_finite(translation, message="Got invalid translation in Rigid")
        # R @ R^T == I (within default tolerance) <=> R is orthogonal.
        tf.debugging.assert_near(tf.matmul(rotation, rotation, transpose_b=True), tf.eye(rank, batch_shape=batch_shape, dtype=dtype), message="Got non-orthogonal rotation in Rigid")

        return (rotation, translation)

    @staticmethod
    def rotation(transform):
        """Return the ``[batch..., rank, rank]`` rotation component."""
        return transform[0]

    @staticmethod
    def translation(transform):
        """Return the ``[batch..., rank]`` translation component."""
        return transform[1]

    @staticmethod
    def dtype(transform):
        """Return the dtype shared by both components."""
        return Rigid.translation(transform).dtype

    @staticmethod
    def batch_shape(transform):
        """Return the batch dimensions as a 1-D int32 tensor."""
        return tf.shape(Rigid.translation(transform))[:-1]

    @staticmethod
    def rank(transform):
        """Return the spatial dimension as a scalar int32 tensor."""
        return tf.shape(Rigid.translation(transform))[-1]

    @staticmethod
    def slice(transform, *slices):
        """Index the batch dimensions of ``transform`` with ``slices``.

        Raises:
            ValueError: If more slices are given than there are batch dimensions.
        """
        translation = Rigid.translation(transform)
        # Use the static rank: len() of a runtime shape tensor fails inside tf.function.
        static_rank = translation.shape.rank
        if static_rank is not None and len(slices) > static_rank - 1:
            raise ValueError("Too many slice dimensions")
        return Rigid.create(
            rotation=Rigid.rotation(transform)[slices],
            translation=translation[slices],
        )

    @staticmethod
    def transform(transform, points):
        """Apply ``rotation @ p + translation`` to every point.

        Args:
            points: ``[batch..., points, dims]`` with batch dims matching the transform.

        Returns:
            Transformed points with the same shape as ``points``.
        """
        tf.debugging.assert_equal(Rigid.batch_shape(transform), tf.shape(points)[:-2], "Arguments do not have the same batch shape in Rigid.transform")
        tf.debugging.assert_equal(Rigid.rank(transform), tf.shape(points)[-1], "Arguments do not have matching dimensions in Rigid.transform")

        # (R @ P^T)^T rotates each row vector of P.
        points = tf.linalg.matrix_transpose(tf.linalg.matmul(Rigid.rotation(transform), points, transpose_b=True))
        points = points + Rigid.translation(transform)[..., tf.newaxis, :]

        return points

    @staticmethod
    def multiply(transform1, transform2):
        """Compose transforms: the result applies ``transform2`` first, then ``transform1``."""
        rotation = tf.linalg.matmul(Rigid.rotation(transform1), Rigid.rotation(transform2))
        translation = Rigid.transform(transform1, Rigid.translation(transform2)[..., tf.newaxis, :])[..., 0, :]
        return Rigid.create(rotation, translation)

    @staticmethod
    def inverse(transform):
        """Return the inverse transform ``(R^T, -R^T t)``."""
        inv_rotation = tf.linalg.matrix_transpose(Rigid.rotation(transform))
        inv_translation = tf.linalg.matmul(inv_rotation, -Rigid.translation(transform)[..., :, tf.newaxis])[..., 0]
        return Rigid.create(inv_rotation, inv_translation)

    @staticmethod
    def divide(transform1, transform2):
        """Return ``transform1`` composed with the inverse of ``transform2``."""
        return Rigid.multiply(transform1, Rigid.inverse(transform2))

    @staticmethod
    def cast(transform, dtype):
        """Return ``transform`` with both components cast to ``dtype``."""
        return Rigid.create(tf.cast(Rigid.rotation(transform), dtype), tf.cast(Rigid.translation(transform), dtype=dtype))

    @staticmethod
    def random2d(translation_stddev, angle_stddev, batch_shape=(), dtype="float32"):
        """Sample random 2D rigid transforms.

        The translation has a normally distributed signed length (stddev
        ``translation_stddev``) along a direction drawn uniformly from [0, pi).
        The rotation angle is normally distributed with stddev ``angle_stddev``.

        Args:
            batch_shape: Batch dimensions of the result; ``()`` gives a single transform.
            dtype: dtype of the sampled tensors.
        """
        translation_angle = tf.random.uniform(batch_shape, 0.0, math.pi, dtype=dtype)  # [batch...]
        translation_distance = tf.random.normal(batch_shape, mean=0.0, stddev=translation_stddev, dtype=dtype)  # [batch...]

        # Polar -> Cartesian, i.e. R(angle) @ [distance, 0]. Built directly so it works
        # for a scalar batch and for any dtype (no float32 zero broadcasting).
        translation = tf.stack(
            [translation_distance * tf.cos(translation_angle), translation_distance * tf.sin(translation_angle)],
            axis=-1,
        )  # [batch..., 2]

        angle = tf.random.normal(batch_shape, mean=0.0, stddev=angle_stddev, dtype=dtype)
        rotation_matrix = angle_to_rotation_matrix(angle)
        return Rigid.create(rotation_matrix, translation, dtype=dtype, rank=2, batch_shape=batch_shape)

    @staticmethod
    def str(transform):
        """Render the transform as text (requires eager tensors for ``.numpy()``)."""
        return f"{{t={Rigid.translation(transform).numpy().tolist()} R={Rigid.rotation(transform).numpy().tolist()}}}"

    @staticmethod
    def from_matrix(matrix):
        """Build a transform from homogeneous ``[batch..., rank+1, rank+1]`` matrices.

        Only the top-left ``rank x rank`` block (rotation) and the top ``rank``
        entries of the last column (translation) are read. The bottom row is not checked.
        """
        rotation = matrix[..., :-1, :-1]
        translation = matrix[..., :-1, -1]
        return Rigid.create(rotation, translation)




class ScaledRigid:
    """Static helpers for similarity transforms stored as ``(rotation, translation, scale)``.

    A transform maps a point ``p`` to ``scale * (rotation @ p) + translation``
    (see ``ScaledRigid.transform``). ``rotation`` is ``[batch..., rank, rank]``,
    ``translation`` is ``[batch..., rank]`` and ``scale`` is ``[batch...]``.
    """

    @staticmethod
    def create(rotation=None, translation=None, scale=None, dtype=None, rank=None, batch_shape=None):
        """Build and validate a ``(rotation, translation, scale)`` tuple.

        Args:
            rotation: ``[batch..., rank, rank]`` orthogonal matrices, or None for identity.
            translation: ``[batch..., rank]`` offsets, or None for zeros.
            scale: ``[batch...]`` scale factors, or None for ones.
            dtype: Target dtype; inferred from rotation, translation, then scale when omitted.
            rank: Spatial dimension; inferred from translation, then rotation, when omitted.
            batch_shape: Leading batch dims (list/tuple/1-D int tensor); inferred from
                translation, then rotation, when omitted.

        Raises:
            ValueError: If dtype, rank or batch_shape cannot be inferred.
            tf.errors.InvalidArgumentError: (eager mode) if shapes mismatch, values are
                non-finite, or ``rotation`` is not orthogonal.
        """
        if rotation is not None and not tf.is_tensor(rotation):
            rotation = tf.convert_to_tensor(rotation)
        if translation is not None and not tf.is_tensor(translation):
            translation = tf.convert_to_tensor(translation)
        if scale is not None and not tf.is_tensor(scale):
            scale = tf.convert_to_tensor(scale)

        if dtype is None:
            if rotation is not None:
                dtype = rotation.dtype
            elif translation is not None:
                dtype = translation.dtype
            elif scale is not None:
                dtype = scale.dtype
            else:
                raise ValueError("Expected dtype argument when rotation, translation and scale are not given")

        if rank is None:
            if translation is not None:
                rank = tf.shape(translation)[-1]
            elif rotation is not None:
                rank = tf.shape(rotation)[-1]
            else:
                raise ValueError("Expected rank argument when rotation and translation are not given")

        if batch_shape is None:
            if translation is not None:
                batch_shape = tf.shape(translation)[:-1]
            elif rotation is not None:
                batch_shape = tf.shape(rotation)[:-2]
            else:
                raise ValueError("Expected batch_shape argument when rotation and translation are not given")
        else:
            # A caller-supplied shape such as () would otherwise convert to an empty
            # float32 tensor and fail to concatenate/compare with integer dims below.
            batch_shape = tf.cast(batch_shape, tf.int32)

        if translation is None:
            translation = tf.zeros(tf.concat([batch_shape, [rank]], axis=0), dtype=dtype)
        if rotation is None:
            rotation = tf.eye(rank, batch_shape=batch_shape, dtype=dtype)
        if scale is None:
            scale = tf.ones(batch_shape, dtype=dtype)

        rotation = tf.cast(rotation, dtype)
        translation = tf.cast(translation, dtype)
        scale = tf.cast(scale, dtype)

        tf.debugging.assert_equal(tf.shape(rotation), tf.concat([batch_shape, [rank, rank]], axis=0), message=f"Got invalid rotation shape {rotation.shape}")
        tf.debugging.assert_equal(tf.shape(translation), tf.concat([batch_shape, [rank]], axis=0), message=f"Got invalid translation shape {translation.shape}")
        tf.debugging.assert_equal(tf.shape(scale), batch_shape, message=f"Got invalid scale shape {scale.shape}")
        tf.debugging.assert_all_finite(rotation, message="Got invalid rotation")
        tf.debugging.assert_all_finite(translation, message="Got invalid translation")
        tf.debugging.assert_all_finite(scale, message="Got invalid scale")
        # R @ R^T == I (within default tolerance) <=> R is orthogonal.
        tf.debugging.assert_near(tf.matmul(rotation, rotation, transpose_b=True), tf.eye(rank, batch_shape=batch_shape, dtype=dtype), message="Got non-orthogonal rotation in ScaledRigid")

        return (rotation, translation, scale)

    @staticmethod
    def rotation(transform):
        """Return the ``[batch..., rank, rank]`` rotation component."""
        return transform[0]

    @staticmethod
    def translation(transform):
        """Return the ``[batch..., rank]`` translation component."""
        return transform[1]

    @staticmethod
    def scale(transform):
        """Return the ``[batch...]`` scale component."""
        return transform[2]

    @staticmethod
    def dtype(transform):
        """Return the dtype shared by all components."""
        return ScaledRigid.translation(transform).dtype

    @staticmethod
    def batch_shape(transform):
        """Return the batch dimensions as a 1-D int32 tensor."""
        return tf.shape(ScaledRigid.translation(transform))[:-1]

    @staticmethod
    def rank(transform):
        """Return the spatial dimension as a scalar int32 tensor."""
        return tf.shape(ScaledRigid.translation(transform))[-1]

    @staticmethod
    def slice(transform, *slices):
        """Index the batch dimensions of ``transform`` with ``slices``.

        Raises:
            ValueError: If more slices are given than there are batch dimensions.
        """
        translation = ScaledRigid.translation(transform)
        # Use the static rank: len() of a runtime shape tensor fails inside tf.function.
        static_rank = translation.shape.rank
        if static_rank is not None and len(slices) > static_rank - 1:
            raise ValueError("Too many slice dimensions")
        return ScaledRigid.create(
            rotation=ScaledRigid.rotation(transform)[slices],
            translation=translation[slices],
            scale=ScaledRigid.scale(transform)[slices],
        )

    @staticmethod
    def transform(transform, points):
        """Apply ``scale * (rotation @ p) + translation`` to every point.

        Args:
            points: ``[batch..., points, dims]`` with batch dims matching the transform.

        Returns:
            Transformed points with the same shape as ``points``.
        """
        tf.debugging.assert_equal(ScaledRigid.batch_shape(transform), tf.shape(points)[:-2], "Arguments do not have the same batch shape")
        tf.debugging.assert_equal(ScaledRigid.rank(transform), tf.shape(points)[-1], "Arguments do not have matching dimensions")

        scaled_rotation = ScaledRigid.scale(transform)[..., tf.newaxis, tf.newaxis] * ScaledRigid.rotation(transform)
        # (sR @ P^T)^T transforms each row vector of P.
        points = tf.linalg.matrix_transpose(tf.linalg.matmul(scaled_rotation, points, transpose_b=True))
        points = points + ScaledRigid.translation(transform)[..., tf.newaxis, :]

        return points

    @staticmethod
    def multiply(transform1, transform2):
        """Compose transforms: the result applies ``transform2`` first, then ``transform1``.

        ``s1 R1 (s2 R2 p + t2) + t1`` gives rotation ``R1 R2``, scale ``s1 s2`` and
        translation ``s1 R1 t2 + t1``.
        """
        rotation = tf.linalg.matmul(ScaledRigid.rotation(transform1), ScaledRigid.rotation(transform2))
        translation = ScaledRigid.transform(transform1, ScaledRigid.translation(transform2)[..., tf.newaxis, :])[..., 0, :]
        scale = ScaledRigid.scale(transform1) * ScaledRigid.scale(transform2)
        return ScaledRigid.create(rotation, translation, scale)

    @staticmethod
    def inverse(transform):
        """Return the inverse transform ``(R^T, -(1/s) R^T t, 1/s)``."""
        inv_rotation = tf.linalg.matrix_transpose(ScaledRigid.rotation(transform))
        inv_scale = 1 / ScaledRigid.scale(transform)
        inv_translation = tf.linalg.matmul(inv_rotation, -ScaledRigid.translation(transform)[..., :, tf.newaxis])[..., 0] * inv_scale[..., tf.newaxis]
        return ScaledRigid.create(inv_rotation, inv_translation, inv_scale)

    @staticmethod
    def divide(transform1, transform2):
        """Return ``transform1`` composed with the inverse of ``transform2``."""
        return ScaledRigid.multiply(transform1, ScaledRigid.inverse(transform2))

    @staticmethod
    def cast(transform, dtype):
        """Return ``transform`` with every component cast to ``dtype``."""
        return ScaledRigid.create(
            tf.cast(ScaledRigid.rotation(transform), dtype=dtype),
            tf.cast(ScaledRigid.translation(transform), dtype=dtype),
            tf.cast(ScaledRigid.scale(transform), dtype=dtype),
        )

    @staticmethod
    def str(transform):
        """Render the transform as text (requires eager tensors for ``.numpy()``)."""
        return f"{{t={ScaledRigid.translation(transform).numpy().tolist()} R={ScaledRigid.rotation(transform).numpy().tolist()} s={ScaledRigid.scale(transform).numpy().tolist()}}}"
