from losses import BootstrappedCE,NormalizedCE,WeightedCE,WeightedCE2,DownsampledCE,AuxLoss
from lr_schedulers import poly_lr_scheduler,cosine_lr_scheduler,step_lr_scheduler,exp_lr_scheduler
from data import get_cityscapes,get_pascal_voc,get_camvid, build_val_transform,Cityscapes,get_mapillary,get_coco,get_custom
from model import RegSeg
import torch
from competitors_models.hardnet import hardnet
from competitors_models.DDRNet_Reimplementation import get_ddrnet_23,get_ddrnet_23slim
from competitors_models.lps import get_lspnet_l,get_lspnet_m,get_lspnet_s
from competitors_models.pidnet.pidnet import get_pidnet_s,get_pidnet_m,get_pidnet_l
from competitors_models.stdc.model_stages import get_stdc

def get_lr_function(config,total_iterations):
    """Return the learning-rate multiplier function to hand to ``LambdaLR``.

    The returned callable takes the current iteration index and delegates to
    the scheduler named by ``config["lr_scheduler"]`` (``"poly"``,
    ``"cosine"``, ``"step"`` or ``"exp"``). Every scheduler receives
    ``total_iterations``, ``config["warmup_iters"]`` and
    ``config["warmup_factor"]``. ``"poly"`` also receives
    ``config["poly_power"]`` and ``"exp"`` also receives ``config["exp_beta"]``.
    Those extra values are read here, not lazily inside the returned function.

    Raises:
        KeyError: if a required config key is missing.
        NotImplementedError: for an unknown scheduler name.
    """
    kind = config["lr_scheduler"]
    # Positional arguments shared by every scheduler, after the iteration index.
    schedule_args = (total_iterations, config["warmup_iters"], config["warmup_factor"])

    if kind == "poly":
        power = config["poly_power"]
        return lambda step: poly_lr_scheduler(step, *schedule_args, power)
    if kind == "cosine":
        return lambda step: cosine_lr_scheduler(step, *schedule_args)
    if kind == "step":
        return lambda step: step_lr_scheduler(step, *schedule_args)
    if kind == "exp":
        beta = config["exp_beta"]
        return lambda step: exp_lr_scheduler(step, *schedule_args, beta)
    raise NotImplementedError()

def get_loss_fun(config):
    """Build the training loss selected by ``config["loss_type"]``.

    ``loss_type`` is optional and defaults to ``"cross_entropy"``. Every loss
    is given ``config["ignore_value"]`` as its ignore index.
    ``config["train_crop_size"]`` may be an int (square crop) or an
    ``(h, w)`` pair. It is only used to size the bootstrapped losses, whose
    minK is ``int(batch_size * crop_h * crop_w / 16)``; that value is printed.

    Raises:
        NotImplementedError: for an unknown ``loss_type``.
    """
    crop_size = config["train_crop_size"]
    ignore_index = config["ignore_value"]
    crop_h, crop_w = (crop_size, crop_size) if isinstance(crop_size, int) else crop_size
    loss_type = config.get("loss_type", "cross_entropy")

    def bootstrap_min_k():
        # One sixteenth of the pixels in a training batch, e.g. 8*768*768/16.
        min_k = int(config["batch_size"] * crop_h * crop_w / 16)
        print(f"bootstrapped minK: {min_k}")
        return min_k

    if loss_type == "cross_entropy":
        return torch.nn.CrossEntropyLoss(ignore_index=ignore_index)
    if loss_type == "normalized_cross_entropy":
        return NormalizedCE(ignore_index=ignore_index)
    if loss_type == "bootstrapped":
        min_k = bootstrap_min_k()
        return BootstrappedCE(min_k, 0.3, ignore_index=ignore_index)
    if loss_type == "aux_loss":
        min_k = bootstrap_min_k()
        # Two independent bootstrapped losses combined by AuxLoss with weight 0.4.
        main_loss = BootstrappedCE(min_k, 0.3, ignore_index=ignore_index)
        aux_loss = BootstrappedCE(min_k, 0.3, ignore_index=ignore_index)
        return AuxLoss(main_loss, aux_loss, 0.4)
    if loss_type in ("weighted_cross_entropy", "weighted_cross_entropy2"):
        # Class weights are moved to the GPU unconditionally.
        class_weight = torch.tensor(config["class_weight"]).cuda()
        loss_cls = WeightedCE if loss_type == "weighted_cross_entropy" else WeightedCE2
        return loss_cls(ignore_index, class_weight)
    if loss_type == "downsampled_cross_entropy":
        return DownsampledCE(ignore_index)
    if loss_type == "normalized_downsampled_cross_entropy":
        return DownsampledCE(ignore_index, True)
    raise NotImplementedError()

def get_optimizer(model,config):
    """Create an SGD optimizer with three parameter groups.

    Parameters are grouped by name, checked in this order:
      1. names containing ``"bn"``: weight decay is ``config["weight_decay"]``
         if ``config["bn_weight_decay"]`` is truthy, otherwise 0;
      2. names containing ``"dilation_rate"``: no weight decay, twice the
         base learning rate;
      3. everything else: ``config["weight_decay"]``.
    The group order is fixed (bn, dilation_rate, default), even when a group
    is empty, so saved optimizer state dicts keep a matching group layout.
    """
    buckets = {"bn": [], "dilation_rate": [], "default": []}
    for param_name, param in model.named_parameters():
        if "bn" in param_name:
            bucket = "bn"
        elif "dilation_rate" in param_name:
            bucket = "dilation_rate"
        else:
            bucket = "default"
        buckets[bucket].append(param)

    weight_decay = config["weight_decay"]
    base_lr = config["lr"]
    bn_weight_decay = weight_decay if config["bn_weight_decay"] else 0
    param_groups = [
        {"params": buckets["bn"], "weight_decay": bn_weight_decay},
        {"params": buckets["dilation_rate"], "weight_decay": 0, "lr": base_lr * 2},
        {"params": buckets["default"], "weight_decay": weight_decay},
    ]
    return torch.optim.SGD(
        param_groups,
        lr=base_lr,
        momentum=config["momentum"],
        weight_decay=weight_decay,
    )

