from imps import *

# %% Settings/parameters

# ---- paths
# Dataset roots are absolute, machine-specific locations; adjust per machine.
pancreas_data = "D:/Data/Pancreas-CT"
pancreas_data_imgs = f"{pancreas_data}/manifest-1599750808610/Pancreas-CT"
pancreas_data_gt = f"{pancreas_data}/TCIA_pancreas_labels-02-05-2017"
#
acdc_data = "D:/Data/Automated Cardiac Diagnosis Challenge (ACDC)"
acdc_data_train = f"{acdc_data}/train"
acdc_data_test = f"{acdc_data}/test"
acdc_data_true_test = f"{acdc_data}/true_test"
acdc_data_validation = f"{acdc_data}/validation"
#
synapse_data = "D:/Data/synapse/Abdomen/RawData"
synapse_data_train = f"{synapse_data}/Training"
synapse_data_test = f"{synapse_data}/Testing"
synapse_data_valid = f"{synapse_data}/Validating"


# ---- images
# Input geometry (columns, rows, channels). An earlier 256x256 setting was
# immediately overridden by this one and has been dropped.
img_cols, img_rows, col_channels = 224, 224, 1

# ACDC: 3 foreground classes plus background.
n_classes = 3 + 1

dropout_rate = 0


import os
from glob import glob
import time
import re
import argparse
import nibabel as nib 
import pandas as pd
# from medpy.metric.binary import hd, dc, hd95
import numpy as np

def dc(result, reference):
    """
    Dice coefficient between two binary objects.

    Every non-zero element is treated as foreground. The score is
    2 * |A & B| / (|A| + |B|) and lies in [0, 1].

    Parameters
    ----------
    result: array_like
    Predicted segmentation; cast to bool.

    reference: array_like
    Reference segmentation; cast to bool and combined element-wise with
    ``result``, so the shapes must match (or broadcast).

    Returns
    -------
    float
    The Dice score. When neither input contains any foreground the two
    agree trivially and 1.0 is returned.
    """
    # Builtin ``bool`` instead of ``numpy.bool``: that alias was deprecated in
    # NumPy 1.20 and removed in 1.24 (AttributeError on current NumPy).
    result = np.atleast_1d(np.asarray(result).astype(bool))
    reference = np.atleast_1d(np.asarray(reference).astype(bool))

    intersection = np.count_nonzero(result & reference)
    total = np.count_nonzero(result) + np.count_nonzero(reference)

    # Both masks empty: perfect (vacuous) agreement.
    if total == 0:
        return 1.0
    return 2.0 * intersection / float(total)

# Column names for the per-case metrics table: the case name, then
# (Dice, volume, volume error) for each of LV, RV and MYO.
HEADER = [
    "Name",
    "Dice LV", "Volume LV", "Err LV(ml)",
    "Dice RV", "Volume RV", "Err RV(ml)",
    "Dice MYO", "Volume MYO", "Err MYO(ml)",
]

#
# Utils functions used to sort strings into a natural order
#
def conv_int(i):
    """
    Return ``int(i)`` when ``i`` is a run of decimal digits, else ``i`` unchanged.

    ``str.isdecimal`` is used rather than ``str.isdigit``. ``isdigit`` also
    accepts characters such as superscripts ("²"), which ``int()`` rejects with
    ValueError. ``isdecimal`` accepts exactly the characters ``int()`` parses
    (and that the ``\\d`` regex class matches).
    """
    return int(i) if i.isdecimal() else i


def natural_order(sord):
    """
    Sort key that orders strings naturally (numbers compared numerically).

    If ``sord`` is a tuple, only its first element is used.

    Ex:

    ['1','10','2'] -> ['1','2','10']

    ['abc1def','ab10d','b2c','ab1d'] -> ['ab1d','ab10d', 'abc1def', 'b2c']

    """
    if isinstance(sord, tuple):
        sord = sord[0]
    # ``re.split`` with a capturing group alternates text / digit-run pieces,
    # so every odd index holds a ``\d+`` match and is always int-parsable.
    # Converting by position avoids ``str.isdigit`` misclassifying text such
    # as "²" (which would make ``int()`` raise).
    pieces = re.split(r'(\d+)', sord)
    return [int(piece) if idx % 2 else piece for idx, piece in enumerate(pieces)]


#
# Utils function to load and save nifti files with the nibabel package
#
def load_nii(img_path):
    """
    Function to load a 'nii' or 'nii.gz' file. The function returns
    everything needed to save another 'nii' or 'nii.gz'
    in the same dimensional space, i.e. the affine matrix and the header.

    Parameters
    ----------

    img_path: string
    String with the path of the 'nii' or 'nii.gz' image file name.

    Returns
    -------
    Three elements: the first is a numpy array of the image values,
    the second is the affine transformation of the image, and the
    last one is the header of the image.
    """
    nimg = nib.load(img_path)
    # ``get_data()`` was deprecated in nibabel 3.0 and removed in 5.0.
    # ``np.asanyarray(img.dataobj)`` is nibabel's documented drop-in for the
    # old behaviour. Unlike ``get_fdata()`` it does not force float64, which
    # matters for integer label maps.
    return np.asanyarray(nimg.dataobj), nimg.affine, nimg.header


