from pyquaternion import Quaternion
import numpy as np
import os, tqdm, cosy, math, imageio
from collections import defaultdict
from . import dataset
from georegdata.ground import FrameId, CameraId, DummyLidarId

def pinhole_intr(shape, fov_degrees=75.0):
    """Build the 3x3 pinhole intrinsic matrix for an image of the given shape.

    The focal length is chosen so that ``fov_degrees`` spans ``shape[0]``
    pixels (the first image axis, i.e. rows for arrays returned by
    ``imageio.imread(...).shape``). The same focal length is used for both
    axes. The principal point is placed at the image centre.

    Args:
        shape: ``(rows, cols)`` of the image (any indexable of length >= 2).
        fov_degrees: Field of view in degrees spanned along ``shape[0]``.
            Defaults to 75.0, the value this module uses for its pinhole crops.

    Returns:
        A ``float32`` numpy array of shape ``(3, 3)``::

            [[f, 0, cols / 2],
             [0, f, rows / 2],
             [0, 0, 1       ]]
    """
    fov = math.radians(fov_degrees)

    fx = 0.5 * shape[0] / math.tan(0.5 * fov)
    fy = fx
    cx = shape[1] / 2.0
    cy = shape[0] / 2.0
    intr = np.asarray(
        [
            [fx, 0, cx],
            [0, fy, cy],
            [0, 0, 1.0],
        ]
    ).astype("float32")

    return intr

# Rotation angle (radians) about the z axis for each virtual pinhole camera.
# Entries pair index-wise with ``pinhole_names``; consecutive cameras are
# 90 degrees apart.
pinhole_thetas = list(map(math.radians, (0.0, 90.0, 180.0, 270.0)))
# Camera names; these are also the sub-directory names under "pinhole/".
pinhole_names = "front right back left".split()

def _city_names(path):
    """Return the sorted names of the city directories under the VIGOR root.

    Every directory directly below ``path`` except ``splits`` is treated as a
    city.
    """
    return sorted(
        city
        for city in os.listdir(path)
        if city != "splits" and os.path.isdir(os.path.join(path, city))
    )


def _read_bearings(bearings_file):
    """Parse a ``bearings.txt`` file into ``{panorama key: bearing}``.

    Each non-empty line is ``<image name> <bearing>``, separated by a single
    space. The bearing is a float that ``load`` later converts with
    ``math.radians``. The key is the first three comma-separated fields of the
    image name (panorama id, latitude, longitude). For panorama files named
    ``<id>,<lat>,<lon>,.jpg`` this equals ``image_file[:-5]``. If a key occurs
    more than once, the last line wins.
    """
    bearings = {}
    with open(bearings_file) as f:
        for line in f:
            line = line.strip()
            if not line:
                continue
            fields = line.split(" ")
            key = ",".join(fields[0].split(",")[:3])
            bearings[key] = float(fields[1])
    return bearings


@dataset.cached_load
@dataset.finetuned_load
def load(path):
    """Load the VIGOR dataset rooted at ``path``.

    Expected layout (as read below)::

        path/<city>/bearings.txt
        path/<city>/panorama/<id>,<lat>,<lon>,.jpg
        path/<city>/pinhole/<front|right|back|left>/<same file name>
        path/splits/            (ignored)

    Every panorama becomes its own single-frame scene with one ``CameraId``
    per entry of ``pinhole_names``. The pinhole resolution is read once, from
    the first pinhole image encountered, and reused for all cameras.

    Args:
        path: Root directory of the dataset.

    Returns:
        A dict mapping each city name to a ``dataset.Dataset`` holding that
        city's scenes.
    """
    # Constants
    epsg3857 = cosy.np.proj.CRS("epsg:3857")
    epsg4326_to_epsg3857 = cosy.np.proj.Transformer("epsg:4326", "epsg:3857")
    # The ego frame is offset 2.5 along z from the original ego frame.
    # NOTE(review): presumably the camera height in metres above ground -- confirm.
    origego_to_ego = cosy.np.Rigid(translation=np.asarray([0, 0, 2.5]))

    cities = _city_names(path)
    bearings = {
        city: _read_bearings(os.path.join(path, city, "bearings.txt"))
        for city in cities
    }

    # Camera-to-origego transforms depend only on each camera's rotation angle,
    # so build them once here instead of once per panorama.
    axes_change = cosy.np.Rigid(
        rotation=np.asarray([[0, 0, 1], [-1, 0, 0], [0, -1, 0]]).astype("float32")
    )
    pinhole_to_origego_by_camera = [
        (
            camera_name,
            cosy.np.Rigid(rotation=Quaternion(axis=(0, 0, 1), radians=-theta).rotation_matrix)
            * axes_change,
        )
        for camera_name, theta in zip(pinhole_names, pinhole_thetas)
    ]

    # Sorting gives a deterministic order regardless of os.listdir ordering.
    tasks = sorted(
        (city, image_file)
        for city in cities
        for image_file in os.listdir(os.path.join(path, city, "panorama"))
    )

    scenes = defaultdict(list)
    image_sizes = None
    pinhole_shape = None

    for city, image_file in tqdm.tqdm(tasks, desc="Vigor"):
        # filename: -0coGdhsIT67C296Wn4VsA,40.720209,-73.984511,.jpg
        scene_id = image_file[:-5]  # drop the trailing ",.jpg"
        fields = image_file.split(",")
        latlon = np.asarray([float(fields[1]), float(fields[2])])
        bearing = bearings[city][scene_id]

        world_to_epsg3857 = cosy.np.proj.eastnorthmeters_at_latlon_to_epsg3857(latlon)
        origego_to_world = cosy.np.Rigid(
            rotation=Quaternion(
                axis=(0, 0, 1),
                radians=epsg4326_to_epsg3857.transform_angle(math.radians(bearing)),
            ).rotation_matrix
        )

        cameras = []
        for camera_name, pinhole_to_origego in pinhole_to_origego_by_camera:
            pinhole_file = os.path.join(path, city, "pinhole", camera_name, image_file)
            if pinhole_shape is None:
                # Assumes every pinhole image has the same resolution as the first one.
                pinhole_shape = np.asarray(imageio.imread(pinhole_file).shape[:2])

            cameras.append(CameraId(
                name=camera_name,
                intr=pinhole_intr(pinhole_shape),
                resolution=pinhole_shape,
                ego_to_camera=pinhole_to_origego.inverse() * origego_to_ego.inverse(),
                image_file=pinhole_file,
            ))

        scene = [FrameId(
            dataset_name="vigor",
            location=city,
            scene_id=scene_id,
            index_in_scene=0,
            timestamp=0,
            ego_to_world=origego_to_world * origego_to_ego.inverse(),
            world_to_crs=world_to_epsg3857,
            crs=epsg3857,
            epsg4326_to_crs=epsg4326_to_epsg3857,
            cameras=cameras,
            lidars=[DummyLidarId("dummy")],
        )]

        if image_sizes is None:
            # Only the first scene's cameras are queried for their resolutions.
            image_sizes = {tuple(camera.load_resolution()) for camera in cameras}

        scenes[city].append(scene)

    return {
        location: dataset.Dataset(
            name="vigor",
            location=location,
            scenes=location_scenes,
            image_sizes=np.asarray(list(image_sizes)),
        )
        for location, location_scenes in scenes.items()
    }
