import yaml, datetime, os, threading
import numpy as np

class Node:
    """Thread-safe tree of named values that can be serialized with ``commit``.

    Reading a missing key with ``node[key]`` creates and returns an empty
    child ``Node``, so nested paths can be written as ``node["a"]["b"] = v``.
    Children exposing a ``commit(name)`` method are serialized recursively;
    any other value is copied into the config as-is.
    """

    def __init__(self):
        self.dict = {}  # child key -> child value (Node, OrderedPairs, or plain value)
        self.lock = threading.Lock()  # guards mutations/lookups of self.dict

    def __setitem__(self, key, value):
        """Add a new child *key*.

        Raises:
            ValueError: if *key* already exists (use ``overwrite`` to replace).
        """
        with self.lock:
            if key in self.dict:
                raise ValueError(f"Node already contains a child with key {key}")
            self.dict[key] = value

    def overwrite(self, key, value):
        """Set child *key* to *value*, replacing any existing child."""
        with self.lock:
            self.dict[key] = value

    def __getitem__(self, key):
        """Return child *key*, creating an empty ``Node`` there if it is missing."""
        with self.lock:
            if key not in self.dict:
                self.dict[key] = Node()
            return self.dict[key]

    def __contains__(self, key):
        """Return True if *key* is a child (does not create it)."""
        with self.lock:
            return key in self.dict

    def items(self):
        """Return a live view of ``(key, child)`` pairs.

        NOTE: the view is not protected by ``self.lock``; iterating it while
        another thread adds children can raise RuntimeError.
        """
        return self.dict.items()

    def OrderedPairs(self, key, *args, **kwargs):
        """Return child *key*, creating ``OrderedPairs(*args, **kwargs)`` if missing.

        If *key* already exists, the existing child is returned unchanged and
        the constructor arguments are ignored.
        """
        with self.lock:
            if key not in self.dict:
                self.dict[key] = OrderedPairs(*args, **kwargs)
            return self.dict[key]

    def TimeSeries(self, key, *args, **kwargs):
        """Return child *key*, creating ``TimeSeries(*args, **kwargs)`` if missing.

        If *key* already exists, the existing child is returned unchanged and
        the constructor arguments are ignored.
        """
        with self.lock:
            if key not in self.dict:
                self.dict[key] = TimeSeries(*args, **kwargs)
            return self.dict[key]

    def commit(self, name):
        """Serialize this subtree.

        Args:
            name: path prefix for this node; each child ``k`` is committed
                under ``f"{name}/{k}"``.

        Returns:
            ``(config, arrays)``: ``config`` maps each child key to that
            child's config (or its raw value); ``arrays`` merges the arrays
            returned by all committable children.
        """
        # Snapshot under the lock: other threads may add children (e.g. via
        # __getitem__) while we serialize, which would otherwise raise
        # "dictionary changed size during iteration".
        with self.lock:
            children = list(self.dict.items())

        config = {}
        arrays = {}
        for key, child in children:
            if hasattr(child, "commit"):
                child_config, child_arrays = child.commit(f"{name}/{key}")
                config[key] = child_config
                arrays.update(child_arrays)
            else:
                config[key] = child
        return config, arrays

class List:
    """Accumulator that stores every added value together with their running mean."""

    def __init__(self):
        self.values = []  # every value passed to add(), in insertion order
        self.mean = None  # running arithmetic mean; None until the first add()

    def add(self, value):
        """Append *value* and update ``self.mean`` incrementally.

        Uses the streaming update ``mean_n = mean_{n-1} + (x_n - mean_{n-1}) / n``.
        """
        self.values.append(value)
        if self.mean is None:
            self.mean = value
        else:
            # Rebind instead of ``+=``: after the first add, self.mean aliases
            # the first stored value, so an in-place update would corrupt it
            # for mutable values such as NumPy arrays (and fails outright for
            # integer arrays, whose dtype cannot hold the float result).
            self.mean = self.mean + (value - self.mean) / len(self.values)

    def get(self):
        """Return the list of all added values (the live list, not a copy)."""
        return self.values

