from common import *
import os
import numpy as np
import nibabel as nib
import scipy.ndimage
import pickle

# Path to the TotalSegmentator data. Defaults to "../data/ts"; set the TS_PATH
# environment variable to point elsewhere instead of editing this file.
ts_path = os.environ.get("TS_PATH", "../data/ts")

# Case IDs used for the experiments, per organ (400 sorted, unique IDs each).
# main() reads each case from <ts_path>/s<ID>/ct.nii.gz and the organ mask from
# <ts_path>/s<ID>/segmentations/<organ key>.nii.gz.
data = {
    "aorta": [
        "0004", "0011", "0013", "0019", "0024", "0028", "0029", "0030", "0032", "0037",
        "0038", "0040", "0042", "0046", "0048", "0049", "0050", "0053", "0057", "0058",
        "0065", "0070", "0071", "0072", "0076", "0080", "0082", "0086", "0089", "0091",
        "0095", "0108", "0109", "0111", "0114", "0115", "0120", "0123", "0124", "0133",
        "0138", "0141", "0145", "0146", "0150", "0151", "0161", "0162", "0163", "0171",
        "0175", "0182", "0183", "0184", "0189", "0196", "0197", "0206", "0223", "0224",
        "0228", "0230", "0231", "0236", "0239", "0240", "0241", "0244", "0248", "0253",
        "0255", "0260", "0271", "0275", "0285", "0287", "0291", "0298", "0300", "0303",
        "0307", "0310", "0315", "0319", "0320", "0322", "0327", "0329", "0331", "0332",
        "0334", "0339", "0342", "0344", "0345", "0350", "0354", "0357", "0358", "0362",
        "0363", "0364", "0365", "0367", "0369", "0370", "0371", "0372", "0374", "0375",
        "0377", "0378", "0382", "0383", "0390", "0392", "0393", "0394", "0398", "0402",
        "0403", "0405", "0407", "0408", "0412", "0413", "0414", "0416", "0420", "0421",
        "0422", "0423", "0425", "0428", "0429", "0430", "0431", "0436", "0437", "0438",
        "0439", "0440", "0441", "0444", "0446", "0447", "0452", "0455", "0456", "0457",
        "0459", "0461", "0462", "0463", "0467", "0469", "0470", "0472", "0476", "0477",
        "0481", "0482", "0483", "0484", "0485", "0487", "0488", "0494", "0498", "0499",
        "0500", "0502", "0507", "0508", "0510", "0511", "0515", "0516", "0517", "0519",
        "0522", "0523", "0529", "0531", "0536", "0541", "0542", "0545", "0546", "0548",
        "0549", "0550", "0551", "0553", "0555", "0556", "0561", "0565", "0571", "0574",
        "0578", "0580", "0582", "0583", "0585", "0587", "0589", "0590", "0591", "0592",
        "0593", "0598", "0600", "0604", "0606", "0607", "0612", "0613", "0616", "0617",
        "0618", "0619", "0621", "0623", "0624", "0625", "0626", "0628", "0629", "0632",
        "0635", "0636", "0637", "0639", "0641", "0644", "0645", "0649", "0650", "0651",
        "0654", "0656", "0657", "0658", "0659", "0661", "0662", "0663", "0664", "0667",
        "0669", "0670", "0672", "0674", "0675", "0680", "0681", "0686", "0687", "0688",
        "0690", "0691", "0692", "0696", "0697", "0702", "0703", "0704", "0707", "0708",
        "0709", "0711", "0714", "0716", "0717", "0718", "0719", "0720", "0721", "0723",
        "0726", "0727", "0729", "0731", "0732", "0733", "0735", "0737", "0739", "0743",
        "0747", "0750", "0751", "0753", "0754", "0759", "0763", "0764", "0765", "0770",
        "0771", "0773", "0775", "0777", "0778", "0780", "0784", "0785", "0788", "0789",
        "0790", "0793", "0794", "0795", "0796", "0797", "0798", "0801", "0802", "0804",
        "0805", "0807", "0808", "0810", "0811", "0812", "0816", "0821", "0822", "0829",
        "0830", "0833", "0834", "0835", "0836", "0842", "0844", "0847", "0859", "0860",
        "0863", "0866", "0869", "0871", "0878", "0880", "0884", "0885", "0889", "0892",
        "0894", "0896", "0897", "0898", "0899", "0903", "0904", "0908", "0912", "0913",
        "0915", "0916", "0918", "0919", "0923", "0924", "0927", "0928", "0930", "0931",
        "0933", "0934", "0939", "0940", "0941", "0945", "0946", "0950", "0954", "0957",
        "0959", "0961", "0963", "0970", "0973", "0975", "0977", "0978", "0982", "0983",
        "0985", "0991", "0992", "0994", "0999", "1006", "1007", "1008", "1012", "1013",
    ],
    "esophagus": [
        "0004", "0011", "0019", "0021", "0024", "0028", "0032", "0037", "0040", "0046",
        "0049", "0050", "0065", "0070", "0071", "0076", "0080", "0086", "0091", "0111",
        "0114", "0123", "0124", "0138", "0141", "0146", "0163", "0171", "0175", "0189",
        "0197", "0213", "0224", "0230", "0231", "0239", "0241", "0271", "0277", "0287",
        "0303", "0310", "0322", "0327", "0331", "0332", "0334", "0345", "0349", "0350",
        "0354", "0357", "0358", "0362", "0363", "0364", "0369", "0370", "0371", "0372",
        "0373", "0374", "0375", "0377", "0393", "0394", "0398", "0402", "0405", "0407",
        "0408", "0412", "0413", "0414", "0420", "0421", "0422", "0423", "0425", "0429",
        "0430", "0431", "0433", "0436", "0437", "0439", "0440", "0444", "0446", "0447",
        "0452", "0456", "0457", "0459", "0461", "0462", "0463", "0467", "0469", "0470",
        "0472", "0476", "0481", "0482", "0483", "0484", "0485", "0487", "0488", "0490",
        "0494", "0498", "0499", "0502", "0507", "0508", "0510", "0511", "0515", "0516",
        "0517", "0519", "0522", "0530", "0536", "0546", "0549", "0550", "0551", "0553",
        "0556", "0557", "0561", "0565", "0571", "0573", "0574", "0578", "0580", "0582",
        "0583", "0585", "0587", "0589", "0590", "0591", "0592", "0593", "0598", "0602",
        "0604", "0606", "0612", "0616", "0617", "0618", "0621", "0623", "0624", "0625",
        "0628", "0629", "0632", "0635", "0636", "0637", "0639", "0641", "0645", "0649",
        "0650", "0651", "0656", "0657", "0658", "0659", "0662", "0663", "0664", "0667",
        "0669", "0672", "0674", "0675", "0680", "0687", "0690", "0691", "0692", "0695",
        "0696", "0697", "0698", "0702", "0703", "0704", "0708", "0709", "0711", "0714",
        "0716", "0717", "0718", "0720", "0723", "0726", "0727", "0731", "0733", "0737",
        "0739", "0743", "0747", "0750", "0751", "0752", "0754", "0759", "0763", "0764",
        "0765", "0771", "0775", "0777", "0778", "0784", "0787", "0788", "0789", "0790",
        "0793", "0794", "0795", "0796", "0797", "0798", "0801", "0802", "0804", "0807",
        "0810", "0811", "0812", "0816", "0821", "0824", "0829", "0830", "0833", "0834",
        "0835", "0842", "0847", "0859", "0860", "0866", "0869", "0878", "0880", "0884",
        "0885", "0889", "0890", "0892", "0894", "0896", "0897", "0898", "0899", "0903",
        "0904", "0910", "0912", "0913", "0914", "0915", "0916", "0919", "0923", "0928",
        "0930", "0933", "0934", "0939", "0940", "0941", "0945", "0946", "0950", "0954",
        "0957", "0959", "0963", "0970", "0975", "0976", "0977", "0978", "0982", "0983",
        "0991", "0992", "0994", "0999", "1004", "1006", "1007", "1012", "1013", "1016",
        "1018", "1022", "1023", "1024", "1028", "1029", "1031", "1038", "1042", "1045",
        "1046", "1052", "1058", "1060", "1061", "1062", "1066", "1067", "1069", "1070",
        "1073", "1077", "1082", "1083", "1085", "1086", "1088", "1089", "1090", "1094",
        "1099", "1100", "1101", "1102", "1105", "1110", "1111", "1120", "1123", "1124",
        "1127", "1133", "1135", "1140", "1142", "1143", "1145", "1152", "1155", "1157",
        "1159", "1161", "1162", "1164", "1167", "1171", "1172", "1173", "1174", "1176",
        "1178", "1189", "1190", "1191", "1206", "1207", "1208", "1209", "1221", "1223",
        "1225", "1228", "1230", "1233", "1235", "1238", "1240", "1245", "1247", "1248",
        "1249", "1255", "1262", "1264", "1267", "1268", "1269", "1273", "1276", "1277",
    ],
    "kidney_right": [
        "0010", "0011", "0013", "0014", "0015", "0016", "0019", "0024", "0028", "0029",
        "0030", "0031", "0038", "0040", "0042", "0045", "0050", "0052", "0053", "0054",
        "0058", "0065", "0072", "0073", "0076", "0077", "0078", "0082", "0086", "0089",
        "0091", "0096", "0108", "0109", "0117", "0119", "0120", "0133", "0137", "0139",
        "0141", "0143", "0150", "0151", "0153", "0157", "0158", "0161", "0163", "0168",
        "0171", "0178", "0182", "0183", "0189", "0193", "0194", "0196", "0201", "0204",
        "0210", "0212", "0218", "0220", "0224", "0227", "0236", "0239", "0243", "0244",
        "0248", "0250", "0252", "0255", "0257", "0260", "0264", "0275", "0286", "0287",
        "0291", "0293", "0300", "0307", "0311", "0314", "0319", "0320", "0321", "0325",
        "0328", "0329", "0332", "0334", "0339", "0341", "0342", "0343", "0344", "0345",
        "0350", "0355", "0358", "0362", "0369", "0370", "0375", "0383", "0390", "0392",
        "0401", "0402", "0403", "0406", "0408", "0423", "0428", "0429", "0436", "0440",
        "0441", "0446", "0447", "0456", "0458", "0461", "0467", "0471", "0472", "0473",
        "0476", "0477", "0480", "0483", "0484", "0494", "0495", "0499", "0500", "0502",
        "0505", "0507", "0509", "0513", "0516", "0519", "0529", "0536", "0541", "0542",
        "0543", "0545", "0546", "0549", "0550", "0551", "0553", "0571", "0574", "0577",
        "0578", "0583", "0585", "0586", "0587", "0589", "0590", "0591", "0592", "0593",
        "0602", "0603", "0612", "0613", "0616", "0617", "0620", "0621", "0623", "0625",
        "0626", "0628", "0629", "0635", "0636", "0637", "0639", "0644", "0648", "0649",
        "0650", "0651", "0656", "0657", "0661", "0662", "0663", "0664", "0667", "0669",
        "0670", "0680", "0682", "0683", "0686", "0687", "0692", "0694", "0699", "0702",
        "0703", "0705", "0708", "0720", "0721", "0726", "0727", "0731", "0733", "0739",
        "0746", "0749", "0751", "0754", "0760", "0762", "0763", "0764", "0765", "0777",
        "0778", "0788", "0790", "0794", "0796", "0797", "0804", "0806", "0807", "0812",
        "0830", "0835", "0836", "0842", "0859", "0863", "0869", "0878", "0880", "0884",
        "0885", "0894", "0896", "0899", "0903", "0904", "0912", "0913", "0916", "0918",
        "0919", "0923", "0924", "0927", "0928", "0933", "0934", "0939", "0940", "0945",
        "0950", "0957", "0959", "0961", "0965", "0970", "0979", "0982", "0983", "0985",
        "0992", "0994", "1006", "1008", "1012", "1016", "1024", "1029", "1031", "1037",
        "1038", "1044", "1061", "1062", "1063", "1069", "1070", "1082", "1085", "1086",
        "1088", "1089", "1090", "1099", "1102", "1105", "1111", "1120", "1121", "1123",
        "1124", "1127", "1130", "1131", "1135", "1143", "1145", "1149", "1152", "1153",
        "1159", "1161", "1171", "1174", "1176", "1183", "1187", "1189", "1206", "1207",
        "1208", "1209", "1210", "1212", "1216", "1224", "1228", "1230", "1233", "1238",
        "1240", "1244", "1247", "1248", "1249", "1264", "1267", "1273", "1276", "1283",
        "1287", "1291", "1293", "1294", "1297", "1304", "1307", "1309", "1314", "1319",
        "1321", "1322", "1332", "1336", "1340", "1344", "1347", "1348", "1349", "1350",
        "1354", "1361", "1362", "1363", "1364", "1365", "1366", "1367", "1368", "1369",
        "1371", "1372", "1373", "1374", "1377", "1379", "1380", "1382", "1383", "1384",
        "1386", "1387", "1388", "1390", "1391", "1394", "1395", "1397", "1400", "1403",
    ],
}

