import itertools, random, skimage.transform, tfcv, tinyobserver, math, imageio, os, collections, cosy
import tinypl as pl
import numpy as np
import georegdata as grd
from collections import defaultdict

class AerialSampler_(collections.abc.Sequence):
    """Sequence of spatial chunks; indexing a chunk returns one frame id from it.

    Frame ids are bucketed into cells of roughly ``meters_per_chunk`` x
    ``meters_per_chunk`` using a spherical-Earth approximation. ``len(sampler)`` is
    the number of non-empty cells (sorted by cell index), and ``sampler[i]`` picks a
    single frame id from cell ``i`` according to ``order``.

    Args:
        frame_ids: Iterable of objects with a ``latlon`` attribute holding
            ``(latitude, longitude)`` in degrees. If the first one also has a
            ``score`` attribute, every chunk is sorted by ascending score, with
            ``None`` scores placed last.
        meters_per_chunk: Cell edge length in meters (must be non-zero).
        order: How ``__getitem__`` picks from a chunk:
            ``"random-uniform"`` (uniformly at random),
            ``"random-weighted"`` (probability proportional to ``score``),
            ``"hardest"`` (last element, i.e. largest score / ``None``) or
            ``"easiest"`` (first element, i.e. smallest score).
    """

    EARTH_RADIUS_METERS = 6.378137e6
    # Sort key substituted for a ``None`` score so such frames sort last.
    _MISSING_SCORE = 999999.0

    def __init__(self, frame_ids, meters_per_chunk, order="random-uniform"):
        super().__init__()
        self.order = order

        radius = AerialSampler_.EARTH_RADIUS_METERS
        chunks = defaultdict(list)
        for frame_id in frame_ids:
            latlon = frame_id.latlon

            # Meridian arc length from the equator, in chunk units. floor() rather
            # than int() so chunk 0 is not twice as tall as every other chunk.
            chunk_y = math.floor(radius * math.radians(latlon[0]) / meters_per_chunk)
            # Latitude of the chunk edge in RADIANS (inverse of the line above).
            chunk_lat = chunk_y * meters_per_chunk / radius
            # Radius of the circle of latitude; chunk_lat is already in radians, so
            # it must not be passed through math.radians() again.
            radius_at_lat = radius * math.cos(chunk_lat)
            chunk_x = math.floor(radius_at_lat * math.radians(latlon[1]) / meters_per_chunk)

            chunks[(chunk_y, chunk_x)].append(frame_id)

        # Deterministic chunk order: sorted by (chunk_y, chunk_x).
        self.chunks = [chunk for _, chunk in sorted(chunks.items())]
        # Guard against empty input before probing the first frame for a score.
        if self.chunks and hasattr(self.chunks[0][0], "score"):
            self.chunks = [sorted(chunk, key=AerialSampler_._score_key) for chunk in self.chunks]

    @staticmethod
    def _score_key(frame_id):
        """Sort key: the frame's score, or a large sentinel when it is None."""
        return frame_id.score if frame_id.score is not None else AerialSampler_._MISSING_SCORE

    def __len__(self):
        return len(self.chunks)

    def __getitem__(self, key):
        """Return one frame id from chunk ``key`` according to ``self.order``.

        Raises:
            ValueError: If ``self.order`` is not a supported ordering.
        """
        chunk = self.chunks[key]
        if self.order == "random-uniform":
            # Draw an index instead of handing the list to np.random.choice, which
            # would first convert the frame ids into an ndarray.
            return chunk[np.random.randint(len(chunk))]
        if self.order == "random-weighted":
            p = np.asarray([w.score for w in chunk], dtype=float)
            p = p / np.sum(p)
            return chunk[np.random.choice(len(chunk), p=p)]
        if self.order == "hardest":
            return chunk[-1]
        if self.order == "easiest":
            return chunk[0]
        raise ValueError(f"Unknown order {self.order!r}")

