from common import *
import os
import torch
from torch.utils.data import Dataset, DataLoader
import numpy as np
import scipy.ndimage
import torch
import torch.nn as nn
import torch.nn.functional as F
import nrrd
import shutil
import pickle

class UNet(nn.Module):
    """3D U-Net mapping a 1-channel volume to a 1-channel logit volume.

    Args:
        lc: feature widths per resolution level, from full resolution down to
            the bottleneck. ``len(lc) - 1`` down/up-sampling stages are built;
            each kernel-2/stride-2 stage halves (down) or doubles (up) every
            spatial dimension, so input sizes should be divisible by
            ``2 ** (len(lc) - 1)`` for the skip concatenations to line up.

    Submodule attribute names and Sequential indices (e.g. ``first.0.weight``,
    ``last.6.weight``) define the ``state_dict`` keys; keep them stable.
    """

    def __init__(self, lc=(32, 64, 128, 256)):
        super().__init__()
        # Private list copy: instances never share (or alias a caller's) width list.
        self.lc = list(lc)
        depth = len(self.lc) - 1

        trackrs = False  # InstanceNorm uses per-sample statistics only (no running stats)

        self.first = nn.Sequential(
            *self._conv_norm_relu(1, self.lc[0], trackrs),
            *self._conv_norm_relu(self.lc[0], self.lc[0], trackrs))

        # Two refinement stages followed by a projection to a single logit channel.
        self.last = nn.Sequential(
            *self._conv_norm_relu(self.lc[0], self.lc[0], trackrs),
            *self._conv_norm_relu(self.lc[0], self.lc[0], trackrs),
            nn.Conv3d(self.lc[0], 1, 3, padding='same'))

        # Encoder blocks, one per level below full resolution (shallow -> deep).
        self.block_list_down = nn.ModuleList([
            self._double_block(self.lc[i + 1], trackrs) for i in range(depth)])
        # All decoder-side lists are ordered deepest level first, matching forward().
        up_levels = range(depth - 1, -1, -1)
        self.block_list_up = nn.ModuleList([
            self._double_block(self.lc[i], trackrs) for i in up_levels])
        self.down_list = nn.ModuleList([
            nn.Conv3d(self.lc[i], self.lc[i + 1], 2, 2) for i in range(depth)])
        self.up_list = nn.ModuleList([
            nn.ConvTranspose3d(self.lc[i + 1], self.lc[i], 2, 2) for i in up_levels])
        # Fuses [skip, upsampled] (2 * width channels) back down to the level's width.
        self.combine_list = nn.ModuleList([
            nn.Conv3d(2 * self.lc[i], self.lc[i], 3, padding='same') for i in up_levels])

    @staticmethod
    def _conv_norm_relu(in_ch, out_ch, track):
        """Return [3x3x3 'same' Conv3d, InstanceNorm3d, ReLU] as a list of layers."""
        return [nn.Conv3d(in_ch, out_ch, 3, padding='same'),
                nn.InstanceNorm3d(out_ch, track_running_stats=track),
                nn.ReLU()]

    @staticmethod
    def _double_block(ch, track):
        """Return two conv-norm-ReLU stages at constant width ``ch`` as one Sequential."""
        return nn.Sequential(*UNet._conv_norm_relu(ch, ch, track),
                             *UNet._conv_norm_relu(ch, ch, track))

    def forward(self, x):
        """Encode with stride-2 downsampling, decode with skip connections, return logits."""
        x = self.first(x)
        skips = []
        for down, block in zip(self.down_list, self.block_list_down):
            skips.append(x)  # feature map at this level, taken before downsampling
            x = block(down(x))
        decoder = zip(self.up_list, self.combine_list, self.block_list_up, reversed(skips))
        for up, combine, block, skip in decoder:
            x = up(x)
            x = block(combine(torch.cat((skip, x), dim=1)))
        return self.last(x)

def get_opt_thresh(m):
    """Return a binarisation threshold derived from the probability map ``m`` alone.

    The values of ``m`` are sorted in descending order. For every prefix length
    k, d_k = 2 * (sum of the k largest values) / (sum(m) + k) is computed; this
    is the soft Dice between ``m`` and the mask of its k largest entries.
    Half of the best d_k is returned as a 0-dim tensor.

    The rank vector is created on ``m``'s own device, so CPU tensors and any
    CUDA device work (previously it was hard-wired to 'cuda').
    """
    psi = torch.sort(torch.flatten(m), descending=True)[0]
    # Integer ranks 1..N; type promotion in the division yields psi's float dtype.
    ranks = torch.arange(1, psi.numel() + 1, device=psi.device)
    d = 2 * torch.cumsum(psi, dim=0) / (torch.sum(m) + ranks)
    t = torch.max(d) / 2
    return t