def _centered_window_start(lo, hi, hsl, size):
    """Return the start index of a length-``2*hsl`` window centred on ``[lo, hi]``.

    The start is clamped so the window lies inside ``[0, size)`` whenever
    ``size >= 2*hsl``. If the axis is shorter than that, the start is pinned to
    0 and the crop along this axis will be shorter than ``2*hsl``.
    """
    start = int((hi + lo) / 2 - hsl)
    # A negative slice start would wrap around to the end of the array, and a
    # stop past the end is silently truncated, so clamp explicitly.
    return min(max(start, 0), max(size - 2 * hsl, 0))


def _crop_transpose_flip(volume, starts, hsl):
    """Crop a ``2*hsl`` box at ``starts``, reverse the axis order, and flip the
    resulting second and third axes.

    Returns a new C-contiguous array; ``volume`` itself is left untouched.
    """
    window = tuple(slice(s, s + 2 * hsl) for s in starts)
    cropped = np.transpose(volume[window])
    return np.ascontiguousarray(cropped[:, ::-1, ::-1])


def _round_robin_folds(cases, n_folds=5):
    """Split ``cases`` into ``n_folds`` cross-validation folds.

    Case ``i`` goes to the validation set of fold ``i % n_folds``. Each fold's
    training set is the other parts concatenated in ascending part order.
    Returns ``{k: {'training': [...], 'validation': [...]}}``.
    """
    parts = [cases[k::n_folds] for k in range(n_folds)]
    return {
        k: {
            'training': [c for j, part in enumerate(parts) if j != k for c in part],
            'validation': parts[k],
        }
        for k in range(n_folds)
    }


