#pragma once

#include <cmath>
#include <cstddef>
#include <limits>
#include <ostream>
#include <stdexcept>
#include <type_traits>
#include <utility>

#include <xti/typedefs.h>
#include <xti/util.h>
#include <xtensor/xadapt.hpp>
#include <xtensor/xview.hpp>
#include <xtensor/xtensor.hpp>
#include <xtensor/xio.hpp>
#include <xtensor-blas/xlinalg.hpp>

#ifdef COSY_CEREAL_INCLUDED
#include <cereal/access.hpp>
#include <xti/cereal.h>
#endif

namespace cosy {

template <typename TScalar, typename = std::enable_if_t<!xti::is_xtensor_v<TScalar>>>
TScalar radians(TScalar degrees)
{
  return degrees / 180 * xt::numeric_constants<TScalar>::PI;
}

template <typename TTensor, typename = std::enable_if_t<xti::is_xtensor_v<TTensor>>>
auto radians(TTensor&& tensor)
{
  return std::forward<TTensor>(tensor) / 180 * xt::numeric_constants<xti::elementtype_t<TTensor>>::PI;
}

template <typename TScalar, typename = std::enable_if_t<!xti::is_xtensor_v<TScalar>>>
TScalar degrees(TScalar degrees)
{
  return degrees * 180 / xt::numeric_constants<TScalar>::PI;
}

template <typename TTensor, typename = std::enable_if_t<xti::is_xtensor_v<TTensor>>>
auto degrees(TTensor&& tensor)
{
  return std::forward<TTensor>(tensor) * 180 / xt::numeric_constants<xti::elementtype_t<TTensor>>::PI;
}

/// Builds the 2x2 matrix [[cos(angle), -sin(angle)], [sin(angle), cos(angle)]]
/// for an angle given in radians.
template <typename TScalar>
xti::mat2T<std::decay_t<TScalar>> angle_to_rotation_matrix(TScalar angle)
{
  const auto c = std::cos(angle);
  const auto s = std::sin(angle);
  return {{c, -s}, {s, c}};
}

/// Recovers the angle (radians, in [-pi, pi] as returned by std::atan2) from a
/// 2D rotation matrix, using only its first column: atan2(R(1, 0), R(0, 0)).
///
/// @param rotation_matrix Anything callable as `rotation_matrix(row, col)`.
template <typename TRotationMatrix>
auto rotation_matrix_to_angle(TRotationMatrix&& rotation_matrix)
{
  const auto cos_component = rotation_matrix(0, 0);
  const auto sin_component = rotation_matrix(1, 0);
  return std::atan2(sin_component, cos_component);
}

/// Wraps an angle into the half-open interval [lower, upper) by adding or
/// subtracting whole turns (2*pi).
///
/// @param angle  Angle in radians.
/// @param lower  Inclusive lower bound of the target interval (default -pi).
/// @param upper  Exclusive upper bound of the target interval (default pi).
/// @return The wrapped angle. The interval is expected to be exactly one turn
///         wide (upper - lower == 2*pi); for other widths the result is not
///         guaranteed to lie in [lower, upper). Non-finite input yields NaN.
template <typename TScalar>
auto normalize_angle(TScalar angle, TScalar lower = -xt::numeric_constants<TScalar>::PI, TScalar upper = xt::numeric_constants<TScalar>::PI)
{
  const TScalar two_pi = 2 * xt::numeric_constants<TScalar>::PI;

  // +/-infinity would never leave the loops below (inf - 2*pi == inf), and NaN
  // skips them. Report both as NaN, matching std::fmod's convention.
  if (!std::isfinite(angle))
  {
    return std::numeric_limits<TScalar>::quiet_NaN();
  }

  // For angles many turns away, repeated subtraction is slow. Once 2*pi drops
  // below the ULP of `angle`, it never terminates at all. Collapse such angles
  // with the exact std::fmod first. Angles within two turns of the interval
  // skip this, so they get exactly the same result as before.
  if (angle >= upper + 2 * two_pi || angle < lower - 2 * two_pi)
  {
    angle = lower + std::fmod(angle - lower, two_pi);
  }

  while (angle >= upper)
  {
    angle -= two_pi;
  }
  while (angle < lower)
  {
    angle += two_pi;
  }
  return angle;
}

/// Signed angle (radians) from `vec1` to `vec2`, computed as the difference of
/// their atan2 headings. The result is NOT wrapped; it lies in (-2*pi, 2*pi).
///
/// @param vec1, vec2  Anything callable as `vec(0)` (x) and `vec(1)` (y).
/// @param clockwise   If true, the sign of the result is flipped.
template <typename TVec1, typename TVec2>
auto angle(TVec1&& vec1, TVec2&& vec2, bool clockwise = false)
{
  const auto heading1 = std::atan2(vec1(1), vec1(0));
  const auto heading2 = std::atan2(vec2(1), vec2(0));
  const auto counter_clockwise = heading2 - heading1;
  if (clockwise)
  {
    return -counter_clockwise;
  }
  return counter_clockwise;
}

/// A TRank-dimensional rotation stored as a TRank x TRank matrix R, applied to
/// column vectors by left-multiplication (p' = R * p).
///
/// inverse() and the *_inverse transforms use the transpose of R. That is only
/// the true inverse when R is orthogonal, which no constructor checks. A matrix
/// with negative determinant (see flips()) is accepted as well.
template <typename TScalar, size_t TRank>
class Rotation
{
private:
  xti::matXT<TScalar, TRank> m_rotation;

