from . import stream, util
import threading, collections

class save_(stream.Stream):
    """Stream wrapper that tags every item with its position in the input order.

    Each item returned by ``next`` carries a ``save_.Annotation`` stored under
    the key ``util.id(self)``. A ``load`` stream constructed with this object as
    its ``order`` argument reads those annotations back, together with
    ``getNextOutputSequenceIndex``, to re-emit items in the recorded order.
    """

    class Annotation:
        """Bookkeeping for one sequence element (one position in the saved order).

        ``num`` counts the items carrying this annotation that are still
        outstanding. It starts at 1 for the item that created the element. Once
        it reaches 0 the element is removed by
        ``save_.getNextOutputSequenceIndex``.
        """

        def __init__(self, constructor, sequence_index):
            self.constructor = constructor  # the owning save_ instance
            self.lock = threading.Lock()  # guards ``num``
            self.sequence_index = sequence_index
            self.num = 1

        def drop(self):
            """Decrement the outstanding-item count.

            Unlike ``dec``, this does not check that the element is at the head
            of the sequence. Presumably it is used when an item is discarded
            rather than output (confirm against callers).
            """
            with self.lock:
                self.num -= 1
                assert self.num >= 0

        def new(self, new_item):
            """Attach this annotation to ``new_item`` and count it as outstanding.

            Presumably used when one annotated item gives rise to further items
            that must stay in the same position of the order (confirm against
            callers).
            """
            with self.lock:
                self.num += 1
                new_item[util.id(self.constructor)] = self

        def dec(self):
            """Decrement the outstanding-item count after an item was output in order."""
            with self.lock:
                self.num -= 1
                assert self.num >= 0
                if self.num == 0:
                    # The last item of an element must only be consumed while
                    # that element is the head of the sequence.
                    assert self.sequence_index == self.constructor.next_output_sequence_index

    def __init__(self, input):
        super().__init__()
        self.input = stream.wrap(input)
        # FIFO of Annotations that have not been fully consumed yet. A deque
        # gives O(1) removal from the front in getNextOutputSequenceIndex.
        self.sequence_elements = collections.deque()
        self.next_input_sequence_index = 0
        self.next_output_sequence_index = 0
        self.input_lock = threading.Lock()  # guards next_input_sequence_index
        self.output_lock = threading.Lock()  # guards sequence_elements / next_output_sequence_index

    def stop(self):
        self.input.stop()

    def next(self):
        """Pull the next input item, tag it with a fresh Annotation and return it.

        Returns:
            The input item, with ``item[util.id(self)]`` set to its Annotation.
        """
        # NOTE(review): the item is fetched before ``input_lock`` is taken, so
        # under concurrent callers the assigned sequence index reflects lock
        # acquisition order, not necessarily the order the input produced items.
        item = stream.next(self.input)
        with self.input_lock:
            annotation = save_.Annotation(self, self.next_input_sequence_index)
            with self.output_lock:
                self.sequence_elements.append(annotation)
            item[util.id(self)] = annotation
            self.next_input_sequence_index += 1
            return item

    def getNextOutputSequenceIndex(self):
        """Return the sequence index of the next element that may be output.

        Leading elements whose outstanding-item count has reached zero are
        discarded first, advancing ``next_output_sequence_index`` once per
        discarded element.
        """
        with self.output_lock:
            elements = self.sequence_elements
            while elements:
                head = elements[0]
                with head.lock:
                    if head.num != 0:
                        break
                    elements.popleft()
                    self.next_output_sequence_index += 1
            return self.next_output_sequence_index

def save(*args, **kwargs):
    """Construct a ``save_`` stream and return it as the pair ``(stream, order)``.

    Both slots hold the same object. The first is meant for continued pipeline
    construction. The second is the order token, kept until ``load`` is called
    with it.
    """
    saved = save_(*args, **kwargs)
    return (saved, saved)

class load(stream.Stream):
    """Stream that re-emits items in the order recorded by a ``save_`` stream.

    Items that arrive before their sequence element reaches the head of the
    saved order are buffered, keyed by sequence index. They are released once
    ``order.getNextOutputSequenceIndex()`` reaches that index.
    """

    def __init__(self, input, order):
        super().__init__()
        self.input = stream.wrap(input)
        self.order = order  # the save_ instance whose annotations are read
        self.lock = threading.Lock()
        # sequence index -> FIFO of buffered items. Empty queues are deleted,
        # so an existing key always has at least one item. A deque gives O(1)
        # removal from the front.
        self.items_by_sequence_index = collections.defaultdict(collections.deque)

    def stop(self):
        self.input.stop()

    def next(self):
        """Return the next item in saved order.

        Input items that are not yet due are buffered, and the input is pulled
        again until a due item is available.
        """
        while True:
            # First, serve a buffered item if one belongs to the current head.
            with self.lock:
                head_index = self.order.getNextOutputSequenceIndex()
                # .get() avoids inserting an empty queue into the defaultdict.
                buffered = self.items_by_sequence_index.get(head_index)
                if buffered:
                    item = buffered.popleft()
                    # Tell the sequence element that one of its items was used up.
                    item[util.id(self.order)].dec()
                    if not buffered:
                        del self.items_by_sequence_index[head_index]
                    return item
            # Nothing buffered is due. Pull a new item without holding self.lock.
            item = stream.next(self.input)
            with self.lock:
                head_index = self.order.getNextOutputSequenceIndex()
                annotation = item[util.id(self.order)]
                if annotation.sequence_index == head_index:
                    # The new item is due now. Tell its element that an item was used up.
                    annotation.dec()
                    return item
                # Not due yet. Buffer it until its sequence element becomes the head.
                self.items_by_sequence_index[annotation.sequence_index].append(item)
