#pragma once

#include <cosy/affine.h>
#include <cosy/camera.h>
#include <xti/util.h>
#include <xtensor/xadapt.hpp>
#include <xtensor/xtensor.hpp>
#include <xti/opencv.h>
#include <xtensor-blas/xlinalg.hpp>
#include <filesystem>
#include <cmath>
#include <stdexcept>
#include <string>
#include <utility>
#include "../exception.h"
#include <cereal/access.hpp>
#include <xti/cereal.h>
#include <cereal/types/memory.hpp>
#include <xtensor/xrandom.hpp>
#include <tiledwebmaps/directory.h>

namespace georegdata::ground {

class Camera;

/// Identifies one camera image together with its calibration.
///
/// Holds the camera name, the pinhole intrinsics, the image resolution the intrinsics refer to (used in (rows, cols)
/// order by load()), the ego-to-camera extrinsics and the path of the image file. load() reads the image from disk
/// and returns a Camera; derived ids override it to apply augmentations.
class CameraId
{
public:
  CameraId(std::string name, cosy::PinholeK<double, 3> projection, xti::vec2s resolution, cosy::Rigid<double, 3> ego_to_camera, std::filesystem::path image_file)
    : m_name(std::move(name))
    , m_projection(std::move(projection))
    , m_resolution(std::move(resolution))
    , m_ego_to_camera(std::move(ego_to_camera))
    , m_image_file(std::move(image_file))
  {
  }

  // load() is virtual and overridden by derived ids, so destruction through a CameraId pointer must be well-defined.
  // The copy/move operations are defaulted explicitly because a user-declared destructor would otherwise suppress
  // the implicit move operations.
  virtual ~CameraId() = default;
  CameraId(const CameraId&) = default;
  CameraId(CameraId&&) = default;
  CameraId& operator=(const CameraId&) = default;
  CameraId& operator=(CameraId&&) = default;

  /// Loads the image and projects `points_ego` (shape (n, 3)) into it, without augmentation.
  virtual Camera load(const xt::xtensor<double, 2>& points_ego) const;

  /// Loads the image with an augmentation: `angle` is called to obtain a rotation angle, `scale` is called with the
  /// intrinsic matrix (rescaled to the loaded image size) to obtain a scale factor, and the result is randomly
  /// cropped to `max_shape` when both of its entries are positive.
  template <typename TAngleFn, typename TScaleFn>
  Camera load(const xt::xtensor<double, 2>& points_ego, TAngleFn&& angle, TScaleFn&& scale, xti::vec2s max_shape) const;

  /// Reads the image file and returns its actual (rows, cols) size, which may differ from get_resolution().
  xti::vec2s load_resolution() const
  {
    // Pass a narrow string, as load() does; implicit path -> std::string conversion is not portable.
    cv::Mat image_cv = tiledwebmaps::safe_imread(m_image_file.string());
    return xti::vec2s({static_cast<size_t>(image_cv.rows), static_cast<size_t>(image_cv.cols)});
  }

  const std::string& get_name() const
  {
    return m_name;
  }

  const cosy::PinholeK<double, 3>& get_projection() const
  {
    return m_projection;
  }

  const xti::vec2s& get_resolution() const
  {
    return m_resolution;
  }

  const cosy::Rigid<double, 3>& get_ego_to_camera() const
  {
    return m_ego_to_camera;
  }

  const std::filesystem::path& get_image_file() const
  {
    return m_image_file;
  }

