#pragma once

#include <cosy/affine.h>
#include <xti/util.h>
#include <xtensor/xadapt.hpp>
#include <xtensor/xtensor.hpp>
#include <xtensor/xview.hpp>
#include <xtensor-blas/xlinalg.hpp>
#include <xtensor-io/xnpz.hpp>
#include <filesystem>
#include <fstream>
#include "../exception.h"
#include <cereal/access.hpp>
#include <xti/cereal.h>
#include <cereal/types/memory.hpp>

namespace georegdata::ground {

class Lidar;

/// Abstract description of a lidar scan source.
///
/// An id stores a name, a linear map and a rigid transform. Concrete subclasses
/// supply the raw points through load_points(); load() then right-multiplies the
/// points by the transpose of the map and passes the result through the rigid
/// transform's transform_all().
class LidarId
{
public:
  /// @param name           Name reported by get_name().
  /// @param map            Matrix applied to the loaded points; its column count is
  ///                       the number of leading input columns that load() uses.
  /// @param loaded_to_ego  Rigid transform applied after the map.
  ///
  /// All arguments are taken by value and moved into place, so callers passing
  /// temporaries do not pay for an extra copy.
  LidarId(std::string name, xt::xtensor<double, 2> map, cosy::Rigid<double, 3> loaded_to_ego)
    : m_name(std::move(name))
    , m_map(std::move(map))
    , m_loaded_to_ego(std::move(loaded_to_ego))
  {
  }

  /// Same as the three-argument constructor with a default-constructed rigid
  /// transform (presumably the identity -- confirm against cosy::Rigid).
  LidarId(std::string name, xt::xtensor<double, 2> map)
    : LidarId(std::move(name), std::move(map), cosy::Rigid<double, 3>())
  {
  }

  /// Same as the two-argument constructor with a 3x3 identity map.
  LidarId(std::string name)
    : LidarId(std::move(name), xt::eye<double>({3, 3}))
  {
  }

  virtual ~LidarId() = default;

  /// Loads the points of this lidar and returns them wrapped in a Lidar.
  /// Defined out of class because it needs the complete Lidar type.
  Lidar load() const;

  const std::string& get_name() const
  {
    return m_name;
  }

  /// Returns a copy of the map matrix.
  xt::xtensor<double, 2> get_map() const
  {
    return m_map;
  }

  const cosy::Rigid<double, 3>& get_loaded_to_ego() const
  {
    return m_loaded_to_ego;
  }

  /// Serializes a lidar id of one of the known subclasses into the archive.
  template <typename TArchive>
  static void save(TArchive& archive, std::shared_ptr<LidarId> lidar_id);

  /// Deserializes a lidar id previously written by save().
  template <typename TArchive>
  static std::shared_ptr<LidarId> load(TArchive& archive);

protected:
  /// Returns the raw points as a 2D array with one point per row; load()
  /// validates the shape and applies the map and rigid transform.
  virtual xt::xtensor<double, 2> load_points() const = 0;

  /// Returns a heap-allocated copy of the concrete id.
  virtual std::unique_ptr<LidarId> clone() const = 0;

private:
  std::string m_name;
  xt::xtensor<double, 2> m_map;
  cosy::Rigid<double, 3> m_loaded_to_ego;
};

/// A loaded lidar scan: the point array together with a copy of the id it was
/// loaded from.
class Lidar
{
  // Members are declared in the order in which the constructor initialises them.
  std::unique_ptr<LidarId> m_id;
  xt::xtensor<double, 2> m_points_ego;

public:
  /// Takes ownership of both the id and the point array.
  Lidar(std::unique_ptr<LidarId>&& id, xt::xtensor<double, 2>&& points_ego)
    : m_id(std::move(id)), m_points_ego(std::move(points_ego))
  {
  }

