import os
import numpy as np
import cv2
import matplotlib.pyplot as plt
from scipy.ndimage.filters import gaussian_filter
import scipy.ndimage
import seaborn as sns
from copy import deepcopy

def get_curve(max_blur,val,order=4):
    """Remap a linear blur amount onto a polynomial curve.

    Evaluates ``(1 / max_blur**(order - 1)) * val**order``.  With this
    normalisation ``val == 0`` maps to 0 and ``val == max_blur`` maps back to
    ``max_blur``; for ``order > 1`` values in between are pulled towards 0,
    i.e. moderately-salient regions receive less blur than a linear mapping
    would give them.

    Args:
        max_blur: Largest blur amount; must be non-zero (it is a divisor).
        val: Linear blur amount to remap.
        order: Degree of the polynomial (default 4, i.e. x**4 / max**3).

    Returns:
        The remapped blur amount.
    """
    normaliser = 1 / (max_blur ** (order - 1))
    powered = val ** order
    return normaliser * powered


def show_heatmap(img,stacked_img):
    """Display the annotation-density map as a heatmap overlaid on ``img``.

    A float copy of ``stacked_img`` is rescaled to 0..255 and smoothed with a
    Gaussian (sigma 5) before plotting, so the caller's array is left
    untouched.

    Args:
        img: Image drawn underneath the heatmap.
        stacked_img: Annotation-density map (higher = more annotated).
    """
    plt.cla()
    # Non-mutating rescale: an in-place `*=` would silently change the
    # caller's array (and fails outright on integer arrays).
    density = np.asarray(stacked_img, dtype=float)
    peak = density.max()
    if peak > 0:  # an all-zero map would otherwise become NaN (0/0)
        density = density * (255.0 / peak)  # change to range 0 to 255
    # Blur boundary between different annotation density levels
    density = scipy.ndimage.gaussian_filter(density, sigma=5)

    # Generate heatmap of saliency map
    hmax = sns.heatmap(density,
        alpha=0.5,
        zorder=2,
        edgecolor="none",
        linewidths=0.0,
        xticklabels=False,
        yticklabels=False,
        cbar=False,
        rasterized=True,
        cmap="jet")

    # Draw the image underneath the heatmap, stretched to the heatmap's extent
    hmax.imshow(img,
        aspect=hmax.get_aspect(),
        extent=hmax.get_xlim() + hmax.get_ylim(),
        zorder=1)
    hmax.axes.get_xaxis().set_visible(False)
    hmax.axes.get_yaxis().set_visible(False)
    plt.grid(False)
    # Set the title before tight_layout so the layout takes it into account
    plt.title("Heatmap representation of saliency map")
    plt.tight_layout(pad=0)
    plt.show()


def blur_image(im,stacked_img,max_blur,curved):
    """Blur ``im`` more strongly where the saliency map is weaker.

    The annotation density in ``stacked_img`` is rescaled, smoothed
    (Gaussian, sigma 5) and mapped onto ``[0, max_blur]``, then rounded to one
    decimal.  Each pixel is taken from a copy of ``im`` blurred with Gaussian
    sigma ``max_blur - level`` (optionally remapped through ``get_curve``).
    The most-annotated pixels stay sharp; pixels whose smoothed density is
    zero receive the full ``max_blur``.

    Args:
        im: Image to blur (e.g. a 2-D grayscale array).  Not modified.
        stacked_img: Annotation-density map with the same height/width as
            ``im``.  Not modified; any numeric dtype is accepted.
        max_blur: Largest Gaussian sigma to apply (must be >= 0).
        curved: If True, map blur amounts through ``get_curve`` (order 4)
            instead of using them linearly.

    Returns:
        A new array with the same shape and dtype as ``im``.

    Raises:
        ValueError: If ``max_blur`` is negative.
    """
    if max_blur < 0:
        raise ValueError("max_blur must be non-negative")
    blurred_im = np.copy(im)  # to hold final image
    if max_blur == 0:
        # Nothing to blur; also avoids the division by zero in get_curve.
        return blurred_im

    # Work on a float copy so the caller's array is neither mutated nor
    # required to be float (in-place `*=` raises on integer arrays).
    density = np.asarray(stacked_img, dtype=float)
    density = density * (255.0 / max(density.max(), 1))  # to 0..255 range
    # Blur the map so there are no hard boundaries between different levels
    density = scipy.ndimage.gaussian_filter(density, 5)

    peak = density.max()
    if peak > 0:
        # Scale back to 0..max_blur so we get the correct blur level
        scaled_img = density / (peak / max_blur)
    else:
        # No annotations at all: 0/0 would yield NaNs and skip blurring
        # entirely; treat every pixel as unannotated (maximum blur) instead.
        scaled_img = np.zeros_like(density)

    scaled_img = np.around(scaled_img, decimals=1)  # round to one decimal place
    for scaled_val in np.unique(scaled_img):  # for all possible blur levels
        # Clamp at 0: rounding can push scaled_val slightly above max_blur.
        blur_amount = max(0.0, round(max_blur - scaled_val, 1))
        if curved:  # based on polynomial instead of linear
            blur_amount = get_curve(max_blur, blur_amount, order=4)

        if blur_amount != 0:  # if blurring required
            mask = scaled_img == scaled_val  # where this blur level applies
            level_blurred = scipy.ndimage.gaussian_filter(im, blur_amount)
            blurred_im[mask] = level_blurred[mask]
    return blurred_im


if __name__ == "__main__":

    sigma_max = 10 # You can set this to any value
    curved = True # Use a non-linear function to calculate blur

    example_image_path = './Example/Images/artificial_eye.png'
    example_image = cv2.imread(example_image_path,cv2.IMREAD_GRAYSCALE)
    if example_image is None:  # cv2.imread signals failure by returning None
        raise SystemExit(f"Could not read example image: {example_image_path}")
    example_annotations = "./Example/Annotations/"

    # Size the density map from the image instead of hard-coding 480x640.
    stacked_img = np.zeros(example_image.shape[:2])
    for annotation in os.listdir(example_annotations):
        ann_path = os.path.join(example_annotations, annotation)
        ann_values = cv2.imread(ann_path,cv2.IMREAD_GRAYSCALE)
        if ann_values is None:
            continue  # not a readable image (stray file or sub-directory)
        # each pixel is either 0 (unannotated) or 1 (annotated) for each example
        pixels = np.where(ann_values > 0, 1, 0)
        stacked_img = stacked_img + pixels # we add all individual annotations together

    # Guard the normalisation so an empty annotation set does not divide by 0.
    saliency_scale = 255.0 / max(stacked_img.max(), 1)
    saliency_map = cv2.cvtColor((stacked_img * saliency_scale).astype("uint8"),cv2.COLOR_GRAY2RGB)
    plt.imshow(saliency_map)
    plt.title("Saliency map showing annotation density")
    plt.show()

    # Pass a copy so the heatmap's rescaling cannot leak into the blur step.
    show_heatmap(example_image,stacked_img.copy()) # show heatmap

    # Generate blurred image
    blurred_image = blur_image(example_image,stacked_img,sigma_max,curved)

    # For display purposes
    color_example_image = cv2.cvtColor(example_image,cv2.COLOR_GRAY2RGB)
    color_blurred_image = cv2.cvtColor(blurred_image,cv2.COLOR_GRAY2RGB)
    combined_image = np.concatenate((color_example_image,color_blurred_image),axis=1)
    plt.cla()
    plt.imshow(combined_image)
    plt.title("Original image vs. Blurred image using saliency map")
    plt.show()

