import numpy as np
import scipy.ndimage
import nrrd
import numpy as np
import os

def dice(x, y):
    """Return the Dice coefficient ``2 * sum(x*y) / (sum(x) + sum(y))``.

    ``x`` and ``y`` are multiplied element-wise, so they may be binary masks
    or soft (probabilistic) maps of broadcast-compatible shape.  The
    denominator is not guarded: when both inputs sum to zero the result is
    NumPy's 0/0 (typically NaN with a RuntimeWarning).
    """
    overlap = np.sum(x * y)
    total_mass = np.sum(x) + np.sum(y)
    return 2 * overlap / total_mass

def gen_opt_dice_seg(m):
    """Binarise marginals ``m`` by the expected-Dice thresholding rule.

    Voxel marginals are ranked in descending order.  For every prefix length
    ``k`` the score ``d_k = 2 * (sum of the k largest marginals) /
    (sum(m) + k)`` is computed.  Every voxel whose marginal is at least
    ``max_k d_k / 2`` is then selected.

    Returns a float array of the same shape as ``m`` holding 1.0 for
    selected voxels and 0.0 elsewhere.  If ``m`` is all zeros the threshold
    is 0 and (for non-negative ``m``) every voxel is selected.  An empty
    ``m`` raises ``ValueError`` from the max reduction.
    """
    ranked = np.sort(m, axis=None)[::-1]          # flattened, descending
    prefix_sizes = np.arange(1, ranked.size + 1)
    scores = 2 * np.cumsum(ranked) / (np.sum(m) + prefix_sizes)
    threshold = scores.max() / 2
    return (m >= threshold).astype(np.float64)

def gen_opt_acc_seg(m):
    """Binarise marginals ``m`` with a fixed 0.5 threshold.

    Assuming ``m`` holds per-voxel foreground probabilities, this is the
    decision that maximises expected voxel-wise accuracy.  Returns a float
    array: 1.0 where ``m >= 0.5``, 0.0 elsewhere.
    """
    cutoff = 0.5
    selected = m >= cutoff
    return selected * 1.0

def get_deformation_sample(l, a, *, b=0.15 * np.sqrt(2), rng=None):
    """Resample ``l`` under one random, smooth displacement field.

    One displacement component per axis is drawn as Gaussian-smoothed white
    noise.  ``l`` is then evaluated at the displaced voxel coordinates with
    ``scipy.ndimage.map_coordinates`` (default spline order, ``'nearest'``
    edge mode), and the result is rounded.

    Parameters
    ----------
    l : numpy.ndarray
        Volume to deform; any number of dimensions.  Presumably a label map,
        given the final rounding -- confirm with callers.
    a : float or sequence of float
        Displacement amplitude as a fraction of axis length.  Component ``i``
        of the field is multiplied by ``a * l.shape[i]``.
    b : float or sequence of float, optional (keyword-only)
        Correlation length as a fraction of axis length.  Component ``i`` is
        smoothed isotropically with sigma ``b * l.shape[i] / sqrt(2)``.
        Defaults to the previously hard-coded ``0.15 * sqrt(2)``.
    rng : numpy.random.Generator or numpy.random.RandomState, optional
        Noise source (keyword-only).  ``None`` uses the global
        ``numpy.random`` state, as before.

    Returns
    -------
    numpy.ndarray
        Deformed volume with the same shape as ``l``, values rounded.
    """
    if rng is None:
        rng = np.random
    ndim = l.ndim
    scaled_a = np.array(a) * np.array(l.shape)
    scaled_b = np.array(b) * np.array(l.shape)

    # Scale factor (2*pi)^(d/4) * prod(sqrt(scaled_b)).  The exponent used to
    # be hard-coded as 3/4 (d = 3).  For a cubic volume this normalises white
    # noise smoothed with sigma = scaled_b/sqrt(2) to ~unit std away from the
    # borders.
    # NOTE(review): each component i is smoothed isotropically with
    # scaled_b[i]/sqrt(2), while this factor multiplies sqrt(scaled_b) over
    # all axes.  The two only agree for cubic volumes -- confirm whether
    # per-axis sigmas were intended.
    weight = (2 * np.pi) ** (ndim / 4) * np.prod(np.sqrt(scaled_b))

    perb = np.array([
        scipy.ndimage.gaussian_filter(
            rng.normal(size=l.shape), scaled_b[i] / np.sqrt(2), mode='constant'
        ) * scaled_a[i] * weight
        for i in range(ndim)
    ])
    # np.indices gives the same (ndim, *shape) identity grid as
    # np.meshgrid(..., indexing='ij').
    coords = np.indices(l.shape) + perb
    return np.round(scipy.ndimage.map_coordinates(l, coords, mode='nearest'))

def get_deformation_marginals(l, a):
    """Gaussian-smooth ``l`` to obtain soft per-voxel marginals.

    ``l`` is filtered with a Gaussian whose sigma along axis ``i`` is
    ``a * l.shape[i]`` voxels, with zero padding outside the volume
    (``mode='constant'``).  This presumably approximates the per-voxel
    marginals of ``get_deformation_sample(l, a)`` -- confirm.

    Integer and bool inputs are promoted to float64 first.
    ``gaussian_filter`` returns an array of the input dtype, so an integer
    label map would otherwise have its fractional marginals truncated,
    typically to all zeros.  Floating inputs keep their dtype.

    Returns the smoothed array, same shape as ``l``.
    """
    l = np.asarray(l)
    if not np.issubdtype(l.dtype, np.inexact):
        l = l.astype(np.float64)
    scaled_a = np.array(a) * np.array(l.shape)
    return scipy.ndimage.gaussian_filter(l, scaled_a, mode='constant')