  /// Read-only access to the stored points.
  const xt::xtensor<double, 2>& get_points() const { return m_points_ego; }
};

/// Loads the raw points through load_points(), validates their shape, applies
/// the map and the rigid transform, and returns the result with a clone of this id.
///
/// Marked inline because this non-template definition lives in a header; without
/// it every translation unit including the header would emit its own definition
/// and linking would fail with a multiple-definition error.
///
/// @throws LoadException if no points were loaded or if the points have fewer
///         columns than the map.
inline Lidar LidarId::load() const
{
  xt::xtensor<double, 2> points_ego = this->load_points();
  const auto map_columns = m_map.shape()[1];

  if (points_ego.shape()[0] == 0)
  {
    throw LoadException(XTI_TO_STRING("Array must contain at least one point"));
  }
  if (points_ego.shape()[1] < map_columns)
  {
    throw LoadException(XTI_TO_STRING("Expected at least " << m_map.shape()[1] << " dimensions for lidar points, got " << points_ego.shape()[1] << " dimensions"));
  }
  if (points_ego.shape()[1] > map_columns)
  {
    // Surplus trailing columns are dropped so that the product with the map is well-formed.
    points_ego = xt::view(std::move(points_ego), xt::all(), xt::range(0, map_columns));
  }

  // Row-wise application of the map: points * map^T.
  points_ego = xt::linalg::dot(std::move(points_ego), xt::transpose(m_map, {1, 0}));
  points_ego = m_loaded_to_ego.transform_all(std::move(points_ego));
  return Lidar(this->clone(), std::move(points_ego));
}

class NpzLidarId : public LidarId
{
public:
  NpzLidarId(std::string name, std::filesystem::path file, xt::xtensor<double, 2> map, cosy::Rigid<double, 3> loaded_to_ego)
    : LidarId(name, map, loaded_to_ego)
    , m_file(file)
  {
  }

  const std::filesystem::path& get_file() const
  {
    return m_file;
  }

  template <typename TArchive>
  static void load_and_construct(TArchive& archive, cereal::construct<NpzLidarId>& construct)
  {
    std::string name;
    archive(name);
    xt::xtensor<double, 2> map;
    archive(map);
    cosy::Rigid<double, 3> loaded_to_ego;
    archive(loaded_to_ego);
    std::string file;
    archive(file);
    construct(name, file, map, loaded_to_ego);
  }

protected:
  xt::xtensor<double, 2> load_points() const
  {
    auto data = xt::load_npz(m_file.string());
    std::string key = "";
    if (data.size() == 1)
    {
      key = data.begin()->first;
    }
    else if (data.count("arr_0"))
    {
      key = "arr_0";
    }
    else if (data.count("points"))
    {
      key = "points";
    }
    else
    {
      throw LoadException(XTI_TO_STRING("Npz file has more than one key"));
    }

    xt::xtensor<double, 2> points_ego;
    if (data[key].m_typestring == "<f4")
    {
      points_ego = data[key].cast<float>();
    }
    else if (data[key].m_typestring == "<f8")
    {
      points_ego = data[key].cast<double>();
    }
    else
    {
      throw LoadException(XTI_TO_STRING("Npz file has invalid datatype " << data[key].m_typestring));
    }

    return points_ego;
  }

  std::unique_ptr<LidarId> clone() const
  {
    return std::make_unique<NpzLidarId>(*this);
  }

private:
  std::filesystem::path m_file;
};

/// cereal non-member save for NpzLidarId. Writes name, map, loaded_to_ego and
/// the file path, in the order NpzLidarId::load_and_construct reads them.
template <typename TArchive>
void save(TArchive& archive, const NpzLidarId& lidar_id)
{
  archive(lidar_id.get_name());
  archive(lidar_id.get_map());
  archive(lidar_id.get_loaded_to_ego());
  // path::string() instead of std::string(path): the conversion operator yields
  // path::string_type, which is std::wstring on Windows and would not compile there.
  archive(lidar_id.get_file().string());
}

/// Lidar id whose points are stored as a flat binary dump of TScalar values.
///
/// The file is interpreted as consecutive rows, each with as many values as the
/// map has columns.
template <typename TScalar>
class BinLidarId : public LidarId
{
public:
  /// @param name           Name of the lidar.
  /// @param file           Path of the binary file.
  /// @param map            Forwarded to LidarId; its column count is the row width of the file.
  /// @param loaded_to_ego  Forwarded to LidarId.
  /// @param crop_to_size   If true, trailing values that do not fill a complete
  ///                       row are dropped instead of raising an error.
  BinLidarId(std::string name, std::filesystem::path file, xt::xtensor<double, 2> map, cosy::Rigid<double, 3> loaded_to_ego, bool crop_to_size)
    : LidarId(std::move(name), std::move(map), std::move(loaded_to_ego))
    , m_file(std::move(file))
    , m_crop_to_size(crop_to_size)
  {
  }