def get_int_thresh(m,target):
    """Return the prediction value that maximises Dice against ``target``.

    Voxels are ranked by predicted value ``m`` (descending). For each prefix
    length k, the Dice between ``target`` and the mask of the k top-ranked
    voxels is 2 * (sum of target over those voxels) / (sum(target) + k). The
    predicted value at the best k is returned, so ``m >= result`` reproduces
    that mask up to ties. Because ``target`` picks the threshold, this is a
    target-informed ("oracle") cut-off.

    The rank vector is created on the target's device, so CPU tensors and any
    CUDA device work (previously it was hard-wired to 'cuda').
    """
    y_pred = torch.flatten(m)
    y_target = torch.flatten(target)

    y_pred_sorted, indices = torch.sort(y_pred, descending=True)
    y_target_sorted = y_target[indices]

    ranks = torch.arange(1, y_target_sorted.shape[0] + 1, device=y_target_sorted.device)
    d = 2 * torch.cumsum(y_target_sorted, dim=0) / (torch.sum(y_target_sorted) + ranks)
    max_index = torch.argmax(d)

    return y_pred_sorted[max_index]


def get_a(out,target):
    """Voxel-wise accuracy of the sigmoid(out) >= 0.5 decision against ``target``.

    The agreement per voxel is p*t + (1-p)*(1-t) with p the binary prediction;
    the mean over all elements is returned as a 0-dim tensor.
    """
    pred_mask = out.sigmoid().ge(0.5).float()
    ref = target.float()
    agreement = pred_mask * ref + (1 - pred_mask) * (1 - ref)
    return agreement.mean()

def get_d(out,target):
    """Dice score of the sigmoid(out) >= 0.5 mask against ``target``.

    When ``target`` has no positive mean (empty mask), ``torch.tensor([1.])``
    is returned regardless of the prediction; otherwise a 0-dim tensor.
    """
    pred_mask = out.sigmoid().ge(0.5).float()
    ref = target.float()
    if not ref.mean().item() > 0:
        return torch.tensor([1.])
    return 2 * (pred_mask * ref).mean() / (pred_mask.mean() + ref.mean())

def get_o(out,target):
    """Dice score after binarising sigmoid(out) at ``get_opt_thresh`` of the prediction.

    The threshold is computed from the prediction alone (not from ``target``).
    An empty ``target`` (mean not > 0) yields ``torch.tensor([1.])``.
    """
    probs = out.sigmoid()
    mask = probs.ge(get_opt_thresh(probs)).float()
    if not target.mean().item() > 0:
        return torch.tensor([1.])
    return 2 * (mask * target).mean() / (mask.mean() + target.mean())

def get_i(out,target):
    """Dice score after binarising sigmoid(out) at ``get_int_thresh(pred, target)``.

    The threshold is chosen using ``target`` itself, so this is a
    target-informed ("oracle") score. An empty ``target`` (mean not > 0)
    yields ``torch.tensor([1.])``.
    """
    probs = out.sigmoid()
    mask = probs.ge(get_int_thresh(probs, target)).float()
    if not target.mean().item() > 0:
        return torch.tensor([1.])
    return 2 * torch.mean(mask * target) / (torch.mean(mask) + torch.mean(target))