def main(organ_str, hsl=32):
    """Build cropped CT/segmentation samples and CV splits for one organ.

    For every case ID in ``data[organ_str]`` this:
      1. loads ``<ts_path>/s<ID>/ct.nii.gz`` and the organ mask
         ``<ts_path>/s<ID>/segmentations/<organ_str>.nii.gz``;
      2. downsamples both by a factor of 0.5 (the mask is rounded afterwards);
      3. crops a ``2*hsl`` cube centred on the mask's bounding box, then
         transposes it and flips its last two axes;
      4. derives noisy samples, marginals, and accuracy-/Dice-optimal
         segmentations at three deformation levels via the helpers imported
         from ``common``;
      5. writes everything as NRRD files into ``1_out/organ/<organ_str>/``.
    Finally it pickles a 5-fold round-robin split of the case IDs to
    ``1_out/splits/<organ_str>_folds_400.pkl``.

    Args:
        organ_str: Key into ``data`` and mask file stem under ``segmentations/``.
        hsl: Half side length (in downsampled voxels) of the cubic crop.

    Raises:
        ValueError: If a case's mask has no foreground voxels after downsampling.
    """
    out_dir = '1_out/organ/' + organ_str + '/'  # writers receive a trailing '/'
    os.makedirs('1_out/organ/' + organ_str, exist_ok=True)
    os.makedirs('1_out/splits', exist_ok=True)
    np.random.seed(0)

    sigmas = (0.01, 0.02, 0.03)
    # NOTE(review): spacing/origin are fixed constants. They are not read from
    # ct_nii's header or adjusted for the 0.5 zoom; confirm this is intended.
    spacing = [0.15, 0.15, 0.15]
    origin = [0, 0, 0]

    for i_str in data[organ_str]:
        print(i_str)

        ct_nii = nib.load(f'{ts_path}/s{i_str}/ct.nii.gz')
        organ_nii = nib.load(f'{ts_path}/s{i_str}/segmentations/{organ_str}.nii.gz')

        ct = scipy.ndimage.zoom(ct_nii.get_fdata(), 0.5)
        organ = np.round(scipy.ndimage.zoom(organ_nii.get_fdata(), 0.5))

        xind, yind, zind = np.nonzero(organ)
        if xind.size == 0:
            raise ValueError(
                f"case {i_str}: '{organ_str}' mask is empty after downsampling"
            )

        # np.nonzero yields indices in C (row-major) order, so only the first
        # axis comes out sorted; take min/max explicitly on every axis.
        starts = [
            _centered_window_start(ind.min(), ind.max(), hsl, size)
            for ind, size in zip((xind, yind, zind), organ.shape)
        ]
        bbct = _crop_transpose_flip(ct, starts, hsl)
        bborgan = _crop_transpose_flip(organ, starts, hsl)

        # The marginals are deliberately recomputed for each consumer rather
        # than reused. This keeps the exact sequence of calls into `common`
        # (and any np.random draws they might make) identical to the original
        # script.
        noisy = [get_deformation_sample(bborgan, [s, s, s]) for s in sigmas]
        marginals = [get_deformation_marginals(bborgan, [s, s, s]) for s in sigmas]
        acc_opt = [gen_opt_acc_seg(get_deformation_marginals(bborgan, [s, s, s]))
                   for s in sigmas]
        dice_opt = [gen_opt_dice_seg(get_deformation_marginals(bborgan, [s, s, s]))
                    for s in sigmas]

        write_nrrd(out_dir, i_str + '_x', bbct, spacing, origin)
        write_seg_nrrd(out_dir, i_str + '_y', bborgan, spacing, origin)
        for level, seg in enumerate(noisy, start=1):
            write_seg_nrrd(out_dir, f'{i_str}_y{level}', seg, spacing, origin)
        for level, marg in enumerate(marginals, start=1):
            write_nrrd(out_dir, f'{i_str}_y{level}m', marg, spacing, origin)
        for level, seg in enumerate(acc_opt, start=1):
            write_seg_nrrd(out_dir, f'{i_str}_y{level}a', seg, spacing, origin)
        for level, seg in enumerate(dice_opt, start=1):
            write_seg_nrrd(out_dir, f'{i_str}_y{level}d', seg, spacing, origin)

    folds = _round_robin_folds(data[organ_str], n_folds=5)
    with open('1_out/splits/' + organ_str + '_folds_400.pkl', 'wb') as fp:
        pickle.dump(folds, fp)


if __name__ == "__main__":
    # makedirs creates the missing "1_out" parent if needed, and exist_ok
    # lets the script be re-run without FileExistsError.
    os.makedirs("1_out/organ", exist_ok=True)
    os.makedirs("1_out/splits", exist_ok=True)

    for organ_name in ("aorta", "esophagus", "kidney_right"):
        main(organ_name)