from .mobilenet_v2_sd import mobilenetv2
# noqa: F401
from .mobilenet import MobileNet
from .resnet_sd import ResNet10, ResNet18, ResNet34, ResNet50
from .resnet import (  # noqa: F401
    resnet18, resnet26, resnet34, resnet50, resnet10,
    resnet101, resnet152, resnet_custom
)
from .resnet_dyconv import DyResNet101, DyDyResNet50, DyResNet152, DyResNet10, DyResNet18, DyResNet34

# from .senet import se_resnext50_32x4d, se_resnext101_32x4d  # noqa: F401
# from .densenet import densenet121, densenet169, densenet201, densenet161  # noqa: F401
# from .toponet import toponet_conv, toponet_sepconv, toponet_mb
# from .hrnet import HRNet  # noqa: F401
# from .mnasnet import mnasnet  # noqa: F401
# from .nas_zoo import (  # noqa: F401
#     mbnas_t29_x0_84, mbnas_t47_x1_00, supnas_t18_x1_00, supnas_t37_x0_92, supnas_t44_x1_00,
#     supnas_t66_x1_11, supnas_t100_x0_96, nas_custom
# )
# from .nas_backbone import nas_backbone  # noqa: F401
# from .resnet_official import (  # noqa: F401
#     resnet18_official, resnet34_official, resnet50_official, resnet101_official, resnet152_official,
#     resnext50_32x4d, resnext101_32x8d, wide_resnet50_2, wide_resnet101_2
# )
# from .condconv_resnet import (  # noqa: F401
#     resnet18_condconv_shared, resnet18_condconv_independent,
#     resnet34_condconv_shared, resnet34_condconv_independent,
#     resnet50_condconv_shared, resnet50_condconv_independent,
#     resnet101_condconv_shared, resnet101_condconv_independent,
#     resnet152_condconv_shared, resnet152_condconv_independent
# )
# from .condconv_mobilenet_v2 import (  # noqa: F401
#     mobilenetv2_condconv_pointwise, mobilenetv2_condconv_independent, mobilenetv2_condconv_shared
# )
# from .mobilenet_v3 import mobilenet_v3  # noqa: F401
# from .ghostnet import ghostnet  # noqa: F401
# from .resnest import resnest50, resnest101, resnest200, resnest269  # noqa: F401
# from .ibnnet import resnet50_ibn_a, resnet101_ibn_a, resnet152_ibn_a  # noqa: F401
# from .fbnet_v2 import fbnetv2_f1, fbnetv2_f4, fbnetv2_l2_hs, fbnetv2_l3  # noqa: F401
# from .weightnet_shufflenet_v2 import (  # noqa: F401
#     shufflenet_v2_x0_5_weightnet, shufflenet_v2_x1_0_weightnet,
#     shufflenet_v2_x1_5_weightnet, shufflenet_v2_x2_0_weightnet
# )
# from .mobilenext import mobilenext  # noqa: F401
# from .dmcp_resnet import (  # noqa: F401
#     dmcp_resnet18_45M, dmcp_resnet18_47M, dmcp_resnet18_51M, dmcp_resnet18_480M, dmcp_resnet50_282M,
#     dmcp_resnet50_1100M, dmcp_resnet50_2200M
# )
# from .bignas_resnet_basicblock import (  # noqa: F401
#     bignas_resnet18_9M, bignas_resnet18_37M, bignas_resnet18_50M,
#     bignas_resnet18_49M, bignas_resnet18_66M, bignas_resnet18_1555M,
#     bignas_resnet18_107M, bignas_resnet18_125M, bignas_resnet18_150M,
#     bignas_resnet18_312M, bignas_resnet18_403M, bignas_resnet18_492M
# )
# from .bignas_resnet_bottleneck import (  # noqa: F401
#     bignas_resnet50_2954M, bignas_resnet50_3145M, bignas_resnet50_3811M
# )
#
import torch
from collections import OrderedDict

