import tensorflow as tf
import numpy as np
import os, h5py, tfcv

class LoadWeightsException(Exception):
    """Raised when stored weights cannot be matched to, or do not fit, a model's variables/layers."""

def load_h5(file, model, convert_name, ignore=None):
    """Assign values stored in an HDF5 file to the variables of ``model``.

    Every object path in the file that contains ``":"`` is read into memory.
    Each model variable is then looked up under ``convert_name(var.name)`` and
    assigned, unless ``ignore(var.name, key)`` returns a truthy value.

    Args:
        file: Path of the HDF5 file.
        model: Object exposing ``variables`` (each with ``name``, ``shape`` and
            ``assign``).
        convert_name: Callable mapping a variable name to an HDF5 path.
        ignore: Optional predicate ``ignore(var_name, h5_key)``; variables for
            which it is truthy are not assigned.

    Raises:
        LoadWeightsException: if a non-ignored variable has no entry in the
            file, if an entry's shape differs from the variable's shape, or if
            file entries remain that no model variable consumed.
    """
    stored = {}
    with h5py.File(file, "r") as h5file:
        paths = []
        h5file.visit(paths.append)
        for path in paths:
            if ":" in path:
                stored[path] = np.asarray(h5file[path])

    for variable in model.variables:
        h5_key = convert_name(variable.name)
        skipped = ignore is not None and ignore(variable.name, h5_key)
        if not skipped:
            if h5_key not in stored:
                raise LoadWeightsException(f"Variable {h5_key} not found in {os.path.basename(file)}")
            value = stored[h5_key]
            if value.shape != variable.shape:
                raise LoadWeightsException(f"Variable {h5_key} expected shape {variable.shape} but got shape {value.shape}")
            variable.assign(value)
        # Entries are consumed even for ignored variables so they are not reported as leftovers.
        stored.pop(h5_key, None)

    leftover = list(stored)
    if leftover:
        raise LoadWeightsException(f"Failed to find variable for weights: {leftover}")

def load_ckpt(file, model, convert_name, ignore=None):
    """Load variable values from a TensorFlow checkpoint into ``model``.

    Each model variable is read from the checkpoint under
    ``convert_name(var.name)``. The stored value may differ from the variable
    only in size-1 dimensions; it is reshaped to the variable's shape before
    being assigned.

    Args:
        file: Checkpoint path/prefix, as accepted by ``tf.train.load_variable``.
        model: Model whose ``variables`` are assigned in place.
        convert_name: Callable mapping a model variable name to its checkpoint key.
        ignore: Optional predicate ``ignore(var_name, ckpt_key)``; variables for
            which it is truthy are skipped (left unchanged).

    Raises:
        LoadWeightsException: if a checkpoint value's shape differs from the
            variable's shape after removing size-1 dimensions.
    """
    for v in model.variables:
        key = convert_name(v.name)
        if ignore is not None and ignore(v.name, key):
            continue
        new_var = np.asarray(tf.train.load_variable(file, key))

        expected_shape = v.shape.as_list()
        # Compare shapes with size-1 dims removed (what np.squeeze would give), without
        # materializing the variable; a real exception is used so -O does not skip the check.
        if [d for d in expected_shape if d != 1] != [d for d in new_var.shape if d != 1]:
            raise LoadWeightsException(f"Variable {key} expected shape {v.shape} but got shape {new_var.shape}")
        v.assign(new_var.reshape(expected_shape))

