from common import *
import numpy as np
import matplotlib.pyplot as plt

def get_deformation_sample_average(m, a, n):
    """Return the element-wise mean of ``n`` deformation samples of ``m``.

    Each sample is drawn with ``get_deformation_sample(m, a)``, which is
    defined outside this function (presumably provided by
    ``from common import *``). The samples are summed into a float
    accumulator shaped like ``m`` and then divided by ``n``.

    Args:
        m: Array to deform; the result has shape ``m.shape``.
        a: Deformation parameter, forwarded unchanged to
            ``get_deformation_sample``.
        n: Number of samples to average; must be at least 1.

    Returns:
        A float ndarray of shape ``m.shape`` holding the sample mean.

    Raises:
        ValueError: If ``n`` is less than 1. Previously this silently produced
            an all-NaN (n == 0) or all -0.0 (n < 0) array.
    """
    if n < 1:
        raise ValueError("n must be at least 1, got %r" % (n,))
    total = np.zeros(m.shape)
    for _ in range(n):
        total += get_deformation_sample(m, a)
    return total / n

def saveplot(name, temp, out_dir='0_out'):
    """Render ``temp`` as an axis-free image and save it to ``<out_dir>/<name>.pdf``.

    Colours are normalised over the fixed range [0, 1] (``vmin=0``,
    ``vmax=1``), so every saved figure uses the same colour scale regardless
    of the array's own min/max.

    Args:
        name: Output file stem; ``.pdf`` is appended.
        temp: Array passed to ``imshow`` (2-D masks/maps in this script).
        out_dir: Directory to write into. It is created if missing. Defaults
            to ``'0_out'``, the previously hard-coded location.
    """
    import os

    os.makedirs(out_dir, exist_ok=True)
    # Draw on a dedicated figure rather than pyplot's implicit "current" one,
    # so a stray open figure cannot bleed into the output. Always close it,
    # even if saving fails, so figures do not pile up in memory.
    fig, ax = plt.subplots()
    try:
        ax.imshow(temp, vmin=0, vmax=1, aspect='equal')
        ax.set_axis_off()
        fig.tight_layout()
        fig.savefig(os.path.join(out_dir, name + '.pdf'), bbox_inches='tight',
                    pad_inches=0.0, transparent=True, dpi=200)
    finally:
        plt.close(fig)

# Synthetic binary test shape on an N x N grid over [0, 1]^2. It has three
# filled regions, with a small square hole cut out afterwards. Axis 0 is
# indexed by x[i], axis 1 by x[j].
N = 500
x = np.linspace(0, 1, N)
# Broadcast x[i] down axis 0 and x[j] along axis 1. Each comparison below is
# then an (N, N) boolean mask equal to the per-pixel test of a double loop
# over i and j.
xi = x[:, np.newaxis]
xj = x[np.newaxis, :]

central_square = (xi > 0.4) & (xi < 0.6) & (xj > 0.4) & (xj < 0.6)
thin_bar = (xi > 0.49) & (xi < 0.51) & (xj > 0.5) & (xj < 0.8)
wide_bar = (xi > 0.48) & (xi < 0.52) & (xj > 0.2) & (xj < 0.5)
hole = (xi > 0.48) & (xi < 0.52) & (xj > 0.48) & (xj < 0.52)

# The hole is applied last, so it overrides all three filled regions. The
# result is float64 0/1, same as filling np.zeros([N, N]) pixel by pixel.
m = ((central_square | thin_bar | wide_bar) & ~hole).astype(float)

# Deformation parameters passed as ``a`` to the deformation helpers. In each
# result list below, index 0 is the undeformed case (``m`` itself), so list
# position k is saved with file suffix k + 1.
deformation_levels = [0.01, 0.02, 0.03, 0.04, 0.05, 0.06]
n_samples = 5  # samples averaged per level by get_deformation_sample_average

# Computed kind by kind, in the same order as the original hand-written
# sequence: all marginals, then all accuracy segmentations, then all Dice
# segmentations, then all sample averages. If the helpers consume random
# state, they draw it in the same sequence as before.
marginals = [m] + [get_deformation_marginals(m, a) for a in deformation_levels]
acc_segs = [gen_opt_acc_seg(mk) for mk in marginals]
dice_segs = [gen_opt_dice_seg(mk) for mk in marginals]
sample_avgs = [m] + [
    get_deformation_sample_average(m, a, n_samples) for a in deformation_levels
]

# Keep the historical per-image module-level names available.
m1, m2, m3, m4, m5, m6, m7 = marginals
a1, a2, a3, a4, a5, a6, a7 = acc_segs
d1, d2, d3, d4, d5, d6, d7 = dice_segs
s1, s2, s3, s4, s5, s6, s7 = sample_avgs

# Saves m1..m7, then a1..a7, d1..d7, s1..s7 -- the same files in the same order.
for prefix, images in (('m', marginals), ('a', acc_segs),
                       ('d', dice_segs), ('s', sample_avgs)):
    for k, image in enumerate(images, start=1):
        saveplot(prefix + str(k), image)