import tensorflow as tf
import numpy as np
import cosy, math, georeg, tfcv

def error_volume(new_to_gt, image_shape, angles): # TODO: remove
    """Per-cell pose error of an (angle, x, y) candidate volume against a ground-truth transform.

    Args:
        new_to_gt: Batched rigid transform whose translation and rotation are read through
            ``cosy.tf.Rigid`` (presumably 2D, given the single-angle conversion below — confirm).
        image_shape: Spatial size ``[H, W]`` of the candidate grid.
        angles: Candidate angles, shape ``[batch, angles]`` as consumed by ``axy_volume``.

    Returns:
        ``[batch, angles, H, W]`` tensor: the Euclidean translation error divided by
        ``||image_shape||``, plus the absolute rotation error divided by its maximum over the
        whole tensor (all batch entries together). The rotation term is 0 everywhere when that
        maximum is 0.
    """
    new_to_gt_translation = cosy.tf.Rigid.translation(new_to_gt)
    new_to_gt_rotation = cosy.tf.Rigid.rotation(new_to_gt)
    dtype = new_to_gt_translation.dtype

    axy = axy_volume(image_shape, angles=angles, dtype=dtype) # [batch, angles, dims..., axy]

    gt_translation = new_to_gt_translation[:, tf.newaxis, tf.newaxis, tf.newaxis, :]
    translation_error_volume = tf.math.sqrt(tf.reduce_sum(tf.math.square(axy[..., 1:] - gt_translation), axis=-1)) # [batch, angles, dims...]

    gt_angle = cosy.tf.rotation_matrix_to_angle(new_to_gt_rotation)[:, tf.newaxis, tf.newaxis, tf.newaxis]
    rotation_error_volume = tf.math.abs(axy[..., 0] - gt_angle) # [batch, angles, dims...]
    # Take the shorter way around the circle (assumes both angles lie within one 2*pi period)
    rotation_error_volume = tf.where(rotation_error_volume > math.pi, 2 * math.pi - rotation_error_volume, rotation_error_volume)

    # Normalise in the volume's own dtype so e.g. float64 transforms do not hit a float32 mismatch
    translation_error_volume = translation_error_volume / tf.norm(tf.cast(image_shape, dtype))
    max_rotation_error = tf.reduce_max(rotation_error_volume)
    rotation_error_volume = tf.math.divide_no_nan(rotation_error_volume, max_rotation_error)

    return translation_error_volume + rotation_error_volume

# Numerically stable spatial softmax
def softmax(corr_logits, valid_corr):
    """Softmax taken jointly over the angle and the two spatial axes, ignoring invalid cells.

    Args:
        corr_logits: Logits laid out as ``b... a s...`` (any leading axes, then an angle axis,
            then two spatial axes).
        valid_corr: Boolean mask with the same shape as ``corr_logits``.

    Returns:
        Tensor shaped like ``corr_logits``. For each leading index, the valid cells sum to 1 and
        invalid cells are 0. If no cell is valid, the output is all zeros.
    """
    s = tf.shape(corr_logits)[-2:]

    corr = tfcv.model.einops.apply("b... a s... -> b... (a s...)", corr_logits, s=s)
    valid_corr = tfcv.model.einops.apply("b... a s... -> b... (a s...)", valid_corr, s=s)
    # Sanitise invalid cells so non-finite garbage there cannot leak into the reductions below
    corr = tf.where(valid_corr, corr, 0.0)

    # Maximum over the valid cells only (invalid cells are replaced by the row minimum first)
    numerical_offset = tf.reduce_max(tf.where(valid_corr, corr, tf.reduce_min(corr, axis=-1, keepdims=True)), axis=-1, keepdims=True)
    # Mask *before* exp: invalid cells hold 0, and 0 - offset can overflow exp to inf when all
    # valid logits are very negative. That inf is hidden in the forward pass but produces NaN
    # gradients (0 * inf) through tf.where.
    corr = tf.where(valid_corr, corr - numerical_offset, 0.0)
    corr = tf.math.exp(corr)
    corr = tf.where(valid_corr, corr, 0.0)
    corr = tf.math.divide_no_nan(corr, tf.reduce_sum(corr, axis=-1, keepdims=True))

    corr = tfcv.model.einops.apply("b... (a s...) -> b... a s...", corr, s=s)

    return corr