def get_val_dataset(config):
    """Build a standalone validation dataset; only Cityscapes is supported.

    All config keys (``val_input_size``, ``val_label_size``, ``dataset_dir``,
    ``dataset_name``, ``val_split``) are read before the name is checked.
    The dataset uses the validation transform from ``build_val_transform``
    with ``target_type="semantic"`` and ``class_uniform_pct=0``.

    Raises:
        NotImplementedError: if ``dataset_name`` is not ``"cityscapes"``.
    """
    input_size = config["val_input_size"]
    label_size = config["val_label_size"]
    dataset_root = config["dataset_dir"]
    dataset_name = config["dataset_name"]
    split = config["val_split"]

    if dataset_name != "cityscapes":
        raise NotImplementedError()

    transforms = build_val_transform(input_size, label_size)
    return Cityscapes(
        dataset_root,
        split=split,
        target_type="semantic",
        transforms=transforms,
        class_uniform_pct=0,
    )

def get_dataset_loaders(config):
    """Build the training/validation loaders for ``config["dataset_name"]``.

    Every dataset factory is called positionally with the same eight leading
    config values (see ``leading`` below), followed by dataset-specific
    trailing values (see ``trailing``). The lengths of both loaders are
    printed.

    Returns:
        ``(train_loader, val_loader, train_set)`` exactly as returned by the
        dataset-specific factory.

    Raises:
        NotImplementedError: for an unsupported dataset name.
    """
    name = config["dataset_name"]
    if name == "cityscapes":
        factory = get_cityscapes
        trailing = ("class_uniform_pct", "train_split", "val_split", "num_workers", "ignore_value")
    elif name == "camvid":
        factory = get_camvid
        trailing = ("train_split", "val_split", "num_workers", "ignore_value")
    elif name == "coco":
        factory = get_coco
        trailing = ("num_workers", "ignore_value")
    elif name == "mapillary":
        factory = get_mapillary
        trailing = ("num_workers", "ignore_value", "mapillary_reduced")
    elif name == "synthetic":
        factory = get_custom
        trailing = ("train_split", "val_split", "num_workers", "ignore_value")
    else:
        raise NotImplementedError()

    leading = (
        "dataset_dir",
        "batch_size",
        "train_min_size",
        "train_max_size",
        "train_crop_size",
        "val_input_size",
        "val_label_size",
        "aug_mode",
    )
    factory_args = [config[key] for key in leading + trailing]
    train_loader, val_loader, train_set = factory(*factory_args)
    print("train size:", len(train_loader))
    print("val size:", len(val_loader))
    return train_loader, val_loader, train_set


def get_model(config):
    """Instantiate the segmentation network described by ``config``.

    ``model_type`` selects the family:

    * ``"regseg"`` / ``"experimental2"``: builds ``RegSeg`` from
      ``model_name``, ``num_classes`` and ``pretrained_path``. The optional
      flags ``ablate_decoder``, ``change_num_classes`` and ``aux_loss``
      default to False. The decoder is downsampled when the (optional)
      ``loss_type`` contains ``"downsampled"``.
    * ``"competitor"``: builds the baseline named by ``model_name``.

    Raises:
        NotImplementedError: for an unknown ``model_type`` or competitor
            ``model_name``.
    """
    model_type = config["model_type"]
    if model_type in ("experimental2", "regseg"):
        # loss_type is optional (get_loss_fun defaults it to "cross_entropy"),
        # so it must not be required here either.
        loss_type = config.get("loss_type", "cross_entropy")
        # NOTE(review): pretrained_path is forwarded even when config["resume"]
        # is set; confirm that resuming overrides these weights.
        return RegSeg(
            name=config["model_name"],
            num_classes=config["num_classes"],
            pretrained=config["pretrained_path"],
            ablate_decoder=config.get("ablate_decoder", False),
            change_num_classes=config.get("change_num_classes", False),
            downsample="downsampled" in loss_type,
            aux=config.get("aux_loss", False),
        )
    if model_type == "competitor":
        # Lambdas keep construction (and the num_classes lookup) lazy, so only
        # the selected model is built.
        builders = {
            "hardnet": lambda: hardnet(config["num_classes"]),
            "ddrnet23": lambda: get_ddrnet_23(config["num_classes"]),
            "ddrnet23slim": lambda: get_ddrnet_23slim(config["num_classes"]),
            "pidnet-s": lambda: get_pidnet_s(config["num_classes"]),
            "pidnet-m": lambda: get_pidnet_m(config["num_classes"]),
            "pidnet-l": lambda: get_pidnet_l(config["num_classes"]),
            # lspnet-s/-m are assumed to share get_lspnet_l's (num_classes)
            # signature -- confirm.
            "lspnet-s": lambda: get_lspnet_s(config["num_classes"]),
            "lspnet-m": lambda: get_lspnet_m(config["num_classes"]),
            "lspnet-l": lambda: get_lspnet_l(config["num_classes"]),
            # NOTE(review): get_stdc is not passed num_classes; confirm its
            # head matches the dataset.
            "stdc2": lambda: get_stdc("STDC2"),
            "stdc1": lambda: get_stdc("STDC1"),
        }
        builder = builders.get(config["model_name"])
        if builder is None:
            raise NotImplementedError()
        return builder()
    raise NotImplementedError()
