import tensorflow as tf
import tfcv, math, georeg
from functools import partial

def up(bev, stride, name, config):
    """Upsample a feature map by an integer factor, then apply a 1x1 conv.

    Args:
        bev: Feature map whose dimensions ``-3:-1`` are treated as the spatial
            size (presumably ``[batch, height, width, channels]`` — TODO confirm).
        stride: Integer factor applied to both spatial dimensions.
        name: Name of the 1x1 convolution layer.
        config: tfcv model config forwarded to the helpers.

    Returns:
        The bilinearly upsampled map after the 1x1 convolution.
    """
    target_size = tf.shape(bev)[-3:-1] * stride
    upsampled = tfcv.model.util.resize(upsampled_input := bev, target_size, method="bilinear", config=config) if False else tfcv.model.util.resize(bev, target_size, method="bilinear", config=config)
    return tfcv.model.util.conv(upsampled, kernel_size=1, stride=1, name=name, config=config)

def down(bev, stride, name, config):
    """Downsample a feature map by an integer factor.

    Uses ``tfcv.model.util.norm_conv`` with ``kernel_size == stride``, so the
    convolution windows do not overlap.

    Args:
        bev: Feature map to downsample.
        stride: Integer downsampling factor (also used as the kernel size).
        name: Name of the convolution layer.
        config: tfcv model config forwarded to the helper.

    Returns:
        The downsampled feature map.
    """
    return tfcv.model.util.norm_conv(
        bev,
        kernel_size=stride,
        stride=stride,
        name=name,
        config=config,
    )

def resize(bev, in_stride, out_stride, name, config):
    """Bring a feature map from ``in_stride`` to ``out_stride``.

    Delegates to ``up`` when the target stride is finer and to ``down`` when it
    is coarser. The input is returned unchanged when neither comparison holds.
    The scale factor is the integer ratio of the two strides.

    Args:
        bev: Feature map to resize.
        in_stride: Stride of ``bev``.
        out_stride: Desired stride.
        name: Layer name used for the created up/down layer.
        config: tfcv model config.

    Returns:
        The resized (or original) feature map.
    """
    if in_stride > out_stride:
        factor = in_stride // out_stride
        print(f"Up {name} from {in_stride} to {out_stride}")
        return up(bev, factor, name, config=config)
    if in_stride < out_stride:
        factor = out_stride // in_stride
        print(f"Down {name} from {in_stride} to {out_stride}")
        return down(bev, factor, name, config=config)
    return bev