def gather_window_2d(y, valid_y, window_shape):
    """Gather a fixed-size window of a 2D score map around its maximum.

    Args:
        y: Scores laid out as ``b s...`` (batch, two spatial axes).
        valid_y: Boolean mask with the same shape as ``y``.
        window_shape: Window size: a scalar (square window) or any length-2 sequence/array.

    Returns:
        ``(x_window, y_window, valid_y_window)`` with shapes ``[b, p, 2]``, ``[b, p]`` and
        ``[b, p]``, where ``p = window_shape[0] * window_shape[1]``. ``x_window`` holds the
        ``xy_volume`` ("conv" layout) coordinates of the gathered cells, in row-major window order.

    Raises:
        ValueError: If ``window_shape`` is neither a scalar nor of shape ``(2,)``.

    Before clipping, the window is placed so that ``argmax2d(y)`` sits at offset
    ``window_shape // 2``. ``valid_y`` is not consulted for placement. The window is then
    shifted, never wrapped, to stay inside the map.
    NOTE(review): if the map is smaller than the window the clip range is empty — confirm
    callers never do that.
    """
    b = tf.shape(y)[0]
    s = tf.shape(y)[-2:]

    # Accept Python/numpy scalars as well as any length-2 sequence
    window_shape = np.asarray(window_shape).astype("int32")
    if window_shape.ndim == 0:
        window_shape = np.full([2], window_shape, dtype="int32")
    if window_shape.shape != (2,):
        raise ValueError(f"window_shape must be a scalar or have shape (2,), got shape {window_shape.shape}")
    window_left_half = window_shape // 2
    window_right_half = window_shape - window_left_half

    argmax = tf.cast(argmax2d(y), "int32")
    # Shift (rather than shrink) windows that would cross the border
    argmax = tf.clip_by_value(argmax, window_left_half, s - window_right_half)
    start = argmax - window_left_half[tf.newaxis, :] # b 2

    # Turns the numpy constant into a tensor tied to y. Kept from the original; a similar trick
    # elsewhere in this file is noted as needed for tf.keras.models.load_model.
    window_shape = window_shape * (tf.shape(y)[0] * 0 + 1)
    window_indices = tf.stack([
        tfcv.model.einops.apply("x -> (x y)", tf.range(window_shape[0]), x=window_shape[0], y=window_shape[1]),
        tfcv.model.einops.apply("y -> (x y)", tf.range(window_shape[1]), x=window_shape[0], y=window_shape[1]),
    ], axis=-1) # p 2

    window_indices = tfcv.model.einops.apply("p f -> b p f", window_indices, b=b, f=2) \
                   + tfcv.model.einops.apply("b f -> b 1 f", start, f=2) # b p 2

    x = tfcv.model.einops.apply("s... f -> b s... f", xy_volume(s), b=b, f=2)

    x_window = tf.gather_nd(
        x,
        window_indices,
        batch_dims=1,
    ) # b p 2
    y_window = tf.gather_nd(
        y[..., tf.newaxis],
        window_indices,
        batch_dims=1,
    )[..., 0] # b p

    valid_y_window = tf.gather_nd(
        valid_y[..., tf.newaxis],
        window_indices,
        batch_dims=1,
    )[..., 0] # b p

    return x_window, y_window, valid_y_window

def gather_window_3d(y, valid_y, angles, window_shape):
    """Gather a fixed-size (angle, x, y) window of a score volume around its maximum.

    Args:
        y: Scores laid out as ``b a s...`` (batch, angle, two spatial axes).
        valid_y: Boolean mask with the same shape as ``y``.
        angles: Angle values, shape ``[b, a]`` as consumed by ``axy_volume``.
        window_shape: Window size: a scalar (cube) or any length-3 sequence/array.

    Returns:
        ``(x_window, y_window, valid_y_window)`` with shapes ``[b, p, 3]``, ``[b, p]`` and
        ``[b, p]``, where ``p = prod(window_shape)``. ``x_window`` holds the ``axy_volume``
        coordinates (angle value, x, y) of the gathered cells.

    Raises:
        ValueError: If ``window_shape`` is neither a scalar nor of shape ``(3,)``.

    The window is placed around ``argmax3d(y)`` (``valid_y`` is not consulted) and shifted,
    never wrapped, to stay inside the volume on every axis, including the angle axis.
    """
    b = tf.shape(y)[0]
    s = tf.shape(y)[-2:]
    as_ = tf.shape(y)[1:] # a s...

    window_shape = np.asarray(window_shape).astype("int32")
    if window_shape.ndim == 0:
        window_shape = np.full([3], window_shape, dtype="int32")
    if window_shape.shape != (3,):
        raise ValueError(f"window_shape must be a scalar or have shape (3,), got shape {window_shape.shape}")
    window_left_half = window_shape // 2
    window_right_half = window_shape - window_left_half

    argmax = tf.cast(argmax3d(y), "int32")
    # Shift (rather than shrink) windows that would cross the border
    argmax = tf.clip_by_value(argmax, window_left_half, as_ - window_right_half)
    start = argmax - window_left_half[tf.newaxis, :] # b 3

    # Turns the numpy constant into a tensor tied to y (kept from the original)
    window_shape = window_shape * (tf.shape(y)[0] * 0 + 1)
    window_indices = tf.stack([
        tfcv.model.einops.apply("a -> (a x y)", tf.range(window_shape[0]), a=window_shape[0], x=window_shape[1], y=window_shape[2]),
        tfcv.model.einops.apply("x -> (a x y)", tf.range(window_shape[1]), a=window_shape[0], x=window_shape[1], y=window_shape[2]),
        tfcv.model.einops.apply("y -> (a x y)", tf.range(window_shape[2]), a=window_shape[0], x=window_shape[1], y=window_shape[2]),
    ], axis=-1) # p 3

    window_indices = tfcv.model.einops.apply("p f -> b p f", window_indices, b=b, f=3) \
                   + tfcv.model.einops.apply("b f -> b 1 f", start, f=3) # b p 3

    x = axy_volume(s, angles=angles) # b a s... 3

    x_window = tf.gather_nd(
        x,
        window_indices,
        batch_dims=1,
    ) # b p 3
    y_window = tf.gather_nd(
        y[..., tf.newaxis],
        window_indices,
        batch_dims=1,
    )[..., 0] # b p
    valid_y_window = tf.gather_nd(
        valid_y[..., tf.newaxis],
        window_indices,
        batch_dims=1,
    )[..., 0] # b p

    return x_window, y_window, valid_y_window