def aug_transform(image,target,ground_target,corrected_target):
    """Randomly shift, rotate and blur one sample with probability 0.5.

    Each array is indexed as ``a[0, :, :, :]``. All four volumes receive the
    same integer shift and the same three rotations (linear interpolation,
    shape preserved). The label volumes are then rounded, and only the image
    is Gaussian-smoothed. Results are written back into the input arrays in
    place, and the arrays are also returned. Randomness is drawn from the
    global ``np.random`` state.
    """
    max_shift = 0.1   # shift bound as a fraction of each axis length, per direction
    max_angle = 30    # rotation bound in degrees, per direction
    max_smooth = 0.04 # Gaussian sigma bound as a fraction of each axis length

    if np.random.rand() >= 0.5:
        return image, target, ground_target, corrected_target

    arrays = (image, target, ground_target, corrected_target)
    volumes = [a[0, :, :, :] for a in arrays]

    offset = np.round(2*max_shift*np.asarray(volumes[0].shape)*(np.random.random(3)-0.5))
    volumes = [scipy.ndimage.shift(v, offset, order=1) for v in volumes]

    angles = max_angle*2*(np.random.random(3)-0.5)
    planes = ((1, 0), (2, 0), (2, 1))
    for k in range(len(volumes)):
        for angle, plane in zip(angles, planes):
            volumes[k] = scipy.ndimage.rotate(volumes[k], angle, reshape=False, axes=plane, order=1)

    # Interpolated labels are snapped back to whole values; only the image is blurred.
    labels = [np.round(v) for v in volumes[1:]]
    smoothed = volumes[0]
    sigma = np.random.rand()*max_smooth*np.asarray(smoothed.shape)
    smoothed = scipy.ndimage.gaussian_filter(smoothed, sigma, mode='constant')

    for dst, src in zip(arrays, [smoothed] + labels):
        dst[0, :, :, :] = src

    return image, target, ground_target, corrected_target

class SegmentationDataset(Dataset):
    """Map-style dataset yielding (image, target, ground_target, corrected_target).

    Each item is read from NRRD files with ``index_order='C'``. The image gets a
    leading channel axis. Each label volume is squeezed first and then given a
    leading channel axis. The optional ``transform`` receives and returns the
    four NumPy arrays before the image is z-score normalised. All four are
    returned as float32 torch tensors.

    Args:
        inputs, targets, ground_targets, corrected_targets: equally long lists
            of file paths, indexed in parallel.
        transform: optional callable ``(image, target, ground, corrected)`` that
            returns the same four arrays.
    """

    def __init__(self, inputs, targets, ground_targets, corrected_targets, transform=None):
        self.inputs = inputs
        self.targets = targets
        self.ground_targets = ground_targets
        self.corrected_targets = corrected_targets
        self.transform = transform

    def __len__(self):
        return len(self.inputs)

    @staticmethod
    def _read_label(path):
        """Load a label volume, drop singleton axes, then add a leading channel axis."""
        label, _ = nrrd.read(path, index_order='C')
        return np.expand_dims(np.squeeze(label), axis=[0])

    @staticmethod
    def _normalize(image):
        """Return ``image`` z-scored over all of its elements.

        A constant volume (std == 0) is only mean-centred: dividing by zero
        would turn every voxel into NaN and poison the loss.
        """
        centred = image - np.mean(image)
        std = np.std(image)
        if std == 0:
            return centred
        return centred / std

    def __getitem__(self, idx):
        image, _ = nrrd.read(self.inputs[idx], index_order='C')
        image = np.expand_dims(image, axis=[0])  # image is not squeezed, unlike labels

        target = self._read_label(self.targets[idx])
        ground_target = self._read_label(self.ground_targets[idx])
        corrected_target = self._read_label(self.corrected_targets[idx])

        if self.transform:
            image, target, ground_target, corrected_target = self.transform(
                image, target, ground_target, corrected_target)

        image = self._normalize(image)

        return (torch.from_numpy(image).float(), torch.from_numpy(target).float(),
                torch.from_numpy(ground_target).float(), torch.from_numpy(corrected_target).float())


def softDice(out,target):
    """Soft Dice loss 1 - 2*mean(p*t) / (mean(p) + mean(t)) with p = sigmoid(out).

    Means are taken over all elements of the tensors (the whole batch at once),
    and the result is a 0-dim tensor usable as a training criterion.
    """
    probs = out.sigmoid()
    overlap = (probs * target).mean()
    return 1 - 2 * overlap / (probs.mean() + target.mean())


# Label variants every prediction is scored against, used as metric/file suffixes:
# '' = training target ('_y<noise>.seg.nrrd'), 'g' = ground truth ('_y.seg.nrrd'),
# 'c' = corrected target ('_y<noise><corr>.seg.nrrd').
_TARGET_SUFFIXES = ('', 'g', 'c')


def _case_paths(data_path, names, noise_level, corr):
    """Return (inputs, targets, ground_targets, corrected_targets) path lists for ``names``."""
    inputs = [data_path + name + '_x.nrrd' for name in names]
    targets = [data_path + name + '_y' + noise_level + '.seg.nrrd' for name in names]
    ground_targets = [data_path + name + '_y.seg.nrrd' for name in names]
    corrected_targets = [data_path + name + '_y' + noise_level + corr + '.seg.nrrd' for name in names]
    return inputs, targets, ground_targets, corrected_targets


