from matplotlib import pyplot as plt
import numpy as np

def get_dilated_kernel(d):
    """Return the normalized 1-D kernel [1, 2, 1] dilated by rate ``d``.

    ``d - 1`` zeros are inserted between neighbouring taps, and the result is
    divided by its sum (4), so it sums to 1. For ``d <= 1`` no zeros are
    inserted and the plain ``[0.25, 0.5, 0.25]`` kernel is returned.
    """
    gap = [0] * (d - 1)
    taps = np.array([1, *gap, 2, *gap, 1])
    return taps / taps.sum()
def chaining_dilations(ds):
    """Return the full convolution of the dilated kernels for every rate in ``ds``.

    Each rate is turned into a kernel with ``get_dilated_kernel``, and the
    kernels are chained left to right with ``np.convolve`` (full mode). The
    result is the combined impulse response of the stacked layers. Raises
    IndexError if ``ds`` is empty.
    """
    kernels = [get_dilated_kernel(rate) for rate in ds]
    combined = kernels[0]
    for kernel in kernels[1:]:
        combined = np.convolve(combined, kernel)
    return combined
def dilation_update(fs, d):
    """Return a copy of ``fs`` with each kernel size replaced by its dilated extent.

    Args:
        fs: sequence of ``(kernel_size, stride)`` pairs (tuples or lists).
        d: per-layer dilation rates. ``d[i]`` applies to ``fs[i]``. It may be
            shorter than ``fs``, in which case the remaining layers are left
            unchanged.

    Returns:
        A new list of ``[kernel_size, stride]`` lists. ``fs`` is not modified.

    Raises:
        IndexError: if ``d`` has more entries than ``fs``.

    A kernel of size ``k`` dilated by rate ``r`` spans ``k + (k - 1) * (r - 1)``
    input positions. For ``k == 3`` this is ``k + 2 * (r - 1)``, which was the
    previously hard-coded special case.
    """
    # Copy into fresh lists so callers' (possibly aliased, e.g. [[3, 1]] * 13)
    # entries are never mutated.
    updated = [[k, s] for k, s in fs]
    for i, rate in enumerate(d):
        k = updated[i][0]
        updated[i][0] = k + (k - 1) * (rate - 1)
    return updated
def field_of_vision(fs):
    """Return ``(receptive_field, total_stride)`` for a stack of layers.

    ``fs`` is an iterable of ``(kernel_size, stride)`` pairs, ordered from
    input to output. Each layer widens the receptive field by
    ``(kernel_size - 1)`` times the stride accumulated *before* it. The total
    stride is the product of all strides. An empty stack gives ``(1, 1)``.
    """
    field, jump = 1, 1
    for kernel_size, stride in fs:
        field += (kernel_size - 1) * jump
        jump *= stride
    return field, jump
def print_fov(fs, d):
    """Apply per-layer dilations ``d`` to ``fs`` and print ``receptive_field stride``."""
    dilated = dilation_update(fs, d)
    field, stride = field_of_vision(dilated)
    print(field, stride)

def get_max_dilation(k, s):
    """Return ``ceil(k / 2)`` (for integer ``k``) integer-divided by stride ``s``.

    NOTE(review): presumably the largest dilation worth using after a stack
    with receptive field ``k`` and total stride ``s``. Confirm the intent
    with callers.
    """
    half_field = (k + 1) // 2
    return half_field // s
def clip_to_integer(ds):
    """Widen each learned ``[d1, d2]`` dilation pair outward to integers.

    Each pair is ordered, the smaller value is floored and the larger is
    ceiled. A group collapses to a single ``[d]`` when both ends land on the
    same integer; otherwise it is ``[low, high]``.
    """
    import math

    clipped = []
    for pair in ds:
        low, high = sorted(pair)
        low, high = math.floor(low), math.ceil(high)
        clipped.append([low] if low == high else [low, high])
    return clipped
def clip_to_integer2(ds):
    """Round each learned ``[d1, d2]`` dilation pair to the nearest integers.

    Works like ``clip_to_integer``, but uses ``round`` on both ends instead of
    floor/ceil. Each pair is ordered, both ends are rounded, and the group
    collapses to ``[d]`` when they coincide; otherwise it is ``[low, high]``.

    Note: the built-in ``round`` rounds exact halves to the nearest *even*
    integer (``round(4.5) == 4``, ``round(5.5) == 6``), not always upward.
    """
    new_ds = []
    for pair in ds:
        low, high = sorted(pair)
        low, high = round(low), round(high)
        new_ds.append([low] if low == high else [low, high])
    return new_ds
def staged_net(ds):
    """Build a kernel-3 layer list with one stage per entry of ``ds``.

    ``ds[i]`` is the number of layers in stage ``i``. The first layer of each
    stage has stride 2 and the rest stride 1. Every layer is a fresh
    ``[kernel_size, stride]`` list.
    """
    return [
        [3, 2] if position == 0 else [3, 1]
        for depth in ds
        for position in range(depth)
    ]