def argmax2d(x):
    """Return the (row, column) index of the maximum over the last two axes, shape ``b... 2``.

    The row comes from the per-row maxima and the column from the per-column maxima. Both
    therefore point at the global maximum, but with ties they may point at different maxima.
    """
    row_peaks = tf.reduce_max(x, axis=-1)    # best value in every row
    column_peaks = tf.reduce_max(x, axis=-2) # best value in every column
    best_row = tf.argmax(row_peaks, axis=-1)
    best_column = tf.argmax(column_peaks, axis=-1)
    return tf.stack([best_row, best_column], axis=-1) # b... 2

def argmax3d(x):
    """Return the index of the maximum over the last three axes, shape ``b... 3``.

    Each coordinate is the argmax of that axis's profile of maxima over the other two axes.
    With ties, the three coordinates may refer to different maxima.
    """
    peaks_over_last = tf.reduce_max(x, axis=-1)                      # b... d0 d1
    profile_0 = tf.reduce_max(peaks_over_last, axis=-1)              # b... d0
    profile_1 = tf.reduce_max(peaks_over_last, axis=-2)              # b... d1
    profile_2 = tf.reduce_max(tf.reduce_max(x, axis=-2), axis=-2)    # b... d2
    return tf.stack(
        [tf.argmax(profile, axis=-1) for profile in (profile_0, profile_1, profile_2)],
        axis=-1,
    ) # b... 3

def corr_argmax_discrete(y, valid_y, angles):
    """Return the discrete (angle, x, y) location of the maximum of a score volume.

    Args:
        y: Scores laid out as ``b a s...`` (batch, angle, two spatial axes).
        valid_y: Validity mask. NOTE(review): currently unused, so invalid cells can win the
            argmax — confirm this is intended.
        angles: Angle values, shape ``[b, a]`` (indexed per batch along axis 1).

    Returns:
        ``[b, 3]`` tensor ``(angle, x, y)``. The translation part is the ``xy_volume``
        ("conv" layout) coordinate of the best cell, cast to the dtype of ``angles`` so the
        concatenation is well-typed.
    """
    # y: b a s...

    b = tf.shape(y)[0]
    s = tf.shape(y)[-2:]

    # Angle slice that contains the global maximum
    max_angle_indices = tf.argmax(tf.reduce_max(tf.reduce_max(y, axis=-1), axis=-1), axis=-1) # b
    y = tf.gather(y, max_angle_indices[..., tf.newaxis], batch_dims=1, axis=1)[:, 0, :, :] # b s...

    angle = tf.gather(angles, max_angle_indices[..., tf.newaxis], batch_dims=1, axis=1)[..., 0] # b

    x = tfcv.model.einops.apply("s... f -> b s... f", xy_volume(s), b=b, f=2)

    # Best cell inside that slice, looked up in the coordinate grid
    indices = tf.argmax(tfcv.model.einops.apply("b s... -> b (s...)", y, b=b, s=s), axis=-1) # b
    result = tf.gather(tfcv.model.einops.apply("b s... f -> b (s...) f", x, b=b, s=s, f=2), indices[..., tf.newaxis], batch_dims=1, axis=1)[..., 0, :] # b 2
    # Match the angle column's dtype (a hard float32 cast broke tf.concat for other angle dtypes)
    result = tf.cast(result, angle.dtype)

    return tf.concat([angle[:, tf.newaxis], result], axis=-1) # b 3

def corr_argmax_weighted(y, valid_y, angles):
    """Return the score-weighted mean (angle, x, y) over the whole volume, shape ``[b, 3]``.

    Args:
        y: Non-negative weights laid out as ``b a s...`` (presumably softmax output — confirm).
        valid_y: Boolean mask with the same shape as ``y``; invalid cells get weight 0.
        angles: Angle values, shape ``[b, a]`` as consumed by ``axy_volume``.

    A batch entry whose valid weights sum to 0 yields ``(0, 0, 0)`` instead of NaN.
    """
    # y: b a s...

    y = tf.where(valid_y, y, 0.0)
    total = tfcv.model.einops.apply("b a s... -> b 1...", y, reduction="sum", output_ndims=4)
    y = tf.math.divide_no_nan(y, total)

    axy = axy_volume(tf.shape(y)[-2:], angles=angles) # b a s... 3

    return tfcv.model.einops.apply("b a s... f -> b f", y[..., tf.newaxis] * axy, reduction="sum")

