from blocks import *
from competitor_blocks import BiseNetDecoder,SFNetDecoder,FaPNDecoder
from benchmark import benchmark_eval,benchmark_train,benchmark_memory

class RegSeg(nn.Module):
    """Segmentation network assembled from a stem, a body and a decoder.

    The architecture is chosen by ``name``, which has the form
    ``"<body>_<decoder>"`` (for example ``"exp48_decoder26"``).

    exp48_decoder26 is what we call RegSeg in our paper.
    exp53_decoder29 is a larger version of exp48_decoder26.
    All the other models are for ablation studies.
    """

    def __init__(self, name, num_classes, pretrained="", ablate_decoder=False,change_num_classes=False,downsample=False,aux=False):
        """Build the network and optionally load pretrained weights.

        Args:
            name: ``"<body>_<decoder>"``; it must contain exactly one underscore
                because it is split on ``"_"`` into two parts.
            num_classes: number of output classes, forwarded to the decoder.
            pretrained: path of a checkpoint to load; ``""`` disables loading.
            ablate_decoder: when true, the pretrained checkpoint is not loaded.
            change_num_classes: when true, checkpoint tensors whose shape does
                not match the freshly built model (or that are missing from the
                checkpoint) are skipped and the fresh weights are kept.
            downsample: means using the DownsampleCE loss function; it changes
                what ``forward`` returns (see ``forward``).
            aux: forwarded to ``Exp2_Decoder26`` only.

        Raises:
            NotImplementedError: if the body or decoder name is unknown.
        """
        super().__init__()
        self.stem=ConvBnAct(3,32,3,2,1)
        self.downsample=downsample
        body_name, decoder_name=name.split("_")
        if body_name=="exp30":
            self.body=RegSegBody(5*[[1,4]]+8*[[1,10]])
        elif body_name=="exp43":
            self.body=RegSegBody([[1],[1,2],[1,4],[1,6],[1,8],[1,10]]+7*[[1,12]])
        elif body_name=="exp46":
            self.body=RegSegBody([[1],[1,2],[1,4],[1,6],[1,8]]+8*[[1,10]])
        elif body_name=="exp47":
            self.body=RegSegBody([[1],[1,2],[1,4],[1,6],[1,8],[1,10],[1,12]]+6*[[1,14]])
        elif body_name=="exp48":
            self.body=RegSegBody([[1],[1,2]]+4*[[1,4]]+7*[[1,14]])
        elif body_name=="exp49":
            self.body=RegSegBody([[1],[1,2]]+6*[[1,4]]+5*[[1,6,12,18]])
        elif body_name=="exp50":
            self.body=RegSegBody([[1],[1,2],[1,4],[1,6],[1,8],[1,10]]+7*[[1,3,6,12]])
        elif body_name=="exp51":
            self.body=RegSegBody([[1],[1,2],[1,4],[1,6],[1,8],[1,10]]+7*[[1,4,8,12]])
        elif body_name=="exp52":
            self.body=RegSegBody([[1],[1,2],[1,4]]+10*[[1,6]])
        elif body_name=="exp53":
            self.body=RegSegBody2([[1],[1,2]]+4*[[1,4]]+7*[[1,14]])
        elif body_name=="exp55":
            # d2=ceil(d2)
            # d1=floor(d1)
            # learn2_decoder26
            self.body=RegSegBody([[1], [1, 2], [1, 2], [1, 3], [2, 3], [2, 7], [2, 3], [2, 6], [2, 5], [2, 9], [2, 11], [4, 7], [5, 14]])
        elif body_name=="exp56":
            # output stride 32
            # L2
            self.body=RegSegBody3([[1], [1, 2], [1, 2], [1, 2], [1], [1, 2], [1, 2], [3, 5], [2, 5], [5, 8], [2, 9], [2, 11], [2, 6]])
        elif body_name=="exp57":
            #L1
            # better init
            self.body=RegSegBody([[1], [1, 3], [1, 3], [1, 3], [2, 5], [2, 8], [2, 5], [2, 7], [2, 6], [2, 11], [2, 15], [4, 9], [5, 16]])
        elif body_name=="exp58":
            # 2 stages at os=16
            # L3
            self.body=RegSegBody4([[1, 3], [1, 3], [1, 4], [1, 2], [1, 4], [3, 8], [2, 14], [3, 6], [3, 9], [5, 9], [8, 14], [2, 20]])
        elif body_name=="exp59":
            # better init2
            self.body=RegSegBody([[1], [1, 3], [1, 3], [1, 4], [2, 6], [2, 9], [1, 5], [2, 8], [2, 6], [2, 12], [2, 16], [4, 10], [4, 17]])
        elif body_name=="exp60":
            # L5
            self.body=RegSegBody([[1], [1, 3], [2, 3], [1, 2], [2, 3], [3], [3, 5], [2, 7], [3, 5], [2, 7], [4, 5], [5, 8], [3, 10]])
        elif body_name=="exp61":
            # L6
            self.body=RegSegBody([[1], [3, 5], [2, 7], [2, 13], [5, 8], [6, 10], [3, 18], [2, 14], [3, 9], [3, 12], [5, 7], [8, 12], [16, 26]])
        elif body_name=="exp62":
            # L6 768 res
            #[[5.0, 3.0], [8.0, 4.0], [2.0, 14.0], [7.0, 11.0], [3.0, 21.0], [5.0, 10.0], [2.0, 17.0], [4.0, 6.0], [3.0, 12.0], [7.0, 4.0], [13.0, 9.0], [29.0, 18.0]]
            self.body=RegSegBody([[1], [3, 5], [3, 5], [3, 13], [7, 9], [2, 16], [6, 8], [2, 13], [2, 9], [3, 4], [5, 9], [7, 11], [15, 25]])
        elif body_name=="exp63":
            # L6 with 1000 epochs
            #[[4.0, 3.0], [3.0, 10.0], [6.0, 2.0], [3.0, 19.0], [5.0, 7.0], [10.0, 9.0], [12.0, 2.0], [7.0, 4.0], [2.0, 13.0], [5.0, 11.0], [7.0, 10.0], [16.0, 25.0]]
            self.body=RegSegBody([[1], [3, 6], [3, 12], [2, 8], [5, 10], [9, 20], [3, 16], [4, 27], [7, 14], [9, 14], [5, 9], [3, 11], [16, 29]])
        elif body_name=="exp64":
            # d2=ceil(d2)
            # d1=floor(d1)
            # L1, 768_res, mapillary
            self.body=RegSegBody([[1], [1, 2], [1, 2], [1, 3], [2, 5], [2, 3], [2, 6], [2, 8], [2, 7], [2, 5], [2, 10], [4, 7], [2, 12]])
        elif body_name=="exp65":
            # self.body = RegSegBody([[1], [1, 2], [1, 2], [1, 3], [2, 3], [2, 7], [2, 3], [2, 6],[2, 5], [2, 9], [2, 11], [4, 7], [5, 14]])
            self.body=RegSegBody([[1], [1, 2], [1, 2], [1, 4], [1, 5], [2, 3], [2, 8], [2, 5], [2, 9], [2, 6], [4, 12], [5, 9], [2, 15]])
        elif body_name=="exp66":
            # same as exp64
            self.body=RegSegBody([[1], [1, 2], [1, 2], [1, 3], [2, 5], [2, 3], [2, 6], [2, 8], [2, 7], [2, 5], [2, 10], [4, 7], [2, 12]])
        elif body_name=="regnety600mf":
            self.body=RegNetY600MF()
        elif body_name=="mobilenetv3":
            # This body is fed the raw input: the default stem is replaced.
            self.stem=nn.Identity()
            self.body=MobileNetV3()
        elif body_name=="L1":
            # better init
            self.body=RegSegBodyLearnable(
                initial_dilation_rates=[[1,1]]*12,
                num_splits=2,
                transform="none"
            )
        elif body_name=="L2":
            # os=32
            self.body=RegSegBodyLearnable2(
                initial_dilation_rates=[(1,1)]*12,
                num_splits=2
            )
        elif body_name=="L3":
            # 2 stage at os=16
            # better init
            self.body=RegSegBodyLearnable3(
                initial_dilation_rates=[[1, 2], [1, 2], [1, 2], [1, 2], [2, 3], [4, 6], [2, 12], [3, 6], [3, 8], [5, 8], [8, 11], [2, 15]],
                num_splits=2
            )
        elif body_name=="L4":
            # num_splits=4
            self.body=RegSegBodyLearnable(
                initial_dilation_rates=[(1,1,1,1)]*12,
                num_splits=4,
                transform="none"
            )
        elif body_name=="L5":
            # transform="steep"
            self.body=RegSegBodyLearnable(
                initial_dilation_rates=[[1,1]]*12,
                num_splits=2,
                transform="steep"
            )
        elif body_name=="L6":
            # transform="round"
            self.body=RegSegBodyLearnable(
                initial_dilation_rates=[[1,1]]*12,
                num_splits=2,
                transform="round"
            )
        else:
            raise NotImplementedError(f"unknown body name: {body_name!r}")
        if decoder_name=="decoder4":
            self.decoder=Exp2_Decoder4(num_classes,self.body.channels())
        elif decoder_name=="decoder10":
            self.decoder=Exp2_Decoder10(num_classes,self.body.channels())
        elif decoder_name=="decoder12":
            self.decoder=Exp2_Decoder12(num_classes,self.body.channels())
        elif decoder_name=="decoder14":
            self.decoder=Exp2_Decoder14(num_classes,self.body.channels())
        elif decoder_name=="decoder26":
            self.decoder=Exp2_Decoder26(num_classes,self.body.channels(),aux)
        elif decoder_name=="decoder29":
            self.decoder=Exp2_Decoder29(num_classes,self.body.channels())
        elif decoder_name=="decoder30":
            # for output stride 32
            self.decoder=Exp2_Decoder30(num_classes,self.body.channels())
        elif decoder_name=="BisenetDecoder":
            self.decoder=BiseNetDecoder(num_classes,self.body.channels())
        elif decoder_name=="SFNetDecoder":
            self.decoder=SFNetDecoder(num_classes,self.body.channels())
        elif decoder_name=="FaPNDecoder":
            self.decoder=FaPNDecoder(num_classes,self.body.channels())
        elif decoder_name=="lraspp":
            self.decoder=LRASPP(num_classes,self.body.channels())
        else:
            raise NotImplementedError(f"unknown decoder name: {decoder_name!r}")
        if pretrained != "" and not ablate_decoder:
            self._load_pretrained(pretrained,change_num_classes)

    @staticmethod
    def _merge_state_dicts(current,loaded):
        """Return a state dict with the keys of ``current``, preferring ``loaded``.

        A tensor from ``loaded`` is used only when the key exists there and its
        size equals the size in ``current``; otherwise the key is printed and
        the tensor from ``current`` is kept.
        """
        merged={}
        for key,fresh in current.items():
            if key in loaded and loaded[key].size()==fresh.size():
                merged[key]=loaded[key]
            else:
                print(key)
                merged[key]=fresh
        return merged

    def _load_pretrained(self,pretrained,change_num_classes):
        """Load the checkpoint at ``pretrained`` into this model (strictly)."""
        # NOTE: torch.load unpickles the file, so only load checkpoints that
        # come from a trusted source.
        dic = torch.load(pretrained, map_location='cpu')
        # Checkpoints may wrap the weights as {"model": state_dict, ...}.
        if isinstance(dic,dict) and "model" in dic:
            dic=dic['model']
        if change_num_classes:
            print("change_num_classes: True")
            dic=self._merge_state_dicts(self.state_dict(),dic)
        self.load_state_dict(dic,strict=True)

    def forward(self,x):
        """Run stem, body and decoder, then resize to the input resolution.

        In eval mode the decoder output is bilinearly resized to the spatial
        size of ``x``; with ``downsample`` set, a softmax over dim 1 is applied
        first. In training mode with ``downsample`` set, the decoder output is
        returned as is. Otherwise the output (or each element of a
        ``(main, aux)`` tuple) is bilinearly resized to the input size.
        """
        input_shape=x.shape[-2:]
        x=self.stem(x)
        x=self.body(x)
        x=self.decoder(x)
        if not self.training:
            if self.downsample:
                x=torch.softmax(x,dim=1)
            return F.interpolate(x, size=input_shape, mode='bilinear', align_corners=False)
        if self.downsample:
            return x
        elif isinstance(x,tuple):
            x,aux_x=x
            x = F.interpolate(x, size=input_shape, mode='bilinear', align_corners=False)
            aux_x = F.interpolate(aux_x, size=input_shape, mode='bilinear', align_corners=False)
            return x,aux_x
        else:
            return F.interpolate(x, size=input_shape, mode='bilinear', align_corners=False)

