import numpy as np
import os, imageio, tfcv
from collections import defaultdict

# Maps each label color, written as a "#rrggbb" hex string, to the name of the
# class annotated with that color. Several names end in " <number>": the code
# that follows this table strips that suffix to obtain a shared "unique name"
# (e.g. "Car 1" .. "Car 4" -> "Car").
# NOTE(review): the class list looks like the A2D2 semantic segmentation
# palette - confirm against the dataset's own class list.
colorhash_to_name = {
    "#ff0000": "Car 1",
    "#c80000": "Car 2",
    "#960000": "Car 3",
    "#800000": "Car 4",
    "#b65906": "Bicycle 1",
    "#963204": "Bicycle 2",
    "#5a1e01": "Bicycle 3",
    "#5a1e1e": "Bicycle 4",
    "#cc99ff": "Pedestrian 1",
    "#bd499b": "Pedestrian 2",
    "#ef59bf": "Pedestrian 3",
    "#ff8000": "Truck 1",
    "#c88000": "Truck 2",
    "#968000": "Truck 3",
    "#00ff00": "Small vehicles 1",
    "#00c800": "Small vehicles 2",
    "#009600": "Small vehicles 3",
    "#0080ff": "Traffic signal 1",
    "#1e1c9e": "Traffic signal 2",
    "#3c1c64": "Traffic signal 3",
    "#00ffff": "Traffic sign 1",
    "#1edcdc": "Traffic sign 2",
    "#3c9dc7": "Traffic sign 3",
    "#ffff00": "Utility vehicle 1",
    "#ffffc8": "Utility vehicle 2",
    "#e96400": "Sidebars",
    "#6e6e00": "Speed bumper",
    "#808000": "Curbstone",
    "#ffc125": "Solid line",
    "#400040": "Irrelevant signs",
    "#b97a57": "Road blocks",
    "#000064": "Tractor",
    "#8b636c": "Non-drivable street",
    "#d23273": "Zebra crossing",
    "#ff0080": "Obstacles / trash",
    "#fff68f": "Poles",
    "#960096": "RD restricted area",
    "#ccff99": "Animals",
    "#eea2ad": "Grid structure",
    "#212cb1": "Signal corpus",
    "#b432b4": "Drivable cobblestone",
    "#ff46b9": "Electronic traffic",
    "#eee9bf": "Slow drive area",
    "#93fdc2": "Nature object",
    "#9696c8": "Parking area",
    "#b496c8": "Sidewalk",
    "#48d1cc": "Ego car",
    "#c87dd2": "Painted driv. instr.",
    "#9f79ee": "Traffic guide obj.",
    "#8000ff": "Dashed line",
    "#ff00ff": "RD normal street",
    "#87ceff": "Sky",
    "#f1e6ff": "Buildings",
    "#60458f": "Blurred area",
    "#352e52": "Rain dirt",
}

# Lookup tables derived from colorhash_to_name.
#   colorid    - the RGB color packed into one integer: r + 256 * g + 65536 * b
#   name       - full class name, possibly with a numeric variant suffix ("Car 2")
#   uniquename - class name with that variant suffix removed ("Car")
colorid_to_name = {}
colorid_to_uniquename = {}
colorid_to_color = {}
name_to_uniquename = {}
uniquename_to_names = defaultdict(list)
for colorhash, name in colorhash_to_name.items():
    # colorhash is "#rrggbb": decode the three two-digit hex components.
    r = int(colorhash[1:3], 16)
    g = int(colorhash[3:5], 16)
    b = int(colorhash[5:7], 16)
    colorid = r + 256 * g + 256 * 256 * b
    colorid_to_name[colorid] = name
    colorid_to_color[colorid] = (r, g, b)

    # Strip a trailing " <number>" variant suffix. The whole last token is
    # removed, so multi-digit suffixes ("Car 10") are handled as well, not just
    # single digits.
    uniquename = name.rsplit(" ", 1)[0] if name.rsplit(" ", 1)[-1].isnumeric() else name
    colorid_to_uniquename[colorid] = uniquename
    name_to_uniquename[name] = uniquename
    uniquename_to_names[uniquename].append(name)
# Several colors share one unique name; the color id that is inserted last wins.
uniquename_to_colorid = {uniquename: colorid for colorid, uniquename in colorid_to_uniquename.items()}
colorids = list(colorid_to_name)

