import numpy as np
from tensorflow.keras.applications import VGG16
from tensorflow.keras.layers import *
from tensorflow.keras.models import Sequential
from tensorflow.keras.regularizers import l2

from sarnn.components import BatchNormSparsityRegularizer, L1WeightRegularizer, PostBatchNormOffset
from sarnn.utils import count_layer_synapses, layer_typename
from snn.imagenet import IMAGENET_H, IMAGENET_W, IMAGENET_CLASSES
from snn.voc import VOC_CLASSES, voc_item_size


def conv_cifar(
        n_classes,
        beta_penalty=None,
        use_offset=False,
        l1_synapse_penalty=None,
        l2_decay=None):
    """Build a Sequential convolutional classifier for 32x32x3 inputs.

    The network has three stages of three 3x3 "same"-padded convolutions
    (64, 128 and 256 filters), each stage followed by 2x2 average pooling,
    then a 512-unit dense layer and an ``n_classes``-unit dense layer with
    a softmax. Every convolution and the hidden dense layer is followed by
    the layers added by ``_add_bn_relu``.

    Args:
        n_classes: Number of units in the final dense layer.
        beta_penalty: Forwarded to ``_add_bn_relu``.
        use_offset: Forwarded to ``_add_bn_relu``.
        l1_synapse_penalty: Forwarded to ``_kernel_regularizer``.
        l2_decay: Forwarded to ``_kernel_regularizer``.

    Returns:
        The assembled ``Sequential`` model (not compiled here).
    """
    net = Sequential()

    def _regularizer():
        # Each weighted layer gets its own regularizer instance.
        return _kernel_regularizer(l1_synapse_penalty, l2_decay)

    def _conv_bn_relu(n_filters, **conv_kwargs):
        net.add(Conv2D(
            n_filters, 3,
            padding="same",
            kernel_regularizer=_regularizer(),
            **conv_kwargs))
        _add_bn_relu(net, beta_penalty, use_offset)

    # Only the very first convolution declares the input shape.
    first_conv_kwargs = {"input_shape": (32, 32, 3)}
    for n_filters in (64, 128, 256):
        for _ in range(3):
            _conv_bn_relu(n_filters, **first_conv_kwargs)
            first_conv_kwargs = {}
        net.add(AveragePooling2D(pool_size=2))

    # Classifier head
    net.add(Flatten())
    net.add(Dense(512, kernel_regularizer=_regularizer()))
    _add_bn_relu(net, beta_penalty, use_offset)
    net.add(Dense(n_classes, kernel_regularizer=_regularizer()))
    net.add(Softmax())

    # Let the custom regularizers know the per-layer and whole-model sizes.
    _count_neurons(net)
    _count_synapses(net)

    return net


def conv_mnist(
        beta_penalty=None,
        use_offset=False,
        l1_synapse_penalty=None,
        l2_decay=None):
    """Build a Sequential convolutional classifier for 28x28x1 inputs.

    The network has two stages of two 3x3 "same"-padded convolutions
    (32 and 64 filters), each stage followed by 2x2 average pooling, then
    a 128-unit dense layer and a 10-unit dense layer with a softmax. Every
    convolution and the hidden dense layer is followed by the layers added
    by ``_add_bn_relu``.

    Args:
        beta_penalty: Forwarded to ``_add_bn_relu``.
        use_offset: Forwarded to ``_add_bn_relu``.
        l1_synapse_penalty: Forwarded to ``_kernel_regularizer``.
        l2_decay: Forwarded to ``_kernel_regularizer``.

    Returns:
        The assembled ``Sequential`` model (not compiled here).
    """
    net = Sequential()

    def _regularizer():
        # Each weighted layer gets its own regularizer instance.
        return _kernel_regularizer(l1_synapse_penalty, l2_decay)

    def _conv_bn_relu(n_filters, **conv_kwargs):
        net.add(Conv2D(
            n_filters, 3,
            padding="same",
            kernel_regularizer=_regularizer(),
            **conv_kwargs))
        _add_bn_relu(net, beta_penalty, use_offset)

    # Only the very first convolution declares the input shape.
    first_conv_kwargs = {"input_shape": (28, 28, 1)}
    for n_filters in (32, 64):
        for _ in range(2):
            _conv_bn_relu(n_filters, **first_conv_kwargs)
            first_conv_kwargs = {}
        net.add(AveragePooling2D(pool_size=2))

    # Classifier head
    net.add(Flatten())
    net.add(Dense(128, kernel_regularizer=_regularizer()))
    _add_bn_relu(net, beta_penalty, use_offset)
    net.add(Dense(10, kernel_regularizer=_regularizer()))
    net.add(Softmax())

    # Let the custom regularizers know the per-layer and whole-model sizes.
    _count_neurons(net)
    _count_synapses(net)

    return net


