from pytorch_grad_cam import GradCAM
import torch
from models.modelTwoTailClassifier import TwoTailClassifier_smallerKernel
from pytorch_grad_cam.utils.model_targets import ClassifierOutputTarget
import os
from data.dataLoaderPL import GradCamDataModule
from PIL import Image
import numpy as np

def gradcam(classifier_weights_path, dst_folder_path_gradCam, data_dir_text):
    """Compute a Grad-CAM map for every image of the dataset and save it to disk.

    Args:
        classifier_weights_path: Path of the checkpoint holding the trained
            classifier weights.
        dst_folder_path_gradCam: Folder where the Grad-CAM maps are written;
            it is created if it does not exist yet.
        data_dir_text: Forwarded to ``GradCamDataModule`` as ``data_dir_text``.

    Each map is saved under the base name of the image it was computed from.
    """
    device = torch.device('cuda' if torch.cuda.is_available() else 'cpu')

    # One image per batch: every Grad-CAM map is saved under its own file name.
    dm = GradCamDataModule(batch_size=1, data_dir_text=data_dir_text)
    dataloader = dm.dataloader()

    # NOTE(review): assumes TwoTailClassifier_smallerKernel is a PyTorch Lightning
    # module. There `load_from_checkpoint` is a classmethod that RETURNS the
    # restored model, so its result must be kept; calling it on an instance and
    # discarding the return value leaves the freshly initialised weights in place.
    classifier = TwoTailClassifier_smallerKernel.load_from_checkpoint(
        classifier_weights_path, map_location=device
    )
    classifier.to(device)
    classifier.eval()

    # Class index 1 is the target of the explanation (XAI map for text areas).
    targets = [ClassifierOutputTarget(1)]

    # Layer whose activations and gradients Grad-CAM uses.
    target_layers = [classifier.conv2]

    # Construct the CAM object once, and then re-use it on many images.
    # Use `device.type` rather than comparing the torch.device object itself with
    # the string 'cpu', so the CPU case cannot be mistaken for CUDA.
    cam = GradCAM(
        model=classifier,
        target_layers=target_layers,
        use_cuda=(device.type == 'cuda'),
    )

    os.makedirs(dst_folder_path_gradCam, exist_ok=True)

    for batch in dataloader:
        image_with_text = batch['image_with_text'].to(device)
        image_with_text_name = batch['image_with_text_name']

        # Grad-CAM map of the image with text.
        # Presumably values lie in [0, 1] (pytorch_grad_cam default) — TODO confirm.
        grayscale_cam = cam(image_with_text, targets=targets)

        # Rescale to 8-bit so the map can be stored as a regular grayscale image.
        heatmap = (grayscale_cam * 255).astype(np.uint8)

        # Keep only the file name of the source image, whatever the length of
        # the dataset directory prefix is.
        file_name = os.path.basename(image_with_text_name[0])
        Image.fromarray(heatmap.squeeze()).save(
            os.path.join(dst_folder_path_gradCam, file_name)
        )


def gradcam_ensambel(input_folder_path, output_folder_path):
    """Fuse same-named images from several folders with a pixel-wise maximum.

    Args:
        input_folder_path: Sequence of folder paths. The file names are taken
            from the first folder; every other folder is expected to contain
            files with the same names.
        output_folder_path: Folder where each fused image is saved under the
            same file name.
    """
    # The first folder defines which file names are processed.
    reference_folder = input_folder_path[0]

    for fname in os.listdir(reference_folder):
        # Load this image from every folder as a grayscale pixel array.
        grayscale_maps = [
            np.array(Image.open(os.path.join(folder, fname)).convert('L'))
            for folder in input_folder_path
        ]

        # Pixel-wise maximum across all folders.
        fused = np.max(grayscale_maps, axis=0).astype(np.uint8)

        # Store the fused image under the same name in the output folder.
        Image.fromarray(fused).save(os.path.join(output_folder_path, fname))