import tensorflow as tf
import tfcv, re
import numpy as np
from ... import stochasticdepth, convnext
from ... import config as config_
from functools import partial

# Per-channel normalization statistics. These are the widely used ImageNet
# mean/std values; channel order is presumably RGB — confirm against the loader.
color_mean = np.array([0.485, 0.456, 0.406])
color_std = np.array([0.229, 0.224, 0.225])

def preprocess(color):
    """Normalize raw pixel values for the pretrained ConvNeXt weights.

    Divides by 255 and then standardizes each channel with ``color_mean`` and
    ``color_std``. Both statistics have length 3 and broadcast against the
    trailing axis of ``color``, so that axis must have size 3 (or 1).
    Presumably ``color`` holds values in [0, 255] — confirm with callers.
    """
    unit_range = color / 255.0
    centered = unit_range - color_mean
    return centered / color_std

def convert_name(name):
    """Translate a Keras variable name into the matching PyTorch state-dict key.

    ``create_x`` passes this function to ``load_pth`` as the name-conversion
    callback. The rewrite steps run in this fixed order:

    1. drop the first ``/``-separated component (yielding ``""`` if there is no ``/``);
    2. ``stem/conv`` -> ``stem.0`` and ``stem/norm`` -> ``stem.1``;
    3. a leading ``block<S>/unit<B>`` -> ``stages.<S-1>.blocks.<B-1>``;
    4. ``scale`` -> ``gamma``, then ``depthwise`` -> ``conv_dw`` (substring replace);
    5. ``pointwise/<i>`` -> ``mlp.fc<i>``;
    6. ``downsample<S>/<layer>`` -> ``stages.<S-1>.downsample.<1 if layer == "conv" else 0>``;
    7. every remaining ``/`` -> ``.``.

    Example::

        convert_name("m/block1/unit2/depthwise/kernel")
        # -> "stages.0.blocks.1.conv_dw.kernel"
    """
    _, _, path = name.partition("/")

    for keras_part, torch_part in (("stem/conv", "stem.0"), ("stem/norm", "stem.1")):
        path = path.replace(keras_part, torch_part)

    # The Keras names count stages/units from 1; the PyTorch keys count from 0.
    def block_unit(match):
        stage, unit = (int(group) - 1 for group in match.groups())
        return f"stages.{stage}.blocks.{unit}"

    path = re.sub("^block([0-9]*)/unit([0-9]*)", block_unit, path)

    for keras_part, torch_part in (("scale", "gamma"), ("depthwise", "conv_dw")):
        path = path.replace(keras_part, torch_part)

    path = re.sub("pointwise/([0-9]*)", lambda match: f"mlp.fc{int(match.group(1))}", path)

    def downsample(match):
        stage = int(match.group(1)) - 1
        # "conv" maps to index 1; any other layer (presumably the norm) maps to index 0.
        index = int(match.group(2) == "conv")
        return f"stages.{stage}.downsample.{index}"

    path = re.sub("downsample([0-9]*)/([a-z]*)", downsample, path)

    return path.replace("/", ".")

# Layer factories that create_x hands to the ConvNeXt variant constructor.
# PytorchConfig presumably supplies PyTorch-compatible defaults for everything
# not overridden here — confirm in the config module.
config = config_.PytorchConfig(
    # LayerNormalization with epsilon=1e-6 (the Keras default is 1e-3); this is
    # presumably chosen to match the checkpoint's LayerNorm eps — confirm against timm.
    norm=lambda x, *args, **kwargs: tf.keras.layers.LayerNormalization(*args, epsilon=1e-6, **kwargs)(x),
    # The project's resize helper with align_corners=False bound as a default argument.
    resize=config_.partial_with_default_args(config_.resize, align_corners=False),
    # GELU activation; tf.keras.activations.gelu defaults to the exact (erf) form, not the tanh approximation.
    act=lambda x, **kwargs: tf.keras.layers.Activation(tf.keras.activations.gelu, **kwargs)(x),
)

