from skimage.color.colorconv import rgb2lab
import torch
import numpy as np
from .hist import get_image_hist
from skimage.feature import canny as extract_canny
from scipy.ndimage import gaussian_filter
from scipy.spatial.distance import jensenshannon
from skimage.measure import label as extract_regions

def color_distance(img1, img2):
    """Return the Jensen-Shannon distance between two images' colour histograms.

    For each image the histogram from ``get_image_hist`` is divided by its
    total, smoothed with ``gaussian_filter`` (sigma 1.0), and flattened; the
    two resulting vectors are compared with ``jensenshannon``.
    """
    def smoothed_distribution(img):
        counts = get_image_hist(img)
        probs = counts / counts.sum()
        return gaussian_filter(probs, 1.0).ravel()

    p = smoothed_distribution(img1)
    q = smoothed_distribution(img2)
    return jensenshannon(p, q)

def _canny_rgb(x):
    """Return the union of Canny edge maps computed on each channel of ``x``.

    Args:
        x: image array. A 2-D array is treated as a single channel; otherwise
            the last axis is taken as the channel axis and every channel
            (including any alpha channel) contributes edges.

    Returns:
        Boolean array over the non-channel axes, True where Canny (sigma=0)
        fired on at least one channel.
    """
    x = np.asarray(x)
    if x.ndim == 2:
        # Grayscale: iterating the last axis would hand 1-D columns to Canny.
        return extract_canny(x, 0)
    # Start from an explicit boolean mask rather than the int 0 so the result
    # dtype and shape do not depend on numpy's scalar promotion rules and an
    # image with zero channels still yields an (all-False) mask.
    edges = np.zeros(x.shape[:-1], dtype=bool)
    for channel in np.moveaxis(x, -1, 0):
        edges |= extract_canny(channel, 0)
    return edges

def structure_distance(img1, img2):
    """Return the symmetric mean nearest-edge distance between two images.

    Edge pixels of each image come from ``_canny_rgb``. For every edge pixel
    of one image, the Euclidean distance (in pixels) to the closest edge pixel
    of the other image is taken. The result is the average of the two
    directional means.

    Raises:
        ValueError: if either image has no detected edge pixels (the distance
            is undefined in that case).
    """
    # Nearest-neighbour queries on a KD-tree give the same distances as the
    # dense [M, N] distance matrix but use O(M + N) memory instead of O(M * N).
    from scipy.spatial import cKDTree

    img1, img2 = map(np.array, (img1, img2))
    canny1, canny2 = map(_canny_rgb, (img1, img2))
    ep1 = np.argwhere(canny1)  # [M, 2] coordinates of edge pixels
    ep2 = np.argwhere(canny2)  # [N, 2]
    if len(ep1) == 0 or len(ep2) == 0:
        raise ValueError(
            "structure_distance: no edge pixels detected in one or both images"
        )
    d12, _ = cKDTree(ep2).query(ep1)  # [M] each img1 edge -> nearest img2 edge
    d21, _ = cKDTree(ep1).query(ep2)  # [N] each img2 edge -> nearest img1 edge
    return (d21.mean() + d12.mean()) / 2

def _region_deviations(img, contour):
    """Return each region pixel's deviation from the mean of its region.

    Regions are the connected components that ``extract_regions`` (background
    0) finds in the binary mask ``contour > 128``. Pixels outside every region
    are ignored.

    Args:
        img: image array whose first two axes index the same pixels as
            ``contour``; an optional trailing axis holds channels.
        contour: 2-D array-like; pixels brighter than 128 belong to regions.

    Returns:
        float64 array of shape [K, C]. Each of the K region pixels contributes
        one row: its per-channel value minus its region's per-channel mean
        (C == 1 for a 2-D ``img``).

    Raises:
        ValueError: if ``contour`` has no pixels above 128.
    """
    fat_canny = (np.asarray(contour) > 128).astype(np.uint8)
    mask, num = extract_regions(fat_canny, 0, True)
    if num == 0:
        raise ValueError(
            "contour has no pixels above 128, so there are no regions to measure"
        )
    inside = mask > 0
    region_ids = mask[inside]  # [K] labels in 1..num
    pixels = np.asarray(img)[inside].astype(np.float64)
    pixels = pixels.reshape(len(region_ids), -1)  # [K, C]
    # Accumulate all per-region sums in a single pass over the region pixels,
    # instead of rescanning the whole label image once per region.
    counts = np.bincount(region_ids, minlength=num + 1)[:, None]  # [num+1, 1]
    sums = np.zeros((num + 1, pixels.shape[1]))
    np.add.at(sums, region_ids, pixels)
    # Row 0 (background) never appears in region_ids; max() only avoids 0/0.
    means = sums / np.maximum(counts, 1)
    return pixels - means[region_ids]


def flatness_L2_rgb(img, contour):
    """Return the mean squared deviation of region pixels from their region mean.

    Regions are the connected components of ``contour > 128``. Squared
    deviations are averaged over every region pixel and every channel of
    ``img``.

    Raises:
        ValueError: if ``contour`` has no pixels above 128.
    """
    return np.square(_region_deviations(img, contour)).mean()


def flatness_L1_rgb(img, contour):
    """Return the mean absolute deviation of region pixels from their region mean.

    Regions are the connected components of ``contour > 128``. Absolute
    deviations are averaged over every region pixel and every channel of
    ``img``.

    Raises:
        ValueError: if ``contour`` has no pixels above 128.
    """
    return np.abs(_region_deviations(img, contour)).mean()


def flatness_L1_lab(img, contour):
    """Like ``flatness_L1_rgb`` but measured after converting ``img`` to CIELAB.

    ``img`` is passed through ``rgb2lab`` first; absolute deviations from each
    region's mean Lab value are averaged over all region pixels and channels.

    Raises:
        ValueError: if ``contour`` has no pixels above 128.
    """
    return np.abs(_region_deviations(rgb2lab(np.array(img)), contour)).mean()