from collections import defaultdict
def auto_convert_name(loaded_weights, model):
    """Try to pair loaded weights with Keras convolution weights by shape.

    Work-in-progress diagnostic helper. It prints the weight shapes found on
    both sides. Each shape may occur exactly once among the loaded weights and
    exactly once among the model weights. If that shape belongs to a
    Conv1D/Conv2D/Conv3D layer, the match is recorded, together with the
    adjacent kernel or bias.

    Model conv kernels are compared after transposing the Keras
    (*spatial, in, out) layout to (out, in, *spatial).

    Args:
        loaded_weights: Mapping from weight name to an array-like with a
            ``shape`` attribute. Its iteration order is relied upon: a conv
            bias is expected to immediately follow its kernel.
        model: Keras model whose ``layers`` are inspected.

    Returns:
        dict mapping ``(layer_name, weight_index)`` (0 = kernel, 1 = bias) to
        the matching key of ``loaded_weights``.

    Raises:
        ValueError: if a uniquely matched conv kernel (bias) is not directly
            followed (preceded) by a loaded weight whose shape equals the
            layer's bias (kernel) shape, ignoring size-1 dimensions.
    """
    conv_types = (tf.keras.layers.Conv1D, tf.keras.layers.Conv2D, tf.keras.layers.Conv3D)

    def squeezed_shape(shape):
        # Shape with size-1 dimensions dropped, i.e. what np.squeeze(array).shape gives.
        return tuple(d for d in shape if d != 1)

    def map_weight(layer, weight, weight_index):
        # Reorder a Keras conv kernel (*spatial, in, out) to (out, in, *spatial).
        if weight_index == 0 and isinstance(layer, conv_types):
            return np.transpose(weight, [-1, -2] + list(range(len(weight.shape) - 2)))
        return weight

    loaded_weights_ordered = [(k, loaded_weights[k]) for k in loaded_weights.keys()]
    # Position of each loaded weight, used for the kernel/bias adjacency checks.
    loaded_index = {name: i for i, (name, _) in enumerate(loaded_weights_ordered)}

    model_weights_ordered = []
    for layer in model.layers:
        # get_weights() copies all weights of the layer to numpy; call it once per layer.
        for weight_index, weight in enumerate(layer.get_weights()):
            model_weights_ordered.append(((layer, weight_index), map_weight(layer, weight, weight_index)))
    # Model weight shapes in the comparison layout, keyed by (layer name, weight index).
    model_shapes = {
        (layer.name, weight_index): tuple(weight.shape)
        for (layer, weight_index), weight in model_weights_ordered
    }

    mapping = {}

    in_shapes = defaultdict(list)
    for name, weight in loaded_weights_ordered:
        in_shapes[tuple(weight.shape)].append(name)

    out_shapes = defaultdict(list)
    for (layer, weight_index), weight in model_weights_ordered:
        out_shapes[tuple(weight.shape)].append((layer, weight_index))

    print(f"in_shapes={list(in_shapes.keys())}")
    print(f"out_shapes={list(out_shapes.keys())}")

    in_not_out_shapes = set(in_shapes.keys()).difference(set(out_shapes.keys()))
    out_not_in_shapes = set(out_shapes.keys()).difference(set(in_shapes.keys()))
    print(f"in_not_out_shapes={[(s, in_shapes[s]) for s in in_not_out_shapes]}")
    print(f"out_not_in_shapes={[(s, [l.name for l, w in out_shapes[s]]) for s in out_not_in_shapes]}")

    unique_in_shapes = [s for s in in_shapes.keys() if len(in_shapes[s]) == 1]
    unique_out_shapes = [s for s in out_shapes.keys() if len(out_shapes[s]) == 1]
    print(f"unique_in_shapes={unique_in_shapes}")
    print(f"unique_out_shapes={unique_out_shapes}")

    unique_shapes = set(unique_in_shapes).intersection(set(unique_out_shapes))
    print(f"unique_shapes={unique_shapes}")

    for shape in unique_shapes:
        layer, weight_index = out_shapes[shape][0]
        in_name = in_shapes[shape][0]
        if not isinstance(layer, conv_types):
            continue
        index = loaded_index[in_name]

        if weight_index == 0:
            print(f"Found unique shape of conv kernel with in={in_name} and out={layer.name}")
            mapping[(layer.name, 0)] = in_name

            if layer.bias is not None:
                if index + 1 >= len(loaded_weights_ordered):
                    raise ValueError("Expected conv bias after conv kernel")
                next_name, next_weight = loaded_weights_ordered[index + 1]
                if squeezed_shape(next_weight.shape) != squeezed_shape(model_shapes[(layer.name, 1)]):
                    raise ValueError("Expected conv bias after conv kernel")
                print(f"    Next loaded weight {next_name} is bias of {layer.name}")
                mapping[(layer.name, 1)] = next_name

        elif weight_index == 1:
            print(f"Found unique shape of conv bias with in={in_name} and out={layer.name}")
            # The uniquely matched weight is the bias itself.
            mapping[(layer.name, 1)] = in_name

            if index - 1 < 0:
                raise ValueError("Expected conv kernel before conv bias")
            prev_name, prev_weight = loaded_weights_ordered[index - 1]
            # Compare against the transposed (loaded-layout) kernel shape, not the raw Keras one.
            if squeezed_shape(prev_weight.shape) != squeezed_shape(model_shapes[(layer.name, 0)]):
                raise ValueError("Expected conv kernel before conv bias")
            print(f"    Previous loaded weight {prev_name} is kernel of {layer.name}")
            mapping[(layer.name, 0)] = prev_name

        else:
            raise ValueError(f"Unexpected weight index {weight_index} for convolution layer {layer.name}")

    return mapping

