import numpy as np
import tensorflow as tf
import math
from pyquaternion import Quaternion

def rotation_matrix_to_angle(rotation_matrix):
    """Return the planar rotation angle (radians, in [-pi, pi]) of a rotation matrix.

    Only entries ``[0, 0]`` (cosine) and ``[1, 0]`` (sine) are read, which is the
    layout produced by ``angle_to_rotation_matrix``.
    """
    cosine = rotation_matrix[0, 0]
    sine = rotation_matrix[1, 0]
    return math.atan2(sine, cosine)

def angle_to_rotation_matrix(angle):
    """Build the 2x2 counter-clockwise rotation matrix for ``angle`` radians.

    Returns a float64 array ``[[cos, -sin], [sin, cos]]``.
    """
    c = math.cos(angle)
    s = math.sin(angle)
    return np.array([[c, -s], [s, c]])

def angle(v1, v2, clockwise=False):
    """Return the signed angle (radians) that turns direction ``v1`` onto ``v2``.

    The counter-clockwise angle is wrapped into ``(-pi, pi]``; with ``clockwise=True``
    its negation is returned instead. Only the first two components of each vector
    are read.
    """
    def heading(vector):
        # Polar angle of the direction (x, y).
        return math.atan2(vector[1], vector[0])

    start = np.asarray(v1)
    end = np.asarray(v2)
    delta = heading(end) - heading(start)

    # The raw difference lies in [-2*pi, 2*pi]; one correction brings it into (-pi, pi].
    if delta > math.pi:
        delta -= 2 * math.pi
    elif delta <= -math.pi:
        delta += 2 * math.pi

    if clockwise:
        return -delta
    return delta

# TODO: make ScaledRigid an extension of Rigid to avoid code duplication, put this as staticmethod?
def project_2d_to_3d(transform, axes=(0, 1)):
    """Embed a 2D ``Rigid`` or ``ScaledRigid`` transform into 3D.

    The 2D rotation block and translation are written into the plane spanned by
    ``axes`` of a 3D transform. The remaining axis keeps an identity rotation and
    zero translation. A ``ScaledRigid`` keeps its scale.

    Args:
        transform: A rank-2 ``Rigid`` or ``ScaledRigid``.
        axes: The two 3D axis indices that the 2D x and y axes are mapped to.

    Returns:
        A rank-3 transform of the same class as ``transform``.

    Raises:
        ValueError: if ``transform`` is neither a ``Rigid`` nor a ``ScaledRigid``, or
            is not two-dimensional.
    """
    if not isinstance(transform, (Rigid, ScaledRigid)):
        raise ValueError("Invalid transform")
    if transform.rank() != 2:
        raise ValueError(f"Expected a 2D transform, got rank {transform.rank()}")

    rotation = np.eye(3, dtype=transform.dtype())
    translation = np.zeros((3,), dtype=transform.dtype())
    for i in range(2):
        translation[axes[i]] = transform.translation[i]
        for j in range(2):
            rotation[axes[i], axes[j]] = transform.rotation[i, j]

    if isinstance(transform, ScaledRigid):
        return ScaledRigid(rotation=rotation, translation=translation, scale=transform.scale)
    return Rigid(rotation=rotation, translation=translation)

def rotate_points(points, angle, origin=0):
    """Rotate 2D point(s) counter-clockwise by ``angle`` radians about ``origin``.

    Args:
        points: Array-like of shape ``(..., 2)``. Lists and tuples are accepted.
        angle: Rotation angle in radians.
        origin: Center of rotation; either a length-2 array-like or ``0``.

    Returns:
        The rotated point(s) as a numpy array.
    """
    if isinstance(points, (list, tuple)):
        # ``list - origin`` is not defined, so convert plain sequences first.
        points = np.asarray(points)
    return origin + Rigid(rotation=angle_to_rotation_matrix(angle))(points - origin)