  /// Throws std::invalid_argument unless `points` is a rank-2 tensor of shape (n, TRank).
  template <typename TTensor>
  static void check_points_shape(const TTensor& points)
  {
    // Test the rank first: reading shape()[1] of a rank-0/1 tensor would be out of bounds.
    if (points.dimension() != 2 || points.shape()[1] != TRank)
    {
      throw std::invalid_argument(XTI_TO_STRING("Points tensor must have shape (n, " << TRank << "), got shape " << xt::adapt(points.shape())));
    }
  }

public:
  /// Identity rotation.
  Rotation()
    : m_rotation(xt::eye<TScalar>(TRank))
  {
  }

  /// 2D only: builds [[cos(angle), -sin(angle)], [sin(angle), cos(angle)]] from `angle` in radians.
  template <bool TDummy = true, typename = std::enable_if_t<TDummy && TRank == 2, void>>
  Rotation(TScalar angle)
    : m_rotation(angle_to_rotation_matrix(angle))
  {
  }

  /// Takes the upper-left TRank x TRank block of `transformation_matrix`.
  /// The remaining entries are currently not validated.
  Rotation(xti::matXT<TScalar, TRank + 1> transformation_matrix)
    : m_rotation(xt::view(transformation_matrix, xt::range(0, TRank), xt::range(0, TRank)))
  {
    // TODO: check that all other elements of matrix are 0, with epsilon
    // if (xt::view(transformation_matrix, TRank, xt::range(0, TRank)) != 0 || transformation_matrix(TRank, TRank) != 1)
  }

  /// Wraps an existing TRank x TRank matrix (not checked for orthogonality).
  Rotation(xti::matXT<TScalar, TRank> rotation)
    : m_rotation(rotation)
  {
  }

  /// Converts from a rotation with a different scalar type.
  template <typename TScalar2>
  Rotation(const Rotation<TScalar2, TRank>& other)
    : m_rotation(other.m_rotation)
  {
  }

  /// Assigns from a rotation with a different scalar type.
  template <typename TScalar2>
  Rotation<TScalar, TRank>& operator=(const Rotation<TScalar2, TRank>& other)
  {
    this->m_rotation = other.m_rotation;
    return *this;
  }

  /// Rotates a single point: returns R * point.
  auto transform(xti::vecXT<TScalar, TRank> point) const
  {
    return xt::linalg::dot(m_rotation, point);
  }

  /// Rotates every row of an (n, TRank) tensor of points.
  ///
  /// @return (R * points^T)^T, as a transpose of the evaluated product.
  /// @throws std::invalid_argument if `points` does not have shape (n, TRank).
  template <typename TTensor>
  auto transform_all(TTensor&& points) const
  {
    check_points_shape(points);
    return xt::transpose(xt::eval(xt::linalg::dot(m_rotation, xt::transpose(xt::eval(std::forward<TTensor>(points)), {1, 0}))), {1, 0});
  }