# Dense class ids: a class id is the position of the full class name in
# alphabetical order.
id_to_name = list(colorid_to_name.values())
id_to_name.sort()
name_to_id = dict(zip(id_to_name, range(len(id_to_name))))

# Class id for a packed color id and for an (r, g, b) tuple.
colorid_to_id = {key: name_to_id[colorid_to_name[key]] for key in colorid_to_name}
color_to_id = {colorid_to_color[key]: colorid_to_id[key] for key in colorids}

# One entry per color, so a unique name shared by several colors is repeated.
# NOTE(review): presumably meant to line up index-by-index with id_to_name;
# that relies on both alphabetical sorts agreeing - confirm.
id_to_uniquename = list(colorid_to_uniquename.values())
id_to_uniquename.sort()
# A repeated unique name keeps the index of its last occurrence.
uniquename_to_id = dict(zip(id_to_uniquename, range(len(id_to_uniquename))))

names = list(name_to_id)
uniquenames = list(uniquename_to_id)

uniquename_to_color = {label: colorid_to_color[uniquename_to_colorid[label]] for label in uniquenames}

# Train ids enumerate the unique names alphabetically, leaving out three
# classes; those three get train id -1.
trainid_to_uniquename = sorted(label for label in uniquenames if label not in ("Rain dirt", "Blurred area", "Ego car"))
uniquename_to_trainid = dict.fromkeys(uniquenames, -1)
uniquename_to_trainid.update((label, index) for index, label in enumerate(trainid_to_uniquename))
trainids = range(len(trainid_to_uniquename))

# Train id for every packed color id. uniquename_to_trainid has an entry for
# every unique name (-1 for the classes left out of training), so the lookup is
# always defined; a separate fallback value for "missing" names can never be
# reached and is therefore not needed.
colorid_to_trainid = {colorid: uniquename_to_trainid[colorid_to_uniquename[colorid]] for colorid in colorids}

# Train id for every dense class id, following the order of id_to_uniquename.
id_to_trainid = [uniquename_to_trainid[uniquename] for uniquename in id_to_uniquename]


class Frame:
    """A loaded frame: the image, its labels and the identifier it came from."""

    def __init__(self, image, labels, frame_id):
        # `frame_id` is stored under the attribute name `id`.
        self.image, self.labels, self.id = image, labels, frame_id

class FrameId:
    """Handle for one frame on disk; holds file paths only until load() is called."""

    def __init__(self, image_file, labels_file, cam, sequence, subset):
        # Either file may be None when the frame lacks that sensor.
        self.image_file, self.labels_file = image_file, labels_file
        self.cam, self.sequence, self.subset = cam, sequence, subset

    def load(self):
        """Read the referenced files and return them wrapped in a Frame.

        The image keeps only its first three channels. A labels file that
        decodes to a 3-D array must have 3 or 4 channels and is converted to
        class ids via ``color_to_id``; an array of any other rank is passed on
        unchanged. A missing file yields ``None`` for that field.

        Raises:
            IOError: if the labels array is 3-D but has neither 3 nor 4 channels.
        """
        image = None
        if self.image_file is not None:
            image = imageio.imread(self.image_file)[:, :, :3]

        labels = None
        if self.labels_file is not None:
            labels = imageio.imread(self.labels_file)

            original_shape = labels.shape
            if len(original_shape) == 3:
                if original_shape[2] not in (3, 4):
                    raise IOError(f"Got invalid shape {original_shape} for labels file {self.labels_file}")

                # Color-coded labels: translate every pixel color into its class id.
                labels = tfcv.image.color_to_class(labels[:, :, :3], color_to_class=color_to_id)

        return Frame(image, labels, self)