def apply_matrix(matrix, points):
    """Transform point(s) with a homogeneous (or promotable) matrix.

    Args:
        matrix: A homogeneous ``(d + 1, d + 1)`` matrix, or a matrix with ``d`` rows
            (e.g. a ``(d, d)`` linear map or a ``(d, d + 1)`` affine map). A ``d``-row
            matrix is first embedded into the top-left of a ``(d + 1, d + 1)`` identity.
        points: Point(s) of shape ``(..., d)`` (array, list or tuple), or the scalar
            ``0`` for the origin. In the scalar case ``d`` is taken as
            ``matrix.shape[1] - 1``, i.e. the matrix is assumed homogeneous or affine.

    Returns:
        The transformed point(s) as an array with the same shape as ``points``.

    Raises:
        ValueError: if a scalar other than ``0`` is given, or the matrix shape does
            not match the point dimension.
    """
    matrix = np.asarray(matrix)

    # Normalize the input to an array *before* inspecting its shape.
    if np.ndim(points) == 0:
        if points != 0:
            raise ValueError(f"Scalar points must be 0 (the origin), got {points}")
        points = np.full((matrix.shape[1] - 1,), points)
    else:
        points = np.asarray(points)

    rank = points.shape[-1]
    batch_shape = points.shape[:-1]

    if matrix.shape[0] == rank:
        # Promote a d-row matrix to homogeneous form.
        homogeneous = np.eye(rank + 1, dtype=matrix.dtype)
        homogeneous[:matrix.shape[0], :matrix.shape[1]] = matrix
        matrix = homogeneous
    if matrix.shape != (rank + 1, rank + 1):
        raise ValueError(f"Cannot apply matrix of shape {matrix.shape} to points of dimension {rank}")

    # Flatten all batch dimensions so that one matmul transforms every point.
    flat = np.reshape(points, (int(np.prod(batch_shape)), rank))
    flat = np.insert(flat, rank, 1, axis=1)
    flat = np.transpose(np.matmul(matrix, np.transpose(flat, (1, 0))), (1, 0))
    # NOTE: the homogeneous coordinate is dropped without a perspective division.
    # This is exact only for affine matrices (last row 0, ..., 0, 1).
    flat = flat[:, :-1]

    return np.reshape(flat, (*batch_shape, rank))

def lerp(x, xs, ys, lerp2=lambda y1, y2, amount: (1 - amount) * y1 + amount * y2):
    """Piecewise-linearly interpolate the samples ``ys`` at position ``x``.

    ``xs`` must be strictly increasing and as long as ``ys``. Outside ``[xs[0], xs[-1]]``
    the nearest end value is returned. Between two samples, ``lerp2(y_lo, y_hi, amount=t)``
    combines the neighbouring values, with ``t`` in ``[0, 1]``.

    Raises:
        ValueError: if ``xs`` is not strictly ordered or the lengths differ.
    """
    knots = np.asarray(xs)
    if not np.all(knots[1:] > knots[:-1]):
        raise ValueError("xs must be strictly ordered")
    count = knots.shape[0]
    if count != len(ys):
        raise ValueError("xs and ys must have same number of items")

    nearest = int(np.argmin(np.abs(knots - x)))
    nearest_x = knots[nearest]

    # Clamp to the end values outside the sampled range.
    if nearest == 0 and x <= nearest_x:
        return ys[0]
    if nearest == count - 1 and x >= nearest_x:
        return ys[-1]

    # Pick the bracketing interval [lo, lo + 1] that contains x.
    lo = nearest - 1 if x < nearest_x else nearest
    hi = lo + 1
    x_lo = knots[lo]
    x_hi = knots[hi]
    amount = (x - x_lo) / (x_hi - x_lo)
    assert 0 <= amount and amount <= 1
    return lerp2(ys[lo], ys[hi], amount=amount)