def save_nii(img_path, data, affine, header):
    """
    Write ``data`` to disk as a NIfTI-1 image.

    Parameters
    ----------

    img_path: string
    Destination file; should end with '.nii' or '.nii.gz'.

    data: np.array
    Voxel values to store.

    affine: list of list or np.array
    Affine transformation stored alongside the data.

    header: nib.Nifti1Header
    Header describing the data (see the nibabel documentation).
    """
    nib.Nifti1Image(data, affine=affine, header=header).to_filename(img_path)


#
# Functions to process files, directories and metrics
#
def metrics(img_gt, img_pred, voxel_size, dset="acdc"):
    """
    Function to compute the metrics between two segmentation maps given as input.

    Parameters
    ----------
    img_gt: np.array
    Array of the ground truth segmentation map.

    img_pred: np.array
    Array of the predicted segmentation map.

    voxel_size: list, tuple, np.array or the scalar 0
    The size of a voxel of the images, used to compute the volumes
    (voxel count * prod(voxel_size) / 1000; reported in ml per HEADER,
    which presumably assumes a voxel size in mm). For "acdc", passing the
    scalar 0 returns only the three Dice scores.

    dset: str
    "acdc" (label ids 3, 1, 2 -> LV, RV, MYO) or "synapse" (label ids 1..8).

    Return
    ------
    For "acdc": [Dice LV, Volume LV, Err LV(ml), Dice RV, Volume RV, Err RV(ml),
    Dice MYO, Volume MYO, Err MYO(ml)], where Err is predicted minus
    ground-truth volume. If voxel_size is 0: [Dice LV, Dice RV, Dice MYO].
    For "synapse": the Dice score for each label 1..8, in order.

    Raises
    ------
    ValueError
    If the two arrays have a different number of dimensions, or if dset
    is not one of the supported values.
    """
    # Dice = (2 x Intersection) / (|A| + |B|) = (2 x Intersection) / (Union + Intersection)
    # (equivalently, the F1 score of the voxel-wise classification).

    if img_gt.ndim != img_pred.ndim:
        raise ValueError("The arrays 'img_gt' and 'img_pred' should have the "
                         "same dimension, {} against {}".format(img_gt.ndim,
                                                                img_pred.ndim))

    if dset == "acdc":
        voxel_volume = np.prod(voxel_size)
        res = []
        # Order matches HEADER: label 3 -> LV, 1 -> RV, 2 -> MYO.
        for c in (3, 1, 2):
            # Boolean masks of class c; the inputs are not modified.
            gt_c = img_gt == c
            pred_c = img_pred == c

            dice = dc(gt_c, pred_c)
            volpred = pred_c.sum() * voxel_volume / 1000.
            volgt = gt_c.sum() * voxel_volume / 1000.
            res += [dice, volpred, volpred - volgt]

        # Dice-only mode is requested with the scalar 0. The np.ndim guard
        # avoids an ambiguous elementwise comparison when voxel_size is an array.
        if np.ndim(voxel_size) == 0 and voxel_size == 0:
            res = res[0::3]
        return res

    if dset == "synapse":
        return [dc(img_gt == c, img_pred == c) for c in range(1, 9)]

    raise ValueError('The argument "dset" must be "acdc" or "synapse", '
                     'got {!r}.'.format(dset))


def compute_metrics_on_files(path_gt, path_pred):
    """
    Print a two-line table (HEADER, then one row) with the metrics comparing
    a ground-truth NIfTI file against a predicted one.

    Parameters
    ----------

    path_gt: string
    Path of the ground truth image; its basename (up to the first '.')
    labels the row.

    path_pred: string
    Path of the predicted image.
    """
    gt, _, header = load_nii(path_gt)
    pred = load_nii(path_pred)[0]

    case_name = os.path.basename(path_gt).split('.')[0]
    scores = ["{:.3f}".format(value)
              for value in metrics(gt, pred, header.get_zooms())]

    row_fmt = "{:>14}, {:>7}, {:>9}, {:>10}, {:>7}, {:>9}, {:>10}, {:>8}, {:>10}, {:>11}"
    for row in (HEADER, [case_name] + scores):
        print(row_fmt.format(*row))