def corr_argmax_weighted_window(y, valid_y, angles, window_shape):
    """Return the score-weighted mean (angle, x, y) inside a window around the maximum.

    Args:
        y: Non-negative weights laid out as ``b a s...``.
        valid_y: Boolean mask with the same shape as ``y``; invalid cells get weight 0.
        angles: Angle values, shape ``[b, a]``.
        window_shape: Window size forwarded to ``gather_window_3d``.

    Returns:
        ``[b, 3]`` tensor. A batch entry whose window weights sum to 0 yields ``(0, 0, 0)``
        instead of NaN.
    """
    # y: b a s...

    x, y, valid_y = gather_window_3d(y, valid_y, angles, window_shape) # b p 3, b p, b p

    weights = tf.where(valid_y, y, 0.0)
    weights = tf.math.divide_no_nan(weights, tf.reduce_sum(weights, axis=1, keepdims=True))

    result = tfcv.model.einops.apply("b p f -> b f", weights[..., tf.newaxis] * x, reduction="sum")

    return result


def corr_argmax_paraboloid_2d(y, valid_y, angles, window_shape=3, max_deviation=1):
    """Sub-pixel peak: discrete best angle plus a 2D paraboloid fit over translation.

    Args:
        y: Scores laid out as ``b a s...``.
        valid_y: Boolean mask with the same shape as ``y``; invalid cells do not constrain the fit.
        angles: Angle values, shape ``[b, a]``.
        window_shape: Translation window size forwarded to ``gather_window_2d``.
        max_deviation: Maximum per-axis distance of the refined translation from the discrete peak.

    Returns:
        ``[b, 3]`` tensor ``(angle, x, y)``. Falls back to the discrete translation when the
        fitted quadratic is degenerate (``|c^2 - 4ab| <= 1e-8``).
    """
    # y: b a s...
    discrete_argmax = corr_argmax_discrete(y, valid_y, angles) # b 3

    # Restrict to the angle slice holding the global maximum (same rule as corr_argmax_discrete)
    max_angle_indices = tf.argmax(tf.reduce_max(tf.reduce_max(y, axis=-1), axis=-1), axis=-1) # b
    y = tf.gather(y, max_angle_indices[..., tf.newaxis], batch_dims=1, axis=1)[:, 0, :, :] # b s...
    valid_y = tf.gather(valid_y, max_angle_indices[..., tf.newaxis], batch_dims=1, axis=1)[:, 0, :, :] # b s...
    angle = tf.gather(angles, max_angle_indices[..., tf.newaxis], batch_dims=1, axis=1)[..., 0] # b

    x, y, valid_y = gather_window_2d(y, valid_y, window_shape) # b p 2, b p, b p

    # Fit in coordinates relative to the discrete peak. Absolute pixel coordinates make the squared
    # columns large and the float32 least-squares problem ill-conditioned. The quadratic
    # coefficients (and hence the degeneracy test) are unchanged by this shift.
    center = tf.cast(discrete_argmax[:, 1:], x.dtype) # b 2
    u = x - center[:, tf.newaxis, :] # b p 2

    # Least-squares fit of f(u) = c0 u0^2 + c1 u1^2 + c2 u0 u1 + c3 u0 + c4 u1 + c5
    design = tf.stack([
        u[..., 0] ** 2,
        u[..., 1] ** 2,
        u[..., 0] * u[..., 1],
        u[..., 0],
        u[..., 1],
        tf.ones_like(u[..., 0]),
    ], axis=-1) # b p 6
    rhs = y[..., tf.newaxis] # b p 1
    # Invalid cells become all-zero rows, so they do not constrain the fit
    design = tf.where(valid_y[..., tf.newaxis], design, 0)
    rhs = tf.where(valid_y[..., tf.newaxis], rhs, 0)
    coeffs = tf.linalg.lstsq(design, rhs, fast=False)[..., 0] # b 6

    # Stationary point: solve [[2c0, c2], [c2, 2c1]] u = -[c3, c4] with Cramer's rule
    c0, c1, c2, c3, c4 = coeffs[..., 0], coeffs[..., 1], coeffs[..., 2], coeffs[..., 3], coeffs[..., 4]
    numerator = tf.stack([2 * c3 * c1 - c2 * c4, 2 * c0 * c4 - c3 * c2], axis=-1) # b 2
    denominator = (c2 * c2 - 4 * c0 * c1)[..., tf.newaxis] # b 1
    offset = tf.where(
        tf.math.abs(denominator) > 1e-8,
        tf.math.divide_no_nan(numerator, denominator),
        tf.zeros_like(numerator),
    ) # b 2

    max_deviation = tf.cast(max_deviation, offset.dtype)
    offset = tf.clip_by_value(offset, -max_deviation, max_deviation) # b 2
    result = center + offset # b 2
    result = tf.concat([angle[:, tf.newaxis], result], axis=-1) # b 3

    return result