def AerialSampler(frame_ids, meters_per_chunk, order="random-uniform"):
    """Build a frame sampler, optionally grouping frames into spatial chunks.

    Args:
        frame_ids: Iterable of frame ids. When chunking, each must expose ``latlon``.
        meters_per_chunk: Chunk edge length in meters; ``0`` disables chunking.
        order: Selection order within a chunk, forwarded to ``AerialSampler_``.

    Returns:
        A new list holding every frame id (in input order) when
        ``meters_per_chunk == 0``; otherwise an ``AerialSampler_`` that yields one
        frame id per chunk.
    """
    if meters_per_chunk == 0:
        return list(frame_ids)
    return AerialSampler_(frame_ids, meters_per_chunk, order)

def augment_aerial_image(image, get_epoch=None, color=True, shadow=0):
    """Apply random photometric augmentation to an aerial image.

    The random generator is seeded from the pixel content and, optionally, the
    current epoch. The same image therefore gets the same augmentation within an
    epoch, and different images get different augmentations.

    Args:
        image: numpy image array. Presumably HxWx3 RGB in [0, 255] — TODO confirm.
        get_epoch: Optional zero-argument callable returning the current epoch
            (a non-negative int); it is mixed into the seed.
        color: Whether to apply HSV jitter, blur/sharpen and Gaussian noise.
        shadow: Maximum number of synthetic shadows. Shadow augmentation is not
            enabled yet; any value > 0 raises ``NotImplementedError``.

    Returns:
        The augmented image (the input object itself when ``color`` is False).

    Raises:
        NotImplementedError: If ``shadow > 0``.
    """
    if shadow > 0:
        # Only an unvalidated draft exists (_add_random_shadows); fail loudly rather
        # than silently returning an image without shadows.
        raise NotImplementedError("Shadow augmentation is not implemented yet")

    import zlib

    # Hash the pixels instead of multiplying them. The product overflows int64: it is
    # 0 for almost every real image, which would give every image the same seed, and
    # it can wrap to a negative value that np.random.default_rng rejects.
    content_hash = zlib.crc32(image.astype("int32").tobytes())
    epoch = get_epoch() if get_epoch is not None else 0
    rng = np.random.default_rng([content_hash, int(epoch)])

    if color:
        image = tfcv.image.multiply(lambda: np.clip(rng.normal(loc=0.0, scale=1.0, size=(3,)) * np.asarray([0.0, 0.05, 0.05]), -0.2, 0.2) + 1.0, colorspace="hsv")(image)
        image = tfcv.image.add(lambda: np.clip(rng.normal(loc=0.0, scale=0.03, size=(3,)), -0.1, 0.1) * np.asarray([255.0, 0.0, 0.0]), colorspace="hsv", out_of_bounds="repeat")(image)
        image = tfcv.image.blur_sharpen(amount=lambda: rng.uniform(-0.3, 0.3), std=4.0)(image)
        image = tfcv.image.add_gaussian_noise(std=lambda: rng.uniform(0.0, 0.01) * 255.0)(image)
    return image

def _add_random_shadows(image, max_num, rng):
    """Darken up to ``max_num`` randomly rotated rectangles to imitate cast shadows.

    NOTE(review): unvalidated draft preserved from a previously disabled code path;
    nothing calls it yet. It assumes ``image`` is HxWx3 (the shadow factor has 3
    channels) and that ``cosy.np.Rigid(...)`` maps an (N, 2) vertex array. Confirm
    both before enabling.
    """
    import scipy.ndimage
    import skimage.draw

    min_rectangle_shape = 0.1
    max_rectangle_shape = 0.7
    brightness = rng.uniform(0.2, 0.5)
    blue_factor = rng.uniform(0.7, 0.9)
    sigma = rng.uniform(0.8, 2.0)

    num = int(rng.random() * (max_num + 1))
    shape = np.asarray(image.shape[:2])

    mask = np.zeros(shape=image.shape[:2], dtype="bool")
    for _ in range(num):
        rectangle_shape = rng.uniform(min_rectangle_shape, max_rectangle_shape, (2,)) * shape
        # Axis-aligned rectangle centered on the origin.
        rectangle_vertices = np.asarray([
            [+rectangle_shape[0], +rectangle_shape[1]],
            [-rectangle_shape[0], +rectangle_shape[1]],
            [-rectangle_shape[0], -rectangle_shape[1]],
            [+rectangle_shape[0], -rectangle_shape[1]],
        ]) / 2

        angle = rng.uniform(0.0, 2 * math.pi)
        rectangle_vertices = cosy.np.Rigid(rotation=cosy.np.angle_to_rotation_matrix(angle))(rectangle_vertices)

        location = (rng.uniform(0.0, 1.0, (2,)) * shape).astype("int32")
        rectangle_vertices = location + rectangle_vertices

        mask = np.logical_or(mask, skimage.draw.polygon2mask(image.shape[:2], rectangle_vertices))

    # The last channel is darkened less than the first two; assuming RGB order this
    # gives a slightly bluish shadow.
    factor = np.asarray([brightness, brightness, 1 - (1 - brightness) * blue_factor])
    factor = np.where(mask[:, :, np.newaxis], factor[np.newaxis, np.newaxis, :], 1.0)
    # Soften the shadow edges spatially, but never blur across channels.
    factor = scipy.ndimage.gaussian_filter(factor, sigma=(sigma, sigma, 0.0))

    return factor * image