  /// Applies the inverse rotation to a single point: returns R^T * point.
  auto transform_inverse(xti::vecXT<TScalar, TRank> point) const
  {
    return xt::linalg::dot(xt::transpose(m_rotation, {1, 0}), point);
  }

  /// Applies the inverse rotation to every row of an (n, TRank) tensor of points.
  ///
  /// @return (R^T * points^T)^T, as a transpose of the evaluated product.
  /// @throws std::invalid_argument if `points` does not have shape (n, TRank).
  template <typename TTensor>
  auto transform_all_inverse(TTensor&& points) const
  {
    check_points_shape(points);
    // The trailing {1, 0} is the permutation of the OUTER transpose (mirrors transform_all).
    return xt::transpose(xt::eval(xt::linalg::dot(xt::transpose(m_rotation, {1, 0}), xt::transpose(xt::eval(std::forward<TTensor>(points)), {1, 0}))), {1, 0});
  }

  /// Returns the rotation with matrix R^T (the inverse if R is orthogonal).
  Rotation<TScalar, TRank> inverse() const
  {
    Rotation<TScalar, TRank> result;
    result.get_rotation() = xt::transpose(m_rotation, {1, 0});
    return result;
  }

  /// Composes in place: R = R * right.R, so `right` is applied to points first.
  Rotation<TScalar, TRank>& operator*=(const Rotation<TScalar, TRank>& right)
  {
    m_rotation = xt::linalg::dot(m_rotation, right.get_rotation());
    return *this;
  }

  /// Mutable access to the underlying matrix.
  xti::matXT<TScalar, TRank>& get_rotation()
  {
    return m_rotation;
  }

  /// Read-only access to the underlying matrix.
  const xti::matXT<TScalar, TRank>& get_rotation() const
  {
    return m_rotation;
  }

  /// True if det(R) < 0, i.e. the matrix includes a reflection.
  bool flips() const
  {
    return xt::linalg::det(m_rotation) < 0;
  }

  // Grants the converting constructor/assignment access to other instantiations' m_rotation.
  template <typename TScalar2, size_t TRank2>
  friend class Rotation;
};

/// Composition of two rotations: the result has matrix left.R * right.R,
/// so `right` is applied to points first.
template <typename TScalar, size_t TRank>
Rotation<TScalar, TRank> operator*(const Rotation<TScalar, TRank>& left, const Rotation<TScalar, TRank>& right)
{
  Rotation<TScalar, TRank> composed(left);
  composed *= right;
  return composed;
}

/// "Division" of rotations: left composed with the inverse of right
/// (matrix left.R * right.R^T).
template <typename TScalar, size_t TRank>
Rotation<TScalar, TRank> operator/(const Rotation<TScalar, TRank>& left, const Rotation<TScalar, TRank>& right)
{
  const Rotation<TScalar, TRank> right_inverse = right.inverse();
  return left * right_inverse;
}

/// Streams the rotation as "Rotation( R=<matrix>)".
template <typename TScalar, size_t TRank>
std::ostream& operator<<(std::ostream& stream, const Rotation<TScalar, TRank>& transform)
{
  stream << "Rotation(";
  stream << " R=";
  stream << transform.get_rotation();
  stream << ")";
  return stream;
}

#ifdef COSY_CEREAL_INCLUDED
/// cereal save hook: the rotation is serialized as its matrix alone.
template <typename TArchive, typename TScalar, size_t TRank>
void save(TArchive& archive, const cosy::Rotation<TScalar, TRank>& transform)
{
  const auto& matrix = transform.get_rotation();
  archive(matrix);
}

/// cereal load hook: reads the matrix directly into the rotation's storage.
template <typename TArchive, typename TScalar, size_t TRank>
void load(TArchive& archive, cosy::Rotation<TScalar, TRank>& transform)
{
  auto& matrix = transform.get_rotation();
  archive(matrix);
}
#endif

} // cosy