def _batch_metrics(out, targets, criterion):
    """Score logits ``out`` against each label variant in ``targets``.

    Returns a dict keyed ``<metric><suffix>`` (e.g. 'l', 'dg', 'ic') where the
    metric letter is l=criterion, a=get_a, d=get_d, o=get_o, i=get_i and the
    suffix comes from ``_TARGET_SUFFIXES``.
    """
    scores = {}
    for suffix, tgt in zip(_TARGET_SUFFIXES, targets):
        scores['l' + suffix] = criterion(out, tgt).item()
        scores['a' + suffix] = get_a(out, tgt).item()
        scores['d' + suffix] = get_d(out, tgt).item()
        scores['o' + suffix] = get_o(out, tgt).item()
        scores['i' + suffix] = get_i(out, tgt).item()
    return scores


def _format_progress(epoch, step, scores):
    """Return the per-batch console line: 'epoch/step (l..) (d..) (o..) (i..) '."""
    groups = ''.join(
        '(' + ', '.join('{:2f}'.format(scores[metric + suffix]) for suffix in _TARGET_SUFFIXES) + ') '
        for metric in 'ldoi')
    return str(epoch) + '/' + str(step) + ' ' + groups


def _log_epoch_means(save_path, prefix, batch_scores):
    """Average per-batch scores and append each mean to ``<save_path><prefix>_<key>.txt``.

    Returns the dict of means; with no batches every mean is NaN (``np.mean([])``).
    """
    keys = [metric + suffix for suffix in _TARGET_SUFFIXES for metric in 'ladoi']
    means = {key: np.mean([scores[key] for scores in batch_scores]) for key in keys}
    for key in keys:
        with open(save_path + prefix + '_' + key + '.txt', 'a+') as f:
            f.write(str(means[key]) + '\n')
    return means