def augment_ground_image(image, get_epoch=None, color=True):
    """Apply random photometric augmentation to a ground-level image.

    The random generator is seeded from the pixel content and, optionally, the
    current epoch. The same image therefore gets the same augmentation within an
    epoch, and different images get different augmentations.

    Args:
        image: numpy image array. Presumably HxWx3 RGB in [0, 255] — TODO confirm.
        get_epoch: Optional zero-argument callable returning the current epoch
            (a non-negative int); it is mixed into the seed.
        color: Whether to apply HSV jitter, blur/sharpen and Gaussian noise.

    Returns:
        The augmented image (the input object itself when ``color`` is False).
    """
    import zlib

    # Hash the pixels instead of multiplying them. The product overflows int64: it is
    # 0 for almost every real image, which would give every image the same seed, and
    # it can wrap to a negative value that np.random.default_rng rejects.
    content_hash = zlib.crc32(image.astype("int32").tobytes())
    epoch = get_epoch() if get_epoch is not None else 0
    rng = np.random.default_rng([content_hash, int(epoch)])

    if color:
        image = tfcv.image.multiply(lambda: np.clip(rng.normal(loc=0.0, scale=1.0, size=(3,)) * np.asarray([0.0, 0.05, 0.05]), -0.2, 0.2) + 1.0, colorspace="hsv")(image)
        image = tfcv.image.add(lambda: np.clip(rng.normal(loc=0.0, scale=0.03, size=(3,)), -0.1, 0.1) * np.asarray([255.0, 0.0, 0.0]), colorspace="hsv", out_of_bounds="repeat")(image)
        image = tfcv.image.blur_sharpen(amount=lambda: rng.uniform(-0.3, 0.3), std=4.0)(image)
        image = tfcv.image.add_gaussian_noise(std=lambda: rng.uniform(0.0, 0.01) * 255.0)(image)
    return image

