import scipy.ndimage, skimage.transform, math
import numpy as np
from . import param, op
# TODO: some of these are with tensorflow, others with numpy

def rotate(angle, axes=(1, 0), ndim=2):
    """Return an op that rotates every image by ``angle`` radians.

    Args:
        angle: Rotation angle in radians; it is converted with ``math.degrees``
            before being passed to ``skimage.transform.rotate``.
        axes: Pair of axes spanning the rotation plane. In 2D both entries must
            be < 2. When ``axes[0] == 0`` the angle is negated, so ``(0, 1)``
            rotates in the opposite direction to the default ``(1, 0)``.
        ndim: Number of spatial dimensions. Only ``2`` is implemented.

    Raises (when the op is applied):
        ValueError: Invalid 2D axes, an unsupported ``ndim``, or a 3D rotation
            requested on an array that is not 3-dimensional.
        NotImplementedError: ``ndim == 3``; 3D rotation is not implemented yet.
    """
    def call(images, types, angle=angle, axes=axes, ndim=ndim):
        angle = param.array(angle, name="angle", dtype=np.float32, shape=())
        axes = param.array(axes, name="axes", dtype=np.int32, shape=(2,))
        if ndim not in (2, 3):
            # Previously any other ndim silently returned the images unchanged.
            raise ValueError(f"Unsupported number of dimensions {ndim}")
        if ndim == 2:
            if np.any(axes >= 2):
                raise ValueError(f"Invalid axes {axes}")
            if axes[0] == 0:
                # Swapped axis order means rotating the other way round.
                angle = -angle

        for i in range(len(images)):
            dtype = images[i].dtype
            if ndim == 2:
                # Rotate in float32 with preserve_range so integer images keep
                # their value range, then cast back to the original dtype below.
                images[i] = skimage.transform.rotate(images[i].astype("float32"), math.degrees(angle), order=types[i].interpolation_order, mode="constant", cval=types[i].border_value, preserve_range=True)
            else:
                if images[i].ndim != 3:
                    raise ValueError(f"Cannot perform 3D rotation on array with shape {images[i].shape}")
                # TODO: implement (e.g. via scipy.ndimage.rotate with axes=axes)
                # and test it. Raise explicitly instead of `assert`, which is
                # stripped under `python -O` and would let images pass unrotated.
                raise NotImplementedError("3D rotation is not implemented yet")
            images[i] = images[i].astype(dtype)

        return images

    return op.op(call)

def flip(axis=None, probability=1.0):
    """Return an op that randomly mirrors all images along one shared axis.

    Each time the op runs, one uniform sample in [0, 1) is drawn. If it is
    below ``probability``, every image is flipped along the same axis;
    otherwise the images are returned untouched.

    Args:
        axis: Axis to flip along. If ``None``, an axis is drawn uniformly from
            ``[0, r)`` on each application, where ``r`` is the smallest
            ``ndim`` among the images.
        probability: Chance of flipping on a given application.
    """
    def call(images, types, axis=axis):
        if not np.random.uniform(0.0, 1.0) < probability:
            return images

        if axis is None:
            smallest_rank = min(image.ndim for image in images)
            axis = np.random.randint(0, smallest_rank)
        axis = param.array(axis, name="axis", dtype=np.int32, shape=())

        for index, image in enumerate(images):
            images[index] = np.flip(image, axis=axis)
        return images

    return op.op(call)

def flip_lr(probability=1.0):
    """Return a ``flip`` op fixed to axis 1 (left-right, assuming images are (height, width, ...))."""
    return flip(probability=probability, axis=1)

def resize_to(size, ndim=2):
    """Return an op that resizes every image to an absolute spatial ``size``.

    Args:
        size: Target size of the first ``ndim`` axes, parsed by ``param.dims``.
        ndim: Number of spatial dimensions.

    Images are resized in float32 with ``preserve_range=True`` and then cast
    back to their original dtype. Borders use constant mode with the image
    type's ``border_value``. Anti-aliasing is enabled whenever the type's
    interpolation order is above zero.
    """
    def call(images, types, size=size):
        target_shape = param.dims(size, name="size", dtype=np.int32, images=images, ndim=ndim)

        for index, image in enumerate(images):
            image_type = types[index]
            order = image_type.interpolation_order
            resized = skimage.transform.resize(
                image.astype("float32"),
                target_shape,
                order=order,
                mode="constant",
                cval=image_type.border_value,
                preserve_range=True,
                anti_aliasing=order > 0,
            )
            images[index] = resized.astype(image.dtype)
        return images

    return op.op(call)