# Maps a model name to the filesystem path of its pretrained checkpoint.
# model_entry() indexes this table with config['type'] when pretrained=True,
# and it only gets that far for types that are also names in this module's
# globals(); entries whose key matches no such name are not reachable that way.
load_path = {
    'resnet18': '/mnt/lustre/liyuhang1/ImageNet/ckpts/resnet18_fp_strongbaseline/ckpt.pth.tar',
    'resnet34': '/mnt/lustre/liyuhang1/ImageNet/ckpts/resnet34/res34.pth.tar',
    # 'resnet50': '/mnt/lustre/liyuhang1/ImageNet/ckpts/resnet50_fp_strongbaseline/ckpt.pth.tar',
    'resnet50': '/mnt/lustre/share/prototype_model_zoo/resnet50_batch1k_epoch100_nesterov_wd0.0001/checkpoints/ckpt.pth.tar',
    'resnet50ad': '/mnt/lustre/share/prototype_model_zoo/resnet50ad_batch1k_epoch200_coslr_nesterov_wd0.0001_mixup0.2_fp16/checkpoints/ckpt.pth.tar',
    'resnet101ad': '/mnt/lustre/share/prototype_model_zoo/resnet101ad_batch1k_epoch200_coslr_nesterov_wd0.0001_mixup0.2_fp16/checkpoints/ckpt.pth.tar',
    'resnet152ad': '/mnt/lustre/share/prototype_model_zoo/resnet152ad_batch1k_epoch200_coslr_nesterov_wd0.0001_mixup0.2_fp16/checkpoints/ckpt.pth.tar',
    'regnetx_200m': '/mnt/lustre/share/liyuhang1/ImageNet/ckpts/regnet_200m.pth.tar',
    'regnetx_400m': '/mnt/lustre/share/liyuhang1/ImageNet/ckpts/regnet_400m.pth.tar',
    'regnetx_600m': '/mnt/lustre/share/liyuhang1/ImageNet/ckpts/regnet_600m.pth.tar',
    'regnetx_800m': '/mnt/lustre/share/liyuhang1/ImageNet/ckpts/regnet_800m.pth.tar',
    'regnetx_1600m': '/mnt/lustre/share/liyuhang1/ImageNet/ckpts/regnet_1600m.pth.tar',
    'regnetx_3200m': '/mnt/lustre/share/liyuhang1/ImageNet/ckpts/regnet_3200m.pth.tar',
    'regnetx_4000m': '/mnt/lustre/share/liyuhang1/ImageNet/ckpts/regnet_4000m.pth.tar',
    # NOTE(review): this path names an "sd_res18" run rather than a MobileNetV2
    # one -- confirm it is the intended checkpoint.
    'mobilenet_v2_1.0': '/mnt/lustre/share/prototype_model_zoo/sd_res18_0.1_phi_snip10/checkpoints/ckpt.pth.tar',
    'mobilenet_v2_2.0': '/mnt/lustre/share/prototype_model_zoo/mbv2_2.0_batch1k_epoch200_coslr_nesterov_wd0.00004_bn_nowd_fp16/checkpoints/ckpt.pth.tar',
    'mnasnet_1.0': '/mnt/lustre/share/prototype_model_zoo/mnasnet_1.0_batch1k_epoch200_coslr_nesterov_wd0.00004_bn_nowd/checkpoints/ckpt.pth.tar',
    'mnasnet_2.0': '/mnt/lustre/share/prototype_model_zoo/mnasnet_2.0_batch1k_epoch200_coslr_nesterov_wd0.00004_bn_nowd/checkpoints/ckpt.pth.tar',
    'vgg16_bn': '/mnt/lustre/liyuhang1/ImageNet/ckpts/vgg16_bn-6c64b313.pth',
    'se_resnext50_32x4d': '/mnt/lustre/share/prototype_model_zoo/se_resnext50_32x4d_batch1k_epoch100_coslr_nesterov_wd0.0001_bn_nowd/checkpoints/ckpt.pth.tar',
    'densenet121': '/mnt/lustre/share/prototype_model_zoo/densenet121_batch1k_epoch100_nesterov_wd0.0001/checkpoints/ckpt.pth.tar',
    'densenet201': '/mnt/lustre/share/prototype_model_zoo/densenet201_batch1k_epoch100_nesterov_wd0.0001_fp16/checkpoints/ckpt.pth.tar',
    'shufflenet_v2_x1_5': '/mnt/lustre/share/prototype_model_zoo/shufflev2_1.5_batch1k_epoch300_coslr_nesterov_wd0.00004_bn_nowd/checkpoints/ckpt.pth.tar',
    'shufflenet_v2_x2_0': '/mnt/lustre/share/prototype_model_zoo/shufflev2_2.0_batch1k_epoch300_coslr_nesterov_wd0.00004_bn_nowd/checkpoints/ckpt.pth.tar',
    'mobilenet_v3_1.0': '/mnt/lustre/share/prototype_model_zoo/mobilenet_v3_exp/mbv3_large_1.0_batch1k_epoch350_coslr_nesterov_wd0.00003_bn_nowd_fp16_ema0.9999_dropout0.2/checkpoints/ckpt_360000.pth.tar',
    # NOTE(review): the file name below is spelled "cktp" while every other
    # entry uses "ckpt" -- confirm against the file on disk that it is not a typo.
    'mobilenet_v3_1.4': '/mnt/lustre/share/prototype_model_zoo/mobilenet_v3_exp/mbv3_large_1.4_batch1k_epoch350_coslr_nesterov_wd0.00003_bn_nowd_fp16_ema0.9999_dropout0.2/checkpoints/cktp.pth.tar',
    'supnas_t44_x1_00': '/mnt/lustre/share/prototype_model_zoo/nas_sup_t44_batch2k_epoch300_coslr_nesterov_wd0.0001_bn_nowd/checkpoints/ckpt_187000.pth.tar',
    'supnas_t66_x1_11': '/mnt/lustre/share/prototype_model_zoo/nas_sup_t66_batch2k_epoch300_coslr_nesterov_wd0.0001_bn_nowd/checkpoints/ckpt.pth.tar',
    'mobilenetv1': "/mnt/lustre/liyuhang1/SNN/raw/mobilenetv1b_imagenet.pth.tar",
    'efficientnet_b0': '/mnt/lustre/share/prototype_model_zoo/effnet_b0_batch1k_epoch350_stepdecaylr_bn_nowd_fp16_rmsprop/checkpoints/ckpt_437000.pth.tar',
    'vgg11_bn': '/mnt/lustre/liyuhang1/ImageNet/ckpts/vgg11_bn-6002323d.pth',
}


