import pickle, os, gzip, inspect, cosy, concurrent, io
import numpy as np
from georegdata.backend.ground import FrameId

# pointspace: x forward, y left, z up
# cameraspace: x right, y down, z forward

def finetuned_load(internal_load):
    """Decorate a dataset loader with an optional fine-tuned pose override.

    The returned ``load(path, **kwargs)`` forwards to ``internal_load`` and
    returns its result (a mapping whose values are datasets). When enabled,
    it reads ``finetuned_ego_to_crs.pkl.gz`` from ``path`` and rebuilds every
    frame of each listed scene with a recomputed ``ego_to_world`` parameter.

    NOTE: the override step is currently switched off by a hard-coded
    ``False`` (see the TODO below), so the loader result is returned as is.
    """
    def load(path, **kwargs):
        datasets = internal_load(path, **kwargs)
        override_file = os.path.join(path, "finetuned_ego_to_crs.pkl.gz")

        # TODO: remove
        override_enabled = False and os.path.isfile(override_file)
        if not override_enabled:
            return datasets

        with gzip.open(override_file, "rb") as f:
            overrides = pickle.load(f)

        for dataset in datasets.values():
            for scene in dataset.scenes:
                scene_id = scene[0].scene_id
                if scene_id not in overrides:
                    continue

                # Consume the entry so that leftovers can be reported below.
                scene_ego_to_crs = overrides[scene_id]
                del overrides[scene_id]
                assert len(scene_ego_to_crs) == len(scene)

                for idx in range(len(scene)):
                    frame = scene[idx]
                    # TODO: remove/ generalize cosy.np.project_2d_to_3d
                    params = frame.get_params()
                    crs_to_world = cosy.np.project_2d_to_3d(frame.world_to_crs).inverse()
                    params["ego_to_world"] = crs_to_world * scene_ego_to_crs[idx]
                    scene[idx] = FrameId(**params)

        # Entries still present were not matched by any scene in the datasets.
        if len(overrides) > 0:
            print(f"Found scenes in {override_file} that were not in dataset {list(datasets.values())[0].name}: {list(overrides.keys())}")
        return datasets

    return load

def _read_index_cache(cache_file):
    """Return the cache dict stored at ``cache_file``, or ``{}`` if none exists.

    If only a gzip-compressed ``cache_file + ".gz"`` exists, it is loaded and
    an uncompressed copy is written to ``cache_file`` for subsequent reads.

    NOTE: the files are unpickled; only use this on trusted dataset folders,
    since unpickling untrusted data can execute arbitrary code.
    """
    if os.path.isfile(cache_file):
        with open(cache_file, "rb") as f:
            return pickle.load(f)

    gz_file = cache_file + ".gz"
    if os.path.isfile(gz_file):
        with gzip.open(gz_file, "rb") as f:
            entries = pickle.load(f)
        _write_index_cache(cache_file, entries)
        return entries

    return {}


def _write_index_cache(cache_file, entries):
    """Pickle ``entries`` to ``cache_file`` without clobbering it on failure.

    The data is written to a temporary file next to ``cache_file`` and then
    moved into place, so a failed or interrupted dump leaves any previously
    stored cache untouched.
    """
    tmp_file = f"{cache_file}.{os.getpid()}.tmp"
    try:
        with open(tmp_file, "wb") as f:
            pickle.dump(entries, f)
        os.replace(tmp_file, cache_file)
    except BaseException:
        # Do not leave a partial temp file behind; re-raise the real error.
        try:
            os.remove(tmp_file)
        except OSError:
            pass
        raise


def cached_load(internal_load):
    """Decorate a loader so its results can be cached on disk per dataset folder.

    The returned ``load(path, cache=False, **kwargs)`` behaves as follows:

    * ``cache=False``: call ``internal_load(path, **kwargs)`` directly.
    * ``cache=True``: look the call up in ``<path>/index-cache.pkl`` (falling
      back to ``index-cache.pkl.gz``); on a miss, call ``internal_load`` and
      store the result in the cache file.
    * ``cache="rebuild"``: discard all existing cache entries, call
      ``internal_load`` and write a fresh cache file containing the result.

    Cache entries are keyed by the keyword arguments of the call, completed
    with the defaults from ``internal_load``'s signature, so passing a default
    explicitly hits the same entry as omitting it. Keyword values must be
    hashable and the result must be picklable.

    Raises:
        ValueError: if ``cache`` is neither a bool nor ``"rebuild"``.
    """
    func_signature = inspect.signature(internal_load)

    def load(path, cache=False, **kwargs):
        force_rebuild = isinstance(cache, str) and cache == "rebuild"
        use_cache = True if force_rebuild else cache
        if not isinstance(use_cache, bool):
            raise ValueError("Parameter cache must be bool or \"rebuild\"")
        if not use_cache:
            return internal_load(path, **kwargs)

        # Build the cache key: explicit kwargs plus signature defaults for
        # everything not passed (parameters without a default contribute
        # inspect.Parameter.empty), sorted by name to be order-independent.
        load_params = dict(kwargs)
        for k, v in func_signature.parameters.items():
            if k != "path" and k != "cache" and k not in load_params:
                load_params[k] = v.default
        load_params = tuple(sorted(load_params.items(), key=lambda item: item[0]))

        cache_file = os.path.join(path, "index-cache.pkl")
        entries = {} if force_rebuild else _read_index_cache(cache_file)

        if load_params in entries:
            return entries[load_params]

        # Cache miss: compute the result and persist it.
        result = internal_load(path, **kwargs)
        entries[load_params] = result
        _write_index_cache(cache_file, entries)
        return result

    return load

class Dataset:
    """A named set of scenes belonging to one location.

    Attributes:
        name: Name of the dataset.
        location: Location the scenes belong to.
        scenes: Sliceable sequence of scenes; each scene is itself sliceable.
        image_sizes: ``image_sizes`` converted with ``np.asarray``.
    """

    def __init__(self, name, location, scenes, image_sizes):
        self.name, self.location, self.scenes = name, location, scenes
        self.image_sizes = np.asarray(image_sizes)

    @property
    def fullname(self):
        """Dataset name followed by the location in square brackets."""
        return f"{self.name}[{self.location}]"

    def __getstate__(self):
        # Scenes are serialized through FrameId's own scene pickling helper.
        return (
            self.name,
            self.location,
            FrameId.pickle_scenes(self.scenes),
            self.image_sizes,
        )

    def __setstate__(self, state):
        # Mirror of __getstate__: (name, location, pickled scenes, image_sizes).
        self.name, self.location = state[0], state[1]
        self.image_sizes = state[3]
        self.scenes = FrameId.unpickle_scenes(state[2])

    def _with_scenes(self, scenes):
        """Return a new Dataset that differs from this one only in its scenes."""
        return Dataset(self.name, self.location, scenes, self.image_sizes)

    def slice_frames_per_scene(self, slice):
        """Return a copy in which ``slice`` has been applied to every scene."""
        return self._with_scenes([scene[slice] for scene in self.scenes])

    def slice_scenes(self, slice):
        """Return a copy in which ``slice`` has been applied to the scene list."""
        return self._with_scenes(self.scenes[slice])
