from . import simnet, alexnet, vgg, resnet, \
    senet, resnext, densenet, simplenetv1, \
    efficientnetv2, googlenet, xception, mobilenetv2, \
    inceptionv3, wideresnet, shufflenetv2, squeezenet, mnasnet, \
    convnext


# Registry mapping a backbone name to (constructor, output channel count).
# Each constructor takes the input channel count. They are wrapped in lambdas
# so that the backbone modules' attributes are only looked up when that
# particular backbone is requested (as the original if/elif chain did).
_BACKBONES = {
    'alexnet': (lambda c: alexnet.alexnet(c), 256),
    'vgg16': (lambda c: vgg.vgg16_bn(c), 512),
    'resnet34': (lambda c: resnet.resnet34(c), 512),
    'resnet50': (lambda c: resnet.resnet50(c), 2048),
    'senet50': (lambda c: senet.seresnet50(c), 2048),
    'wideresnet28': (lambda c: wideresnet.wide_resnet28_10(c), 640),
    'resnext50': (lambda c: resnext.resnext50(c), 2048),
    'densenet121': (lambda c: densenet.densenet121(c), 1024),
    'efficientnetv2s': (lambda c: efficientnetv2.effnetv2_s(c), 1792),
    'efficientnetv2l': (lambda c: efficientnetv2.effnetv2_l(c), 1792),
    'googlenet': (lambda c: googlenet.googlenet(c), 1024),
    'xception': (lambda c: xception.xception(c), 2048),
    'mobilenetv2': (lambda c: mobilenetv2.mobilenetv2(c), 1280),
    # convnext takes the input channel count as the `in_chans` keyword.
    'convnextb': (lambda c: convnext.convnext_base(in_chans=c), 1024),
}


def load_backbone(model_name, in_channels=3, *, strict=False):
    """Build a backbone network by name.

    Args:
        model_name: Key identifying the architecture (e.g. 'resnet50',
            'convnextb'); the supported names are the keys of ``_BACKBONES``.
        in_channels: Number of input channels passed to the constructor.
        strict: Keyword-only. If True, raise ``ValueError`` for an
            unrecognized ``model_name`` instead of returning
            ``(None, None)``. Defaults to False to keep the historical
            behavior.

    Returns:
        ``(model, out_channel)`` where ``out_channel`` is the channel count
        hard-coded for that architecture, or ``(None, None)`` when
        ``model_name`` is unknown and ``strict`` is False.

    Raises:
        ValueError: If ``strict`` is True and ``model_name`` is unknown.
    """
    # NOTE: inceptionv3, shufflenetv2, squeezenet, mnasnet and simplenetv1 are
    # imported by this module but not registered in _BACKBONES.
    entry = _BACKBONES.get(model_name)
    if entry is None:
        if strict:
            raise ValueError(
                f"Unknown backbone {model_name!r}; expected one of "
                f"{sorted(_BACKBONES)}"
            )
        return None, None
    build, out_channel = entry
    return build(in_channels), out_channel
