import tqdm, skimage.draw, skimage.transform, threading, cv2, cosy
import numpy as np
from collections import defaultdict
import tinypl as pl

def draw_line(image, t1, t2, color, thickness=1):
    """Paint a straight segment from ``t1`` to ``t2`` into ``image`` in place.

    The segment is rasterised with ``skimage.draw.line`` on the integer-truncated
    endpoints. It is then stamped at every (row, col) offset in
    ``range(1 - thickness, thickness)`` squared, so ``thickness=1`` draws a
    single-pixel line and ``thickness=2`` a 3x3-wide one. Pixels that fall
    outside the first two image dimensions are silently dropped.

    Args:
        image: Array indexed as ``image[row, col]``; modified in place.
        t1: First endpoint ``(row, col)``, or ``None`` to draw nothing.
        t2: Second endpoint ``(row, col)``, or ``None`` to draw nothing.
        color: Value assigned to every painted pixel.
        thickness: Half-width of the stamp plus one.
    """
    if t1 is None or t2 is None:
        return
    color = np.asarray(color)
    line_rows, line_cols = skimage.draw.line(int(t1[0]), int(t1[1]), int(t2[0]), int(t2[1]))
    height, width = image.shape[0], image.shape[1]
    offsets = range(1 - thickness, thickness)
    for d_row in offsets:
        for d_col in offsets:
            shifted_rows = line_rows + d_row
            shifted_cols = line_cols + d_col
            # Keep only pixels that land inside the image.
            visible = (shifted_rows >= 0) & (shifted_rows < height) & (shifted_cols >= 0) & (shifted_cols < width)
            image[shifted_rows[visible], shifted_cols[visible]] = color

def draw_points(image, points, color, radius=0.6):
    """Stamp a small disk of pixels at each point of ``points`` into ``image`` in place.

    The disk offsets come from the grid ``np.mgrid[:2 * radius, :2 * radius]``
    shifted by ``-radius``, keeping cells strictly inside ``radius``. Offset
    positions are truncated to ``int32`` after out-of-bounds filtering.

    Args:
        image: Array indexed as ``image[row, col]`` with a trailing channel axis
            (3 channels are written); modified in place.
        points: One ``(row, col)`` point, or an ``(N, 2)`` array of points.
        color: A single color for all points, or one color per point.
            3-channel colors are written directly. 4-channel colors are
            alpha-blended over the existing pixels, treating the fourth
            channel as alpha in ``[0, 255]``.
        radius: Disk radius in pixels.

    Raises:
        ValueError: If the colors have neither 3 nor 4 channels.
    """
    pts = np.asarray(points)
    if pts.ndim == 1:
        pts = pts[None]
    col = np.asarray(color)
    if col.ndim == 1:
        # Broadcast a single color to every point.
        col = np.tile(col, (pts.shape[0], 1))

    # Offsets of the disk cells relative to each point.
    grid_r, grid_c = np.mgrid[:2 * radius, :2 * radius]
    inside = (grid_r - radius) ** 2 + (grid_c - radius) ** 2 < radius ** 2
    offsets = np.stack([grid_r[inside], grid_c[inside]], axis=1) - radius

    cells_per_disk = offsets.shape[0]
    point_count = pts.shape[0]

    # Expand every point (and its color) into one entry per disk cell.
    col = np.repeat(col, cells_per_disk, axis=0)
    pts = np.repeat(pts, cells_per_disk, axis=0) + np.tile(offsets, (point_count, 1))

    # Discard cells that fall outside the image.
    limits = np.asarray(image.shape[:2])
    keep = ((pts >= 0) & (pts < limits)).all(axis=-1)
    col = col[keep]
    pts = pts[keep].astype("int32")
    rows, cols = pts[:, 0], pts[:, 1]

    channels = col.shape[1]
    if channels == 3:
        image[rows, cols] = col
    elif channels == 4:
        alpha = col[:, 3:]
        foreground = col[:, :3].astype("float32") * alpha / 255.0
        background = image[rows, cols].astype("float32") * (255 - alpha) / 255.0
        image[rows, cols] = np.clip(foreground + background, 0.0, 255.0).astype("uint8")
    else:
        raise ValueError(f"Color must have 3 or 4 channels, found {channels}")