def load(path):
    """Index every frame stored below ``path`` without reading any image data.

    The tree is walked as ``<path>/<subset>/<sequence>/<sensor>/<cam>/<file>.png``
    where ``<sensor>`` is ``"camera"`` or ``"label"``. Entries that are not
    directories (or, at the last level, not ``.png`` files) are skipped. A
    camera file and a label file are treated as the same frame when they share
    subset, sequence, cam and file name once the extension and the second
    ``_``-separated token of the file name have been removed.

    Args:
        path: Root directory of the dataset.

    Returns:
        A list of ``FrameId`` objects sorted by image file path; a frame that
        has no camera image is sorted by its labels file path instead.
    """
    frame_ids = []
    for subset in os.listdir(path):
        subset_path = os.path.join(path, subset)
        if not os.path.isdir(subset_path):
            continue
        for sequence in os.listdir(subset_path):
            sequence_path = os.path.join(subset_path, sequence)
            if not os.path.isdir(sequence_path):
                continue

            # frame name -> cam -> sensor -> file path
            sequence_files = defaultdict(lambda: defaultdict(dict))
            for sensor in ("camera", "label"):  # "lidar" is not indexed
                sensor_path = os.path.join(sequence_path, sensor)
                if not os.path.isdir(sensor_path):
                    continue
                for cam in os.listdir(sensor_path):
                    cam_path = os.path.join(sensor_path, cam)
                    if not os.path.isdir(cam_path):
                        continue
                    for filename in os.listdir(cam_path):
                        if not filename.endswith(".png"):
                            continue
                        # Drop the extension and the second "_"-separated token
                        # (presumably the sensor tag - confirm) so the camera
                        # and label file of one frame end up under the same key.
                        tokens = filename[: -len(".png")].split("_")
                        name = "_".join(tokens[:1] + tokens[2:])
                        sequence_files[name][cam][sensor] = os.path.join(cam_path, filename)

            for cams in sequence_files.values():
                for cam, sensors in cams.items():
                    frame_ids.append(FrameId(sensors.get("camera"), sensors.get("label"), cam, sequence, subset))

    # FrameId stores the label path as `labels_file`; the fallback must use that
    # attribute or label-only frames raise AttributeError while sorting.
    frame_ids.sort(
        key=lambda frame_id: frame_id.image_file if frame_id.image_file is not None else frame_id.labels_file
    )
    return frame_ids





# Maps every unique class name of this dataset to the name of a Cityscapes
# label. Several classes are mapped to "unlabeled".
# NOTE(review): some targets are not semantic matches (line markings and
# "Zebra crossing" -> "rider", "Parking area" / "Slow drive area" -> "bus",
# "Utility vehicle" -> "train"). Presumably they reuse Cityscapes labels that
# would otherwise stay unused, so that the coverage assert following this table
# holds - confirm before changing any entry.
uniquename_to_cityscapesname = {
    "Animals": "unlabeled",
    "Bicycle": "bicycle",
    "Blurred area": "unlabeled",
    "Buildings": "building",
    "Car": "car",
    "Curbstone": "sidewalk",
    "Dashed line": "rider",
    "Drivable cobblestone": "road",
    "Ego car": "unlabeled",
    "Electronic traffic": "unlabeled",
    "Grid structure": "fence",
    "Irrelevant signs": "unlabeled",
    "Nature object": "vegetation",
    "Non-drivable street": "terrain",
    "Obstacles / trash": "unlabeled",
    "Painted driv. instr.": "rider",
    "Parking area": "bus",
    "Pedestrian": "person",
    "Poles": "pole",
    "Rain dirt": "unlabeled",
    "RD normal street": "road",
    "RD restricted area": "unlabeled",
    "Road blocks": "wall",
    "Sidebars": "wall",
    "Sidewalk": "sidewalk",
    "Signal corpus": "pole",
    "Sky": "sky",
    "Slow drive area": "bus",
    "Small vehicles": "motorcycle",
    "Solid line": "rider",
    "Speed bumper": "road",
    "Tractor": "truck",
    "Traffic guide obj.": "wall",
    "Traffic sign": "traffic sign",
    "Traffic signal": "traffic light",
    "Truck": "truck",
    "Utility vehicle": "train",
    "Zebra crossing": "rider",
}
from . import cityscapes
# Cityscapes train id for every dense class id, in the order of id_to_uniquename.
id_to_cityscapes_trainid = [
    cityscapes.name2label[uniquename_to_cityscapesname[class_name]].trainId
    for class_name in id_to_uniquename
]
# Sanity check on the mapping table: the distinct train ids must number exactly
# one more than cityscapes.trainid_num (presumably the extra one is the id
# shared by everything mapped to "unlabeled" - confirm in the cityscapes module).
assert len(set(id_to_cityscapes_trainid)) == cityscapes.trainid_num + 1
