import os, yaml, re, time, tqdm
import numpy as np

class Database:
    """A directory of experiment runs, one subdirectory per experiment."""

    def __init__(self, path):
        """Open the database rooted at ``path``, creating the directory if missing.

        Args:
            path: Directory whose subdirectories are experiment runs.
        """
        self.path = path
        # exist_ok=True avoids the check-then-create race when several
        # processes open the same database at the same time.
        os.makedirs(path, exist_ok=True)

    def get_all(self):
        """Load every experiment stored in the database.

        Every subdirectory of ``self.path`` other than ``mlog`` is loaded as an
        ``Experiment``; a tqdm progress bar is shown while loading.

        Returns:
            list of ``Experiment``, ordered by directory name so the result is
            deterministic (``os.listdir`` order is arbitrary).
        """
        paths = [
            os.path.join(self.path, name)
            for name in sorted(os.listdir(self.path))
            if name != "mlog"
        ]
        paths = [p for p in paths if os.path.isdir(p)]
        return [Experiment(p) for p in tqdm.tqdm(paths)]

class Experiment:
    """One experiment run, loaded from its directory.

    Reads ``mlog/config.yaml`` (the config), ``mlog/arrays.npz`` (arrays that
    the config may reference) and, optionally, ``log.yaml`` (server info).
    Config string values of the form ``"!mlog.arrays[NAME]"`` are replaced by
    the array ``NAME`` from ``arrays.npz``; references to arrays that are not
    present are removed from the config.
    """

    # Matches "!mlog.arrays[NAME]" at the start of a string; group 1 is NAME.
    _ARRAY_REF = re.compile(re.escape("!mlog.arrays[") + "(.*)" + re.escape("]"))

    # Retry policy for reading arrays.npz. NOTE(review): presumably the file
    # can be read while another process is still writing it — confirm.
    _LOAD_ATTEMPTS = 10
    _LOAD_RETRY_DELAY = 0.1

    def __init__(self, path):
        """Load the experiment stored in directory ``path``.

        Raises:
            ValueError: if an array referenced from the config fails to load.
        """
        self.path = path
        self.name = os.path.basename(self.path)
        self.config = self._load_config()
        arrays = self._load_arrays()
        self._merge_log_info()
        try:
            self._resolve_array_refs(self.config, arrays)
        except ValueError as e:
            raise ValueError(f"Failed to load config in {self.path}") from e

    def _load_config(self):
        """Return the parsed ``mlog/config.yaml``, or ``{}`` if missing or empty."""
        try:
            with open(os.path.join(self.path, "mlog", "config.yaml"), "r") as f:
                config = yaml.safe_load(f)
        except FileNotFoundError:
            return {}
        return {} if config is None else config

    def _load_arrays(self):
        """Return all arrays in ``mlog/arrays.npz`` as a dict.

        Returns ``{}`` immediately when the file does not exist. Other read
        errors are retried up to ``_LOAD_ATTEMPTS`` times; if every attempt
        fails, ``{}`` is returned (best effort).
        """
        npz_path = os.path.join(self.path, "mlog", "arrays.npz")
        for _ in range(self._LOAD_ATTEMPTS):
            try:
                with np.load(npz_path) as data:
                    return dict(data)
            except FileNotFoundError:
                # No arrays were saved for this run; retrying only wastes time.
                return {}
            except Exception:
                # Catch Exception (not a bare except) so Ctrl-C still works.
                time.sleep(self._LOAD_RETRY_DELAY)
        return {}

    def _merge_log_info(self):
        """Set ``config["server"]`` from ``log.yaml`` when it has the needed keys."""
        try:
            with open(os.path.join(self.path, "log.yaml"), "r") as f:
                log = yaml.safe_load(f)
        except FileNotFoundError:
            return
        # An empty log.yaml parses to None; only a mapping can hold the keys.
        if not isinstance(log, dict):
            return
        if "server" in log and "gnu-screen" in log:
            # Appends the second "."-separated field of the screen value.
            self.config["server"] = log["server"] + " " + log["gnu-screen"].split(".")[1]

    @classmethod
    def _resolve_array_refs(cls, node, arrays):
        """Replace array references inside ``node`` in place.

        Recurses into lists and dicts. A dict value matching ``_ARRAY_REF`` is
        replaced by ``np.asarray(arrays[NAME])``, or deleted when ``NAME`` is
        not in ``arrays``.

        Raises:
            ValueError: if a referenced array cannot be converted.
        """
        if isinstance(node, list):
            for item in node:
                cls._resolve_array_refs(item, arrays)
        elif isinstance(node, dict):
            # Iterate over a snapshot of the keys since entries may be deleted.
            for key in list(node):
                value = node[key]
                if not isinstance(value, str):
                    cls._resolve_array_refs(value, arrays)
                    continue
                match = cls._ARRAY_REF.match(value)
                if not match:
                    continue
                name = match.group(1)
                if name not in arrays:
                    del node[key]
                    continue
                try:
                    node[key] = np.asarray(arrays[name])
                except ValueError as e:
                    raise ValueError(f"Failed to load {name}") from e

    def __getitem__(self, key):
        """Return ``self.config[key]``."""
        return self.config[key]

    def __contains__(self, key):
        """Return whether ``key`` is a top-level config key."""
        return key in self.config

# TODO: remove
# import numpy as np
# experiments = []
#
# xs = np.linspace(-30.0, 30.0, num=600)
# experiments.append({
#     "name": "experiment1",
#     "xs": xs,
#     "ys": np.sin(xs / 3),
# })
#
# xs = np.linspace(-30.0, 30.0, num=600)
# experiments.append({
#     "name": "experiment2",
#     "xs": xs,
#     "ys": np.sin(xs),
# })
#
# def get_all():
#     return experiments