def num_classes_speed_test():
    """Benchmark "exp48_decoder26" for a range of class counts.

    Two passes are run over the same class counts: one with the default
    model and default loss, one with ``downsample=True`` and a
    ``DownsampledCE(255)`` loss. In each pass every model is handed to
    ``benchmark_train`` as it is built, and all models of the pass are then
    handed to ``benchmark_eval`` together.
    """
    from losses import DownsampledCE
    # We find that the training speed highly correlates with the number of classes
    # while the eval speed does not depend much on the number of classes
    class_counts=[10,20,30,40,50,60,70,80]

    # Pass 1: default model and default loss.
    models={}
    for n in class_counts:
        net=RegSeg("exp48_decoder26",num_classes=n)
        models[n]=net
        benchmark_train([net],8,768,True,n)
    dummy=torch.randn(1,3,1024,2048)
    benchmark_eval(models,dummy,True)
    print()

    # Pass 2: downsample=True together with the DownsampledCE loss.
    models={}
    criterion=DownsampledCE(255)
    for n in class_counts:
        net=RegSeg("exp48_decoder26",num_classes=n,downsample=True)
        models[n]=net
        benchmark_train([net],8,768,True,n,loss_fun=criterion)
    dummy=torch.randn(1,3,1024,2048)
    benchmark_eval(models,dummy,True)