class Rigid:
    """A rigid transformation ``x -> rotation @ x + translation``.

    ``rotation`` is a ``(rank, rank)`` orthogonal matrix and ``translation`` is a
    ``(rank,)`` vector. Both are stored as numpy arrays of the same dtype.
    """

    def __init__(self, rotation=None, translation=None, dtype=None, rank=None):
        """Create and validate a rigid transformation.

        Args:
            rotation: A ``(rank, rank)`` orthogonal matrix (array-like), or a scalar
                angle in radians for a 2D rotation. Defaults to the identity.
            translation: A ``(rank,)`` vector (array-like). Defaults to zero.
            dtype: Storage dtype. Inferred from ``rotation``, else from ``translation``.
            rank: Dimensionality. Inferred from ``translation``, else from ``rotation``.

        Raises:
            ValueError: if dtype/rank cannot be inferred, shapes are wrong, values are
                non-finite, or ``rotation @ rotation.T`` is not close to the identity.
        """
        if isinstance(rotation, (int, float, np.integer, np.floating)):
            # A scalar is a 2D rotation angle; numpy scalars (e.g. np.float32) count too.
            rotation = angle_to_rotation_matrix(float(rotation))
        if rotation is not None:
            rotation = np.asarray(rotation)
        if translation is not None:
            translation = np.asarray(translation)

        if dtype is None:
            if rotation is not None:
                dtype = rotation.dtype
            elif translation is not None:
                dtype = translation.dtype
            else:
                raise ValueError("Expected dtype argument when rotation and translation are not given")

        if rank is None:
            if translation is not None:
                rank = translation.shape[0]
            elif rotation is not None:
                rank = rotation.shape[0]
            else:
                raise ValueError("Expected rank argument when rotation and translation are not given")

        if translation is None:
            translation = np.zeros((rank,), dtype=dtype)
        if rotation is None:
            rotation = np.eye(rank, dtype=dtype)
        if rotation.shape != (rank, rank):
            raise ValueError(f"Got invalid rotation shape {rotation.shape}")
        if translation.shape != (rank,):
            raise ValueError(f"Got invalid translation shape {translation.shape}")
        if not np.all(np.isfinite(rotation)):
            raise ValueError(f"Got non-finite rotation {rotation}")
        if not np.all(np.isfinite(translation)):
            raise ValueError(f"Got non-finite translation {translation}")
        # Orthogonality: rotation @ rotation^T must be (approximately) the identity.
        if not np.allclose(np.matmul(rotation, np.transpose(rotation, (1, 0))), np.eye(rank), atol=1e-6):
            error = np.sum(np.matmul(rotation, np.transpose(rotation, (1, 0))) - np.eye(rank))
            raise ValueError(f"Got non-orthogonal rotation {rotation} with error={error} and dtype {dtype}")

        self.rotation = rotation.astype(dtype)
        self.translation = translation.astype(dtype)

    def rank(self): # TODO: use python property
        """Return the dimensionality of the transformation."""
        return self.translation.shape[0]

    def dtype(self): # TODO: use python property
        """Return the dtype of the stored rotation and translation."""
        return self.translation.dtype

    def __call__(self, points):
        """Apply the transformation to point(s).

        Args:
            points: Array-like of shape ``(..., rank)``, or the scalar ``0``. The scalar
                is treated as the origin, so the result is ``translation``.

        Returns:
            An array with the same shape as ``points`` (``(rank,)`` for the scalar case)
            and dtype ``self.dtype()``.

        Raises:
            ValueError: if a scalar other than ``0`` is given.
        """
        if np.ndim(points) == 0:
            if points != 0:
                raise ValueError(f"Scalar points must be 0 (the origin), got {points}")
            points = np.zeros((self.rank(),), dtype=self.dtype())
        else:
            points = np.asarray(points).astype(self.dtype())

        # Flatten all leading batch dimensions so a single matmul handles every point.
        batch_shape = points.shape[:-1]
        points = np.reshape(points, (1 if len(batch_shape) == 0 else np.prod(batch_shape), points.shape[-1]))

        points = np.transpose(np.matmul(self.rotation, np.transpose(points, (1, 0))), (1, 0))
        points = points + self.translation[np.newaxis, :]

        return np.reshape(points, (*batch_shape, points.shape[-1]))

    def __str__(self):
        return f"{{t={self.translation.tolist()} R={self.rotation.tolist()}}}"

    def __repr__(self):
        return f"cosy.np.Rigid(translation=np.asarray({self.translation.tolist()}), " \
            +  f"rotation=np.asarray({self.rotation.tolist()}), " \
            +  f"dtype=\"{str(self.dtype())}\")"

    def __mul__(self, other):
        """Compose transforms so that ``(self * other)(x) == self(other(x))``.

        Composing with a ``ScaledRigid`` yields a ``ScaledRigid``, so its scale is not
        dropped.
        """
        if not isinstance(other, Rigid) and isinstance(other, ScaledRigid):
            return ScaledRigid(self) * other
        return Rigid(np.matmul(self.rotation, other.rotation), self(other.translation), dtype=self.dtype())

    def inverse(self):
        """Return the transform that undoes this one: ``x -> R^T (x - t)``."""
        inv_rotation = np.transpose(self.rotation, (1, 0))
        inv_translation = np.matmul(inv_rotation, -self.translation)
        return Rigid(inv_rotation, inv_translation, dtype=self.dtype())

    def __truediv__(self, other):
        """Return ``self * other.inverse()``."""
        return self * other.inverse()

    def astype(self, dtype):
        """Return a copy of this transform stored with ``dtype``."""
        return Rigid(self.rotation, self.translation, dtype=dtype)

    @staticmethod
    def random2d(translation_stddev, angle_stddev, dtype="float32"):
        """Sample a random 2D rigid transform.

        The translation direction is drawn uniformly from ``[0, pi)`` and scaled by a
        signed, normally distributed distance with std ``translation_stddev``. This
        covers all directions. The rotation angle is normal with std ``angle_stddev``.
        """
        translation_angle = np.random.uniform(size=(1,), low=0.0, high=math.pi).astype(dtype)[0]
        translation_distance = np.random.normal(size=(1,), loc=0.0, scale=translation_stddev).astype(dtype)[0]
        translation = np.matmul(angle_to_rotation_matrix(translation_angle), np.stack([translation_distance, 0.0], axis=0))

        angle = np.random.normal(size=(1,), loc=0.0, scale=angle_stddev).astype(dtype)[0]
        return Rigid(angle_to_rotation_matrix(angle), translation, dtype=dtype)

    @staticmethod
    def least_squares(from_points, to_points):
        """Fit the rigid transform that best maps ``from_points`` onto ``to_points``.

        Minimizes ``sum_i |rotation @ from_i + translation - to_i|^2`` over proper
        rotations (determinant +1). In 2D this uses a closed form. In other ranks it
        uses the Kabsch/SVD method with a reflection correction.

        Args:
            from_points: ``(n, rank)`` source points.
            to_points: ``(n, rank)`` target points, corresponding row by row.

        Returns:
            The fitted ``Rigid``.

        Raises:
            ValueError: if the arrays are not 2D with equal shapes, or if the 2D problem
                is degenerate (the rotation is undetermined).
        """
        from_points = np.asarray(from_points)
        to_points = np.asarray(to_points)
        if from_points.shape != to_points.shape or from_points.ndim != 2:
            raise ValueError(f"Expected two (n, rank) point arrays of equal shape, got {from_points.shape} and {to_points.shape}")
        rank = from_points.shape[1]

        mean1 = np.mean(from_points, axis=0)
        mean2 = np.mean(to_points, axis=0)
        from_points = from_points - mean1
        to_points = to_points - mean2

        # Cross-covariance W = sum_i to_i from_i^T. The optimal R maximizes trace(R^T W).
        W = np.matmul(np.transpose(to_points, (1, 0)), from_points)

        if rank == 2:
            y = 0.5 * (W[1, 0] - W[0, 1])
            x = 0.5 * (W[0, 0] + W[1, 1])
            r = math.sqrt(x * x + y * y)
            if r == 0:
                raise ValueError("Degenerate point sets: rotation is undetermined")
            rotation_matrix = np.stack([
                np.stack([x, -y]),
                np.stack([y, x])
            ]) * (1.0 / r)
        else:
            u, _, vT = np.linalg.svd(W)
            # With W = U S V^T, trace(R^T W) is maximal for R = U V^T. Flip the last
            # singular direction if needed so R is a rotation, not a reflection.
            correction = np.eye(rank, dtype=W.dtype)
            correction[-1, -1] = np.sign(np.linalg.det(np.matmul(u, vT)))
            rotation_matrix = np.matmul(np.matmul(u, correction), vT)

        translation = mean2 - np.matmul(rotation_matrix, mean1)

        return Rigid(rotation_matrix, translation)

    @staticmethod
    def from_matrix(m):
        """Build a ``Rigid`` from a ``(rank + 1, rank + 1)`` homogeneous matrix.

        Raises:
            ValueError: if the last row is not ``(0, ..., 0, 1)``.
        """
        rank = m.shape[0] - 1
        if not np.allclose(m[rank, :-1], 0.0) or not np.allclose(m[rank, rank], 1.0):
            raise ValueError("Not a valid Rigid transformation matrix")
        return Rigid(
            rotation=m[:rank, :rank],
            translation=m[:rank, rank]
        )

    def to_matrix(self):
        """Return the ``(rank + 1, rank + 1)`` homogeneous matrix of this transform."""
        rank = self.rank()
        m = np.eye(rank + 1, dtype=self.dtype())
        m[:rank, :rank] = self.rotation
        m[:rank, rank] = self.translation
        return m

    @staticmethod
    def slerp(transform1, transform2=None, amount=0.5):
        """Interpolate between two rigid transforms.

        Rotations follow the shorter arc. In 2D the wrapped angle difference is
        interpolated; in 3D quaternion slerp is used. Translations are interpolated
        linearly. The result has the dtype of ``transform1``.

        Args:
            transform1: Start transform, returned as-is for ``amount == 0``.
            transform2: End transform, returned as-is for ``amount == 1``. If omitted,
                interpolates from the identity to ``transform1``.
            amount: Interpolation factor in ``[0, 1]``.

        Raises:
            ValueError: if ``amount`` is outside ``[0, 1]``, the ranks differ, or the
                rank is neither 2 nor 3.
        """
        if not (0 <= amount <= 1):
            raise ValueError(f"Expected amount in [0, 1], got {amount}")
        if transform2 is None:
            transform2 = transform1
            transform1 = Rigid(dtype=transform1.dtype(), rank=transform1.rank()) # Identity
        if transform1.rank() != transform2.rank():
            raise ValueError(f"Cannot interpolate between ranks {transform1.rank()} and {transform2.rank()}")

        if amount == 0:
            return transform1
        elif amount == 1:
            return transform2

        rank = transform1.rank()
        translation = transform1.translation + amount * (transform2.translation - transform1.translation)

        if rank == 2:
            angle1 = rotation_matrix_to_angle(transform1.rotation)
            angle2 = rotation_matrix_to_angle(transform2.rotation)
            # Wrap the difference into (-pi, pi] to rotate the shorter way around.
            delta = angle2 - angle1
            if delta > math.pi:
                delta -= 2 * math.pi
            elif delta <= -math.pi:
                delta += 2 * math.pi
            rotation = angle_to_rotation_matrix(float(angle1 + amount * delta))
        elif rank == 3:
            q1 = Quaternion(matrix=transform1.rotation, rtol=1e-04, atol=1e-06)
            q2 = Quaternion(matrix=transform2.rotation, rtol=1e-04, atol=1e-06)
            rotation = Quaternion.slerp(q1, q2, amount).rotation_matrix
        else:
            raise ValueError(f"Slerp for transformation with rank {rank} not supported")

        return Rigid(rotation=rotation, translation=translation, dtype=transform1.dtype())

    def to_json(self):
        """Serialize to a JSON-compatible dict (inverse of ``from_json``)."""
        return {
            "rotation": self.rotation.tolist(),
            "translation": self.translation.tolist(),
            "dtype": str(self.dtype()),
        }

    @staticmethod
    def from_json(data):
        """Build a ``Rigid`` from a dict produced by ``to_json``."""
        return Rigid(
            rotation=data["rotation"],
            translation=data["translation"],
            dtype=data["dtype"],
        )