def dense_mnist(
        width, depth,
        pool=False,
        beta_penalty=None,
        use_offset=False,
        l1_synapse_penalty=None,
        l2_decay=None):
    """Build a Sequential fully-connected classifier for 28x28x1 inputs.

    The input is optionally 2x2 average-pooled, flattened, passed through
    ``depth`` hidden blocks (a ``width``-unit dense layer followed by the
    layers added by ``_add_bn_relu``) and finished with a 10-unit dense
    layer and a softmax.

    Args:
        width: Number of units in every hidden dense layer.
        depth: Number of hidden blocks. One block is always added, so any
            value below 1 yields the same model as ``depth == 1``.
        pool: If true, apply 2x2 average pooling before flattening.
        beta_penalty: Forwarded to ``_add_bn_relu``.
        use_offset: Forwarded to ``_add_bn_relu``.
        l1_synapse_penalty: Forwarded to ``_kernel_regularizer``.
        l2_decay: Forwarded to ``_kernel_regularizer``.

    Returns:
        The assembled ``Sequential`` model (not compiled here).
    """
    net = Sequential()

    def _regularizer():
        # Each weighted layer gets its own regularizer instance.
        return _kernel_regularizer(l1_synapse_penalty, l2_decay)

    def _hidden_block():
        net.add(Dense(width, kernel_regularizer=_regularizer()))
        _add_bn_relu(net, beta_penalty, use_offset)

    # Whichever layer comes first declares the input shape.
    if not pool:
        net.add(Flatten(input_shape=(28, 28, 1)))
    else:
        net.add(AveragePooling2D(pool_size=2, input_shape=(28, 28, 1)))
        net.add(Flatten())

    # First hidden block is unconditional; the rest depend on ``depth``.
    _hidden_block()
    for _ in range(depth - 1):
        _hidden_block()

    net.add(Dense(10, kernel_regularizer=_regularizer()))
    net.add(Softmax())

    # Let the custom regularizers know the per-layer and whole-model sizes.
    _count_neurons(net)
    _count_synapses(net)

    return net


def fcn32_voc(
        beta_penalty=None,
        use_offset=False,
        l1_synapse_penalty=None,
        l2_decay=None,
        pretrained=True):
    """Build a fully-convolutional Sequential model with 32x upsampling.

    The backbone has five stages of 3x3 "same"-padded convolutions
    (2x64, 2x128, 3x256, 3x512, 3x512), each stage followed by 2x2 average
    pooling. It is followed by a 7x7 convolution with 4096 filters, a 1x1
    convolution with 4096 filters and a ReLU, a 1x1 convolution with
    ``VOC_CLASSES`` filters, and bilinear upsampling by a factor of 32.

    Args:
        beta_penalty: Forwarded to ``_add_bn_relu``.
        use_offset: Forwarded to ``_add_bn_relu``.
        l1_synapse_penalty: Forwarded to ``_kernel_regularizer``.
        l2_decay: Forwarded to ``_kernel_regularizer``.
        pretrained: If true, weights from an ImageNet-pretrained ``VGG16``
            are copied into the model's ``Conv2D`` layers before returning.

    Returns:
        The assembled ``Sequential`` model (not compiled here).
    """
    net = Sequential()

    def _regularizer():
        # Each weighted layer gets its own regularizer instance.
        return _kernel_regularizer(l1_synapse_penalty, l2_decay)

    def _conv_bn_relu(n_filters, kernel_size=3, **conv_kwargs):
        net.add(Conv2D(
            n_filters, kernel_size,
            padding="same",
            kernel_regularizer=_regularizer(),
            **conv_kwargs))
        _add_bn_relu(net, beta_penalty, use_offset)

    # (filters, number of convolutions) for each pooled stage.
    backbone_stages = ((64, 2), (128, 2), (256, 3), (512, 3), (512, 3))

    # Only the very first convolution declares the input shape.
    first_conv_kwargs = {"input_shape": (voc_item_size() + (3,))}
    for n_filters, n_convs in backbone_stages:
        for _ in range(n_convs):
            _conv_bn_relu(n_filters, **first_conv_kwargs)
            first_conv_kwargs = {}
        net.add(AveragePooling2D(pool_size=2))

    # Convolutional head
    _conv_bn_relu(4096, kernel_size=7)
    net.add(Conv2D(4096, 1, kernel_regularizer=_regularizer()))
    net.add(ReLU())
    net.add(Conv2D(VOC_CLASSES, 1, kernel_regularizer=_regularizer()))
    net.add(UpSampling2D(size=32, interpolation="bilinear"))
    # voc.voc_loss performs a softmax, so one is not needed here

    # Let the custom regularizers know the per-layer and whole-model sizes.
    _count_neurons(net)
    _count_synapses(net)

    if pretrained:
        # Copy weights from pretrained VGG (with some modification)
        source_net = VGG16(include_top=True, weights="imagenet")
        cursor = -1
        # The last VGG layer is left out of the copy.
        for source_layer in source_net.layers[:-1]:
            if layer_typename(source_layer) not in ("Conv2D", "Dense"):
                continue

            # Advance to the next Conv2D layer of this model.
            cursor += 1
            while layer_typename(net.layers[cursor]) != "Conv2D":
                cursor += 1
            target_layer = net.layers[cursor]
            source_weights = source_layer.get_weights()

            # Pretrained ImageNet weights expect BGR inputs with scale
            # of order 255
            if cursor == 0:
                flipped = np.flip(source_weights[0], axis=-2)
                source_weights[0] = 255.0 * flipped

            # Reshape each source array to the shape the target expects.
            target_weights = target_layer.get_weights()
            for k, target_weight in enumerate(target_weights):
                source_weights[k] = source_weights[k].reshape(
                    target_weight.shape)
            target_layer.set_weights(source_weights)

    return net