def run(fold, data_path, save_path, epochs, noise_level='', lossfunction='ce', corr=''):
    """Train a UNet on one cross-validation fold, logging metrics every epoch.

    Args:
        fold: mapping with 'training' and 'validation' lists of case names.
        data_path: prefix prepended directly to file names (include trailing '/').
        save_path: output directory that must not exist yet; also used as a
            plain string prefix, so include the trailing '/'.
        epochs: number of training epochs.
        noise_level: suffix selecting the training label
            ``<name>_y<noise_level>.seg.nrrd``.
        lossfunction: 'ce' (binary cross-entropy with logits) or 'sd' (softDice).
        corr: extra suffix selecting the corrected label
            ``<name>_y<noise_level><corr>.seg.nrrd``.

    Every prediction is scored against the training target, the ground truth
    and the corrected target. Per-epoch means are appended to
    ``training_*.txt`` / ``validation_*.txt``, and the learning rate to
    ``learning_rate.txt``. Validation inputs and labels are exported at epoch 0;
    predictions and a pickled model are exported every ``save_freq`` epochs.
    Requires a CUDA device.

    Raises:
        ValueError: unknown ``lossfunction`` (checked before anything is created).
        SystemExit: ``save_path`` already exists (aborts the whole process).
    """
    criteria = {'ce': F.binary_cross_entropy_with_logits, 'sd': softDice}
    if lossfunction not in criteria:
        raise ValueError("lossfunction must be 'ce' or 'sd', got %r" % (lossfunction,))
    criterion = criteria[lossfunction]

    if os.path.exists(save_path):
        print('Save path already exists')
        raise SystemExit  # same effect as exit(), without depending on the site module
    os.mkdir(save_path)

    tr_inputs, tr_targets, tr_ground, tr_corrected = _case_paths(
        data_path, fold['training'], noise_level, corr)
    training_data = SegmentationDataset(inputs=tr_inputs, targets=tr_targets, ground_targets=tr_ground,
                                        corrected_targets=tr_corrected, transform=aug_transform)
    training_DataLoader = DataLoader(training_data, batch_size=1, shuffle=True,
                                     num_workers=4, persistent_workers=True)

    val_inputs, val_targets, val_ground, val_corrected = _case_paths(
        data_path, fold['validation'], noise_level, corr)
    validation_data = SegmentationDataset(inputs=val_inputs, targets=val_targets, ground_targets=val_ground,
                                          corrected_targets=val_corrected)
    validation_DataLoader = DataLoader(validation_data, batch_size=1, shuffle=False,
                                       num_workers=4, persistent_workers=True)

    device = torch.device('cuda')

    # Seed before the model is built so that weight initialisation is reproducible too.
    torch.manual_seed(0)
    np.random.seed(0)
    model = UNet(lc=[32, 64, 128, 256, 512]).to(device)

    optimizer = torch.optim.Adam(model.parameters(), lr=1e-4, weight_decay=1e-5)
    scheduler = torch.optim.lr_scheduler.ReduceLROnPlateau(optimizer, factor=0.2, patience=30)
    save_freq = 10  # epochs between prediction exports / model checkpoints
    val_dir = save_path + 'validation/'

    for epoch in range(epochs):

        print('')
        print('********** Training **********')

        model.train()
        batch_scores = []
        for i, (x, y, z1, z2) in enumerate(training_DataLoader):
            image, target, ground_target, corrected_target = x.to(device), y.to(device), z1.to(device), z2.to(device)
            optimizer.zero_grad()
            out = model(image)
            loss = criterion(out, target)

            # Monitoring only: score a detached copy so no extra autograd graph is built.
            with torch.no_grad():
                scores = _batch_metrics(out.detach(), (target, ground_target, corrected_target), criterion)
            batch_scores.append(scores)
            print(_format_progress(epoch, i, scores))

            loss.backward()
            optimizer.step()

        learning_rate = optimizer.param_groups[-1]['lr']  # LR used during this epoch
        training_means = _log_epoch_means(save_path, 'training', batch_scores)
        scheduler.step(training_means['l'])  # plateau detection on the mean training loss
        with open(save_path + 'learning_rate.txt', 'a+') as f:
            f.write(str(learning_rate) + '\n')

        print('')
        print('********** Validation **********')

        model.eval()
        batch_scores = []
        for i, (x, y, z1, z2) in enumerate(validation_DataLoader):
            image, target, ground_target, corrected_target = x.to(device), y.to(device), z1.to(device), z2.to(device)
            case = fold['validation'][i]  # valid because batch_size=1 and shuffle=False

            with torch.no_grad():
                out = model(image)
                scores = _batch_metrics(out, (target, ground_target, corrected_target), criterion)
            batch_scores.append(scores)
            print(_format_progress(epoch, i, scores))

            if epoch == 0:
                # Export validation inputs and labels once, next to later predictions.
                if not os.path.exists(val_dir):
                    os.mkdir(val_dir)
                write_nrrd(val_dir, case + '_x', image[0, 0, :, :, :].cpu().detach().numpy(),
                           [0.15, 0.15, 0.15], [0, 0, 0])
                write_seg_nrrd(val_dir, case + '_y' + noise_level, target[0, 0, :, :, :].cpu().detach().numpy(),
                               [0.15, 0.15, 0.15], [0, 0, 0])
                if noise_level != '':
                    shutil.copy(data_path + case + '_y' + noise_level + 'm.nrrd', val_dir)
                shutil.copy(val_targets[i], val_dir)
                shutil.copy(val_ground[i], val_dir)
                shutil.copy(val_corrected[i], val_dir)

            if epoch % save_freq == 0 and epoch > 0:
                pred_name = case + '_e' + str(epoch) + '_pred' + noise_level
                write_nrrd(val_dir, pred_name, torch.sigmoid(out[0, 0, :, :, :]).cpu().detach().numpy(),
                           [0.15, 0.15, 0.15], [0, 0, 0])
                write_seg_nrrd(val_dir, pred_name, (out[0, 0, :, :, :] >= 0.0).float().cpu().detach().numpy(),
                               [0.15, 0.15, 0.15], [0, 0, 0])

        _log_epoch_means(save_path, 'validation', batch_scores)

        if epoch % save_freq == 0 and epoch > 0:
            models_dir = save_path + 'models/'
            if not os.path.exists(models_dir):
                os.mkdir(models_dir)
            torch.save(model, models_dir + 'model_' + str(epoch) + '.pt')