def xception71():
    """Print the layer list and ``receptive_field stride`` for the Xception-71 layout used here.

    The layout is four leading layers, staged blocks of sizes 3/6/51/4 from
    ``staged_net``, and a final kernel-19 stride-1 layer.
    """
    leading = [(3, 2), (3, 1), (3, 1), (3, 1)]
    trailing = [(19, 1)]
    layers = leading + staged_net([3, 6, 51, 4]) + trailing
    print(layers)
    field, stride = field_of_vision(layers)
    print(field, stride)
def WideResNet41():
    """Print the layer list and ``receptive_field stride`` for the WideResNet-41 layout used here.

    The layout is staged blocks of sizes 7/5/12/6/5 from ``staged_net``
    followed by a kernel-19 stride-1 layer.
    """
    layers = [*staged_net([7, 5, 12, 6, 5]), (19, 1)]
    print(layers)
    field, stride = field_of_vision(layers)
    print(field, stride)
def SWideRNet():
    """Print the layer list and ``receptive_field stride`` for the SWideRNet layout used here.

    The layout is staged blocks of sizes 7/5/12/6/2 from ``staged_net``
    followed by three kernel-7 stride-1 layers. (A trailing kernel-19 layer
    was tried and disabled.)
    """
    layers = staged_net([7, 5, 12, 6, 2])
    layers.extend([(7, 1)] * 3)
    print(layers)
    field, stride = field_of_vision(layers)
    print(field, stride)
def decode_learn_helper(ds, func=clip_to_integer):
    """Clip learned dilation pairs, print them, and print the resulting receptive field.

    Args:
        ds: learned ``[d1, d2]`` dilation pairs.
        func: converts ``ds`` into integer dilation groups (default
            ``clip_to_integer``).

    The network is six leading layers followed by thirteen kernel-3 stride-1
    layers (19 in total). The first seven layers use dilation 1. Each
    following layer uses the largest dilation of its group, in order.
    """
    groups = func(ds)
    print(groups)
    layers = [(3, 2), [3, 2], [3, 2], [3, 1], [3, 1], [3, 2]] + [[3, 1]] * 13
    rates = [1] * 7 + [max(group) for group in groups]
    print_fov(layers, rates)
def get_fov(ds, func=clip_to_integer):
    """Return the receptive field obtained from learned dilation pairs ``ds``.

    Args:
        ds: learned ``[d1, d2]`` dilation pairs.
        func: converts ``ds`` into integer dilation groups (default
            ``clip_to_integer``).

    Returns:
        The receptive field of the 19-layer network (six leading layers and
        thirteen kernel-3 stride-1 layers) when each group's maximum is used
        as a dilation rate.

    Raises:
        IndexError: if there are more than 13 dilation groups.
    """
    fs = [(3, 2), [3, 2], [3, 2], [3, 1], [3, 1], [3, 2]] + [[3, 1]] * 13
    n_tail_layers = 13  # trailing stride-1 layers that may carry learned dilations
    learned = [max(r) for r in func(ds)]
    if len(learned) > n_tail_layers:
        raise IndexError(
            f"got {len(learned)} dilation groups but only {n_tail_layers} layers can take them"
        )
    # Right-align the learned dilations onto the last layers; earlier layers
    # keep dilation 1. With 12 groups this yields seven leading 1s, matching
    # decode_learn_helper (this resolves the former "TODO add a 1").
    d = [1] * (len(fs) - len(learned)) + learned
    k, _ = field_of_vision(dilation_update(fs, d))
    return k
def decode_learn_helper2(ds, func=clip_to_integer):
    """Like ``decode_learn_helper``, but with an extra stride-2 layer among the trailing layers.

    The network is six leading layers, four kernel-3 stride-1 layers, one
    kernel-3 stride-2 layer, and eight kernel-3 stride-1 layers. The first
    seven layers use dilation 1; the group maxima of ``func(ds)`` follow.
    """
    groups = func(ds)
    print(groups)
    leading = [(3, 2), [3, 2], [3, 2], [3, 1], [3, 1], [3, 2]]
    tail = [[3, 1]] * 4 + [[3, 2]] + [[3, 1]] * 8
    print_fov(leading + tail, [1] * 7 + [max(group) for group in groups])