class OrderedPairs:
    """Append-only sequence of ``(x, y)`` pairs, serialized as two arrays.

    Two mutually exclusive modes:

    * ``y_type`` given: ``pairs[x]`` returns a ``y_type()`` accumulator for
      *x*, reusing the most recent one when *x* equals the last key used.
      On commit, each accumulator's ``get()`` result becomes the y value.
    * ``y_type`` omitted: ``pairs[x] = y`` appends a pair directly. If
      ``ema_decay`` is given, ``self.ema`` tracks an exponential moving
      average of the assigned y values.
    """

    def __init__(self, y_type=None, ema_decay=None):
        """Raises ValueError if both *y_type* and *ema_decay* are passed."""
        if not (y_type is None or ema_decay is None):
            raise ValueError("Cannot pass both y_type and ema_decay")
        self.pairs = []  # (x, y) tuples in insertion order
        self.y_type = y_type
        self.ema_decay = ema_decay
        # Always defined (None until the first value) so reading .ema never
        # raises AttributeError, even when no decay was configured.
        self.ema = None

    def __getitem__(self, key):
        """Return the ``y_type`` accumulator for *key*, appending a new pair if needed.

        Only the last pair is checked, so revisiting an older key starts a
        new pair rather than reusing the old accumulator.
        """
        if self.y_type is None:
            raise ValueError("Must pass y_type in constructor when using __getitem__")
        if self.pairs and self.pairs[-1][0] == key:
            return self.pairs[-1][1]
        y = self.y_type()
        self.pairs.append((key, y))
        return y

    def __setitem__(self, key, value):
        """Append ``(key, value)`` and update the EMA when ``ema_decay`` is set."""
        if self.y_type is not None:
            raise ValueError("Cannot use both y_type and __setitem__")
        self.pairs.append((key, value))
        if self.ema_decay is not None:
            if self.ema is None:
                self.ema = value
            else:
                self.ema = self.ema_decay * self.ema + (1 - self.ema_decay) * value

    def __iter__(self):
        return iter(self.pairs)

    def commit(self, name):
        """Serialize the pairs as ``{name}/xs`` and ``{name}/ys`` arrays.

        Returns:
            ``(config, arrays)`` where ``config`` holds ``!mlog.arrays[...]``
            placeholders referencing the keys of ``arrays``.
        """
        # Snapshot once so xs and ys are built from the same pairs even if
        # another thread appends between the two passes.
        pairs = list(self.pairs)
        xs = [x for x, _ in pairs]
        ys = [y if self.y_type is None else y.get() for _, y in pairs]

        config = {
            "xs": f"!mlog.arrays[{name}/xs]",
            "ys": f"!mlog.arrays[{name}/ys]",
        }
        arrays = {
            f"{name}/xs": np.asarray(xs),
            f"{name}/ys": self._to_array(ys),
        }
        return config, arrays

    @staticmethod
    def _to_array(values):
        """Convert *values* to an ndarray, falling back to a 1-D object array.

        NumPy >= 1.24 raises ValueError for ragged input, e.g. ``y_type``
        accumulators that collected a different number of samples per key.
        """
        try:
            return np.asarray(values)
        except ValueError:
            out = np.empty(len(values), dtype=object)
            for i, v in enumerate(values):
                out[i] = v  # element-wise so each item is stored as one object
            return out

class TimeSeries(OrderedPairs):
    """``OrderedPairs`` whose x values are wall-clock timestamps."""

    def __init__(self, *args, **kwargs):
        # Accepts exactly the OrderedPairs constructor arguments.
        super().__init__(*args, **kwargs)

    def add(self, value):
        """Append *value* keyed by ``datetime.now()`` converted to ``np.datetime64``."""
        timestamp = np.datetime64(datetime.datetime.now())
        self[timestamp] = value

class Session(Node):
    """Root ``Node`` that persists its tree into a directory on ``commit()``.

    ``commit()`` writes ``config.yaml`` (the serialized tree) and
    ``arrays.npz`` (the arrays returned by committable children) into *path*.
    """

    def __init__(self, path):
        Node.__init__(self)

        self.commit_lock = threading.Lock()  # serializes concurrent commit() calls

        self.path = path
        # exist_ok avoids the race between a separate isdir() check and makedirs().
        os.makedirs(self.path, exist_ok=True)

        self["time"]["start"] = datetime.datetime.now().replace(microsecond=0).isoformat()

    def commit(self):
        """Write the current tree to ``config.yaml`` and ``arrays.npz`` in ``self.path``.

        Also records ``time/last_commit``. Each file is written to a ``.tmp``
        sibling and then moved into place with ``os.replace``, so a crash
        mid-write leaves the previous commit intact instead of a truncated file.
        """
        self["time"].overwrite("last_commit", datetime.datetime.now().replace(microsecond=0).isoformat())
        with self.commit_lock:
            config, arrays = Node.commit(self, "")

            config_path = os.path.join(self.path, "config.yaml")
            tmp_path = config_path + ".tmp"
            with open(tmp_path, "w", encoding="utf-8") as f:
                yaml.dump(config, f, default_flow_style=False)
            os.replace(tmp_path, config_path)

            arrays_path = os.path.join(self.path, "arrays.npz")
            tmp_path = arrays_path + ".tmp"
            # Pass a file object: np.savez only appends ".npz" to string paths.
            with open(tmp_path, "wb") as f:
                np.savez(f, **arrays)
            os.replace(tmp_path, arrays_path)