def decode(x, filters, image, norm, strides, name, config):
    """Build a multi-scale decoder that fuses encoder features per output stride.

    Encoder features are treated as having strides 4, 8, 16 and 32. Output
    strides are processed from the last entry to the first. Each output fuses
    every currently available feature map whose stride is >= the output stride,
    and then replaces those maps with the fused result. So when ``strides`` is
    ascending, finer outputs build on the coarser outputs computed before them.

    Args:
        x: Either the last encoder feature map, or a list of encoder feature
            maps. For a single tensor, the stride-4/8/16 features are looked up
            with ``tfcv.model.graph.get_unique`` using a predicate on layer
            names (prefix ``name``, suffix ``block1``/``block2``/``block3``).
            NOTE(review): a list input is paired with strides [4, 8, 16, 32]
            from the end; its length is not checked.
        filters: Output channels for each entry of ``strides``.
        image: Optional input image. When given, residual blocks built on it
            provide stride-1 and/or stride-2 features.
        norm: If truthy, normalize each encoder feature before fusion.
        strides: Output strides, one per entry of ``filters``. They are
            presumably ascending; the slicing below relies on ``xs_strides``
            staying sorted — TODO confirm.
        name: Name prefix for all created layers.
        config: tfcv model config.

    Returns:
        List of fused feature maps in the same order as ``strides``.

    Raises:
        ValueError: If ``filters`` and ``strides`` differ in length; if
            ``image`` is given but ``strides[0]`` is neither 1 nor 2; if an
            output stride exceeds every available feature stride; or if an
            output would need bilinear upsampling although ``image`` was given.
    """
    # Explicit errors instead of asserts: asserts are stripped under `python -O`.
    if len(filters) != len(strides):
        raise ValueError(f"decode: filters and strides must have the same length, got {len(filters)} and {len(strides)}")

    # Find encoder blocks
    if not isinstance(x, list):
        xs = [tfcv.model.graph.get_unique(x, pred=lambda layer: layer.name.endswith(f"block{i}") and layer.name.startswith(name)) for i in [1, 2, 3]] + [x]
    else:
        xs = list(x)
    if norm:
        xs = [tfcv.model.util.norm(feature, name=tfcv.model.util.join(name, "neck", f"{i + 1}", f"norm"), config=config) for i, feature in enumerate(xs)]
    xs_strides = [4, 8, 16, 32]

    # Add lower stride blocks computed directly from the image
    if image is not None:
        block = tfcv.model.resnet.basic_block_v1
        # block = tfcv.model.convnext.block
        if len(strides) >= 2 and strides[0] == 1 and strides[1] == 2:
            x = image
            x1 = block(x, filters=filters[0], stride=1, dilation_rate=1, name=tfcv.model.util.join(name, "decode-shortcut", f"1"), config=config)
            x2 = block(x1, filters=filters[1], stride=2, dilation_rate=1, name=tfcv.model.util.join(name, "decode-shortcut", f"2"), config=config)
            xs = [x1, x2] + xs
            xs_strides = [1, 2] + xs_strides
        elif strides[0] == 2:
            x = image
            x = block(x, filters=filters[0], stride=1, dilation_rate=1, name=tfcv.model.util.join(name, "decode-shortcut", f"1"), config=config)
            x = block(x, filters=filters[0], stride=2, dilation_rate=1, name=tfcv.model.util.join(name, "decode-shortcut", f"2"), config=config)
            xs = [x] + xs
            xs_strides = [2] + xs_strides
        elif strides[0] == 1:
            x = image
            x = block(x, filters=filters[0], dilation_rate=1, name=tfcv.model.util.join(name, "decode-shortcut", f"1"), config=config)
            x = block(x, filters=filters[0], dilation_rate=2, name=tfcv.model.util.join(name, "decode-shortcut", f"2"), config=config)
            xs = [x] + xs
            xs_strides = [1] + xs_strides
        else:
            raise ValueError(f"decode: unsupported lowest output stride {strides[0]} when image is given (expected 1 or 2)")

    # Fuse strides, coarsest requested output first
    ys = []
    for out_index in reversed(range(len(filters))):
        out_stride = strides[out_index]
        num_blocks = sum(1 for s in xs_strides if s >= out_stride)
        if num_blocks == 0:
            # Without this guard, xs[-0:] would silently select every block.
            raise ValueError(f"decode: output stride {out_stride} is larger than every available feature stride {xs_strides}")

        xs2 = xs[-num_blocks:]
        xs_strides2 = xs_strides[-num_blocks:]

        print(f"Backbone decoder for {name}: Output stride {out_stride} uses encoder blocks with strides {xs_strides2}")

        # Project every participating block to the output channel count
        for i in range(len(xs2)):
            xs2[i] = tfcv.model.util.conv(xs2[i], filters=filters[out_index], kernel_size=1, stride=1, name=tfcv.model.util.join(name, f"out-stride-{out_stride}", f"in{i + 1}"), config=config)

        # Resize coarser blocks to the finest participating block and sum
        for i in range(1, len(xs2)):
            xs2[i] = tfcv.model.util.resize(xs2[i], tf.shape(xs2[0])[1:-1], method="bilinear", config=config)
        x = tf.math.add_n(xs2)

        x = tfcv.model.util.conv_norm_act(x, kernel_size=1, stride=1, name=tfcv.model.util.join(name, f"out-stride-{out_stride}", f"out"), config=config)

        # Final upsampling if no block exists at exactly this stride (only allowed without image)
        if xs_strides2[0] != out_stride:
            print("    => requires final bilinear upsampling", xs_strides2[0] // out_stride)
            if image is not None:
                raise ValueError(f"decode: output stride {out_stride} would require bilinear upsampling from stride {xs_strides2[0]}, which is not supported when image is given")
            if xs_strides2[0] < out_stride:
                raise ValueError(f"decode: output stride {out_stride} cannot be produced from feature stride {xs_strides2[0]}")
            x = tfcv.model.util.resize(x, tf.shape(x)[1:-1] * (xs_strides2[0] // out_stride), method="bilinear", config=config)

        ys.append(x)

        # The fused output replaces the blocks it consumed
        xs = xs[:-num_blocks] + [x]
        xs_strides = xs_strides[:-num_blocks] + [out_stride]

    return list(reversed(ys))

def reduce_tokens(x, num, name=None, config=tfcv.model.config.Config()):
    """Pool a token sequence down to ``num`` tokens using learned weights.

    For each of the ``num`` output tokens, a softmax over the input-token axis
    gives pooling weights. The output is the weighted sum of (1x1-projected)
    input tokens.

    Args:
        x: Tokens, ``[batch, tokens, features]``.
        num: Number of output tokens.
        name: Name prefix for the created layers.
        config: tfcv model config.

    Returns:
        Tensor of shape ``[batch, num, features]``.
    """
    # TODO: combine weights/conv1 and conv? try
    pool = tfcv.model.util.conv(x, kernel_size=1, stride=1, bias=True, name=tfcv.model.util.join(name, "weights", "conv1"), config=config)
    pool = tfcv.model.util.act(pool, name=tfcv.model.util.join(name, "weights", "act"))
    pool = tfcv.model.util.conv(pool, filters=num, kernel_size=1, stride=1, bias=False, name=tfcv.model.util.join(name, "weights", "conv2"), config=config)
    # Normalize over input tokens (axis 1): [batch, tokens, num]
    pool = tf.nn.softmax(pool, axis=1)

    projected = tfcv.model.util.conv(x, kernel_size=1, stride=1, bias=False, name=tfcv.model.util.join(name, "conv"), config=config)

    # [batch, num, tokens] @ [batch, tokens, features] -> [batch, num, features]
    return tf.linalg.matmul(pool, projected, transpose_a=True)
# TODO: tfcv https://ai.googleblog.com/2021/12/improving-vision-transformer-efficiency.html
def expand_tokens(x_orig, x, use_norm=True, name=None, config=tfcv.model.config.Config()):
    """Map a reduced token set back onto the original token positions.

    Each original token gets sigmoid weights over the reduced tokens, computed
    from ``x_orig``. The reduced tokens are mixed along the token axis and then
    combined with those weights.

    Args:
        x_orig: Original tokens, ``[batch, tokens, features]``.
        x: Reduced tokens, ``[batch, reduced_tokens, features]``. The static
            size of axis 1 must be known; it sets the weight conv width.
        use_norm: Whether to normalize ``x`` before and after the mixing conv.
        name: Name prefix for the created layers.
        config: tfcv model config.

    Returns:
        Tensor of shape ``[batch, tokens, features]``.
    """
    weights = tfcv.model.util.conv(x_orig, filters=x.shape[1], kernel_size=1, stride=1, bias=False, name=tfcv.model.util.join(name, "weights", "conv"), config=config)
    weights = tf.nn.sigmoid(weights)  # [batch, tokens, reduced_tokens]
    # TODO: try softmax instead

    if use_norm:
        x = tfcv.model.util.norm(x, name=tfcv.model.util.join(name, "norm1"), config=config)
    # Transpose so the 1x1 conv mixes along the reduced-token axis rather than features.
    x = tf.transpose(x, (0, 2, 1))
    x = tfcv.model.util.conv(x, kernel_size=1, stride=1, bias=False, name=tfcv.model.util.join(name, "channelwise_conv"), config=config)
    x = tf.transpose(x, (0, 2, 1))
    if use_norm:
        x = tfcv.model.util.norm(x, name=tfcv.model.util.join(name, "norm2"), config=config)

    # Weighted sum over reduced tokens, written as a batched matmul. This avoids
    # materializing a [batch, tokens, reduced_tokens, features] broadcast product.
    return tf.linalg.matmul(weights, x)  # [batch, tokens, features]

def positional_encoding_1d(length, features, dtype):
    """Sinusoidal positional encoding along one axis.

    Args:
        length: Number of positions.
        features: Requested encoding width. The result has
            ``2 * (features // 2)`` columns (sin half, then cos half), so an odd
            ``features`` yields one column fewer.
        dtype: Floating dtype of the result.

    Returns:
        Tensor of shape ``[length, 2 * (features // 2)]``.

    NOTE(review): the frequencies are ``10000 ** (-2 * i / (features // 2))``.
    That exponent is twice the one in the "Attention Is All You Need" formula
    (``10000 ** (-2 * i / features)``) — confirm this is intentional.
    """
    half = features // 2
    freq_index = tf.cast(tf.range(half), dtype)
    inv_freq = tf.math.exp(2 * freq_index * -math.log(10000.0) / tf.cast(half, dtype))

    positions = tf.cast(tf.range(length), dtype)
    angles = tf.expand_dims(positions, axis=1) * tf.expand_dims(inv_freq, axis=0)  # [length, half]

    return tf.concat([tf.math.sin(angles), tf.math.cos(angles)], axis=-1)

def positional_encoding_2d(shape, features, dtype):
    """Sinusoidal 2-D positional encoding.

    The first ``features // 2`` channels encode the position along axis 0 and
    the remaining channels encode axis 1.

    Args:
        shape: Two-element spatial size ``[size0, size1]``.
        features: Total encoding width. Presumably must be divisible by 4:
            otherwise each 1-D encoding is narrower than ``features // 2`` and
            ``tf.broadcast_to`` fails.
        dtype: Floating dtype of the result.

    Returns:
        Tensor of shape ``[size0, size1, features]``.
    """
    half = features // 2
    enc0 = positional_encoding_1d(shape[0], half, dtype=dtype)  # [size0, half]
    enc1 = positional_encoding_1d(shape[1], half, dtype=dtype)  # [size1, half]

    target_shape = tf.concat([shape, [half]], axis=0)  # [size0, size1, half]
    along0 = tf.broadcast_to(enc0[:, tf.newaxis, :], target_shape)
    along1 = tf.broadcast_to(enc1[tf.newaxis, :, :], target_shape)
    return tf.concat([along0, along1], axis=-1)

def positional_encoding_polar(polar_image, features):
    """Sinusoidal encoding of a rank-4 image holding two coordinates per pixel.

    Args:
        polar_image: Rank-4 tensor whose last axis has at least two entries.
            Entries 0 and 1 are encoded (presumably the two polar angles —
            TODO confirm).
        features: Output width; ``features // 4`` frequencies are used per
            coordinate, for sin and cos each.

    Returns:
        Tensor of shape ``polar_image.shape[:3] + [4 * (features // 4)]``,
        laid out as sin/cos of coordinate 0, then sin/cos of coordinate 1.
    """
    dtype = polar_image.dtype
    quarter = features // 4
    freq_index = tf.cast(tf.range(quarter), dtype)
    inv_freq = tf.math.exp(2 * freq_index * -math.log(10000.0) / tf.cast(quarter, dtype))

    # [b, h, w, coords, 1] * [1, 1, 1, 1, quarter] -> [b, h, w, coords, quarter]
    angles = polar_image[:, :, :, :, tf.newaxis] * inv_freq[tf.newaxis, tf.newaxis, tf.newaxis, tf.newaxis, :]
    coord0 = angles[:, :, :, 0, :]
    coord1 = angles[:, :, :, 1, :]

    return tf.concat([tf.math.sin(coord0), tf.math.cos(coord0), tf.math.sin(coord1), tf.math.cos(coord1)], axis=-1)

class PolarPositionalEncodingLayer(tf.keras.layers.Layer):
    """Sinusoidal encoding of 2-component query points with trainable frequencies.

    Args:
        filters: Output width; must be divisible by 4 (sin/cos for each of the
            two input components).
    """

    def __init__(self, filters, **kwargs):
        super().__init__(**kwargs)
        if filters % 4 != 0:
            raise ValueError(f"Filters must be divisible by 4, but got {filters}")
        self.filters = filters

    def build(self, input_shape):
        num_freqs = self.filters // 4
        self.positional_encoding_factor = self.add_weight("positional_encoding_factor", shape=(num_freqs,), trainable=True)

        # Initialize the trainable frequencies to a fixed geometric schedule.
        dtype = self.positional_encoding_factor.dtype
        freq_index = tf.cast(tf.range(num_freqs), dtype)
        exponent = 2 * freq_index * -math.log(10000.0) / tf.cast(num_freqs, dtype) - 5.0
        self.positional_encoding_factor.assign(tf.math.exp(exponent))

    def call(self, query_points):  # el az
        """Encode ``query_points`` (last axis of size 2) into ``filters`` channels."""
        # Outer product: [..., 2] x [filters // 4] -> [..., 2, filters // 4]
        args = tfcv.model.einops.apply("b... f1, f2 -> b... f1 f2", query_points, self.positional_encoding_factor, f1=2)
        first = args[..., 0, :]
        second = args[..., 1, :]
        return tf.concat([tf.math.sin(first), tf.math.cos(first), tf.math.sin(second), tf.math.cos(second)], axis=-1)

    def get_config(self):
        return {**super().get_config(), "filters": self.filters}

def _interpolate_bilinear(grid, query_points):
    """Sample ``grid`` bilinearly at fractional ``query_points``.

    Pure-TensorFlow replacement for ``tfa.image.interpolate_bilinear`` with its
    default ``indexing="ij"``. The integer cell of each query is clamped to
    ``[0, size - 2]`` and the interpolation weights are clamped to ``[0, 1]``,
    so queries outside the grid take the nearest edge value.

    Args:
        grid: Tensor of shape ``[batch, height, width, channels]``; ``height``
            and ``width`` must be at least 2.
        query_points: Float tensor of shape ``[batch, num_queries, 2]`` holding
            ``(row, column)`` coordinates in pixel units.

    Returns:
        Tensor of shape ``[batch, num_queries, channels]`` in ``grid``'s dtype.
    """
    grid_shape = tf.shape(grid)
    batch_size, height, width, channels = grid_shape[0], grid_shape[1], grid_shape[2], grid_shape[3]
    query_dtype = query_points.dtype

    floors, ceils, alphas = [], [], []
    for dim, size in enumerate((height, width)):
        queries = query_points[:, :, dim]
        max_floor = tf.cast(size - 2, query_dtype)
        floor = tf.math.minimum(tf.math.maximum(tf.constant(0.0, dtype=query_dtype), tf.math.floor(queries)), max_floor)
        int_floor = tf.cast(floor, tf.int32)
        floors.append(int_floor)
        ceils.append(int_floor + 1)
        alpha = tf.cast(queries - floor, grid.dtype)
        alpha = tf.math.minimum(tf.math.maximum(tf.constant(0.0, dtype=grid.dtype), alpha), tf.constant(1.0, dtype=grid.dtype))
        alphas.append(alpha[:, :, tf.newaxis])

    # Flatten so each corner lookup is a single gather over linear indices.
    flat_grid = tf.reshape(grid, [batch_size * height * width, channels])
    batch_offsets = tf.reshape(tf.range(batch_size) * height * width, [batch_size, 1])

    def gather(rows, cols):
        return tf.gather(flat_grid, batch_offsets + rows * width + cols)

    top_left = gather(floors[0], floors[1])
    top_right = gather(floors[0], ceils[1])
    bottom_left = gather(ceils[0], floors[1])
    bottom_right = gather(ceils[0], ceils[1])

    top = alphas[1] * (top_right - top_left) + top_left
    bottom = alphas[1] * (bottom_right - bottom_left) + bottom_left
    return alphas[0] * (bottom - top) + top


class QueriedPositionalEncodingLayer(tf.keras.layers.Layer):
    """Trainable positional-encoding grid sampled bilinearly at query points.

    Args:
        shape: Shape of the trainable grid. Presumably ``(height, width,
            channels)``, since the first two entries scale the query points
            (TODO confirm).
    """

    def __init__(self, shape, **kwargs):
        super().__init__(**kwargs)
        self.shape = shape

    def build(self, input_shape):
        self.positional_encoding = self.add_weight("positional_encoding", shape=self.shape, trainable=True)

    def call(self, query_points):
        """Sample the grid at ``query_points`` (last axis of size 2).

        The points are multiplied by ``shape[:2]`` before sampling, so they are
        presumably normalized coordinates (TODO confirm). The output has the
        batch shape of ``query_points`` and the grid's channel count.
        """
        batch_shape = tf.shape(query_points)[:-1]
        query_points = tfcv.model.einops.apply("b... f -> 1 (b...) f", query_points, f=2)
        query_points = query_points * tf.cast(tf.convert_to_tensor(self.shape)[tf.newaxis, tf.newaxis, :2], query_points.dtype)
        # The original called tfa.image.interpolate_bilinear, but `tfa` is never
        # imported in this module; the local helper has the same semantics.
        pos_enc = _interpolate_bilinear(self.positional_encoding[tf.newaxis, ...], query_points)
        pos_enc = tfcv.model.einops.apply("1 (b...) f -> b... f", pos_enc, b=batch_shape)
        return pos_enc

    def get_config(self):
        config = super().get_config()
        config["shape"] = self.shape
        return config

class PositionalEncodingLayer(tf.keras.layers.Layer):
    """Trainable positional-encoding map, bicubically resized on each call.

    Args:
        shape: Shape of the trainable map (presumably ``(height, width,
            channels)`` — TODO confirm).
    """

    def __init__(self, shape, **kwargs):
        super().__init__(**kwargs)
        self.shape = shape
        self.positional_encoding = self.add_weight("positional_encoding", shape=self.shape, trainable=True)

    def call(self, shape):
        """Return the map with a leading batch axis, resized to ``shape``."""
        batched = self.positional_encoding[tf.newaxis, ...]
        return tfcv.model.util.resize(batched, shape, method="bicubic")

    def get_config(self):
        return {**super().get_config(), "shape": self.shape}

class PositionalEncodingNoResizeLayer(tf.keras.layers.Layer):
    """Trainable positional-encoding tensor returned as-is.

    Args:
        shape: Shape of the trainable tensor.
    """

    def __init__(self, shape, **kwargs):
        super().__init__(**kwargs)
        self.shape = shape
        self.positional_encoding = self.add_weight("positional_encoding", shape=self.shape, trainable=True)

    def call(self, dummy):
        """Return the encoding; ``dummy`` only contributes a zero offset.

        The zero is derived from ``tf.shape(dummy)[0]``, so the output depends
        on the input in the graph (presumably to connect the layer in a Keras
        functional model — TODO confirm).
        """
        zero = tf.cast(tf.shape(dummy)[0] * 0, self.positional_encoding.dtype)
        return self.positional_encoding + zero

    def get_config(self):
        return {**super().get_config(), "shape": self.shape}

class Weight(tf.keras.layers.Layer):
    """Layer that exposes a single trainable tensor of a fixed shape.

    Args:
        shape: Shape of the trainable tensor.
    """

    def __init__(self, shape, **kwargs):
        super().__init__(**kwargs)
        self.shape = shape
        self.weight = self.add_weight("weight", shape=self.shape, trainable=True)

    def call(self, dummy):
        """Return the weight; ``dummy`` only contributes a zero offset.

        The zero comes from ``tf.shape(dummy)[0]``, tying the output to the
        input in the graph.
        """
        zero = tf.cast(tf.shape(dummy)[0] * 0, self.weight.dtype)
        return self.weight + zero

    def get_config(self):
        return {**super().get_config(), "shape": self.shape}

# def scatter(indices, updates, length, batch_dims=0, mode="add"):
#     if batch_dims > 0:
#         coords = tf.stack(tf.meshgrid(*[tf.range(tf.shape(indices)[b]) for b in range(batch_dims)], indexing="ij"), axis=-1)
#         coords = tf.concat([
#             tfcv.model.einops.apply("b... batch_dims -> b... indices batch_dims", coords, indices=tf.shape(indices)[batch_dims]),
#             tfcv.model.einops.apply("b... indices -> b... indices 1", indices),
#         ], axis=-1)
#     else:
#         coords = indices
#
#     shape = tf.concat([tf.shape(indices)[:batch_dims], [length], [tf.shape(updates)[batch_dims + 1]]], axis=0)
#     if mode == "add" or mode == "sum":
#         tensor = tf.zeros(shape, dtype=updates.dtype)
#         tensor = tf.tensor_scatter_nd_add(tensor, coords, updates)
#     elif mode == "max":
#         tensor = tf.fill(shape, tf.cast(-math.inf, dtype=updates.dtype))
#         tensor = tf.tensor_scatter_nd_max(tensor, coords, updates)
#     elif mode == "min":
#         tensor = tf.fill(shape, tf.cast(math.inf, dtype=updates.dtype))
#         tensor = tf.tensor_scatter_nd_min(tensor, coords, updates)
#     elif mode == "any":
#         updates = tf.where(updates, 1, 0)
#         tensor = tf.zeros(shape, dtype=updates.dtype)
#         tensor = tf.tensor_scatter_nd_max(tensor, coords, updates)
#         tensor = tensor > 0
#     else:
#         raise ValueError(f"Invalid mode parameter '{mode}'")
#     return tensor


def self_attention(x, filters=None, heads=8, qkv_bias=True, attention=tfcv.model.transformer.full_attention, shortcut=tfcv.model.stochasticdepth.shortcut, name=None, config=tfcv.model.config.Config()):
    """Pre-norm multi-head self-attention block with a residual shortcut.

    Args:
        x: Input features; the output keeps their channel count.
        filters: Width of each of the query/key/value projections (defaults to
            the channel count of ``x``).
        heads: Number of attention heads, forwarded to ``attention``.
        qkv_bias: Whether the input and output projections use a bias.
        attention: Attention function called as ``attention(q, k, v, heads=..., name=..., config=...)``.
        shortcut: Function combining the block input and the attention output.
        name: Name prefix for the created layers.
        config: tfcv model config.

    Returns:
        ``shortcut(x, attention_output)``.
    """
    residual_in = x
    qkv_filters = x.shape[-1] if filters is None else filters

    normed = tfcv.model.util.norm(x, name=tfcv.model.util.join(name, "norm"), config=config)
    # One fused projection, then split into query/key/value.
    qkv = tfcv.model.util.conv(normed, filters=3 * qkv_filters, kernel_size=1, stride=1, bias=qkv_bias, name=tfcv.model.util.join(name, "in_proj"), config=config)
    query, key, value = tf.split(qkv, num_or_size_splits=3, axis=-1)

    attended = attention(query, key, value, heads=heads, name=tfcv.model.util.join(name, "attention"), config=config)
    projected = tfcv.model.util.conv(attended, filters=residual_in.shape[-1], kernel_size=1, stride=1, bias=qkv_bias, name=tfcv.model.util.join(name, "out_proj"), config=config)

    return shortcut(residual_in, projected, name=tfcv.model.util.join(name, "shortcut"), config=config)

def cross_attention(q, kv, filters_qk=None, filters_v=None, heads=8, qkv_bias=True, attention=tfcv.model.transformer.full_attention, shortcut=tfcv.model.stochasticdepth.shortcut, name=None, config=tfcv.model.config.Config()):
    """Pre-norm multi-head cross-attention block with a residual shortcut.

    Queries are projected from ``q``; keys and values are projected from ``kv``.

    Args:
        q: Query features; the output keeps their channel count.
        kv: Key/value features.
        filters_qk: Width of the query and key projections (defaults to the
            channel count of ``q``).
        filters_v: Width of the value projection (defaults to the channel
            count of ``kv``).
        heads: Number of attention heads, forwarded to ``attention``.
        qkv_bias: Whether the input and output projections use a bias.
        attention: Attention function called as ``attention(q, k, v, heads=..., name=..., config=...)``.
        shortcut: Function combining ``q`` and the attention output.
        name: Name prefix for the created layers.
        config: tfcv model config.

    Returns:
        ``shortcut(q, attention_output)``.
    """
    if filters_qk is None:
        filters_qk = q.shape[-1]
    if filters_v is None:
        filters_v = kv.shape[-1]

    x_orig = q
    q = tfcv.model.util.norm(q, name=tfcv.model.util.join(name, "norm_q"), config=config)
    kv = tfcv.model.util.norm(kv, name=tfcv.model.util.join(name, "norm_kv"), config=config)

    query = tfcv.model.util.conv(q, filters=filters_qk, kernel_size=1, stride=1, bias=qkv_bias, name=tfcv.model.util.join(name, "in_proj_q"), config=config)
    # Fused key/value projection; split by width since filters_qk may differ from filters_v.
    kv = tfcv.model.util.conv(kv, filters=filters_qk + filters_v, kernel_size=1, stride=1, bias=qkv_bias, name=tfcv.model.util.join(name, "in_proj_kv"), config=config)
    key = kv[..., :filters_qk]
    value = kv[..., filters_qk:]

    # Forward `heads` (previously dropped, so the attention's default was always used).
    x = attention(query, key, value, heads=heads, name=tfcv.model.util.join(name, "attention"), config=config)
    x = tfcv.model.util.conv(x, filters=x_orig.shape[-1], kernel_size=1, stride=1, bias=qkv_bias, name=tfcv.model.util.join(name, "out_proj"), config=config)

    x = shortcut(x_orig, x, name=tfcv.model.util.join(name, "shortcut"), config=config)
    return x

def mlp(x, filters=None, kernel_size=None, qkv_bias=True, shortcut=tfcv.model.stochasticdepth.shortcut, mlp_layers=2, name=None, config=tfcv.model.config.Config()):
    """Pre-norm pointwise MLP (optionally with depthwise convs) plus a shortcut.

    Args:
        x: Input features; the output keeps their channel count.
        filters: Hidden width (defaults to the channel count of ``x``).
        kernel_size: If given, a depthwise conv of this size follows every
            hidden pointwise conv.
        qkv_bias: Accepted but not used here; all pointwise and depthwise
            convs use ``bias=True``.
        shortcut: Function combining the block input and the MLP output.
        mlp_layers: Number of pointwise convs. All but the last are followed
            by an activation.
        name: Name prefix for the created layers.
        config: tfcv model config.

    Returns:
        ``shortcut(x, mlp_output)``.
    """
    residual_in = x
    out_filters = x.shape[-1]
    hidden_filters = out_filters if filters is None else filters

    x = tfcv.model.util.norm(x, name=tfcv.model.util.join(name, "norm"), config=config)
    for layer_index in range(mlp_layers):
        is_last = layer_index == mlp_layers - 1
        x = tfcv.model.util.conv(x, filters=out_filters if is_last else hidden_filters, kernel_size=1, stride=1, bias=True, name=tfcv.model.util.join(name, f"{layer_index + 1}", "pointwise"), config=config)
        if is_last:
            continue
        if kernel_size is not None:
            x = tfcv.model.util.conv(x, filters=hidden_filters, groups=hidden_filters, kernel_size=kernel_size, stride=1, bias=True, name=tfcv.model.util.join(name, f"{layer_index + 1}", "depthwise"), config=config)
        x = tfcv.model.util.act(x, config=config)

    return shortcut(residual_in, x, name=tfcv.model.util.join(name, "shortcut"), config=config)

def rezero(x, name=None, config=tfcv.model.config.Config()):
    """Multiply ``x`` by a learnable scale initialized to 0.

    ``axis=[]`` presumably makes the scale a single scalar (TODO confirm
    against ``ScaleLayer``). ``config`` is accepted for signature consistency
    but not used.
    """
    scale = tfcv.model.util.ScaleLayer(initial_value=0, axis=[], name=tfcv.model.util.join(name, "scale"))
    return scale(x)

def rezero_shortcut_add(shortcut, residual, name=None, config=tfcv.model.config.Config()):
    """Add ``residual`` to ``shortcut`` after a zero-initialized learnable scale.

    ``config`` is accepted for signature consistency but not forwarded.
    """
    scaled = rezero(residual, name=name)
    return shortcut + scaled