def decode_learn():
    """Print clipped dilations and receptive fields for several recorded learning runs.

    Each run is a list of learned ``[d1, d2]`` dilation pairs, passed to
    ``decode_learn_helper`` together with the clipping function to use.
    """
    # Run notes: L1_decoder26_bootstrapped, L1_decoder26, L3_decoder26,
    # L1_decoder26 better init.
    learn2 = [[1, 2], [1, 2], [1, 3], [2, 3], [2, 7], [2, 3], [2, 6], [2, 5], [2, 9], [2, 11], [4, 7], [5, 14]]
    L1_768_res = [[1.39, 1.47], [1.35, 1.33], [1.22, 2.52], [4.5, 2.45], [2.46, 2.47], [5.5, 2.47], [2.39, 7.57], [6.51, 2.37], [2.37, 4.46], [2.36, 9.53], [6.5, 4.43], [2.35, 11.49]]
    L1_decoder26 = [[1.42, 1.52], [1.36, 1.4], [1.29, 3.48], [1.21, 4.52], [2.44, 2.45], [7.5, 2.44], [2.41, 4.5], [8.52, 2.41], [2.36, 5.49], [4.41, 11.54], [5.44, 8.47], [2.36, 14.51]]
    L1_decoder26_bootstrapped = [[1.25, 2.55], [1.23, 2.46], [1.2, 2.5], [2.51, 4.59], [2.5, 7.8], [2.49, 4.63], [2.55, 6.79], [2.38, 5.7], [2.41, 10.76], [2.35, 14.73], [4.45, 8.72], [5.6, 15.68]]
    # L2 run (evaluated with decode_learn_helper2, currently disabled):
    # [[1.35, 2.55], [1.26, 1.25], [1.64, 1.23], [1.36, 1.3], [1.61, 1.55], [1.42, 1.37], [4.73, 3.57], [4.57, 2.68], [5.81, 7.5], [8.75, 2.39], [2.65, 10.75], [2.49, 5.65]]
    L3 = [[1.24, 2.5], [1.13, 2.41], [1.08, 3.81], [1.44, 1.52], [1.42, 3.68], [3.74, 7.55], [2.39, 13.73], [3.55, 5.7], [3.58, 8.73], [5.6, 8.74], [8.63, 13.78], [2.36, 19.71]]
    L5 = [[1.38, 2.64], [2.35, 2.55], [1.41, 2.42], [2.57, 2.42], [2.56, 2.57], [2.61, 4.63], [1.6, 6.69], [4.6, 2.68], [2.39, 6.67], [4.58, 4.44], [4.59, 7.72], [2.63, 9.6]]

    runs = [
        (learn2, clip_to_integer),
        (L1_768_res, clip_to_integer),
        (L1_decoder26, clip_to_integer),
        (L1_decoder26_bootstrapped, clip_to_integer),
        (L3, clip_to_integer),
        (L5, clip_to_integer2),  # L5 is rounded rather than floor/ceil-clipped
    ]
    for learned, clip in runs:
        decode_learn_helper(learned, clip)

def g():
    """Print and plot the chained dilated kernel for one dilation schedule.

    The schedule is 1, 2, 4, 8, 8, 8 followed by 16 times each of
    1, 5, 7, 12, 7. The chained kernel comes from ``chaining_dilations``.
    (An alternative schedule ending in 16*12 repeated seven times was
    previously plotted under label "1".)
    """
    multipliers = [1, 5, 7, 12, 7]
    schedule = [1, 2, 4, 8, 8, 8] + [16 * m for m in multipliers]
    response = chaining_dilations(schedule)
    print(response)
    plt.plot(response, label="2")

    plt.legend(loc='lower right')
    plt.show()
def f():
    """Print ``receptive_field stride`` for a series of hand-designed dilation schedules.

    The network is six leading layers followed by thirteen kernel-3 stride-1
    layers. Every schedule keeps dilation 1 on the first seven layers. The
    remaining twelve are given as ``(dilation, repeat_count)`` runs. The
    comment on each schedule is the receptive field it prints.
    """
    fs = [(3, 2), [3, 2], [3, 2], [3, 1], [3, 1], [3, 2]] + [[3, 1]] * 13
    leading = [1] * 7
    schedules = [
        [(1, 11), (2, 1)],  # 511
        [(2, 9), (4, 3)],  # 1055
        [(2, 2), (4, 10)],  # 1503
        [(2, 2), (4, 2), (6, 8)],  # 2015
        [(2, 1), (4, 1), (6, 5), (8, 5)],  # 2527
        [(2, 1), (4, 1), (6, 1), (8, 5), (10, 4)],  # 3039
        [(2, 1), (4, 1), (6, 1), (8, 1), (10, 4), (12, 4)],  # 3551
        [(2, 1), (4, 1), (6, 1), (8, 1), (10, 1), (12, 2), (14, 5)],  # 4063
    ]
    for runs in schedules:
        d = leading + [rate for rate, count in runs for _ in range(count)]
        print_fov(fs, d)
def test1():
    """Print ``receptive_field stride`` for RegNet-style stage layouts.

    Each layout is a single kernel-3 stride-2 layer followed by
    ``staged_net(stage_depths)``.
    """
    regnets = [
        [1, 3, 7, 4],
    ]
    for stage_depths in regnets:
        layers = [(3, 2), *staged_net(stage_depths)]
        field, stride = field_of_vision(layers)
        print(field, stride)
if __name__ == "__main__":
    # Alternative entry points that can be run instead: test1(), g()
    decode_learn()