if __name__=='__main__':

    with open('1_out/splits/kidney_right_folds_400.pkl','rb') as fp:
        folds = pickle.load(fp)
    data_path = '1_out/organ/kidney_right/'

    run(folds[0],  data_path,'2_out/kidney_right_ce_fold0_noise0/', 101, noise_level='', lossfunction='ce', corr='')
    run(folds[0],  data_path,'2_out/kidney_right_ce_fold0_noise1/', 101, noise_level='1', lossfunction='ce',corr='a')
    run(folds[0],  data_path,'2_out/kidney_right_ce_fold0_noise2/', 101, noise_level='2', lossfunction='ce',corr='a')
    run(folds[0],  data_path,'2_out/kidney_right_ce_fold0_noise3/', 101, noise_level='3', lossfunction='ce',corr='a')
    run(folds[1],  data_path,'2_out/kidney_right_ce_fold1_noise0/', 101, noise_level='', lossfunction='ce', corr='')
    run(folds[1],  data_path,'2_out/kidney_right_ce_fold1_noise1/', 101, noise_level='1', lossfunction='ce',corr='a')
    run(folds[1],  data_path,'2_out/kidney_right_ce_fold1_noise2/', 101, noise_level='2', lossfunction='ce',corr='a')
    run(folds[1],  data_path,'2_out/kidney_right_ce_fold1_noise3/', 101, noise_level='3', lossfunction='ce',corr='a')
    run(folds[2],  data_path,'2_out/kidney_right_ce_fold2_noise0/', 101, noise_level='', lossfunction='ce', corr='')
    run(folds[2],  data_path,'2_out/kidney_right_ce_fold2_noise1/', 101, noise_level='1', lossfunction='ce',corr='a')
    run(folds[2],  data_path,'2_out/kidney_right_ce_fold2_noise2/', 101, noise_level='2', lossfunction='ce',corr='a')
    run(folds[2],  data_path,'2_out/kidney_right_ce_fold2_noise3/', 101, noise_level='3', lossfunction='ce',corr='a')
    run(folds[3],  data_path,'2_out/kidney_right_ce_fold3_noise0/', 101, noise_level='', lossfunction='ce', corr='')
    run(folds[3],  data_path,'2_out/kidney_right_ce_fold3_noise1/', 101, noise_level='1', lossfunction='ce',corr='a')
    run(folds[3],  data_path,'2_out/kidney_right_ce_fold3_noise2/', 101, noise_level='2', lossfunction='ce',corr='a')
    run(folds[3],  data_path,'2_out/kidney_right_ce_fold3_noise3/', 101, noise_level='3', lossfunction='ce',corr='a')
    run(folds[4],  data_path,'2_out/kidney_right_ce_fold4_noise0/', 101, noise_level='', lossfunction='ce', corr='')
    run(folds[4],  data_path,'2_out/kidney_right_ce_fold4_noise1/', 101, noise_level='1', lossfunction='ce',corr='a')
    run(folds[4],  data_path,'2_out/kidney_right_ce_fold4_noise2/', 101, noise_level='2', lossfunction='ce',corr='a')
    run(folds[4],  data_path,'2_out/kidney_right_ce_fold4_noise3/', 101, noise_level='3', lossfunction='ce',corr='a')

    run(folds[0],  data_path,'2_out/kidney_right_sd_fold0_noise0/', 101, noise_level='', lossfunction= 'sd',corr='')
    run(folds[0],  data_path,'2_out/kidney_right_sd_fold0_noise1/', 101, noise_level='1', lossfunction='sd',corr='d')
    run(folds[0],  data_path,'2_out/kidney_right_sd_fold0_noise2/', 101, noise_level='2', lossfunction='sd',corr='d')
    run(folds[0],  data_path,'2_out/kidney_right_sd_fold0_noise3/', 101, noise_level='3', lossfunction='sd',corr='d')
    run(folds[1],  data_path,'2_out/kidney_right_sd_fold1_noise0/', 101, noise_level='', lossfunction= 'sd',corr='')
    run(folds[1],  data_path,'2_out/kidney_right_sd_fold1_noise1/', 101, noise_level='1', lossfunction='sd',corr='d')
    run(folds[1],  data_path,'2_out/kidney_right_sd_fold1_noise2/', 101, noise_level='2', lossfunction='sd',corr='d')
    run(folds[1],  data_path,'2_out/kidney_right_sd_fold1_noise3/', 101, noise_level='3', lossfunction='sd',corr='d')
    run(folds[2],  data_path,'2_out/kidney_right_sd_fold2_noise0/', 101, noise_level='', lossfunction= 'sd',corr='')
    run(folds[2],  data_path,'2_out/kidney_right_sd_fold2_noise1/', 101, noise_level='1', lossfunction='sd',corr='d')
    run(folds[2],  data_path,'2_out/kidney_right_sd_fold2_noise2/', 101, noise_level='2', lossfunction='sd',corr='d')
    run(folds[2],  data_path,'2_out/kidney_right_sd_fold2_noise3/', 101, noise_level='3', lossfunction='sd',corr='d')
    run(folds[3],  data_path,'2_out/kidney_right_sd_fold3_noise0/', 101, noise_level='', lossfunction= 'sd',corr='')
    run(folds[3],  data_path,'2_out/kidney_right_sd_fold3_noise1/', 101, noise_level='1', lossfunction='sd',corr='d')
    run(folds[3],  data_path,'2_out/kidney_right_sd_fold3_noise2/', 101, noise_level='2', lossfunction='sd',corr='d')
    run(folds[3],  data_path,'2_out/kidney_right_sd_fold3_noise3/', 101, noise_level='3', lossfunction='sd',corr='d')
    run(folds[4],  data_path,'2_out/kidney_right_sd_fold4_noise0/', 101, noise_level='', lossfunction= 'sd',corr='')
    run(folds[4],  data_path,'2_out/kidney_right_sd_fold4_noise1/', 101, noise_level='1', lossfunction='sd',corr='d')
    run(folds[4],  data_path,'2_out/kidney_right_sd_fold4_noise2/', 101, noise_level='2', lossfunction='sd',corr='d')
    run(folds[4],  data_path,'2_out/kidney_right_sd_fold4_noise3/', 101, noise_level='3', lossfunction='sd',corr='d')

    with open('1_out/splits/esophagus_folds_400.pkl','rb') as fp:
        folds = pickle.load(fp)
    data_path = '1_out/organ/esophagus/'

    run(folds[0],  data_path,'2_out/aorta_ce_fold0_noise0/', 101, noise_level='', lossfunction='ce', corr='')
    run(folds[0],  data_path,'2_out/aorta_ce_fold0_noise1/', 101, noise_level='1', lossfunction='ce',corr='a')
    run(folds[0],  data_path,'2_out/aorta_ce_fold0_noise2/', 101, noise_level='2', lossfunction='ce',corr='a')
    run(folds[0],  data_path,'2_out/aorta_ce_fold0_noise3/', 101, noise_level='3', lossfunction='ce',corr='a')
    run(folds[1],  data_path,'2_out/aorta_ce_fold1_noise0/', 101, noise_level='', lossfunction='ce', corr='')
    run(folds[1],  data_path,'2_out/aorta_ce_fold1_noise1/', 101, noise_level='1', lossfunction='ce',corr='a')
    run(folds[1],  data_path,'2_out/aorta_ce_fold1_noise2/', 101, noise_level='2', lossfunction='ce',corr='a')
    run(folds[1],  data_path,'2_out/aorta_ce_fold1_noise3/', 101, noise_level='3', lossfunction='ce',corr='a')
    run(folds[2],  data_path,'2_out/aorta_ce_fold2_noise0/', 101, noise_level='', lossfunction='ce', corr='')
    run(folds[2],  data_path,'2_out/aorta_ce_fold2_noise1/', 101, noise_level='1', lossfunction='ce',corr='a')
    run(folds[2],  data_path,'2_out/aorta_ce_fold2_noise2/', 101, noise_level='2', lossfunction='ce',corr='a')
    run(folds[2],  data_path,'2_out/aorta_ce_fold2_noise3/', 101, noise_level='3', lossfunction='ce',corr='a')
    run(folds[3],  data_path,'2_out/aorta_ce_fold3_noise0/', 101, noise_level='', lossfunction='ce', corr='')
    run(folds[3],  data_path,'2_out/aorta_ce_fold3_noise1/', 101, noise_level='1', lossfunction='ce',corr='a')
    run(folds[3],  data_path,'2_out/aorta_ce_fold3_noise2/', 101, noise_level='2', lossfunction='ce',corr='a')
    run(folds[3],  data_path,'2_out/aorta_ce_fold3_noise3/', 101, noise_level='3', lossfunction='ce',corr='a')
    run(folds[4],  data_path,'2_out/aorta_ce_fold4_noise0/', 101, noise_level='', lossfunction='ce', corr='')
    run(folds[4],  data_path,'2_out/aorta_ce_fold4_noise1/', 101, noise_level='1', lossfunction='ce',corr='a')
    run(folds[4],  data_path,'2_out/aorta_ce_fold4_noise2/', 101, noise_level='2', lossfunction='ce',corr='a')
    run(folds[4],  data_path,'2_out/aorta_ce_fold4_noise3/', 101, noise_level='3', lossfunction='ce',corr='a')

    run(folds[0],  data_path,'2_out/esophagus_sd_fold0_noise0/', 101, noise_level='', lossfunction= 'sd',corr='')
    run(folds[0],  data_path,'2_out/esophagus_sd_fold0_noise1/', 101, noise_level='1', lossfunction='sd',corr='d')
    run(folds[0],  data_path,'2_out/esophagus_sd_fold0_noise2/', 101, noise_level='2', lossfunction='sd',corr='d')
    run(folds[0],  data_path,'2_out/esophagus_sd_fold0_noise3/', 101, noise_level='3', lossfunction='sd',corr='d')
    run(folds[1],  data_path,'2_out/esophagus_sd_fold1_noise0/', 101, noise_level='', lossfunction= 'sd',corr='')
    run(folds[1],  data_path,'2_out/esophagus_sd_fold1_noise1/', 101, noise_level='1', lossfunction='sd',corr='d')
    run(folds[1],  data_path,'2_out/esophagus_sd_fold1_noise2/', 101, noise_level='2', lossfunction='sd',corr='d')
    run(folds[1],  data_path,'2_out/esophagus_sd_fold1_noise3/', 101, noise_level='3', lossfunction='sd',corr='d')
    run(folds[2],  data_path,'2_out/esophagus_sd_fold2_noise0/', 101, noise_level='', lossfunction= 'sd',corr='')
    run(folds[2],  data_path,'2_out/esophagus_sd_fold2_noise1/', 101, noise_level='1', lossfunction='sd',corr='d')
    run(folds[2],  data_path,'2_out/esophagus_sd_fold2_noise2/', 101, noise_level='2', lossfunction='sd',corr='d')
    run(folds[2],  data_path,'2_out/esophagus_sd_fold2_noise3/', 101, noise_level='3', lossfunction='sd',corr='d')
    run(folds[3],  data_path,'2_out/esophagus_sd_fold3_noise0/', 101, noise_level='', lossfunction= 'sd',corr='')
    run(folds[3],  data_path,'2_out/esophagus_sd_fold3_noise1/', 101, noise_level='1', lossfunction='sd',corr='d')
    run(folds[3],  data_path,'2_out/esophagus_sd_fold3_noise2/', 101, noise_level='2', lossfunction='sd',corr='d')
    run(folds[3],  data_path,'2_out/esophagus_sd_fold3_noise3/', 101, noise_level='3', lossfunction='sd',corr='d')
    run(folds[4],  data_path,'2_out/esophagus_sd_fold4_noise0/', 101, noise_level='', lossfunction= 'sd',corr='')
    run(folds[4],  data_path,'2_out/esophagus_sd_fold4_noise1/', 101, noise_level='1', lossfunction='sd',corr='d')
    run(folds[4],  data_path,'2_out/esophagus_sd_fold4_noise2/', 101, noise_level='2', lossfunction='sd',corr='d')
    run(folds[4],  data_path,'2_out/esophagus_sd_fold4_noise3/', 101, noise_level='3', lossfunction='sd',corr='d')

    # Aorta experiments: reload the 5-fold split for the aorta and point the
    # runs at the aorta data directory.
    with open('1_out/splits/aorta_folds_400.pkl','rb') as fp:
        folds = pickle.load(fp)
    data_path = '1_out/organ/aorta/'

    # Sweep loss function x fold x noise level. Order is loss-major, then fold,
    # then noise: all 'ce' runs come first, then all 'sd' runs.
    # Clean runs (noise 0) pass noise_level='' and corr=''.
    # Noisy runs pass noise_level='<m>' plus the loss-specific correction code:
    # 'a' for 'ce' and 'd' for 'sd'.
    # NOTE(review): 101 is the 4th positional argument to run(); its meaning
    # (presumably the epoch/iteration count) is defined outside this chunk.
    for aorta_loss, aorta_corr in (('ce', 'a'), ('sd', 'd')):
        for aorta_fold in range(5):
            for aorta_noise in range(4):
                aorta_out_dir = '2_out/aorta_%s_fold%d_noise%d/' % (
                    aorta_loss, aorta_fold, aorta_noise)
                if aorta_noise:
                    aorta_level, aorta_run_corr = str(aorta_noise), aorta_corr
                else:
                    aorta_level, aorta_run_corr = '', ''
                run(folds[aorta_fold], data_path, aorta_out_dir, 101,
                    noise_level=aorta_level,
                    lossfunction=aorta_loss,
                    corr=aorta_run_corr)