def load_pth(file, model, convert_name, ignore=None, map=None):
    """Load weights from a PyTorch checkpoint into the layers of a Keras model.

    The file is loaded on CPU. If it contains a ``"state_dict"``,
    ``"model_state"`` or ``"model"`` entry (checked in that order), that entry
    is used as the weight dict.

    For every layer with weights, ``convert_name(layer.name)`` gives the torch
    key prefix. Supported layers are Conv1D/Conv2D/Conv3D (``.weight``/``_weight``
    and ``.bias``/``_bias``; kernels are transposed from (out, in, *spatial) to
    (*spatial, in, out)), BatchNormalization, LayerNormalization, Embedding and
    ``tfcv.model.util.ScaleLayer``. Loaded weights that differ from the layer's
    weights only in size-1 dimensions are reshaped.

    Args:
        file: Path of the ``.pth`` file.
        model: Keras model whose layers are assigned in place.
        convert_name: Callable mapping a Keras layer name to a torch key prefix.
        ignore: Optional predicate ``ignore(torch_key)``; torch entries for which
            it is truthy may remain unused without raising.
        map: Optional dict from torch key to a callable applied to that weight
            (as a numpy array) instead of the layer's default conversion.

    Raises:
        LoadWeightsException: if a needed torch entry is missing or ambiguous,
            shapes are incompatible, a layer type is unsupported, or unused
            torch entries remain.
    """
    import torch

    # ``map`` shadows the builtin but is the keyword callers use; default to None
    # rather than a shared mutable ``{}``.
    mappers = {} if map is None else map

    # NOTE(security): torch.load unpickles the file and can execute arbitrary code;
    # only load checkpoints from trusted sources.
    all_weights = torch.load(file, map_location=torch.device("cpu"))
    # Unwrap common training-checkpoint containers (first matching key wins).
    for container_key in ("state_dict", "model_state", "model"):
        if container_key in all_weights:
            all_weights = all_weights[container_key]
            break

    file_name = os.path.basename(file)

    def get_weight(keys, default_mapper=lambda x: x):
        # Pop exactly one of the candidate ``keys`` from all_weights and return it converted.
        if not isinstance(keys, list):
            keys = [keys]
        found = [k for k in keys if k in all_weights]
        if len(found) == 0:
            raise LoadWeightsException(f"Variable {keys} not found in {file_name}")
        if len(found) > 1:
            raise LoadWeightsException(f"More than one input matches variable {keys} in {file_name}")
        key = found[0]

        # Removing the entry marks it as consumed for the leftover check at the end.
        result = np.asarray(all_weights.pop(key))

        if key in mappers:
            return mappers[key](result)
        return default_mapper(result)

    def set_weights(layer, pth_name, weights):
        # Assign ``weights`` to ``layer``, reshaping weights whose squeezed shapes match.
        layer_weights = layer.get_weights()
        for i in range(len(weights)):
            if weights[i].shape == layer_weights[i].shape:
                continue
            if np.squeeze(layer_weights[i]).shape == np.squeeze(weights[i]).shape:
                weights[i] = np.reshape(weights[i], layer_weights[i].shape)
            else:
                raise LoadWeightsException(f"Layer {layer.name} with weight shapes {layer_weights[i].shape} got invalid weight shapes {weights[i].shape} from pth variable {pth_name}")
        layer.set_weights(weights)

    conv_types = (tf.keras.layers.Conv1D, tf.keras.layers.Conv2D, tf.keras.layers.Conv3D)

    for layer in model.layers:
        # get_weights() copies every weight to numpy; call it once per layer.
        layer_weights = layer.get_weights()
        if len(layer_weights) == 0:
            continue
        key = convert_name(layer.name)

        if isinstance(layer, conv_types):
            kernel_shape = layer_weights[0].shape
            weights = get_weight([key + ".weight", key + "_weight"], lambda w: np.transpose(w, list(range(2, len(w.shape))) + [1, 0]))
            if len(weights.shape) != len(kernel_shape):
                # Trailing dimensions must agree; extra leading dims are dropped, missing ones added.
                min_len = min(len(weights.shape), len(kernel_shape))
                if weights.shape[len(weights.shape) - min_len:] != kernel_shape[len(kernel_shape) - min_len:]:
                    raise LoadWeightsException(f"Convolution layer {layer.name} with kernel shape {kernel_shape} got invalid loaded kernel shape {weights.shape}")
                while len(weights.shape) > len(kernel_shape):
                    weights = weights[0]
                while len(weights.shape) < len(kernel_shape):
                    weights = np.expand_dims(weights, axis=0)
            if layer.bias is not None:
                bias = get_weight([key + ".bias", key + "_bias"])
                set_weights(layer, key, [weights, bias])
            else:
                # Check both bias spellings accepted by get_weight above.
                if (key + ".bias") in all_weights or (key + "_bias") in all_weights:
                    raise LoadWeightsException(f"Convolution layer {layer.name} does not have bias, but found bias in weights file")
                set_weights(layer, key, [weights])
        elif isinstance(layer, tf.keras.layers.BatchNormalization):
            weights = get_weight(key + ".weight")
            bias = get_weight(key + ".bias")
            running_mean = get_weight(key + ".running_mean")
            running_var = get_weight(key + ".running_var")
            set_weights(layer, key, [weights, bias, running_mean, running_var])
        elif isinstance(layer, tf.keras.layers.LayerNormalization):
            weights = get_weight(key + ".weight")
            bias = get_weight(key + ".bias")
            set_weights(layer, key, [weights, bias])
        elif isinstance(layer, tf.keras.layers.Embedding):
            weights = get_weight(key)[0]
            set_weights(layer, key, [weights])
        elif isinstance(layer, tfcv.model.util.ScaleLayer):
            weights = get_weight(key)
            set_weights(layer, key, [weights])
        else:
            raise LoadWeightsException(f"Invalid type of layer {layer.name}")

    # BatchNorm step counters and explicitly ignored entries are allowed to stay unused.
    for key in list(all_weights.keys()):
        if "num_batches_tracked" in key or (ignore is not None and ignore(key)):
            del all_weights[key]
    keys = list(all_weights.keys())
    if len(keys) > 0:
        raise LoadWeightsException(f"Failed to find layer for torch variables: {keys}")