  /// cereal deserialization hook for a CameraId (which has no default constructor).
  /// Fields are read in the same order in which the free function save() writes them.
  template <typename TArchive>
  static void load_and_construct(TArchive& archive, cereal::construct<CameraId>& construct)
  {
    std::string name;
    archive(name);
    cosy::PinholeK<double, 3> projection;
    archive(projection);
    xti::vec2s resolution;
    archive(resolution);
    cosy::Rigid<double, 3> ego_to_camera;
    archive(ego_to_camera);
    std::string image_file;
    archive(image_file);
    construct(std::move(name), std::move(projection), std::move(resolution), std::move(ego_to_camera), std::move(image_file));
  }

private:
  std::string m_name;
  cosy::PinholeK<double, 3> m_projection;
  xti::vec2s m_resolution;
  cosy::Rigid<double, 3> m_ego_to_camera;
  std::filesystem::path m_image_file;
};

/// cereal serialization of a CameraId, in the field order read back by CameraId::load_and_construct.
///
/// The image path is written with path::string(). Constructing std::string directly from a path only
/// compiles where path::string_type is std::string (not on Windows, where it is std::wstring).
template <typename TArchive>
void save(TArchive& archive, const CameraId& camera_id)
{
  archive(camera_id.get_name());
  archive(camera_id.get_projection());
  archive(camera_id.get_resolution());
  archive(camera_id.get_ego_to_camera());
  archive(camera_id.get_image_file().string());
}

/// A CameraId whose load() applies a fixed rotation angle and scale factor, and optionally a random crop.
class AugmentedCameraId : public CameraId
{
public:
  /// @param camera_id Camera to augment (copied).
  /// @param angle     Rotation angle forwarded to the augmenting CameraId::load.
  /// @param scale     Image scale factor forwarded to the augmenting CameraId::load.
  /// @param max_shape Maximum (rows, cols) of the loaded image; cropping only happens when both entries are > 0.
  AugmentedCameraId(const CameraId& camera_id, double angle, double scale, xti::vec2s max_shape)
    : CameraId(camera_id)
    , m_angle(angle)
    , m_scale(scale)
    , m_max_shape(max_shape)
  {
  }

  /// Same as above, without cropping.
  AugmentedCameraId(const CameraId& camera_id, double angle, double scale)
    : AugmentedCameraId(camera_id, angle, scale, xti::vec2s({0, 0}))
  {
  }

  Camera load(const xt::xtensor<double, 2>& points_ego) const override;

private:
  double m_angle;
  double m_scale;
  xti::vec2s m_max_shape;
};

/// A CameraId whose load() rescales the image so that its mean focal length (of fx and fy) equals a fixed value.
class FixedIntrCameraId : public CameraId
{
public:
  /// @param camera_id    Camera to load (copied).
  /// @param focal_length Target focal length in pixels of the loaded image.
  /// @param max_shape    Maximum (rows, cols) of the loaded image; cropping only happens when both entries are > 0.
  FixedIntrCameraId(const CameraId& camera_id, double focal_length, xti::vec2s max_shape)
    : CameraId(camera_id)
    , m_focal_length(focal_length)
    , m_max_shape(max_shape)
  {
  }

  /// Same as above, without cropping.
  FixedIntrCameraId(const CameraId& camera_id, double focal_length)
    : FixedIntrCameraId(camera_id, focal_length, xti::vec2s({0, 0}))
  {
  }

  Camera load(const xt::xtensor<double, 2>& points_ego) const override;

private:
  double m_focal_length;
  xti::vec2s m_max_shape;
};

/// A loaded camera: the image, a per-pixel validity mask, and the projection of a set of points into the image.
///
/// The constructor checks that `pixels` has shape (n, 2) and that `points_depth` and `points_mask` have n entries.
/// The id is stored by value as a plain CameraId, so a derived id passed in is sliced to its CameraId part.
class Camera
{
public:
  /// @throws std::invalid_argument if the point tensors disagree in length or `pixels` is not (n, 2).
  Camera(CameraId id, xt::xtensor<uint8_t, 3>&& image, xt::xtensor<bool, 2>&& image_mask, xt::xtensor<double, 2>&& pixels, xt::xtensor<double, 1>&& points_depth, xt::xtensor<bool, 1>&& points_mask)
    : m_id(std::move(id))
    , m_image(std::move(image))
    , m_image_mask(std::move(image_mask))
    , m_pixels(std::move(pixels))
    , m_points_depth(std::move(points_depth))
    , m_points_mask(std::move(points_mask))
  {
    const size_t num_points = m_pixels.shape()[0];
    if (num_points != m_points_depth.shape()[0] || num_points != m_points_mask.shape()[0])
    {
      throw std::invalid_argument(XTI_TO_STRING("pixels, points_depth and points_mask must have same number of points, got " << m_pixels.shape()[0] << " " << m_points_depth.shape()[0] << " " << m_points_mask.shape()[0]));
    }
    if (m_pixels.shape()[1] != 2)
    {
      throw std::invalid_argument(XTI_TO_STRING("pixels must have shape (n, 2), got shape " << xt::adapt(m_pixels.shape())));
    }
  }