  const std::filesystem::path& get_file() const
  {
    return m_file;
  }

  bool crop_to_size() const
  {
    return m_crop_to_size;
  }

  /// cereal hook: reads the fields in the order written by the matching
  /// non-member save() (name, map, loaded_to_ego, file, crop_to_size) and
  /// constructs the object.
  template <typename TArchive>
  static void load_and_construct(TArchive& archive, cereal::construct<BinLidarId<TScalar>>& construct)
  {
    std::string name;
    archive(name);
    xt::xtensor<double, 2> map;
    archive(map);
    cosy::Rigid<double, 3> loaded_to_ego;
    archive(loaded_to_ego);
    std::string file;
    archive(file);
    bool crop_to_size = false;
    archive(crop_to_size);
    construct(name, file, map, loaded_to_ego, crop_to_size);
  }

protected:
  /// Reads the whole file and reshapes it to (rows, map columns), converting to double.
  ///
  /// Bytes at the end of the file that do not make up a full TScalar are ignored.
  ///
  /// @throws LoadException if the file cannot be opened, if the map has zero
  ///         columns, or if the number of values is not a multiple of the row
  ///         width and cropping is disabled.
  xt::xtensor<double, 2> load_points() const override
  {
    std::ifstream input(m_file, std::ios::binary);
    if (!input)
    {
      throw LoadException(XTI_TO_STRING("Failed to open file " << m_file));
    }
    std::vector<uint8_t> buffer(std::istreambuf_iterator<char>(input), {});
    size_t size = buffer.size() / sizeof(TScalar);
    const size_t dimensions = this->get_map().shape()[1];
    if (dimensions == 0)
    {
      // Guards the modulo below against an integer division by zero.
      throw LoadException(XTI_TO_STRING("Cannot load " << m_file << ": lidar map has zero columns"));
    }

    const size_t remainder = size % dimensions;
    if (remainder != 0)
    {
      if (!m_crop_to_size)
      {
        throw LoadException(XTI_TO_STRING("Expected number of elements in " << m_file << " to be multiple of " << dimensions << ", got " << size));
      }
      // Drop the incomplete trailing row.
      size -= remainder;
    }

    std::vector<std::size_t> shape = {size / dimensions, dimensions};
    // The adaptor only views the buffer; the cast to double is materialised into
    // points_ego before the buffer goes out of scope.
    xt::xtensor<double, 2> points_ego = xt::cast<double>(xt::adapt(reinterpret_cast<TScalar*>(buffer.data()), size, xt::no_ownership(), shape));
    return points_ego;
  }

  std::unique_ptr<LidarId> clone() const override
  {
    return std::make_unique<BinLidarId<TScalar>>(*this);
  }

private:
  std::filesystem::path m_file;
  bool m_crop_to_size;
};

/// cereal non-member save for BinLidarId. Writes name, map, loaded_to_ego, the
/// file path and the crop flag, in the order BinLidarId::load_and_construct reads them.
template <typename TArchive, typename TScalar>
void save(TArchive& archive, const BinLidarId<TScalar>& lidar_id)
{
  archive(lidar_id.get_name());
  archive(lidar_id.get_map());
  archive(lidar_id.get_loaded_to_ego());
  // path::string() instead of std::string(path): the conversion operator yields
  // path::string_type, which is std::wstring on Windows and would not compile there.
  archive(lidar_id.get_file().string());
  archive(lidar_id.crop_to_size());
}

class DummyLidarId : public LidarId
{
public:
  DummyLidarId(std::string name)
    : LidarId(name)
  {
  }