def load_model_pytorch(model, load_model, replace_dict=None):
    """Load checkpoint weights from ``load_model`` into ``model``.

    The checkpoint is read onto the CPU. The weights are taken from its
    ``'state_dict'`` entry if present, else from its ``'model'`` entry, else
    the checkpoint object itself is used as the state dict.

    Args:
        model: Object exposing ``load_state_dict(state_dict, strict=...)``.
        load_model: Checkpoint location, passed straight to ``torch.load``.
        replace_dict: Optional mapping ``{old_substring: new_substring}``.
            Each pair is applied, in the mapping's iteration order, to every
            key of the state dict via ``str.replace``. ``None`` (the default)
            or an empty mapping leaves the keys untouched.

    Returns:
        None. The weights are loaded with ``strict=False`` and the result of
        ``load_state_dict`` is discarded.
    """
    # NOTE: torch.load unpickles the file; only point this at checkpoints
    # that come from a trusted source.
    checkpoint = torch.load(load_model, map_location='cpu')

    if 'state_dict' in checkpoint:
        load_from = checkpoint['state_dict']
    elif 'model' in checkpoint:
        load_from = checkpoint['model']
    else:
        load_from = checkpoint

    # remove "module." in case the model is saved with Dist Mode.
    # Only the first key is inspected; an empty state dict has no first key
    # and is passed through unchanged instead of raising IndexError.
    # Every occurrence of "module." in a key is removed, not only a prefix.
    first_key = next(iter(load_from), '')
    if 'module.' in first_key:
        load_from = OrderedDict(
            (k.replace("module.", ""), v) for k, v in load_from.items()
        )

    for old, new in (replace_dict or {}).items():
        load_from = OrderedDict(
            (k.replace(old, new), v) for k, v in load_from.items()
        )

    model.load_state_dict(load_from, strict=False)


def model_entry(config, pretrained=False):
    """Build the model described by ``config``.

    ``config['type']`` is looked up among this module's global names and the
    object found is called with ``**config['kwargs']``. A type that is not a
    global name but starts with ``'spring_'`` is instead handed to
    ``spring.models.SPRING_MODELS_REGISTRY``.

    Args:
        config: Mapping with a ``'type'`` entry (model name) and, for models
            built from this module, a ``'kwargs'`` entry (constructor
            keyword arguments). For ``'spring_*'`` types ``config['type']``
            is rewritten IN PLACE to the name without the prefix before the
            config is passed to the registry.
        pretrained: When true, load the checkpoint registered in
            ``load_path`` under ``config['type']`` into the built model.
            Not applied to ``'spring_*'`` models, which return early.

    Returns:
        The constructed model.

    Raises:
        ImportError: The type starts with ``'spring_'`` but ``spring.models``
            cannot be imported.
        KeyError: The type is neither a name in this module nor a
            ``'spring_*'`` type, or ``pretrained`` is true and ``load_path``
            has no entry for the type.
    """
    model_type = config['type']

    if model_type not in globals():
        if model_type.startswith('spring_'):
            try:
                from spring.models import SPRING_MODELS_REGISTRY
            except ImportError as err:
                # Fail here with a clear message instead of printing and then
                # crashing on the unbound registry name below.
                raise ImportError('Please install Spring2 first!') from err
            # The registry is given the bare name, without the prefix.
            config['type'] = model_type[len('spring_'):]
            return SPRING_MODELS_REGISTRY.build(config)
        raise KeyError(
            'unknown model type {!r}: not defined in this module and not a '
            "'spring_' model".format(model_type)
        )

    model = globals()[model_type](**config['kwargs'])

    if pretrained:
        if model_type not in load_path:
            raise KeyError(
                'no pretrained checkpoint registered in load_path for model '
                'type {!r}'.format(model_type)
            )
        load_model_pytorch(model, load_path[model_type])

    return model