def write_nrrd(path, filename, data, pixel_size, corner_center):
    """Write ``data`` to ``<path>/<filename>.nrrd`` with gzip encoding.

    The array is transposed (all axes reversed) before writing.  Spacing and
    origin are therefore reversed as well, so header axis 0 uses
    ``pixel_size[2]``.  Both are multiplied by 10 to express them in mm; the
    inputs are presumably in cm -- confirm.

    Parameters
    ----------
    path : str
        Output directory.
    filename : str
        Base name; ``".nrrd"`` is appended.
    data : numpy.ndarray
        Volume to write; the header describes three spatial axes.
    pixel_size : sequence of 3 numbers
        Voxel spacing in the input units.
    corner_center : sequence of 3 numbers
        Position of the corner voxel centre in the input units.
    """
    volume = np.transpose(data)
    gzip_level = 7

    # Scale to mm.  Wrapping in np.array first makes list/tuple inputs scale
    # element-wise instead of being repeated by `10 * list`.
    spacing_mm = 10 * np.array(pixel_size)
    origin_mm = 10 * np.array(corner_center)

    # Diagonal direction matrix in reversed axis order, matching the
    # transposed array.
    diagonal = (spacing_mm[2], spacing_mm[1], spacing_mm[0])
    directions = [
        [step if row == col else 0 for col in range(3)]
        for row, step in enumerate(diagonal)
    ]

    # NOTE(review): the writer library may derive "type"/"dimension" from the
    # array itself (e.g. float64 data would not be stored as "float") --
    # confirm against the pynrrd version in use.
    header = {
        "type": "float",
        "dimension": 3,
        "space": "left-posterior-superior",
        "space directions": directions,
        "kinds": ["domain"] * 3,
        "encoding": "gzip",
        "space origin": origin_mm[::-1],
    }

    target = os.path.join(path, filename + ".nrrd")
    nrrd.write(target, volume, header, compression_level=gzip_level)



def write_seg_nrrd(
    path, filename, data, pixel_size, corner_center, roi_names_dict=None
):
    """Write binary segment masks to ``<path>/<filename>.seg.nrrd``.

    The header lays out one ``"list"`` axis of segments followed by three
    spatial ``"domain"`` axes, plus per-segment ``SegmentN_*`` fields.  These
    appear to follow the 3D Slicer segmentation-NRRD convention -- confirm
    against the Slicer docs if other readers matter.

    Parameters
    ----------
    path : str
        Output directory.
    filename : str
        Base name; ``".seg.nrrd"`` is appended.
    data : numpy.ndarray
        A 3-D mask (treated as a single segment), or a 4-D array whose last
        axis indexes segments.  Cast to uint8.
    pixel_size : sequence of 3 numbers
        Voxel spacing.  It is multiplied by 10 to give mm (presumably cm
        input -- confirm).
    corner_center : sequence of 3 numbers
        Corner voxel centre, scaled the same way.
    roi_names_dict : dict, optional
        Maps 1-based segment ids to names.  A name found here is appended as
        ``"<id> - <name>"``; otherwise the segment name is just ``"<id>"``.
    """
    COMPRESSION_LEVEL = 7
    data = data.astype(np.uint8)
    if data.ndim == 3:
        # Single mask: add a trailing segment axis of length 1.
        data = np.expand_dims(data, -1)
    # Reverse all axes: the segment axis becomes axis 0 and the spatial axes
    # are reversed, matching the reversed spacing/origin below.
    data = np.transpose(data)

    # Convert to numpy first so list/tuple inputs are scaled element-wise.
    pixel_size_mm = 10 * np.array(pixel_size)
    corner_center_mm = 10 * np.array(corner_center)

    header = {
        "type": "uint8",
        # The array has a segment axis plus 3 spatial axes.  This was
        # previously hard-coded to 3, contradicting the 4-entry
        # "space directions"/"kinds" below.
        "dimension": data.ndim,
        "space": "left-posterior-superior",
        "space directions": [
            None,  # the segment ("list") axis has no spatial direction
            [pixel_size_mm[2], 0, 0],
            [0, pixel_size_mm[1], 0],
            [0, 0, pixel_size_mm[0]],
        ],
        "kinds": ["list", "domain", "domain", "domain"],
        "encoding": "gzip",
        "space origin": corner_center_mm[::-1],
    }

    # Each segment lives on its own layer with label value 1.  Insertion order
    # of these custom fields is kept stable.
    for segment_id in range(data.shape[0]):
        roi_id = segment_id + 1
        pre = f"Segment{segment_id}_"
        segment_name = str(roi_id)
        if roi_names_dict is not None:
            roi_name = roi_names_dict.get(roi_id)
            if roi_name is not None:
                segment_name = f"{segment_name} - {roi_name}"
        header[pre + "ColorAutoGenerated"] = 1
        header[pre + "ID"] = roi_id
        header[pre + "LabelValue"] = 1
        header[pre + "Layer"] = segment_id
        header[pre + "Name"] = segment_name
        header[pre + "NameAutoGenerated"] = 0
        header[pre + "Tags"] = ""

    nrrd_file_path = os.path.join(path, filename + ".seg.nrrd")
    nrrd.write(nrrd_file_path, data, header, compression_level=COMPRESSION_LEVEL)