def stream(frame_ids, frames_to_model_input, frames_to_loss_input, preprocess=None, batchsize=1, split=None, track_event=None, mlog_session=None, save_first_num=0, save_first_path=None, workers=(12, 12)):
    """Build the concurrent frame-loading and batching pipeline (via ``tinypl``).

    Stages:
      1. Iterate ``frame_ids``. ``pl.marker.after_each`` then ``pl.flatten`` is
         applied, so ``frame_ids`` is presumably an iterable of per-epoch sequences
         of frame ids — TODO confirm against callers.
      2. Tag the first ``save_first_num`` frame ids with a save flag.
      3. Load frames with ``workers[0]``, optionally calling ``preprocess`` first.
      4. Partition frames into batches of ``batchsize``.
      5. Convert each batch with ``workers[1]`` into model and loss inputs.
    Each concurrent stage is wrapped in a ``pl.order.save`` / ``pl.order.load`` pair,
    presumably to restore the pre-concurrency order (per tinypl naming).

    Args:
        frame_ids: Source of frame ids (see stage 1); each must provide ``load()``.
        frames_to_model_input: Callable mapping a batch of frames to model input.
        frames_to_loss_input: Callable mapping a batch of frames to loss input.
        preprocess: Optional callable applied to each frame id before ``load()``.
        batchsize: Number of frames per batch.
        split: Data split name; used as a key for queue-fill logging.
        track_event: Object providing ``subscribe(callback)``. It is created as
            ``tinyobserver.clock.Thread(interval=2.0)`` when ``mlog_session`` and
            ``split`` are given but no ``track_event`` is passed.
        mlog_session: Optional logging session that receives queue fill levels.
        save_first_num: Number of leading frames whose visualizations are written
            as JPEGs into ``save_first_path``.
        save_first_path: Existing output directory for those visualizations.
        workers: Two entries, for the frame-loading and the batch-processing stage.
            Each is either an int ``n`` (``n`` workers, queue size ``n``) or a
            ``(num_workers, maxsize)`` pair.

    Returns:
        ``(stream, epoch_end_marker, fills)``. ``stream`` yields the
        ``(frames, model_input, loss_input)`` tuples built by ``process_batch``.
        ``epoch_end_marker`` is the marker created in stage 1. ``fills`` is a list of
        ``(name, get_fill)`` pairs for the two worker queues.

    Raises:
        ValueError: If ``save_first_num > 0`` but ``save_first_path`` is None.
    """
    if save_first_num > 0 and save_first_path is None:
        raise ValueError("Expected save_first_path when save_first_num > 0")
    if mlog_session is not None and split is not None and track_event is None:
        track_event = tinyobserver.clock.Thread(interval=2.0)

    # Normalize each entry to a (num_workers, maxsize) pair.
    workers = [((w, w) if isinstance(w, int) else w) for w in workers]

    fills = []
    def add_fill_report(stream, name):
        fills.append((name, stream.getFill))
        if track_event is not None and mlog_session is not None and split is not None:
            # Default arguments bind this stage's stream and its 1-based queue index now,
            # rather than when the callback fires.
            track_event.subscribe(lambda stream=stream, name=name, index=len(fills): mlog_session["data_processing"][split].TimeSeries(f"Q{index}.{name}").add(stream.getFill()))

    stream = iter(frame_ids)

    stream, epoch_end_marker = pl.marker.after_each(stream)
    stream = pl.flatten(stream)
    def add_save_flag(frame_id):
        # Runs before pl.sync, so this counter is not touched by the worker pool.
        nonlocal save_first_num
        if save_first_num > 0:
            save_first_num -= 1
            save = True
        else:
            save = False
        return frame_id, save
    stream = pl.map(add_save_flag, stream)
    stream, iterations_order = pl.order.save(stream)
    stream = pl.sync(stream) # Concurrent processing starts after this point

    @pl.unpack
    def load_frame(frame_id, save):
        if preprocess is not None:
            frame_id = preprocess(frame_id)
        frame = frame_id.load()
        if save:
            imageio.imwrite(os.path.join(save_first_path, f"{frame.name}-aerial.jpg"), frame.visualize().astype("uint8"))
            for camera in frame.ground_frame.cameras:
                imageio.imwrite(os.path.join(save_first_path, f"{frame.name}-ground-{camera.name}.jpg"), camera.visualize().astype("uint8"))
        return frame
    stream = pl.map(load_frame, stream)
    stream = pl.queued(stream, workers=workers[0][0], maxsize=workers[0][1])
    add_fill_report(stream, "frame")

    stream = pl.order.load(stream, iterations_order)

    stream = pl.partition(batchsize, stream, markers="after", marker_splits_batch=True)
    stream, batch_order = pl.order.save(stream)
    stream = pl.sync(stream)

    def process_batch(frames):
        model_input = frames_to_model_input(frames)
        loss_input = frames_to_loss_input(frames)
        return frames, model_input, loss_input
    stream = pl.map(process_batch, stream)
    stream = pl.queued(stream, workers=workers[1][0], maxsize=workers[1][1])
    add_fill_report(stream, "batch")
    stream = pl.order.load(stream, batch_order)

    return stream, epoch_end_marker, fills