def mobilenet(
        beta_penalty=None,
        use_offset=False,
        l1_synapse_penalty=None,
        l2_decay=None):
    """Build a Sequential classifier from depthwise/pointwise conv pairs.

    The model starts with a stride-2 3x3 convolution with 32 filters,
    followed by thirteen pairs of a 3x3 depthwise convolution and a 1x1
    convolution. It ends with average pooling over an
    ``(IMAGENET_H // 32, IMAGENET_W // 32)`` window, a flatten, and an
    ``IMAGENET_CLASSES``-unit dense layer with a softmax. Every convolution
    is followed by the layers added by ``_add_bn_relu``.

    Args:
        beta_penalty: Forwarded to ``_add_bn_relu``.
        use_offset: Forwarded to ``_add_bn_relu``.
        l1_synapse_penalty: Forwarded to ``_kernel_regularizer``.
        l2_decay: Forwarded to ``_kernel_regularizer``.

    Returns:
        The assembled ``Sequential`` model (not compiled here).
    """
    net = Sequential()

    def _regularizer():
        # Each weighted layer gets its own regularizer instance.
        return _kernel_regularizer(l1_synapse_penalty, l2_decay)

    # Stem convolution; it also declares the input shape.
    net.add(Conv2D(
        32, (3, 3),
        strides=2,
        padding='same',
        kernel_regularizer=_regularizer(),
        input_shape=(IMAGENET_H, IMAGENET_W, 3)))
    _add_bn_relu(net, beta_penalty, use_offset)

    # (depthwise strides, pointwise filters) for each pair, in order.
    pair_plan = [
        (1, 64),
        (2, 128), (1, 128),
        (2, 256), (1, 256),
        (2, 512),
    ]
    pair_plan += [(1, 512)] * 5
    pair_plan += [(2, 1024), (1, 1024)]

    for depthwise_strides, pointwise_filters in pair_plan:
        net.add(DepthwiseConv2D(
            (3, 3),
            strides=depthwise_strides,
            padding='same',
            kernel_regularizer=_regularizer()))
        _add_bn_relu(net, beta_penalty, use_offset)

        net.add(Conv2D(
            pointwise_filters, (1, 1),
            kernel_regularizer=_regularizer()))
        _add_bn_relu(net, beta_penalty, use_offset)

    pool_window = (IMAGENET_H // 32, IMAGENET_W // 32)
    net.add(AveragePooling2D(pool_size=pool_window))

    # Classifier head
    net.add(Flatten())
    net.add(Dense(IMAGENET_CLASSES, kernel_regularizer=_regularizer()))
    net.add(Softmax())

    # Let the custom regularizers know the per-layer and whole-model sizes.
    _count_neurons(net)
    _count_synapses(net)

    return net


def _add_bn_relu(model, beta_penalty, use_offset):
    """Append a batch-normalization variant followed by a ReLU to ``model``.

    Args:
        model: Model to extend in place via ``model.add``.
        beta_penalty: If not None, the batch norm is centered and its beta
            is regularized by ``BatchNormSparsityRegularizer(beta_penalty)``.
            This takes precedence over ``use_offset``.
        use_offset: Only consulted when ``beta_penalty`` is None. If true,
            an uncentered batch norm followed by a ``PostBatchNormOffset``
            layer is added; otherwise a plain centered batch norm is added.
    """
    penalize_beta = beta_penalty is not None

    if not penalize_beta and use_offset:
        model.add(BatchNormalization(center=False, scale=False))
        model.add(PostBatchNormOffset())
    else:
        # Both remaining variants are centered, unscaled batch norms; they
        # differ only in whether beta carries a regularizer.
        bn_kwargs = {"center": True, "scale": False}
        if penalize_beta:
            bn_kwargs["beta_regularizer"] = BatchNormSparsityRegularizer(
                beta_penalty)
        model.add(BatchNormalization(**bn_kwargs))

    model.add(ReLU())


def _count_neurons(model):
    """Record neuron counts on every ``BatchNormSparsityRegularizer``.

    For each layer of ``model`` whose ``beta_regularizer`` is a
    ``BatchNormSparsityRegularizer``, the product of the layer's non-batch
    output dimensions is stored on the regularizer as ``layer_neurons``.
    The sum over all such layers is then stored on each of those
    regularizers as ``model_neurons``. Other layers are left untouched.

    Args:
        model: Object exposing the layers to inspect as ``model.layers``.
    """
    def _has_sparsity_regularizer(layer):
        return hasattr(layer, "beta_regularizer") and isinstance(
            layer.beta_regularizer, BatchNormSparsityRegularizer)

    tracked_layers = [
        layer for layer in model.layers if _has_sparsity_regularizer(layer)]

    # First pass: per-layer counts and the running total.
    total_neurons = 0
    for layer in tracked_layers:
        n_neurons = np.prod(layer.output_shape[1:])
        layer.beta_regularizer.layer_neurons = n_neurons
        total_neurons += n_neurons

    # Second pass: the total is only known once every layer was visited.
    for layer in tracked_layers:
        layer.beta_regularizer.model_neurons = total_neurons


def _count_synapses(model):
    """Record synapse counts on every ``L1WeightRegularizer``.

    For each layer of ``model`` whose ``kernel_regularizer`` is an
    ``L1WeightRegularizer``, the value of ``count_layer_synapses(layer)``
    is stored on the regularizer as ``layer_synapses``. The sum over all
    such layers is then stored on each of those regularizers as
    ``model_synapses``. Other layers are left untouched.

    Args:
        model: Object exposing the layers to inspect as ``model.layers``.
    """
    def _has_l1_regularizer(layer):
        return hasattr(layer, "kernel_regularizer") and isinstance(
            layer.kernel_regularizer, L1WeightRegularizer)

    tracked_layers = [
        layer for layer in model.layers if _has_l1_regularizer(layer)]

    # First pass: per-layer counts and the running total.
    total_synapses = 0
    for layer in tracked_layers:
        n_synapses = count_layer_synapses(layer)
        layer.kernel_regularizer.layer_synapses = n_synapses
        total_synapses += n_synapses

    # Second pass: the total is only known once every layer was visited.
    for layer in tracked_layers:
        layer.kernel_regularizer.model_synapses = total_synapses


def _kernel_regularizer(l1_synapse_penalty, l2_decay):
    """Select the kernel regularizer for a weighted layer.

    Args:
        l1_synapse_penalty: If not None, an ``L1WeightRegularizer`` is built
            from it, with ``l2_decay`` forwarded as a keyword argument.
        l2_decay: Used on its own only when ``l1_synapse_penalty`` is None,
            in which case a plain Keras ``l2`` regularizer is returned.

    Returns:
        A regularizer instance, or None when both arguments are None.
    """
    if l1_synapse_penalty is not None:
        return L1WeightRegularizer(l1_synapse_penalty, l2_decay=l2_decay)
    if l2_decay is not None:
        # Pass the factor positionally: the keyword was renamed from ``l``
        # to ``l2`` across Keras releases, and newer releases reject ``l=``.
        return l2(l2_decay)
    return None