def corr_argmax_paraboloid_3d(y, valid_y, angles, window_shape, max_deviation=1):
    """Sub-pixel (angle, x, y) peak from a 3D quadric fit around the discrete maximum.

    Args:
        y: Scores laid out as ``b a s...``.
        valid_y: Boolean mask with the same shape as ``y``; invalid cells do not constrain the fit.
        angles: Angle values, shape ``[b, a]``.
        window_shape: Window size forwarded to ``gather_window_3d``.
        max_deviation: Maximum per-axis distance of the result from the discrete peak. It uses the
            same units as the axes (angle values for axis 0, translation for axes 1-2).

    Returns:
        ``[b, 3]`` tensor ``(angle, x, y)``.
    """
    # y: b a s...
    discrete_argmax = corr_argmax_discrete(y, valid_y, angles) # b 3

    x, y, valid_y = gather_window_3d(y, valid_y, angles, window_shape) # b p 3, b p, b p

    # Fit relative to the discrete peak. This keeps the float32 design matrix well conditioned,
    # and for a degenerate fit the minimum-norm lstsq solution then stays near the discrete peak.
    center = tf.cast(discrete_argmax, x.dtype) # b 3
    u = x - center[:, tf.newaxis, :] # b p 3

    # Least-squares fit of
    # f(u) = c0 u0^2 + c1 u1^2 + c2 u2^2 + c3 u0 u1 + c4 u0 u2 + c5 u1 u2 + c6 u0 + c7 u1 + c8 u2 + c9
    design = tf.stack([
        u[..., 0] ** 2,
        u[..., 1] ** 2,
        u[..., 2] ** 2,
        u[..., 0] * u[..., 1],
        u[..., 0] * u[..., 2],
        u[..., 1] * u[..., 2],
        u[..., 0],
        u[..., 1],
        u[..., 2],
        tf.ones_like(u[..., 0]),
    ], axis=-1) # b p 10
    rhs = y[..., tf.newaxis] # b p 1
    # Invalid cells become all-zero rows, so they do not constrain the fit
    design = tf.where(valid_y[..., tf.newaxis], design, 0)
    rhs = tf.where(valid_y[..., tf.newaxis], rhs, 0)
    coeffs = tf.linalg.lstsq(design, rhs, fast=False)[..., 0] # b 10

    # Stationary point: grad f = H u + g = 0 with the *symmetric* Hessian of f.
    # The previous matrix put 2*c1 and 2*c2 in column 0 of rows 2 and 3, which solved the wrong system.
    hessian = tf.stack([
        tf.stack([2 * coeffs[:, 0], coeffs[:, 3], coeffs[:, 4]], axis=-1),
        tf.stack([coeffs[:, 3], 2 * coeffs[:, 1], coeffs[:, 5]], axis=-1),
        tf.stack([coeffs[:, 4], coeffs[:, 5], 2 * coeffs[:, 2]], axis=-1),
    ], axis=-2) # b 3 3
    gradient = coeffs[:, 6:9, tf.newaxis] # b 3 1
    offset = tf.linalg.lstsq(hessian, -gradient, fast=False)[..., 0] # b 3

    max_deviation = tf.cast(max_deviation, offset.dtype)
    offset = tf.clip_by_value(offset, -max_deviation, max_deviation) # b 3

    return center + offset # b 3












#
# def corr_argmax_paraboloid_3d(x, y):
#     assert False, "Not working yet"
#     # Least-squares fit a 3d paraboloid
#     # TODO: rescale translations to (-1, 1) and scale back at the end of least squares
#     A = tf.stack([
#         axy_window[..., 0] ** 2,
#         axy_window[..., 1] ** 2,
#         axy_window[..., 2] ** 2,
#         axy_window[..., 0] * axy_window[..., 1],
#         axy_window[..., 0] * axy_window[..., 2],
#         axy_window[..., 1] * axy_window[..., 2],
#         axy_window[..., 0],
#         axy_window[..., 1],
#         axy_window[..., 2],
#         tf.ones_like(axy_window[..., 0]),
#     ], axis=-1) # [batch, points, 10]
#     b = corr_window[..., tf.newaxis] # [batch, points, 1]
#     x = tf.linalg.lstsq(A, b, fast=False) # [batch, 10, 1]
#
#     # Find maximum of3d paraboloid
#     x = x[..., 0]
#     A = tf.transpose(tf.convert_to_tensor([
#         [2 * x[:, 0], x[:, 3], x[:, 4]],
#         [2 * x[:, 1], x[:, 3], x[:, 5]],
#         [2 * x[:, 2], x[:, 4], x[:, 5]],
#     ]), (2, 0, 1)) # [batch, 3, 3]
#     b = tf.transpose(tf.convert_to_tensor([
#         -x[:, 6], -x[:, 7], -x[:, 8],
#     ]), (1, 0))[..., tf.newaxis] # [batch, 3, 1]
#     x = tf.linalg.solve(A, b) # [batch, 3, 1]
#
#     poses = x[..., 0] # [batch, axy]





