import torch
from benchmark import benchmark_eval
from competitors_models.DDRNet_Reimplementation import get_ddrnet_23, \
        get_ddrnet_23slim, get_ddrnet_39
from competitors_models.bisenetv1 import BiSeNetV1
from competitors_models.bisenetv2 import BiSeNetV2
from competitors_models.hardnet import hardnet
from competitors_models.stdc.model_stages import get_stdc
from competitors_models.hyperseg.hyperseg_v1_0 import get_hyperseg
from competitors_models.sfnet_dfnet.sfnet_dfnet import AlignedDFnetv2
from competitors_models.sfnet_dfnet.sfnet_resnet import DeepR18_SF_deeply
from competitors_models.lps import get_lspnet_l,get_lspnet_m,get_lspnet_s
from competitors_models.pidnet.pidnet import get_pidnet_s,get_pidnet_m,get_pidnet_l
from model import RegSeg
import timm

def calculate_competitor_fps(*, mixed_precision=True, tensorrt="none"):
    """Benchmark RegSeg, backbone baselines and competitor models.

    All models are instantiated up front and split into groups by the input
    resolution each one is evaluated at. Each group is handed to
    ``benchmark_eval`` together with a random ``(1, 3, H, W)`` input tensor.
    Groups run in the order 1024x2048, 768x1536, 512x1024.

    Args:
        mixed_precision: Forwarded unchanged to ``benchmark_eval``
            (previously hard-coded; the default keeps the old value).
        tensorrt: Forwarded unchanged to ``benchmark_eval``
            (previously hard-coded; the default keeps the old value).

    Raises:
        ValueError: If a model in the registry is not assigned to any
            resolution group (it would otherwise be silently skipped), or a
            resolution group names a model that is not in the registry.
    """
    # Name -> constructed model. Several constructors receive 19 as their
    # first positional argument; NOTE(review): presumably the class count
    # (RegSeg is given num_classes=19 explicitly) — confirm per constructor.
    model_dic = {
        "efficientnet_b0": timm.create_model('efficientnet_b0'),
        "regnet006": timm.create_model('regnety_006'),
        "exp48_decoder26": RegSeg("exp48_decoder26", num_classes=19),
        "mobilenetv3_lraspp": RegSeg("mobilenetv3_lraspp", num_classes=19),
        "ddrnet23_slim": get_ddrnet_23slim(),
        "ddrnet23": get_ddrnet_23(),
        "ddrnet39": get_ddrnet_39(),
        "pidnet-s": get_pidnet_s(),
        "pidnet-m": get_pidnet_m(),
        "pidnet-l": get_pidnet_l(),
        "hardnet": hardnet(19),
        "sfnet_dfv2": AlignedDFnetv2(19),
        "sfnet_resnet18": DeepR18_SF_deeply(19),
        "stdc1": get_stdc("STDC1"),
        "stdc2": get_stdc("STDC2"),
        "bisenetv1": BiSeNetV1(19),
        "bisenetv2": BiSeNetV2(19),
        "hyperseg": get_hyperseg(),
        "lsp-l": get_lspnet_l(),
    }

    # Input (height, width) -> names of the models benchmarked at that size.
    # Insertion order is the order in which the groups are benchmarked.
    resolution_groups = {
        (1024, 2048): [
            "exp48_decoder26", "mobilenetv3_lraspp", "efficientnet_b0",
            "regnet006", "ddrnet23_slim", "ddrnet23", "ddrnet39",
            "pidnet-s", "pidnet-m", "pidnet-l", "lsp-l", "hardnet",
            "sfnet_dfv2", "sfnet_resnet18",
        ],
        (768, 1536): ["stdc1", "stdc2", "bisenetv1"],
        (512, 1024): ["bisenetv2", "hyperseg"],
    }

    # Fail loudly if the registry and the resolution table drift apart,
    # instead of silently dropping a model from the benchmark.
    assigned = {name for names in resolution_groups.values() for name in names}
    unassigned = sorted(set(model_dic) - assigned)
    unknown = sorted(assigned - set(model_dic))
    if unassigned or unknown:
        raise ValueError(
            "resolution groups out of sync with model_dic: "
            f"unassigned={unassigned}, unknown={unknown}")

    for (height, width), names in resolution_groups.items():
        wanted = set(names)
        # Iterate model_dic (not `names`) so each group keeps registry order.
        group = {name: model for name, model in model_dic.items()
                 if name in wanted}
        x = torch.randn(1, 3, height, width)
        benchmark_eval(group, x, mixed_precision, tensorrt)
def sanity_check():
    """Smoke-test a training-mode forward pass of PIDNet-M and LPSNet-L.

    Both models are switched to training mode and fed the same random
    ``(4, 3, 1000, 1000)`` batch, and the shape of each model's output is
    printed. NOTE(review): 1000 is not a multiple of 32; presumably this is
    deliberate, to exercise odd input sizes. Confirm.

    The forward passes run under ``torch.no_grad()``. Only output shapes are
    inspected, so there is no reason to keep an autograd graph alive for a
    4x3x1000x1000 batch. ``no_grad`` only disables gradient tracking; the
    train-mode forward computation itself is unchanged.

    NOTE(review): if either model returns a list/tuple of outputs in train
    mode (e.g. auxiliary heads), ``.shape`` will raise. Confirm against the
    model implementations.
    """
    model1 = get_pidnet_m().train()
    model2 = get_lspnet_l().train()
    x = torch.randn(4, 3, 1000, 1000)
    with torch.no_grad():
        print(model1(x).shape)
        print(model2(x).shape)

if __name__ == "__main__":
    # Choose the routine from the command line instead of toggling
    # commented-out calls. With no arguments, sanity_check() runs, as before.
    import argparse

    parser = argparse.ArgumentParser(
        description="Sanity-check or benchmark the competitor segmentation models.")
    parser.add_argument(
        "--fps", action="store_true",
        help="run calculate_competitor_fps() instead of sanity_check()")
    args = parser.parse_args()
    if args.fps:
        calculate_competitor_fps()
    else:
        sanity_check()