def create_x(input, convnext_variant, url, name=None):
    """Build a ConvNeXt backbone and load pretrained PyTorch weights into it.

    Args:
        input: Tensor to build the network on. Pass ``None`` to create a fresh
            ``(None, None, 3)`` Keras input and get back a standalone model.
        convnext_variant: Variant constructor, called as
            ``convnext_variant(x, block=..., name=..., config=...)``.
        url: Location of the ``.pth`` checkpoint. It is fetched with
            ``tf.keras.utils.get_file`` and cached under the URL's last path
            component.
        name: Name forwarded to ``convnext_variant``.

    Returns:
        The ``tf.keras.Model`` when ``input`` is ``None``. Otherwise, the output
        tensor of the backbone applied to ``input``.

    A ``tf.keras.Model`` spanning ``input`` to the output is always built so the
    checkpoint can be loaded into it, even when the caller supplied ``input``.
    Checkpoint keys starting with ``head.`` are skipped (presumably the
    classifier head).
    """
    standalone = input is None
    inputs = tf.keras.layers.Input((None, None, 3)) if standalone else input

    # drop_probability=0.0: presumably stochastic depth is disabled for inference.
    residual = partial(stochasticdepth.shortcut, drop_probability=0.0, scale_at_train_time=True)
    unit = partial(convnext.block, shortcut=residual, factor=4)
    features = convnext_variant(inputs, block=unit, name=name, config=config)

    model = tf.keras.Model(inputs=[inputs], outputs=[features])

    checkpoint_file = url.split("/")[-1]
    checkpoint_path = tf.keras.utils.get_file(checkpoint_file, url)
    tfcv.model.pretrained.weights.load_pth(
        checkpoint_path,
        model,
        convert_name,
        ignore=lambda key: key.startswith("head."),
    )

    if standalone:
        return model
    return features

def make_builder(variant, url):
    """Create a builder class for one ConvNeXt variant with pretrained weights.

    Args:
        variant: Suffix of the ``convnext_<variant>`` constructor looked up in
            the ``convnext`` module (e.g. ``"atto"``). It also forms the default
            model name.
        url: Download URL of the PyTorch checkpoint, forwarded to ``create_x``.

    Returns:
        A class exposing ``create`` (static), ``preprocess`` (static) and
        ``config``.
    """
    class builder:
        @staticmethod
        def create(input=None, name=f"convnext_{variant}"):
            """Build the network (or apply it to ``input``) and load the weights.

            Delegates to ``create_x``. That returns a ``tf.keras.Model`` when
            ``input`` is None and the output tensor otherwise.
            """
            return create_x(
                input=input,
                # The constructor is looked up when create() is called, not when the class is built.
                convnext_variant=vars(convnext)[f"convnext_{variant}"],
                url=url,
                name=name,
            )

        # A plain function stored as a class attribute would bind ``self`` when
        # accessed through an instance (``builder().preprocess(img)``).
        # staticmethod makes instance access match class access, like ``create``.
        preprocess = staticmethod(preprocess)
        config = config
    return builder

class convnext_atto_imagenet1k_224(
    make_builder(
        "atto",
        "https://github.com/rwightman/pytorch-image-models/releases/download/v0.1-rsb-weights/convnext_atto_d2-01bb0f51.pth",
    )
):
    """ConvNeXt-Atto builder loading the pytorch-image-models ``convnext_atto_d2-01bb0f51.pth`` checkpoint."""
class convnext_femto_imagenet1k_224(
    make_builder(
        "femto",
        "https://github.com/rwightman/pytorch-image-models/releases/download/v0.1-rsb-weights/convnext_femto_d1-d71d5b4c.pth",
    )
):
    """ConvNeXt-Femto builder loading the pytorch-image-models ``convnext_femto_d1-d71d5b4c.pth`` checkpoint."""
class convnext_pico_imagenet1k_224(
    make_builder(
        "pico",
        "https://github.com/rwightman/pytorch-image-models/releases/download/v0.1-rsb-weights/convnext_pico_d1-10ad7f0d.pth",
    )
):
    """ConvNeXt-Pico builder loading the pytorch-image-models ``convnext_pico_d1-10ad7f0d.pth`` checkpoint."""
class convnext_nano_imagenet1k_224(
    make_builder(
        "nano",
        "https://github.com/rwightman/pytorch-image-models/releases/download/v0.1-rsb-weights/convnext_nano_d1h-7eb4bdea.pth",
    )
):
    """ConvNeXt-Nano builder loading the pytorch-image-models ``convnext_nano_d1h-7eb4bdea.pth`` checkpoint."""
