from imps import *

# %% Data preparation

# ---- ACDC

def get_acdc(path, input_size=(img_cols, img_rows, col_channels)):
    """Load the ACDC frame volumes found under ``path`` as resized 2-D slices.

    Only files whose name contains both ".gz" and "frame" are read; the
    remaining files (f.i. "patient001_4d") have no ground truth.  Files with
    "_gt" in their name are treated as masks, all others as images.

    Parameters
    ----------
    path : str
        Root directory that is walked recursively.  If it ends with
        "true_test" the masks are returned as a plain list, because that
        split has no values for the masks.
    input_size : tuple
        Only the first two entries are used; they are handed to
        ``cv2.resize`` as the target size of every slice.

    Returns
    -------
    data : list
        ``[images, masks, info]``.  ``images`` (and ``masks``, unless
        ``path`` ends with "true_test") are arrays with a new axis inserted
        at position 3.  ``info`` holds one "<file[:10]>_ED" or
        "<file[:10]>_ES" string per image slice.
    all_affine : list
        One affine per image volume.
    all_header : list
        One header per image volume.

    Raises
    ------
    FileNotFoundError
        If a directory contains an image frame but no "Info.cfg".
    """
    target_size = (input_size[0], input_size[1])
    all_imgs = []
    all_gt = []
    all_header = []
    all_affine = []
    info = []
    for root, directories, files in os.walk(path):
        # os.walk yields entries in arbitrary order.  Sorting makes the
        # result reproducible and keeps each image frame next to its "_gt"
        # file, so that image and mask slices are collected in the same order.
        directories.sort()
        cfg_lines = None  # Info.cfg is read lazily, at most once per directory
        for file in sorted(files):
            # both substrings have to be tested explicitly:
            # `".gz" and "frame" in file` would only test for "frame"
            if ".gz" not in file or "frame" not in file:
                continue

            # load the volume once and reuse it for data, header and affine
            volume = nib.load(os.path.join(root, file))
            img = volume.get_fdata()
            slices = [
                cv2.resize(img[:, :, idx], target_size, interpolation=cv2.INTER_NEAREST)
                for idx in range(img.shape[2])
            ]

            if "_gt" in file:
                all_gt.extend(slices)
                continue

            if cfg_lines is None:
                with open(os.path.join(root, "Info.cfg")) as f:
                    cfg_lines = f.read().splitlines()

            all_header.append(volume.header)
            all_affine.append(volume.affine)
            all_imgs.extend(slices)
            # the frame number in the file name is compared with the number
            # on the first line of Info.cfg to tell ED from ES
            phase = "ED" if int(file[16:18]) == int(cfg_lines[0][3:]) else "ES"
            info.extend([file[:10] + "_" + phase] * len(slices))

    data = [all_imgs, all_gt, info]

    data[0] = np.expand_dims(data[0], axis=3)
    if not path.endswith("true_test"):  # true_test has no values for the masks
        data[1] = np.expand_dims(data[1], axis=3)

    return data, all_affine, all_header


def split_data(data, convert="yes"):
    """Split ``data[0]`` and ``data[1]`` into two partitions.

    Parameters
    ----------
    data : sequence
        ``data[0]`` holds the inputs and ``data[1]`` the matching targets.
    convert : {"yes", "no"}
        "yes": split train and val with ``test_size=0.3`` and insert a new
        axis at position 3 into all four outputs afterwards.
        "no": split val and test with ``test_size=2/3`` and return the
        arrays unchanged.  In this mode the values returned as
        ``X_train``/``y_train`` are the val part and the values returned as
        ``X_val``/``y_val`` are the test part.

    Returns
    -------
    tuple
        ``(X_train, y_train, X_val, y_val)``.

    Raises
    ------
    ValueError
        If ``convert`` is neither "yes" nor "no".
    """
    # fail early with a clear message instead of running into an
    # UnboundLocalError at the return statement
    if convert not in ("yes", "no"):
        raise ValueError(f'convert must be "yes" or "no", got {convert!r}')

    if convert == "yes":
        # split train and val
        X_train, X_val, y_train, y_val = train_test_split(
            data[0], data[1], test_size=0.3, random_state=seed
        )
        # convert from shape (256, 256) to (256, 256, 1)
        X_train = np.expand_dims(X_train, axis=3)
        y_train = np.expand_dims(y_train, axis=3)
        X_val = np.expand_dims(X_val, axis=3)
        y_val = np.expand_dims(y_val, axis=3)
    else:
        # split val and test
        X_train, X_val, y_train, y_val = train_test_split(
            data[0], data[1], test_size=2/3, random_state=seed
        )

    return X_train, y_train, X_val, y_val


def convert_masks(y, data="acdc"):
    """Convert label masks into one binary channel per class.

    Parameters
    ----------
    y : numpy.ndarray
        4-D array of label masks.  Only the last entry of the final axis is
        read as the label of a pixel.
    data : {"acdc", "synapse"}
        Selects the number of classes: 4 for "acdc" (labels 0-3) and 9 for
        "synapse" (0 background, 1 Aorta, 2 Gallbladder, 3 Kidney L,
        4 Kidney R, 5 Liver, 6 Pancreas, 7 Spleen, 8 Stomach).

    Returns
    -------
    numpy.ndarray
        Array of shape ``y.shape[:3] + (n_classes,)`` in which channel ``c``
        is 1 where the label equals ``c`` and 0 elsewhere.

    Raises
    ------
    ValueError
        If ``data`` is neither "acdc" nor "synapse".
    """
    class_counts = {"acdc": 4, "synapse": 9}
    if data not in class_counts:
        raise ValueError(f'data must be "acdc" or "synapse", got {data!r}')
    n_classes = class_counts[data]

    # initialize
    masks = np.zeros((y.shape[0], y.shape[1], y.shape[2], n_classes))

    # explicit 4-D indexing keeps the requirement that y has four axes
    labels = y[:, :, :, -1]
    for class_id in range(n_classes):
        masks[:, :, :, class_id] = labels == class_id

    # the converted masks have to be handed back to the caller
    return masks

def get_synapse(path, sam_size=100, input_size=(img_cols, img_rows, col_channels)):
    """Load the Synapse volumes found under ``path`` as resized 2-D slices.

    Every file inside a directory whose path ends with "img" is read as an
    image volume, every file inside a directory whose path ends with "label"
    as a mask volume.

    Parameters
    ----------
    path : str
        Root directory that is walked recursively.
    sam_size : int
        Currently unused; kept so that existing calls keep working.
    input_size : tuple
        Only the first two entries are used; they are handed to
        ``cv2.resize`` as the target size of every slice.

    Returns
    -------
    list
        ``[images, masks]``, both arrays with a new axis inserted at
        position 3.
    """
    target_size = (input_size[0], input_size[1])
    all_imgs = []
    all_gt = []
    for root, directories, files in os.walk(path):
        # os.walk yields entries in arbitrary order.  Sorting makes the
        # result reproducible and collects image and label volumes in the
        # same order (assuming their file names sort alike -- TODO confirm
        # against the dataset layout).
        directories.sort()

        if root.endswith("img"):
            target = all_imgs
        elif root.endswith("label"):
            target = all_gt
        else:
            continue

        for file in sorted(files):
            img = nib.load(os.path.join(root, file)).get_fdata()
            target.extend(
                cv2.resize(img[:, :, idx], target_size, interpolation=cv2.INTER_NEAREST)
                for idx in range(img.shape[2])
            )

    data = [all_imgs, all_gt]
    data[0] = np.expand_dims(data[0], axis=3)
    data[1] = np.expand_dims(data[1], axis=3)
    return data
     