def compute_metrics_on_directories(dir_gt, dir_pred):
    """
    Write one CSV with a row of metrics for every image pair of two directories.

    Files are paired by natural-sort order and must share the same basename.
    The CSV, with columns HEADER, is written to the current working directory
    as ``results_<YYYYmmdd_HHMMSS>.csv``.

    Parameters
    ----------

    dir_gt: string
    Directory of the ground truth segmentation maps.

    dir_pred: string
    Directory of the predicted segmentation maps.

    Raises
    ------
    ValueError
    If the directories contain a different number of entries, or if a
    paired ground-truth / prediction do not have the same file name.
    """
    lst_gt = sorted(glob(os.path.join(dir_gt, '*')), key=natural_order)
    lst_pred = sorted(glob(os.path.join(dir_pred, '*')), key=natural_order)

    # zip() would silently drop unmatched trailing files; fail loudly instead.
    if len(lst_gt) != len(lst_pred):
        raise ValueError("The two directories don't contain the same number of"
                         " files, {} against {}.".format(len(lst_gt),
                                                         len(lst_pred)))

    rows = []
    for p_gt, p_pred in zip(lst_gt, lst_pred):
        name_gt = os.path.basename(p_gt)
        name_pred = os.path.basename(p_pred)
        if name_gt != name_pred:
            raise ValueError("The two files don't have the same name"
                             " {}, {}.".format(name_gt, name_pred))

        gt, _, header = load_nii(p_gt)
        pred, _, _ = load_nii(p_pred)
        scores = metrics(gt, pred, header.get_zooms())
        rows.append([name_gt.split(".")[0]] + list(scores))

    df = pd.DataFrame(rows, columns=HEADER)
    df.to_csv("results_{}.csv".format(time.strftime("%Y%m%d_%H%M%S")), index=False)


from tensorflow.keras import backend as K
class WarmUpLearningRateScheduler(keras.callbacks.Callback):
    """
    Keras callback that linearly ramps the learning rate up to ``init_lr``.

    Before each batch, while ``batch_count <= warmup_batches``, the optimizer
    learning rate is set to ``batch_count * init_lr / warmup_batches``.
    ``batch_count`` starts at 0, so the first batch runs with a learning rate
    of 0. From batch ``warmup_batches`` on, the rate stays at ``init_lr``.
    The rate in effect after every batch is appended to ``learning_rates``.
    """

    def __init__(self, warmup_batches, init_lr, verbose=0):
        """
        Arguments:
            warmup_batches {int} -- Number of batches over which to ramp up.
            init_lr {float} -- Learning rate reached at the end of warmup.

        Keyword Arguments:
            verbose {int} -- 0: quiet, 1: print every update. (default: {0})
        """
        super().__init__()
        self.warmup_batches = warmup_batches
        self.init_lr = init_lr
        self.verbose = verbose
        # Batches completed so far (incremented in on_batch_end).
        self.batch_count = 0
        # Learning-rate history, one entry per completed batch.
        self.learning_rates = []

    def on_batch_end(self, batch, logs=None):
        self.batch_count += 1
        self.learning_rates.append(K.get_value(self.model.optimizer.lr))

    def on_batch_begin(self, batch, logs=None):
        # Warmup finished: leave the optimizer's learning rate alone.
        if self.batch_count > self.warmup_batches:
            return
        lr = self.batch_count*self.init_lr/self.warmup_batches
        K.set_value(self.model.optimizer.lr, lr)
        if self.verbose > 0:
            print('\nBatch %05d: WarmUpLearningRateScheduler setting learning '
                  'rate to %s.' % (self.batch_count + 1, lr))


def sep_gen(data, ismask, seed=seed, batch_size=15, dset="training"):
    """
    Build a seeded Keras ``ImageDataGenerator`` flow over ``data``.

    dset="training" enables random rotation/zoom/shear/shift/flip
    augmentation; dset="validation" applies no augmentation. When ``ismask``
    is true, every sample is re-binarised to {0, 1} (keeping its dtype).
    Keras runs ``preprocessing_function`` after augmentation, so interpolated
    mask values are snapped back to binary.

    Raises ValueError for any other ``dset`` value.
    """
    if dset == "training":
        options = dict(
            rotation_range=360,
            zoom_range=.2,
            shear_range=.1,
            width_shift_range=.3,
            height_shift_range=.3,
            horizontal_flip=True,
            vertical_flip=True,
        )
    elif dset == "validation":
        options = {}
    else:
        raise ValueError( "The argument \"dset\" can either be \"training\" or \"validation\".")

    if ismask:
        options["preprocessing_function"] = lambda x: np.where(x>0, 1, 0).astype(x.dtype)

    datagen = tf.keras.preprocessing.image.ImageDataGenerator(**options)
    return datagen.flow(data, batch_size=batch_size, seed=seed)



def unite_gen(X, y_4, y_2, y, batch_size, dset):
    """
    Yield forever ``(X_batch, [y_4_batch, y_2_batch, y_batch])``.

    Each array gets its own ``sep_gen`` flow. All four use ``sep_gen``'s
    default seed, presumably so that their random augmentations and
    ordering stay aligned (verify against the seed definition). Mask
    batches are cast to uint8.
    """
    image_flow = sep_gen(X, False, batch_size=batch_size, dset=dset)
    mask_flows = [sep_gen(mask, True, batch_size=batch_size, dset=dset)
                  for mask in (y_4, y_2, y)]
    while True:
        images = next(image_flow)
        masks = [next(flow).astype("uint8") for flow in mask_flows]
        yield (images, masks)