def draw_trajectories(latlons, tile_loader, zoom, bearings=None, colors=None, tile_padding=1, downsample=1, bearing_length=2.0, bearing_stride=5, return_min_pixel=False, verbose=False, sync_tile_loader=False):
    """Render geo trajectories on top of a stitched mosaic of map tiles.

    The tiles around every trajectory point are fetched through
    ``tile_loader``, concurrently on 4 ``tinypl`` workers (hence the locks).
    They are stitched into one RGB image, and each scene's trajectory is drawn
    on top as a polyline. Optionally, a short bearing indicator is drawn for
    every ``bearing_stride``-th frame.

    Args:
        latlons: Nested sequence indexed as ``latlons[scene][frame]``. Each
            entry is passed to ``tile_loader.layout.epsg4326_to_tile`` and
            ``epsg4326_to_pixel`` (presumably a lat/lon pair in EPSG:4326;
            confirm against the layout).
        tile_loader: Object exposing ``load(tile, zoom=...)`` and a ``layout``
            with ``epsg4326_to_tile``, ``epsg4326_to_pixel`` and ``tile_shape``.
        zoom: Zoom level forwarded to every layout and loader call.
        bearings: Optional nested sequence ``bearings[scene][frame]`` of
            headings, passed to ``cosy.np.geo.move_from_latlon``.
        colors: Optional ``colors[scene][segment]``, where segment ``i`` joins
            frame ``i`` and frame ``i + 1``. Defaults to ``[255, 0, 0]``.
        tile_padding: Number of extra tiles loaded in every direction around
            each point's tile.
        downsample: Integer factor by which tiles and pixel coordinates are
            reduced.
        bearing_length: ``distance`` argument for ``move_from_latlon`` (units as
            defined by cosy; presumably meters, so confirm).
        bearing_stride: Draw a bearing indicator for every
            ``bearing_stride``-th frame.
        return_min_pixel: Also return the pixel coordinate of the mosaic origin.
        verbose: Show a tqdm progress bar while tiles are loaded.
        sync_tile_loader: Serialize ``tile_loader.load`` calls with a lock
            (useful when the loader is not thread-safe).

    Returns:
        The ``uint8`` image of shape ``(H, W, 3)``. If ``return_min_pixel`` is
        set, returns the tuple ``(image, min_pixel)`` instead, where
        ``min_pixel`` is the (downsampled) pixel coordinate that maps to
        ``image[0, 0]``.

    Raises:
        ValueError: If ``latlons`` contains no frames at all.
    """
    lock = threading.Lock()
    images = {}
    requested_tiles = set()  # tile keys already claimed by some worker (guarded by ``lock``)
    positions = defaultdict(dict)
    positions2 = defaultdict(dict)  # only populated when ``bearings`` is given
    tile_loader_lock = threading.Lock() if sync_tile_loader else None

    frames = [(scene_index, frame_index) for scene_index in range(len(latlons)) for frame_index in range(len(latlons[scene_index]))]
    if not frames:
        raise ValueError("draw_trajectories requires at least one frame in latlons")

    @pl.unpack
    def load(scene_index, frame_index):
        latlon = latlons[scene_index][frame_index]

        center_tile = tile_loader.layout.epsg4326_to_tile(latlon, zoom=zoom).astype("int64")
        for x in range(-tile_padding, tile_padding + 1):
            for y in range(-tile_padding, tile_padding + 1):
                tile = center_tile + np.asarray([x, y])
                # Keys store the tile components swapped: the stitching below
                # uses key[0] for rows and key[1] for columns.
                key = (tile[1], tile[0])
                # Claim the tile atomically so that concurrent workers never
                # fetch the same tile twice.
                with lock:
                    if key in requested_tiles:
                        continue
                    requested_tiles.add(key)
                if tile_loader_lock is not None:
                    with tile_loader_lock:
                        tile_image = tile_loader.load(tile, zoom=zoom)
                else:
                    tile_image = tile_loader.load(tile, zoom=zoom)
                if downsample != 1:
                    tile_image = skimage.transform.downscale_local_mean(tile_image, (downsample, downsample, 1))
                with lock:
                    images[key] = tile_image

        position = (tile_loader.layout.epsg4326_to_pixel(latlon, zoom=zoom) // downsample).astype("int64")
        with lock:
            positions[scene_index][frame_index] = position

        if bearings is not None:
            bearing = bearings[scene_index][frame_index]
            latlon2 = cosy.np.geo.move_from_latlon(latlon, bearing, distance=bearing_length)
            position2 = (tile_loader.layout.epsg4326_to_pixel(latlon2, zoom=zoom) // downsample).astype("int64")
            with lock:
                positions2[scene_index][frame_index] = position2

    stream = pl.sync(iter(frames))
    stream = pl.map(load, stream)
    stream = pl.queued(stream, workers=4, maxsize=4)
    if verbose:
        stream = tqdm.tqdm(stream, total=len(frames))
    for _ in stream:
        pass

    # Stitch the loaded tiles into one image covering their bounding box.
    tile_shape = tile_loader.layout.tile_shape // downsample
    tile_coords = np.asarray(list(images.keys()))
    tile_min = np.amin(tile_coords, axis=0)
    tile_max = np.amax(tile_coords, axis=0)
    tile_num = (tile_max - tile_min) + 1
    image_shape = tile_num * tile_shape
    image = np.zeros((image_shape[0], image_shape[1], 3), dtype="uint8")
    for tile, subimage in images.items():
        start = ((np.asarray(tile) - tile_min) * tile_shape).astype("int32")
        end = start + tile_shape
        image[start[0]:end[0], start[1]:end[1]] = subimage
    min_pixel = tile_min * tile_shape

    # Draw on the image.
    default_color = np.asarray([255, 0, 0])
    bearing_color = np.asarray([0, 0, 255])
    for scene_index in range(len(latlons)):
        scene_positions = positions.get(scene_index)
        if not scene_positions:
            # Scene without frames: nothing to draw. An empty array would not
            # broadcast against min_pixel.
            continue

        # Trajectory: one segment between each pair of consecutive frames.
        positions_scene = np.asarray([p for _, p in sorted(scene_positions.items())]) - min_pixel
        for i in range(positions_scene.shape[0] - 1):
            color = colors[scene_index][i] if colors is not None else default_color
            draw_line(image, positions_scene[i], positions_scene[i + 1], color, thickness=2)

        # Bearings are per frame (not per segment), so the last frame is eligible too.
        if bearings is not None:
            positions2_scene = np.asarray([p for _, p in sorted(positions2[scene_index].items())]) - min_pixel
            for i in range(0, positions2_scene.shape[0], bearing_stride):
                draw_line(image, positions_scene[i], positions2_scene[i], bearing_color, thickness=1)

    results = [image]
    if return_min_pixel:
        results.append(min_pixel)
    return tuple(results) if len(results) > 1 else results[0]