def cross_correlate_2d(i1, i2, method="fft", pad=None):
    """Cross-correlate two feature maps over their two spatial axes, channel by channel.

    Args:
        i1, i2: Tensors of identical rank laid out as ``[batch..., H, W, C]``. Any number of
            leading batch axes, including none, is accepted.
        method: Any string starting with ``"fft"`` (case-insensitive) for circular correlation
            via ``rfft2d``, or ``"conv"`` for ``depthwise_conv2d`` with "SAME" padding.
        pad: Optional ``(pad_x_kwargs, pad_y_kwargs)``. Each dict is forwarded to ``tf.pad``
            while appending ``H`` rows / ``W`` columns to both inputs; the result is cropped
            back to ``H x W`` afterwards.

    Returns:
        Tensor ``[batch..., H, W, C]`` whose zero-shift cell is at index ``(n + 1) // 2 - 1``
        on each spatial axis.

    Raises:
        ValueError: If the ranks differ or ``method`` is unknown.
    """
    tf.debugging.assert_all_finite(i1, "cross_correlate_2d got non-finite i1")
    tf.debugging.assert_all_finite(i2, "cross_correlate_2d got non-finite i2")
    method = method.lower()

    if len(i1.shape) != len(i2.shape):
        raise ValueError("Input tensors should have the same rank")

    # Collapse any number of leading batch axes into exactly one
    batch_shape = tf.shape(i1)[:-3]
    batch_rank = len(i1.shape) - 3
    if batch_rank == 0:
        i1 = i1[tf.newaxis, ...]
        i2 = i2[tf.newaxis, ...]
    elif batch_rank > 1:
        i1 = tf.reshape(i1, [-1, tf.shape(i1)[-3], tf.shape(i1)[-2], tf.shape(i1)[-1]])
        i2 = tf.reshape(i2, [-1, tf.shape(i2)[-3], tf.shape(i2)[-2], tf.shape(i2)[-1]])

    # i1: [batch, dims..., filters]
    # i2: [batch, dims..., filters]
    image_shape = tf.shape(i1)[1:3] # unpadded size, used for the final crop

    if pad is not None:
        pad_x, pad_y = pad
        front = (0, 0)
        back = image_shape

        paddings = [[0, 0], [front[0], back[0]], [0, 0], [0, 0]]
        i1 = tf.pad(i1, paddings, **pad_x)
        i2 = tf.pad(i2, paddings, **pad_x)

        paddings = [[0, 0], [0, 0], [front[1], back[1]], [0, 0]]
        i1 = tf.pad(i1, paddings, **pad_y)
        i2 = tf.pad(i2, paddings, **pad_y)

    # Read sizes *after* padding. The conv path reshapes with them; reading them before padding
    # made pad + "conv" fail with a size mismatch.
    batchsize = tf.shape(i1)[0]
    height = tf.shape(i1)[1]
    width = tf.shape(i1)[2]
    filters = tf.shape(i1)[3]

    if method.startswith("fft"):
        rank = len(i1.shape)
        channels_first = [rank - 1] + list(range(rank - 1))
        i1 = tf.transpose(i1, channels_first) # [filters, batch, dims...]
        i2 = tf.transpose(i2, channels_first) # [filters, batch, dims...]

        # Without fft_length, irfft2d assumes an even innermost length 2 * (bins - 1) and
        # silently drops a column when the width is odd; pass the true spatial size.
        fft_length = tf.shape(i1)[-2:]
        f1 = tf.signal.rfft2d(i1, fft_length=fft_length)
        f2 = tf.signal.rfft2d(i2, fft_length=fft_length)
        corr = tf.signal.irfft2d(f1 * tf.math.conj(f2), fft_length=fft_length)

        tf.debugging.assert_all_finite(corr, "cross_correlate_2d produced non-finite corr")

        corr = tf.transpose(corr, list(range(1, rank)) + [0]) # [batch, dims..., filters]

        # Move the zero-shift cell from index 0 to index (n - 1) // 2 on each spatial axis
        corr = tf.roll(corr, (tf.shape(corr)[1] - 1) // 2, axis=1)
        corr = tf.roll(corr, (tf.shape(corr)[2] - 1) // 2, axis=2)

    elif method == "conv":
        i1 = tf.transpose(i1, [1, 2, 0, 3]) # [dims..., batch, filters]
        i2 = tf.transpose(i2, [1, 2, 0, 3]) # [dims..., batch, filters]
        i1 = tf.reshape(i1, [height, width, batchsize * filters]) # [dims..., batch * filters]
        i2 = tf.reshape(i2, [height, width, batchsize * filters]) # [dims..., batch * filters]

        i1 = i1[tf.newaxis, ...] # [1, dims..., batch * filters]
        i2 = i2[..., tf.newaxis] # [dims..., batch * filters, 1]
        corr = tf.nn.depthwise_conv2d(i1, i2, strides=[1, 1, 1, 1], dilations=[1, 1], padding="SAME") # [1, dims..., batch * filters]

        corr = tf.reshape(corr, [height, width, batchsize, filters]) # [dims..., batch, filters]
        corr = tf.transpose(corr, [2, 0, 1, 3]) # [batch, dims..., filters]
    else:
        raise ValueError(f"Invalid method argument {method}")

    if pad is not None:
        # Crop the padded result back to the original size around the zero-shift cell
        begin = image_shape // 2
        end = tf.shape(corr)[1:3] - (image_shape + 1) // 2
        corr = corr[:, begin[0]:end[0], begin[1]:end[1], :]

    # Restore the caller's batch layout
    if batch_rank == 0:
        corr = corr[0, ...]
    elif batch_rank > 1:
        corr = tf.reshape(corr, tf.concat([batch_shape, tf.shape(corr)[-3:]], axis=0))

    return corr

def xy_volume(image_shape, layout="conv", dtype="float32"):
    """Return an ``[H, W, 2]`` grid of per-cell translations for a correlation map.

    Channel 0 varies along the first axis and channel 1 along the second. For index ``i`` on an
    axis of length ``n``:

    * ``layout="fft"``: ``-(((i + n // 2) mod n) - n // 2)``, i.e. wrapped FFT ordering, negated.
    * ``layout="conv"``: ``-(i - ((n + 1) // 2 - 1))``, i.e. zero at index ``(n + 1) // 2 - 1``.

    Raises:
        ValueError: For any other ``layout``.
    """
    image_shape = tf.cast(image_shape, "int32")
    height, width = image_shape[0], image_shape[1]

    def _axis_translations(n):
        # Translation value for every index along one axis of length n
        idx = tf.cast(tf.range(n), "int32")
        if layout == "fft":
            return -((idx + n // 2) % n - n // 2)
        if layout == "conv":
            return -(idx - (n + 1) // 2 + 1)
        raise ValueError(f"Invalid layout {layout}")

    along_rows = tf.cast(_axis_translations(height), dtype)
    along_cols = tf.cast(_axis_translations(width), dtype)

    grid_shape = tf.stack([height, width], axis=0)
    along_rows = tf.broadcast_to(along_rows[:, tf.newaxis], grid_shape) # [dims...]
    along_cols = tf.broadcast_to(along_cols[tf.newaxis, :], grid_shape) # [dims...]

    return tf.stack([along_rows, along_cols], axis=-1) # [dims..., xy]

def axy_volume(image_shape, angles, layout="conv", dtype="float32"):
    """Return a ``[batch, angles, H, W, 3]`` grid of (angle, x, y) coordinates.

    Args:
        image_shape: Spatial size ``[H, W]``.
        angles: Angle values, shape ``[batch, angles]``; channel 0 repeats them over H and W.
        layout: Translation layout forwarded to ``xy_volume`` (channels 1-2).
        dtype: Output dtype.
    """
    image_shape = tf.cast(image_shape, "int32")
    grid_shape = tf.stack(
        [tf.shape(angles)[0], tf.shape(angles)[1], image_shape[0], image_shape[1]], axis=0,
    ) # batch angles H W

    xy = xy_volume(image_shape, layout=layout, dtype=dtype) # [dims..., xy], already in dtype
    angles = tf.cast(angles, dtype)

    angle_channel = tf.broadcast_to(angles[:, :, tf.newaxis, tf.newaxis], grid_shape)[..., tf.newaxis] # [batch, angles, dims..., a]
    xy_channels = tf.broadcast_to(
        xy[tf.newaxis, tf.newaxis, :, :], tf.concat([grid_shape, [2]], axis=0),
    ) # [batch, angles, dims..., xy]

    return tf.concat([angle_channel, xy_channels], axis=-1) # [batch, angles, dims..., axy]
#
# def argmax_window_indices_1d(y, window_shape=3):
#     # y: [batch..., dim]
#     if isinstance(window_shape, int):
#         window_shape = np.asarray([window_shape])
#     window_shape = tf.cast(window_shape, "int32")
#
#     argmax = tf.stack([tf.argmax(y, axis=-1)], axis=-1) # [batch..., 1]
#     argmax = tf.cast(argmax, "int32")
#     window_left = window_shape // 2
#     window_right = window_shape - window_left
#     argmax = tf.clip_by_value(argmax, window_left, tf.shape(y)[-1] - window_right)
#     start = argmax - window_left # [batch..., 1]
#
#     # Expand to indices for entire window
#     window_indices = tf.range(window_shape[0])[:, tf.newaxis] # [window_dims..., 1]
#     window_indices = tf.reshape(window_indices, [-1, 1]) # [points, 1]
#     while len(window_indices.shape) - 2 < len(y.shape) - 1:
#         window_indices = window_indices[tf.newaxis, ...]
#     window_indices = window_indices + start[..., tf.newaxis, :] # [batch..., points, 1]
#
#     return window_indices
#
# def argmax_window_indices_2d(y, window_shape=3):
#     # y: [batch..., dims...]
#     if isinstance(window_shape, int):
#         window_shape = np.asarray([window_shape, window_shape])
#     window_shape = window_shape.astype("int32")
#     while len(window_shape.shape) - 1 < len(y.shape) - 2:
#         window_shape = window_shape[np.newaxis, ...]
#     # window_shape [batch..., 2]
#     window_left_half = window_shape // 2
#     window_right_half = window_shape - window_left_half
#
#     argmax = tf.stack([
#         tf.argmax(tf.reduce_max(y, axis=-1), axis=-1),
#         tf.argmax(tf.reduce_max(y, axis=-2), axis=-1),
#     ], axis=-1) # [batch..., 2]
#     argmax = tf.cast(argmax, "int32")
#     argmax = tf.clip_by_value(argmax, window_left_half, tf.shape(y)[-2:] - window_right_half)
#     start = argmax - window_left_half # [batch..., 2]
#
#     # Expand to indices for entire window
#     window_indices = np.stack(np.meshgrid(
#         np.arange(np.squeeze(window_shape)[0]),
#         np.arange(np.squeeze(window_shape)[1]),
#         indexing="ij",
#     ), axis=-1) # [window_dims..., 2]
#     window_indices = np.reshape(window_indices, [-1, 2]) # [points, 2]
#     while len(window_indices.shape) - 2 < len(y.shape) - 2:
#         window_indices = window_indices[np.newaxis, ...]
#     window_indices = window_indices * tf.math.minimum(tf.shape(y)[0], 1) # This is unnecessary, but tf.keras.models.load_model crashes without it
#     window_indices = window_indices + start[..., tf.newaxis, :] # [batch..., points, 2]
#     return window_indices
#
# def register_2d_rotation(image1, image2, polar_shape, radius=None, subpixel_sample=subpixel_argmax, subpixel_window_shape=3):
#     if radius is None:
#         radius = tf.cast(tf.math.minimum(tf.shape(image1)[-3], tf.shape(image1)[-2]), "float32") / 2.0
#
#     # Transform to polar
#     image1 = warp_polar(image1, radius=radius, output_shape=polar_shape)
#     image2 = warp_polar(image2, radius=radius, output_shape=polar_shape)
#     image1 = image1[:, :tf.shape(image1)[1] // 2, :]
#     image2 = image2[:, :tf.shape(image2)[1] // 2, :]
#
#     # Cross correlate via fft: Repeats symmetrically in angle dimension, and everything in scale dimension but center point is discarded so padding is irrelevant there
#     corr = cross_correlate_2d(image1, image2, "fft") # [batch, angle, scale, features]
#     corr = tf.reduce_sum(corr, axis=-1) # [batch, angle, scale]
#     xy = xy_volume(tf.shape(image1)[1:3])[tf.newaxis, ...] # [batch, angle, scale, 2]
#
#     # Drop scale
#     no_scale_y = (polar_shape[1] + 1) // 2 - 1 # TODO: assert xy at no_scale_y is actually zero, check in warp_polar how coordinates are created
#     corr = corr[:, :, no_scale_y] # [angle]
#     angles = xy[:, :, no_scale_y, :1] # [angle, 1]
#
#     window_indices = argmax_window_indices_1d(corr, window_shape=subpixel_window_shape) # [batch, points, 1]
#     angles_window = tf.gather_nd(angles, window_indices, batch_dims=1) # [batch, points, 1]
#     corr_window = tf.gather_nd(corr, window_indices, batch_dims=1) # [batch, points]
#
#     argmax = subpixel_sample(angles_window, corr_window) # [batch, 1]
#
#     return -argmax[:, 0] * 2.0 * math.pi / float(polar_shape[0])
#
# def register_2d(image1, image2, subpixel_sample=subpixel_argmax, subpixel_window_shape=3):
#     corr = cross_correlate_2d(image1, image2, "fft", pad=(({"mode": "CONSTANT"}, {"mode": "CONSTANT"}))) # [batch, dims..., features]
#     corr = tf.reduce_sum(corr, axis=-1)
#     xy = xy_volume(tf.shape(image1)[1:3]) # [batch, dims..., 2]
#
#     window_indices = argmax_window_indices_2d(corr, window_shape=subpixel_window_shape) # [batch, points, 2]
#     xy_window = tf.gather_nd(xy, window_indices, batch_dims=1) # [points, 2]
#     corr_window = tf.gather_nd(corr, window_indices, batch_dims=1) # [points]
#
#     return subpixel_sample(xy_window, corr_window)
#
# def polar_coordinate_map(shape, center, radius):
#     angle = -tf.linspace(-math.pi / 2, math.pi * 3 / 2, num=shape[0] + 1)[:-1]
#     radius = tf.linspace(0.0, radius, num=shape[1] + 1)[:-1]
#
#     angle = tf.broadcast_to(angle[:, tf.newaxis], shape)
#     radius = tf.broadcast_to(radius[tf.newaxis, :], shape)
#
#     ar = tf.stack([angle, radius], axis=-1) # [angle, radius, ar]
#
#     x = ar[..., 1] * tf.math.cos(ar[..., 0]) + center[0]
#     y = ar[..., 1] * tf.math.sin(ar[..., 0]) + center[1]
#
#     xy = tf.stack([x, y], axis=-1) # [angle, radius, xy]
#
#     return xy
#
# def warp_polar(image, radius, output_shape): # TODO: linear or log mapping
#     # image: [batch, dims..., features]
#     input_shape = np.asarray(image.shape[-3:-1])
#
#     xy = polar_coordinate_map(output_shape, tf.cast(input_shape, "float32") / 2.0, radius) # [angle, radius, 2]
#
#     mask = tf.math.reduce_all(tf.math.logical_and(0 <= xy, xy < input_shape[tf.newaxis, tf.newaxis, :]), axis=-1) # [angle, radius]
#     xy = tf.where(mask[..., tf.newaxis], xy, 0) # [angle, radius, 2]
#
#     result = tf.gather_nd(image, tf.cast(xy, "int32")[tf.newaxis, ...], batch_dims=1) # [batch, angle, radius, features]
#     result = tf.where(mask[tf.newaxis, ..., tf.newaxis], result, 0.0) # [batch, angle, radius, features]
#
#     return result