# TODO implement this with member Rigid transformation
class ScaledRigid:
    """A similarity transformation ``x -> scale * rotation @ x + translation``.

    ``rotation`` is a ``(rank, rank)`` orthogonal matrix and ``translation`` is a
    ``(rank,)`` vector. ``scale`` is a scalar stored as a 0-d array. All three share
    one dtype.
    """

    def __init__(self, rotation=None, translation=None, scale=None, dtype=None, rank=None):
        """Create and validate a similarity transformation.

        Args:
            rotation: A ``(rank, rank)`` orthogonal matrix, a scalar angle in radians for
                a 2D rotation, or a ``Rigid`` to promote. Promotion only happens when
                ``translation``, ``scale`` and ``dtype`` are all omitted; the scale is
                then 1. Defaults to the identity.
            translation: A ``(rank,)`` vector. Defaults to zero.
            scale: A scalar, or any array-like that squeezes to shape ``()``. Defaults to 1.
            dtype: Storage dtype. Inferred from ``rotation``, else ``translation``, else
                ``scale``.
            rank: Dimensionality. Inferred from ``translation``, else from ``rotation``.

        Raises:
            ValueError: if dtype/rank cannot be inferred, shapes are wrong, values are
                non-finite, or ``rotation @ rotation.T`` is not close to the identity.
        """
        if translation is None and scale is None and dtype is None and isinstance(rotation, Rigid):
            # Promote a Rigid to a ScaledRigid with unit scale.
            rigid = rotation
            rotation = rigid.rotation
            translation = rigid.translation
            dtype = rigid.dtype()
            scale = 1

        if isinstance(rotation, (int, float, np.integer, np.floating)):
            # A scalar is a 2D rotation angle; numpy scalars (e.g. np.float32) count too.
            rotation = angle_to_rotation_matrix(float(rotation))
        if rotation is not None:
            rotation = np.asarray(rotation)
        if translation is not None:
            translation = np.asarray(translation)
        if scale is not None:
            scale = np.asarray(scale)

        if dtype is None:
            if rotation is not None:
                dtype = rotation.dtype
            elif translation is not None:
                dtype = translation.dtype
            elif scale is not None:
                dtype = scale.dtype
            else:
                raise ValueError("Expected dtype argument when rotation, translation and scale are not given")

        if rank is None:
            if translation is not None:
                rank = translation.shape[0]
            elif rotation is not None:
                rank = rotation.shape[0]
            else:
                raise ValueError("Expected rank argument when rotation and translation are not given")

        if translation is None:
            translation = np.zeros((rank,), dtype=dtype)
        if rotation is None:
            rotation = np.eye(rank, dtype=dtype)
        if scale is None:
            scale = np.asarray((1,))
        scale = np.squeeze(scale)
        if rotation.shape != (rank, rank):
            raise ValueError(f"Got invalid rotation shape {rotation.shape}")
        if translation.shape != (rank,):
            raise ValueError(f"Got invalid translation shape {translation.shape}")
        if scale.shape != ():
            raise ValueError(f"Got invalid scale shape {scale.shape}")
        if not np.all(np.isfinite(rotation)):
            raise ValueError(f"Got non-finite rotation {rotation}")
        if not np.all(np.isfinite(translation)):
            raise ValueError(f"Got non-finite translation {translation}")
        if not np.all(np.isfinite(scale)):
            raise ValueError(f"Got non-finite scale {scale}")
        # Orthogonality: rotation @ rotation^T must be (approximately) the identity.
        if not np.allclose(np.matmul(rotation, np.transpose(rotation, (1, 0))), np.eye(rank), atol=1e-6):
            error = np.sum(np.matmul(rotation, np.transpose(rotation, (1, 0))) - np.eye(rank))
            raise ValueError(f"Got non-orthogonal rotation {rotation} with error={error} and dtype {dtype}")

        self.rotation = rotation.astype(dtype)
        self.translation = translation.astype(dtype)
        self.scale = scale.astype(dtype)

    def rank(self):
        """Return the dimensionality of the transformation."""
        return self.translation.shape[0]

    def dtype(self): # TODO: use python property
        """Return the dtype of the stored rotation, translation and scale."""
        return self.translation.dtype

    def __call__(self, points):
        """Apply the transformation to point(s).

        Args:
            points: Array-like of shape ``(..., rank)``, or the scalar ``0``. The scalar
                is treated as the origin, so the result is ``translation``.

        Returns:
            An array with the same shape as ``points`` (``(rank,)`` for the scalar case)
            and dtype ``self.dtype()``.

        Raises:
            ValueError: if a scalar other than ``0`` is given.
        """
        if np.ndim(points) == 0:
            if points != 0:
                raise ValueError(f"Scalar points must be 0 (the origin), got {points}")
            points = np.zeros((self.rank(),), dtype=self.dtype())
        else:
            points = np.asarray(points).astype(self.dtype())

        # Flatten all leading batch dimensions so a single matmul handles every point.
        batch_shape = points.shape[:-1]
        points = np.reshape(points, (1 if len(batch_shape) == 0 else np.prod(batch_shape), points.shape[-1]))

        linear = self.scale * self.rotation
        points = np.transpose(np.matmul(linear, np.transpose(points, (1, 0))), (1, 0))
        points = points + self.translation[np.newaxis, :]

        return np.reshape(points, (*batch_shape, points.shape[-1]))

    def __str__(self):
        return f"{{t={self.translation.tolist()} R={self.rotation.tolist()}, s={self.scale.tolist()}}}"

    def __repr__(self):
        return f"cosy.np.ScaledRigid(translation=np.asarray({self.translation.tolist()}), " \
            +  f"rotation=np.asarray({self.rotation.tolist()}), " \
            +  f"scale=np.asarray({self.scale.tolist()}), " \
            +  f"dtype=\"{str(self.dtype())}\")"

    def __mul__(self, other):
        """Compose transforms so that ``(self * other)(x) == self(other(x))``.

        A ``Rigid`` operand is promoted to a ``ScaledRigid`` with unit scale.
        """
        if isinstance(other, Rigid):
            other = ScaledRigid(other)
        return ScaledRigid(np.matmul(self.rotation, other.rotation), self(other.translation), self.scale * other.scale, dtype=self.dtype())

    def inverse(self):
        """Return the transform that undoes this one: ``x -> R^T (x - t) / s``."""
        inv_rotation = np.transpose(self.rotation, (1, 0))
        inv_translation = np.matmul(inv_rotation, -self.translation) / self.scale
        inv_scale = 1.0 / self.scale
        return ScaledRigid(inv_rotation, inv_translation, inv_scale, dtype=self.dtype())

    def __truediv__(self, other):
        """Return ``self * other.inverse()``."""
        return self * other.inverse()

    def astype(self, dtype):
        """Return a copy of this transform stored with ``dtype``."""
        return ScaledRigid(self.rotation, self.translation, self.scale, dtype=dtype)

    def to_matrix(self):
        """Return the ``(rank + 1, rank + 1)`` homogeneous matrix of this transform."""
        rank = self.rank()
        m = np.eye(rank + 1, dtype=self.dtype())
        m[:rank, :rank] = self.rotation * self.scale
        m[:rank, rank] = self.translation
        return m

    @staticmethod
    def least_squares(from_points, to_points):
        """Fit the similarity transform that best maps ``from_points`` onto ``to_points``.

        Minimizes ``sum_i |scale * rotation @ from_i + translation - to_i|^2``. The
        rotation is the best rigid rotation: a closed form in 2D, and Kabsch/SVD with
        a reflection correction otherwise. That rotation does not depend on a positive
        scale. The scale is then the least-squares optimum for that rotation.

        Args:
            from_points: ``(n, rank)`` source points.
            to_points: ``(n, rank)`` target points, corresponding row by row.

        Returns:
            The fitted ``ScaledRigid``.

        Raises:
            ValueError: if the arrays are not 2D with equal shapes, or if the point sets
                are degenerate (rotation or scale undetermined).
        """
        from_points = np.asarray(from_points)
        to_points = np.asarray(to_points)
        if from_points.shape != to_points.shape or from_points.ndim != 2:
            raise ValueError(f"Expected two (n, rank) point arrays of equal shape, got {from_points.shape} and {to_points.shape}")
        rank = from_points.shape[1]

        mean1 = np.mean(from_points, axis=0)
        mean2 = np.mean(to_points, axis=0)
        from_points = from_points - mean1
        to_points = to_points - mean2

        # Cross-covariance W = sum_i to_i from_i^T. The optimal R maximizes trace(R^T W).
        W = np.matmul(np.transpose(to_points, (1, 0)), from_points)

        if rank == 2:
            y = 0.5 * (W[1, 0] - W[0, 1])
            x = 0.5 * (W[0, 0] + W[1, 1])
            r = math.sqrt(x * x + y * y)
            if r == 0:
                raise ValueError("Degenerate point sets: rotation is undetermined")
            rotation_matrix = np.stack([
                np.stack([x, -y]),
                np.stack([y, x])
            ]) * (1.0 / r)
        else:
            u, _, vT = np.linalg.svd(W)
            # With W = U S V^T, trace(R^T W) is maximal for R = U V^T. Flip the last
            # singular direction if needed so R is a rotation, not a reflection.
            correction = np.eye(rank, dtype=W.dtype)
            correction[-1, -1] = np.sign(np.linalg.det(np.matmul(u, vT)))
            rotation_matrix = np.matmul(np.matmul(u, correction), vT)

        from_points = np.transpose(np.matmul(rotation_matrix, np.transpose(from_points, (1, 0))), (1, 0))

        scale = np.sum(from_points * to_points) / np.sum(from_points * from_points)
        if not scale > 0:
            raise ValueError(f"Degenerate point sets: got non-positive scale {scale}")

        translation = mean2 - scale * np.matmul(rotation_matrix, mean1)

        return ScaledRigid(rotation_matrix, translation, scale)

    def to_json(self):
        """Serialize to a JSON-compatible dict (inverse of ``from_json``)."""
        return {
            "rotation": self.rotation.tolist(),
            "translation": self.translation.tolist(),
            "scale": self.scale.tolist(),
            "dtype": str(self.dtype()),
        }

    @staticmethod
    def from_json(data):
        """Build a ``ScaledRigid`` from a dict produced by ``to_json``."""
        return ScaledRigid(
            rotation=data["rotation"],
            translation=data["translation"],
            scale=data["scale"],
            dtype=data["dtype"],
        )