  template <typename TArchive>
  static void load_and_construct(TArchive& archive, cereal::construct<DummyLidarId>& construct)
  {
    std::string name;
    archive(name);
    construct(name);
  }

protected:
  xt::xtensor<double, 2> load_points() const
  {
    static const size_t NUM_ANGLES = 100;
    static const float RADIUS = 10.0;
    static const size_t HEIGHTS = 2;

    xt::xtensor<double, 2> points_ego({HEIGHTS * NUM_ANGLES, 3});
    for (size_t i = 0; i < NUM_ANGLES; ++i)
    {
      float yaw = ((float) i) / NUM_ANGLES * 2 * xt::numeric_constants<float>::PI;
      bool front = std::cos(yaw) > 0;
      bool _45 = std::abs(yaw - 0.25 * xt::numeric_constants<float>::PI) < 0.1
              || std::abs(yaw - 0.75 * xt::numeric_constants<float>::PI) < 0.1
              || std::abs(yaw - 1.25 * xt::numeric_constants<float>::PI) < 0.1
              || std::abs(yaw - 1.75 * xt::numeric_constants<float>::PI) < 0.1;
      bool left = yaw < xt::numeric_constants<float>::PI;

      float up = front ? 0 : 2;
      if (_45)
      {
        up += 1;
      }
      // float up = 0;

      for (size_t h = 0; h < HEIGHTS; ++h)
      {
        points_ego(HEIGHTS * i + h, 0) = std::cos(yaw) * RADIUS;
        points_ego(HEIGHTS * i + h, 1) = std::sin(yaw) * RADIUS;
        points_ego(HEIGHTS * i + h, 2) = up + h + (left && (i % 2 == 0) ? 0.5 : 0);
      }
    }

    return points_ego;
  }

  std::unique_ptr<LidarId> clone() const
  {
    return std::make_unique<DummyLidarId>(*this);
  }
};

/// cereal non-member save for DummyLidarId; the name is the only serialized
/// field, matching DummyLidarId::load_and_construct.
template <typename TArchive>
void save(TArchive& archive, const DummyLidarId& lidar_id)
{
  const std::string& name = lidar_id.get_name();
  archive(name);
}


template <typename TArchive>
void LidarId::save(TArchive& archive, std::shared_ptr<LidarId> lidar_id)
{
  #define SUBCLASS(NAME, ID) \
    { \
      std::shared_ptr<NAME> ptr = std::dynamic_pointer_cast<NAME>(lidar_id); \
      if (ptr) \
      { \
        uint8_t id = ID; \
        archive(id); \
        archive(ptr); \
        return; \
      } \
    }
  SUBCLASS(NpzLidarId, 1)
  SUBCLASS(BinLidarId<float>, 2)
  SUBCLASS(DummyLidarId, 3)
  #undef SUBCLASS
  throw LoadException("Invalid lidar subclass");
}

/// Reads a lidar id written by LidarId::save(): first the one-byte subclass tag,
/// then the typed shared pointer (1 = NpzLidarId, 2 = BinLidarId<float>,
/// 3 = DummyLidarId).
///
/// @throws LoadException if the tag is not one of the known values.
template <typename TArchive>
std::shared_ptr<LidarId> LidarId::load(TArchive& archive)
{
  uint8_t id;
  archive(id);

  if (id == 1)
  {
    std::shared_ptr<NpzLidarId> npz;
    archive(npz);
    return npz;
  }
  if (id == 2)
  {
    std::shared_ptr<BinLidarId<float>> bin;
    archive(bin);
    return bin;
  }
  if (id == 3)
  {
    std::shared_ptr<DummyLidarId> dummy;
    archive(dummy);
    return dummy;
  }

  throw LoadException("Invalid lidar subclass id");
}

} // end of ns georegdata::ground