def downsample_ce_speed_test():
    """Benchmark DownsampledCE on low-resolution logits against upsampled logits.

    For 19 and 60 classes, two small heads are built: one that stops at
    256x256, and one that additionally upsamples to 1024x1024. The first is
    benchmarked with ``DownsampledCE(255)``, the second with the benchmark's
    default loss, through ``benchmark_train`` and then ``benchmark_memory``.
    """
    from losses import DownsampledCE
    criterion=DownsampledCE(255)
    for num_classes in [19,60]:
        low_res_head=nn.Sequential(
            nn.Conv2d(3,num_classes,1),
            nn.AdaptiveAvgPool2d((256,256))
        )
        full_res_head=nn.Sequential(
            nn.Conv2d(3,num_classes,1),
            nn.AdaptiveAvgPool2d((256,256)),
            nn.Upsample((1024,1024),mode="bilinear",align_corners=False)
        )
        # Train-speed runs first, then memory runs; low-res head first in each.
        for bench in (benchmark_train,benchmark_memory):
            bench([low_res_head],8,(1024,1024),True,num_classes,loss_fun=criterion)
            bench([full_res_head],8,(1024,1024),True,num_classes)

def dilation_speed_test():
    """Benchmark a grouped 3x3 convolution for dilation rates 1 through 18.

    Each convolution has 256 input and output channels, groups of width 16,
    and padding equal to its dilation. The convolutions are passed to
    ``benchmark_eval`` keyed by their dilation rate.
    """
    group_width=16
    channels=256
    x=torch.randn(1,256,64,128)
    convs={
        rate:nn.Conv2d(channels,channels,3,1,padding=rate,dilation=rate,groups=channels//group_width,bias=False)
        for rate in range(1,19)
    }
    benchmark_eval(convs,x,True)

def block_speed_test():
    """Benchmark DBlock variants that differ only in their dilation rates.

    Prints a header, runs ``benchmark_eval`` on the labelled blocks and
    prints whatever it returns.
    """
    print("block speed test")
    variants=[
        ("YBlock",[1]),
        ("DBlock(1,1)",[1,1]),
        ("DBlock(1,4)",[1,4]),
        ("DBlock(1,10)",[1,10]),
    ]
    blocks={label:DBlock(256,256,dilations,16,1,"se") for label,dilations in variants}
    x=torch.randn(1,256,64,128) # 1/16 original resolution
    timings=benchmark_eval(blocks,x,True)
    print(timings)

def calculate_params(model):
    """Count the parameters of ``model``.

    Args:
        model: an object whose ``parameters()`` yields tensors exposing
            ``numel()`` and ``requires_grad`` (e.g. a ``torch.nn.Module``).

    Returns:
        A ``(trainable, total)`` tuple of Python ints: the number of elements
        in parameters with ``requires_grad`` set, and in all parameters.
    """
    #https://discuss.pytorch.org/t/how-do-i-check-the-number-of-parameters-of-a-model/4325/6
    trainable=0
    total=0
    for p in model.parameters():
        # numel() stays an int even for 0-dim parameters, where
        # np.prod(p.size()) would yield the float 1.0.
        count=p.numel()
        total+=count
        if p.requires_grad:
            trainable+=count
    return trainable,total


def sanity_check():
    """Build two models and print the parameter names of the second one.

    NOTE(review): "learn2" is not among the body names handled by
    ``RegSeg.__init__`` in this file, so constructing the second model
    raises ``NotImplementedError`` before anything is printed. The intended
    body name needs to be confirmed.
    """
    x=torch.randn(2,3,100,100)
    reference=RegSeg("exp48_decoder26",19)
    learnable=RegSeg("learn2_decoder14",19)
    for param_name, _ in learnable.named_parameters():
        print(param_name)

# TODO: this benchmark needs to be modified before use
def model_speed_test():
    """Run ``benchmark_eval`` on ten backbones and segmentation models.

    The candidates are three timm models, two stem + ``RegSegBody3`` stacks
    (11 and 13 blocks of dilation 1), three RegSeg variants and two DDRNet
    variants. They are evaluated on a 1024x2048 input, once with the last
    ``benchmark_eval`` argument set to True and once with False; the results
    of each run are printed.
    """
    import timm
    candidates=[timm.create_model(arch) for arch in ("regnetx_006","regnety_006","resnet18d")]
    for depth in (11,13):
        candidates.append(nn.Sequential(
            ConvBnAct(3,32,3,2,1),
            RegSegBody3([[1]]*depth)
        ))
    for arch in ("exp56_decoder30","exp48_decoder26","exp58_decoder26"):
        candidates.append(RegSeg(arch,num_classes=19))
    from competitors_models.DDRNet_Reimplementation import get_ddrnet_23,get_ddrnet_23slim
    candidates.append(get_ddrnet_23())
    candidates.append(get_ddrnet_23slim())
    x=torch.randn(1,3,1024,1024*2)
    # Each run receives its own list object, as in two independent calls.
    timings=list(benchmark_eval(list(candidates),x,True))
    print(timings)
    print()
    timings=list(benchmark_eval(list(candidates),x,False))
    print(timings)
def decode_L6():
    """Load the "L6_decoder26" checkpoint and print its body's dilation rates."""
    checkpoint="checkpoints/L6_decoder26_run2"
    learned=RegSeg("L6_decoder26",num_classes=19,pretrained=checkpoint)
    rates=learned.body.get_dilation_rates()
    print(rates)
def resolution_speed_test():
    """Run ``benchmark_eval`` on "exp48_decoder26" at several input sizes.

    Heights are multiples of 32 and the width is always twice the height.
    Prints the list of heights, then a dict mapping each height to the
    value returned by ``benchmark_eval``.
    """
    multipliers=[20,22,23,24,26, 28,30, 32, 36]
    heights=[32*m for m in multipliers]
    net=RegSeg("exp48_decoder26",num_classes=19)
    print(heights)
    results={}
    for height in heights:
        sample=torch.randn(1,3,height,height*2)
        results[height]=benchmark_eval([net],sample,True)
    print(results)

def calculate_activations():
    """Print per-module activation counts for RegSeg and DDRNet-23.

    Both models are put in eval mode and analysed with fvcore's
    ``ActivationCountAnalysis`` on a 1x3x1024x2048 input. Only entries whose
    name contains at most two dots and whose count is non-zero are printed,
    with the count integer-divided by 1,000,000.
    """
    from fvcore.nn import FlopCountAnalysis, flop_count_table,ActivationCountAnalysis
    regseg=RegSeg("exp48_decoder26",19).eval()
    from competitors_models.DDRNet_Reimplementation import get_ddrnet_23,get_ddrnet_23slim
    x=torch.randn(1,3,1024,2048)
    ddrnet=get_ddrnet_23().eval()
    for net in [regseg,ddrnet]:
        per_module=ActivationCountAnalysis(net, x).by_module()
        for module_name,count in per_module.items():
            # Skip deeply nested entries to keep the report short.
            if module_name.count(".")>2:
                continue
            if count != 0:
                print(module_name,count//1000000)
# Script entry point: when this file is executed directly (not imported),
# run calculate_activations().
if __name__ == "__main__":
    calculate_activations()