  const CameraId& get_id() const
  {
    return m_id;
  }

  const xt::xtensor<uint8_t, 3>& get_image() const
  {
    return m_image;
  }

  const xt::xtensor<bool, 2>& get_image_mask() const
  {
    return m_image_mask;
  }

  const xt::xtensor<double, 2>& get_pixels() const
  {
    return m_pixels;
  }

  const xt::xtensor<double, 1>& get_points_depth() const
  {
    return m_points_depth;
  }

  const xt::xtensor<bool, 1>& get_points_mask() const
  {
    return m_points_mask;
  }

private:
  CameraId m_id;
  xt::xtensor<uint8_t, 3> m_image;
  xt::xtensor<bool, 2> m_image_mask;
  xt::xtensor<double, 2> m_pixels;
  xt::xtensor<double, 1> m_points_depth;
  xt::xtensor<bool, 1> m_points_mask;
};

/// Rotates and rescales an RGB image about a pivot point using bilinear interpolation.
///
/// For each destination pixel p = (row, col), the sampled source location is `rotation(p / scales - center) + center`.
/// Here `rotation` is cosy::Rotation<double, 2>(angle), and `scales` is the per-axis ratio between the destination
/// and source shapes. A destination pixel is set to black and masked out unless its source location lies in
/// [0, shape - 1) on both axes.
///
/// @param image      Source image of shape (rows, cols, channels) with at least 3 channels; only channels 0..2 are
///                   read. Taken by value because it may be blurred in place before resampling.
/// @param angle      Rotation angle, interpreted by cosy::Rotation<double, 2> (presumably radians — confirm).
/// @param scale      Resampling factor. The destination shape is the source shape times `scale`, truncated. When
///                   `scale < 1`, the source is Gaussian-blurred first to reduce aliasing.
/// @param center     Rotation pivot in source pixel coordinates, in the same (row, col) axis order as the image.
/// @param dest_image Output image of shape (dest_rows, dest_cols, 3).
/// @param dest_mask  Output mask of shape (dest_rows, dest_cols); true where the pixel was interpolated from the source.
/// @throws std::invalid_argument if `scale` is not finite and positive, or `image` has fewer than 3 channels.
///
/// Defined `inline` because this file is a header. A non-inline definition violates the ODR as soon as the header
/// is included from more than one translation unit.
inline void transform(xt::xtensor<uint8_t, 3> image, double angle, double scale, xti::vec2d center, xt::xtensor<uint8_t, 3>& dest_image, xt::xtensor<bool, 2>& dest_mask)
{
  // A non-positive or non-finite scale would make the size_t conversion of the destination shape undefined.
  if (!(scale > 0) || !std::isfinite(scale))
  {
    throw std::invalid_argument(XTI_TO_STRING("scale must be finite and positive, got " << scale));
  }
  // Channels 0..2 are read for every interpolated pixel below.
  if (image.shape()[2] < 3)
  {
    throw std::invalid_argument(XTI_TO_STRING("image must have at least 3 channels, got shape " << xt::adapt(image.shape())));
  }

  if (scale < 1)
  {
    // Low-pass filter before downsampling; sigma grows with the downsampling factor.
    // NOTE(review): the in-place blur only affects `image` if xti::to_opencv wraps its buffer instead of copying
    // it — confirm against xti/opencv.h.
    cv::Mat image_cv = xti::to_opencv(image);
    const double sigma = (1.0 / scale - 1) / 2;
    const int kernel_size = static_cast<int>(std::ceil(sigma) * 4) + 1;
    cv::GaussianBlur(image_cv, image_cv, cv::Size(kernel_size, kernel_size), sigma, sigma);
  }

  const xti::vec2s src_shape = xti::vec2s({image.shape()[0], image.shape()[1]});
  const xti::vec2s dest_shape = src_shape * scale;
  const xti::vec2d scales = xt::cast<double>(dest_shape) / xt::cast<double>(src_shape);

  cosy::Rotation<double, 2> rotation(angle);

  // Reads the first three channels of a source pixel.
  auto read_pixel = [&image](size_t row, size_t col){
    return xti::vec3T<uint8_t>({image(row, col, 0), image(row, col, 1), image(row, col, 2)});
  };

  dest_image = xt::xtensor<uint8_t, 3>({dest_shape(0), dest_shape(1), 3});
  dest_mask = xt::xtensor<bool, 2>({dest_shape(0), dest_shape(1)});
  for (size_t row = 0; row < dest_shape(0); row++)
  {
    for (size_t col = 0; col < dest_shape(1); col++)
    {
      const xti::vec2s dest_pixel({row, col});
      const xti::vec2d src_pixel = rotation.transform(xt::cast<double>(dest_pixel) / scales - center) + center;

      if (!xt::all(0 <= src_pixel && src_pixel < src_shape - 1))
      {
        // No source data to interpolate from: black pixel, masked out.
        dest_image(row, col, 0) = 0;
        dest_image(row, col, 1) = 0;
        dest_image(row, col, 2) = 0;
        dest_mask(row, col) = false;
        continue;
      }

      // Bilinear interpolation. Because 0 <= src_pixel < src_shape - 1, floor(src_pixel) >= 0 and
      // floor(src_pixel) + 1 <= src_shape - 1, so all four neighbors are in bounds without further clamping.
      const xti::vec2i src_lower = xt::floor(src_pixel);
      const xti::vec2i src_upper = src_lower + 1;
      const xti::vec2d t = src_pixel - src_lower;

      const xti::vec3T<float> value00 = read_pixel(src_lower(0), src_lower(1));
      const xti::vec3T<float> value01 = read_pixel(src_lower(0), src_upper(1));
      const xti::vec3T<float> value10 = read_pixel(src_upper(0), src_lower(1));
      const xti::vec3T<float> value11 = read_pixel(src_upper(0), src_upper(1));

      const xti::vec3T<float> value0 = (1 - t(1)) * value00 + t(1) * value01;
      const xti::vec3T<float> value1 = (1 - t(1)) * value10 + t(1) * value11;
      const xti::vec3T<float> value = (1 - t(0)) * value0 + t(0) * value1;

      // Round to nearest. Truncation would darken the image by half a gray level on average. The value is a
      // convex combination of uint8 samples, so it stays within [0, 255].
      for (size_t channel = 0; channel < 3; channel++)
      {
        dest_image(row, col, channel) = static_cast<uint8_t>(std::lround(value(channel)));
      }
      dest_mask(row, col) = true;
    }
  }
}

/// Loads the camera without augmentation: zero rotation angle, unit scale and no cropping ({0, 0} max shape).
///
/// Defined `inline` because this file is a header. A non-inline out-of-class member definition violates the ODR
/// when the header is included from more than one translation unit.
inline Camera CameraId::load(const xt::xtensor<double, 2>& points_ego) const
{
  return this->load(points_ego, [](){return 0.0;}, [](xti::mat3d){return 1.0;}, xti::vec2s({0, 0}));
}

/// Loads the image of this camera, applies a rotation/rescale augmentation and an optional random crop, and projects
/// the given ego-frame points into the resulting image.
///
/// @param points_ego Points in the ego frame, shape (n, 3).
/// @param angle_fn   Called once with no arguments; returns the augmentation angle. The image is transformed with
///                   the negated angle, and the same negated angle is written into the upper-left 2x2 block of the
///                   extra camera rotation that is composed with the extrinsics.
/// @param scale_fn   Called once with the intrinsic matrix rescaled to the loaded image's size; returns the scale
///                   factor passed to transform().
/// @param max_shape  If both entries are > 0, the transformed image is randomly cropped so that it is at most this
///                   (rows, cols) shape; e.g. {0, 0} disables cropping.
/// @return Camera holding a CameraId with the updated intrinsics/extrinsics, the transformed image and its mask,
///         the projected pixel coordinates in (row, col) order, per-point depths (third coordinate in the new camera
///         frame), and a per-point mask that is true only for points with positive depth that land inside the image
///         on a pixel whose image mask is true.
/// @throws std::invalid_argument if points_ego does not have 3 columns.
template <typename TAngleFn, typename TScaleFn>
Camera CameraId::load(const xt::xtensor<double, 2>& points_ego, TAngleFn&& angle_fn, TScaleFn&& scale_fn, xti::vec2s max_shape) const
{
  if (points_ego.shape()[1] != 3)
  {
    throw std::invalid_argument(XTI_TO_STRING("Points tensor must have shape (n, 3), got shape " << xt::adapt(points_ego.shape())));
  }

  // Load and transform image
  // Keep the first three channels and reverse their order (BGR -> RGB, per the variable names).
  cv::Mat image_cv = tiledwebmaps::safe_imread(this->get_image_file().string());
  auto image_bgr = xt::view(xti::from_opencv<uint8_t>(std::move(image_cv)), xt::all(), xt::all(), xt::range(0, 3));
  auto image_rgb = xt::view(std::move(image_bgr), xt::all(), xt::all(), xt::range(xt::placeholders::_, xt::placeholders::_, -1));

  // (rows, cols) of the image as actually loaded; may differ from get_resolution().
  xti::vec2s image_shape = xti::vec2s({image_rgb.shape()[0], image_rgb.shape()[1]});

  // Principal point (column 2 of the intrinsics, rows 0..1), flipped to (row, col) order and rescaled from the
  // nominal resolution to the loaded image size. Used as the rotation pivot in transform().
  xti::vec2d center = xt::flip(xt::view(this->get_projection().get_matrix(), xt::range(0, 2), 2), 0);
  center *= xt::cast<double>(xti::vec2s({image_rgb.shape()[0], image_rgb.shape()[1]})) / this->get_resolution();
  xt::xtensor<uint8_t, 3> image;
  xt::xtensor<bool, 2> image_mask;

  // Intrinsics rescaled to the loaded image size: the ratio is reversed to (cols, rows), so row 0 of the matrix is
  // scaled by the column ratio and row 1 by the row ratio.
  auto before_aug_intr = this->get_projection().get_matrix();
  xt::view(before_aug_intr, xt::range(0, 2), xt::all()) *= xt::view(xt::cast<double>(image_shape) / this->get_resolution(), xt::range(xt::placeholders::_, xt::placeholders::_, -1), xt::newaxis());

  double angle = angle_fn();
  double scale = scale_fn(before_aug_intr);

  transform(std::move(image_rgb), -angle, scale, center, image, image_mask);
  // Intrinsics rescaled (same scheme as above) to the size of the transformed image.
  image_shape = xti::vec2s({image.shape()[0], image.shape()[1]});
  auto intr = this->get_projection().get_matrix();
  xt::view(intr, xt::range(0, 2), xt::all()) *= xt::view(xt::cast<double>(image_shape) / this->get_resolution(), xt::range(xt::placeholders::_, xt::placeholders::_, -1), xt::newaxis());

  // Crop image to max_shape and update intrinsics
  // Per axis, the excess (padding) is split into a random front offset in [0, padding) and the remaining back offset.
  // NOTE(review): uses xtensor's global random engine, so results depend on its seed; presumably not safe to call
  // concurrently from several threads — confirm.
  if (xt::all(max_shape > 0) && xt::any(image_shape > max_shape))
  {
    xti::vec2i padding = xt::maximum(xt::cast<int32_t>(image_shape) - xt::cast<int32_t>(max_shape), 0);
    xti::vec2i offset_front = xt::cast<int32_t>(xt::random::rand<double>({2}) * padding);
    xti::vec2i offset_back = padding - offset_front;
    // std::move is only a cast: the image.shape() arguments are evaluated before xt::view takes ownership of the
    // tensor, so they still see the uncropped shape.
    image = xt::view(
        std::move(image),
        xt::range(offset_front[0], image.shape()[0] - offset_back[0]),
        xt::range(offset_front[1], image.shape()[1] - offset_back[1]),
        xt::all()
    );
    image_mask = xt::view(
        std::move(image_mask),
        xt::range(offset_front[0], image_mask.shape()[0] - offset_back[0]),
        xt::range(offset_front[1], image_mask.shape()[1] - offset_back[1])
    );
    // Shift the principal point (cx, cy) by the (col, row) crop offset.
    xt::view(intr, xt::range(0, 2), 2) -= xt::flip(offset_front, 0);
  }
  image_shape = xti::vec2s({image.shape()[0], image.shape()[1]});

  cosy::PinholeK<double, 3> new_projection(intr);

  // Transform points to camera and screen coordinates
  // The extra camera rotation only sets the upper-left 2x2 block; the default-constructed Rigid is assumed to be the
  // identity otherwise — TODO confirm in cosy.
  cosy::Rigid<double, 3> ego_to_oldcamera = this->get_ego_to_camera();
  cosy::Rigid<double, 3> oldcamera_to_newcamera;
  xt::view(oldcamera_to_newcamera.get_rotation(), xt::range(0, 2), xt::range(0, 2)) = cosy::angle_to_rotation_matrix<double>(-angle);
  cosy::Rigid<double, 3> ego_to_camera = oldcamera_to_newcamera * ego_to_oldcamera;

  // Empty outputs (0 points) are returned when points_ego has no rows.
  xt::xtensor<double, 2> points_s = xt::xtensor<double, 2>(std::array<size_t, 2>{0, 2});
  xt::xtensor<double, 1> depths = xt::xtensor<double, 1>(std::array<size_t, 1>{0});
  xt::xtensor<bool, 1> mask = xt::xtensor<bool, 1>(std::array<size_t, 1>{0});
  if (points_ego.shape()[0] > 0)
  {
    xt::xtensor<double, 2> points_c = ego_to_camera.transform_all(points_ego);
    depths = xt::view(points_c, xt::all(), 2);
    points_s = new_projection.transform_all(points_c);
    // Reverse the projected coordinates so that column 0 is the row and column 1 the column index.
    points_s = xt::view(std::move(points_s), xt::all(), xt::range(xt::placeholders::_, xt::placeholders::_, -1));

    // Compute points mask
    mask =
         depths > 0
      && xt::view(points_s, xt::all(), 0) >= 0
      && xt::view(points_s, xt::all(), 0) < image.shape()[0]
      && xt::view(points_s, xt::all(), 1) >= 0
      && xt::view(points_s, xt::all(), 1) < image.shape()[1]
    ;
    // Clip into the image so the image_mask lookup below is in range; the returned pixels are the clipped ones.
    points_s = xt::clip(points_s, 0, xt::view(image_shape - 1, xt::newaxis(), xt::all()));
    // Additionally require the pixel hit by each visible point to be valid in the transformed image.
    // NOTE(review): image_mask is indexed with double coordinates, presumably truncated to integers — confirm.
    for (size_t i = 0; i < mask.shape()[0]; i++)
    {
      bool& m = mask[i];
      if (m)
      {
        m = image_mask(points_s(i, 0), points_s(i, 1));
      }
    }
  }
  // NOTE(review): the returned id carries intrinsics for the transformed image size but the original resolution and
  // image file — confirm this is intended, since calling load() on it again would rescale those intrinsics.
  return Camera(CameraId(this->get_name(), new_projection, this->get_resolution(), ego_to_camera, this->get_image_file()), std::move(image), std::move(image_mask), std::move(points_s), std::move(depths), std::move(mask));
}

/// Loads the camera with this id's fixed angle, scale and max shape.
///
/// Defined `inline` because this file is a header (ODR).
inline Camera AugmentedCameraId::load(const xt::xtensor<double, 2>& points_ego) const
{
  return this->CameraId::load(points_ego, [this](){return m_angle;}, [this](xti::mat3d){return m_scale;}, m_max_shape);
}

/// Loads the camera without rotation, scaling the image so that the mean of fx and fy of the (loaded-size)
/// intrinsics becomes m_focal_length.
///
/// Defined `inline` because this file is a header (ODR).
///
/// @throws std::invalid_argument if the resulting scale is not finite and positive (e.g. a zero source focal
///         length or a non-positive target focal length), which would otherwise yield an invalid image shape.
inline Camera FixedIntrCameraId::load(const xt::xtensor<double, 2>& points_ego) const
{
  return this->CameraId::load(points_ego, [](){return 0.0;}, [this](xti::mat3d src_intr){
    const double src_focal_length = 0.5 * (src_intr(0, 0) + src_intr(1, 1));
    const double dest_focal_length = this->m_focal_length;
    const double scale = dest_focal_length / src_focal_length;
    if (!(scale > 0) || !std::isfinite(scale))
    {
      throw std::invalid_argument(XTI_TO_STRING("Invalid focal lengths: target " << dest_focal_length << ", source " << src_focal_length));
    }
    return scale;
  }, m_max_shape);
}

} // end of ns georegdata::ground
