import os

import numpy as np
import matplotlib.pyplot as plt

# Input root: one sub-directory per (organ, method, fold, noise) combination,
# named "<organ>_<method>_fold<fi>_noise<ni>" (built by _run_dir below).
path = "2_out/"

organ_list = ['kidney_right', 'aorta', 'esophagus']

# Grid size shared by the score tables and the loops that fill them, so the
# two can never disagree.
n_noise = 4
n_folds = 5

# Score tables indexed as [organ, noise level, fold].
l_ce_score = np.zeros([len(organ_list), n_noise, n_folds])
l_sd_score = np.zeros([len(organ_list), n_noise, n_folds])
l_cce_score = np.zeros([len(organ_list), n_noise, n_folds])


def _run_dir(organ, method, fold, noise):
    """Return the result directory for one organ/method/fold/noise combination."""
    return os.path.join(path, '{}_{}_fold{}_noise{}'.format(organ, method, fold, noise))


def _best_validation_score(run_dir, suffix):
    """Return the validation value at the position of the best training value.

    Loads ``training_<suffix>.txt`` and ``validation_<suffix>.txt`` from
    ``run_dir``, takes the index of the maximum training value (presumably
    the best epoch -- confirm against the training script) and returns the
    validation value at that same index.

    Raises whatever ``np.loadtxt`` raises if either file is missing or
    malformed, and IndexError if that index falls outside the validation log.
    """
    # ndmin=1: a log holding a single value would otherwise load as a 0-d
    # array, and indexing a 0-d array raises IndexError.
    training_res = np.loadtxt(os.path.join(run_dir, 'training_' + suffix + '.txt'), ndmin=1)
    validation_res = np.loadtxt(os.path.join(run_dir, 'validation_' + suffix + '.txt'), ndmin=1)
    return validation_res[np.argmax(training_res)]


for oi, organ in enumerate(organ_list):
    for ni in range(n_noise):
        for fi in range(n_folds):
            ce_dir = _run_dir(organ, 'ce', fi, ni)
            sd_dir = _run_dir(organ, 'sd', fi, ni)
            l_ce_score[oi, ni, fi] = _best_validation_score(ce_dir, 'd')
            l_sd_score[oi, ni, fi] = _best_validation_score(sd_dir, 'd')
            # NOTE(review): "cce" reads the *_o logs from the CE run directory
            # (there is no separate "_cce_" directory) -- presumably
            # intentional; confirm against the training script.
            l_cce_score[oi, ni, fi] = _best_validation_score(ce_dir, 'o')

# Average each [organ, noise, fold] table over the fold axis, giving one
# value per (organ, noise level).
avg_ce_score = np.mean(l_ce_score, axis=2)
avg_sd_score = np.mean(l_sd_score, axis=2)
avg_cce_score = np.mean(l_cce_score, axis=2)

_summaries = (('ce', avg_ce_score), ('sd', avg_sd_score), ('cce', avg_cce_score))

for _name, _avg in _summaries:
    print(_name + '*******')
    print(np.round(_avg, 4))

# np.savetxt does not create missing directories; without this the first
# run on a fresh checkout fails with FileNotFoundError.
out_dir = "3_out"
os.makedirs(out_dir, exist_ok=True)
for _name, _avg in _summaries:
    np.savetxt(os.path.join(out_dir, 'avg_' + _name + '_score.txt'), _avg)

# One panel per organ: fold-averaged validation score against noise-level
# index for each variant. Panel count follows organ_list instead of being
# hard-coded to three.
n_organs = len(organ_list)
plt.figure(figsize=(5 * n_organs, 4))
for oi, organ in enumerate(organ_list):
    plt.subplot(1, n_organs, oi + 1)
    # Raw strings: "\m" is not a valid escape sequence, so a plain literal
    # triggers a SyntaxWarning (DeprecationWarning before Python 3.12).
    plt.plot(avg_ce_score[oi, :], 'r', label=r"$\mathrm{CE}^{(0)}$")
    plt.plot(avg_sd_score[oi, :], 'b', label=r"$\mathrm{SD}^{(0)}$")
    plt.plot(avg_cce_score[oi, :], 'g', label=r"$\mathrm{CE}^{(*)}$")
    plt.title(organ)
    plt.xlabel('noise level index')
    plt.ylabel('validation score (fold mean)')
    plt.legend()

# savefig does not create missing directories.
os.makedirs("3_out", exist_ok=True)
plt.savefig('3_out/fig.png', bbox_inches='tight')