def resize_by(factor, ndim=2):
    """Return an op that rescales every image by a relative ``factor``.

    Args:
        factor: Scale factor(s) for the spatial axes, parsed by ``param.dims``.
            Several values are passed to skimage as a tuple (one per spatial axis).
        ndim: Number of spatial dimensions. An image with more dimensions than
            this is treated as having a trailing channel axis that is not scaled.

    Images are rescaled in float32 with ``preserve_range=True`` and cast back to
    their original dtype. No ``mode``/``cval`` is passed, so borders use
    skimage's default mode rather than the type's ``border_value`` (unlike
    ``resize_to``).
    """
    import inspect

    # scikit-image 0.19 replaced `multichannel=` with `channel_axis=` and 0.21
    # removed `multichannel=` entirely. Pick whichever the installed version
    # supports, so both old and new releases work.
    use_channel_axis = "channel_axis" in inspect.signature(skimage.transform.rescale).parameters

    def call(images, types, factor=factor):
        factor = param.dims(factor, name="factor", dtype=np.float32, images=images, ndim=ndim)
        if factor.shape[0] > 1:
            factor = tuple(factor.tolist())

        for i in range(len(images)):
            dtype = images[i].dtype
            order = types[i].interpolation_order
            has_channels = ndim != images[i].ndim
            if use_channel_axis:
                channel_kwargs = {"channel_axis": -1 if has_channels else None}
            else:
                channel_kwargs = {"multichannel": has_channels}
            images[i] = skimage.transform.rescale(images[i].astype("float32"), factor, order=order, preserve_range=True, anti_aliasing=order > 0, **channel_kwargs)
            images[i] = images[i].astype(dtype)
        return images

    return op.op(call)

def crop(size, position=None, ndim=2):
    """Return an op that crops the same window out of every image.

    Args:
        size: Window size for the first ``ndim`` axes, parsed by ``param.dims``.
            It is clipped to the size of the first image.
        position: Where to place the window.
            - ``None``: a fresh uniform fraction in [0, 1) is drawn per axis on
              every application.
            - ``"center"``: fraction 0.5 on every axis.
            - Floating-point values: fractions of the slack
              ``input_shape - output_shape``, truncated to integers.
            - Otherwise: absolute integer offsets (presumably ``param.dims``
              keeps integer input as integers; TODO confirm).
        ndim: Number of spatial dimensions.

    Raises:
        ValueError: If ``position`` is a string other than ``"center"``.
    """
    def call(images, types, size=size, position=position):
        output_shape = param.dims(size, name="size", dtype=np.int32, images=images, ndim=ndim)
        # The window is computed from the first image and applied to all images.
        input_shape = np.asarray(images[0].shape[:ndim])
        output_shape = np.minimum(output_shape, input_shape)

        if position is None:
            position = np.random.uniform(0.0, 1.0, (ndim,))
        elif isinstance(position, str):
            if position == "center":
                position = 0.5
            else:
                raise ValueError(f"Got invalid position argument {position}")
        position = param.dims(position, name="position", ndim=ndim)

        # `np.float` was removed in NumPy 1.24 (and only matched float64 before
        # that). `np.floating` accepts every float precision, so float32
        # fractions are no longer mistaken for absolute offsets.
        is_fractional = isinstance(position, float) or (
            isinstance(position, np.ndarray) and np.issubdtype(position.dtype, np.floating)
        )
        if is_fractional:
            position = ((input_shape - output_shape) * position).astype("int32")

        window = tuple(slice(p, p + o) for p, o in zip(position, output_shape))
        for i in range(len(images)):
            images[i] = images[i][window]

        return images

    return op.op(call)

def pad(size, location="center", ndim=2):
    """Return an op that pads every image up to at least ``size``.

    Args:
        size: Minimum size of the first ``ndim`` axes, parsed by ``param.dims``.
            Axes of the first image that are already larger are left as they are.
        location: ``"center"`` splits the padding between both sides (the extra
            pixel goes at the end). ``"topleft"`` puts all padding at the end,
            so the original content stays at the start.
        ndim: Number of spatial dimensions to pad.

    Trailing axes beyond the padded ones (e.g. channels) are not padded. The
    padding is filled with each image type's ``border_value``.

    Raises:
        ValueError: If ``location`` is not ``"center"`` or ``"topleft"``.
    """
    def call(images, types, size=size):
        output_shape = param.dims(size, name="size", dtype=np.int32, images=images, ndim=ndim)
        # Padding amounts are computed from the first image and applied to all.
        input_shape = np.asarray(images[0].shape[:ndim])
        output_shape = np.maximum(output_shape, input_shape)
        total = np.asarray(output_shape) - input_shape

        if location == "center":
            front = total // 2
        elif location == "topleft":
            front = np.zeros_like(total)
        else:
            raise ValueError(f"Invalid location argument {location}")
        back = total - front
        pad_width = [(int(f), int(b)) for f, b in zip(front, back)]

        for i in range(len(images)):
            # Leave every remaining axis unpadded. The old code hard-coded `- 2`,
            # which produced a wrong pad_width for any ndim other than 2.
            extra_axes = images[i].ndim - len(pad_width)
            pad_width_i = pad_width + [(0, 0)] * max(extra_axes, 0)
            images[i] = np.pad(images[i], pad_width=tuple(pad_width_i), mode="constant", constant_values=types[i].border_value)

        return images

    return op.op(call)
