import argparse

import scipy
import os
import numpy as np
import cv2
import json
import imageio
import torch
from sklearn import metrics

import subprocess
from scipy.ndimage import gaussian_filter
from matplotlib import pyplot as plt
# The bare 'seaborn' style name was deprecated in matplotlib 3.6 and removed in
# later releases (it became 'seaborn-v0_8'). Prefer the new name when the
# installed matplotlib provides it and fall back to the legacy one otherwise.
plt.style.use('seaborn-v0_8' if 'seaborn-v0_8' in plt.style.available else 'seaborn')

from tqdm import tqdm
from utils import *
import time

# red_tr    = get_alpha_cmap('Reds')

from models.submodular_cub_v2 import CubSubModularExplanationV2
# from models.submodular_general_with_OOD import CubSubModularExplanationV2

# import torch
# device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
# print(f"Using device: {device}")  
# print(f"CUDA Device: {torch.cuda.current_device()}")  
# print(f"CUDA Device Name: {torch.cuda.get_device_name(torch.cuda.current_device())}")

def _str2bool(value):
    """Convert a command-line token such as ``"False"`` or ``"1"`` to a bool.

    ``argparse`` with ``type=bool`` calls ``bool()`` on the raw string, so any
    non-empty value (including ``"False"``) becomes ``True``. This converter
    interprets the text instead.
    """
    if isinstance(value, bool):
        return value
    token = value.strip().lower()
    if token in ("1", "true", "t", "yes", "y", "on"):
        return True
    if token in ("0", "false", "f", "no", "n", "off", ""):
        return False
    raise argparse.ArgumentTypeError(
        "expected a boolean value, got {!r}".format(value))


def parse_args(argv=None):
    """Parse the command-line options of the submodular explanation script.

    Args:
        argv: Optional list of argument strings. When ``None`` (the default)
            the arguments are taken from ``sys.argv[1:]``, exactly as before.

    Returns:
        The populated ``argparse.Namespace``.
    """
    parser = argparse.ArgumentParser(description='Submodular Explanation')
    # general
    parser.add_argument('--Datasets',
                        type=str,
                        default='datasets/celeb-a/test_transformed',
                        help='Datasets.')
    parser.add_argument('--eval-list',
                        type=str,
                        default='datasets/celeb-a/eval.txt',
                        help='Evaluation list file: one "<relative image path> <label>" entry per line.')
    parser.add_argument('--partition',
                        type=str,
                        default="slico",
                        choices=["grad", "pixel", "slico", "seeds"],
                        help="Partition strategy to use.")
    parser.add_argument('--pixel-partition-number',
                        type=int,
                        default=112,
                        help="")
    parser.add_argument('--grad-partition-size',
                        type=int,
                        default=10,
                        help="")
    parser.add_argument('--grad-number-per-set',
                        type=int,
                        default=4,
                        help="")
    # NOTE: ``type=bool`` would turn "--random-patch False" into True, so the
    # value is parsed explicitly. A bare "--random-patch" also enables it.
    parser.add_argument('--random-patch',
                        type=_str2bool,
                        nargs='?',
                        const=True,
                        default=False,
                        help="")
    parser.add_argument('--random-patch-number',
                        type=int, default=14,
                        help='')
    parser.add_argument('--explanation-method',
                        type=str,
                        default='explanation_results_Celeb-A/celeba/HsicAttributionMethod',
                        help='Save path for saliency maps generated by interpretability methods.')
    parser.add_argument('--sub-n',
                        type=int, default=1,
                        help='')
    parser.add_argument('--sub-k',
                        type=int, default=50,
                        help='')
    parser.add_argument('--lambda1',
                        type=float, default=1.,
                        help='')
    parser.add_argument('--lambda2',
                        type=float, default=1.,
                        help='')
    parser.add_argument('--lambda3',
                        type=float, default=1.,
                        help='')
    parser.add_argument('--lambda4',
                        type=float, default=1.,
                        help='')
    parser.add_argument('--cfg',
                        type=str,
                        default="configs/cub/submodular_cfg_cub_tf-resnet-v2.json",
                        help='')
    parser.add_argument('--save-dir',
                        type=str, default='submodular_results_celeba_SWAG_0.85_slico_k_49/celeba',
                        help='output directory to save results')
    return parser.parse_args(argv)

def Partition_image(image, explanation_mask, partition_number=112):
    """Split an image into pixel sets ranked by an explanation mask.

    Pixels are sorted by descending mask value and cut into
    ``partition_number`` consecutive groups. Each returned component has the
    shape and dtype of ``image`` and keeps only the pixels of its group; all
    other pixels are zero.

    Args:
        image: Array of shape (H, W) or (H, W, C).
        explanation_mask: Array with one score per pixel (H * W values).
        partition_number: Number of components to produce.

    Returns:
        A list of ``partition_number`` arrays, most important group first.

    Raises:
        ValueError: If ``partition_number`` is not positive or the mask does
            not hold exactly one value per image pixel.
    """
    if partition_number <= 0:
        raise ValueError("partition_number must be a positive integer")

    height, width = image.shape[:2]
    scores = np.asarray(explanation_mask).reshape(-1)
    if scores.size != height * width:
        raise ValueError(
            "explanation_mask has {} values but the image has {} pixels".format(
                scores.size, height * width))

    # Negating an unsigned or boolean array wraps around / is rejected, which
    # would corrupt the ranking; promote those dtypes to a signed integer first.
    if scores.dtype.kind in "ub":
        scores = scores.astype(np.int64)
    order = np.argsort(-scores)

    pixels_per_partition = scores.size // partition_number
    flat_pixels = image.reshape(height * width, -1)

    components_image_list = []
    for i in range(partition_number):
        start = i * pixels_per_partition
        # The last group also takes the remainder pixels so that the
        # components together cover the whole image.
        if i == partition_number - 1:
            stop = scores.size
        else:
            stop = start + pixels_per_partition
        selected = order[start:stop]

        component = np.zeros_like(flat_pixels)
        component[selected] = flat_pixels[selected]
        components_image_list.append(component.reshape(image.shape))
    return components_image_list

def partition_by_mulit_grad(image, explanation_mask, grad_size = 28, grad_num_per_set = 8):
    """
    Divide the image into a grad_size x grad_size grid and group the grid cells
    according to explanation_mask.

    The mask is resized to the grid resolution, the cells are ranked from the
    highest to the lowest value and taken ``grad_num_per_set`` at a time. Each
    group becomes one component image in which only the pixels of its cells are
    kept. When the number of cells is not divisible by ``grad_num_per_set`` the
    remaining cells form one final, smaller group so that no cell is lost.

    Args:
        image: Image array; its first two axes are used as height and width.
        explanation_mask: 2-D saliency array accepted by ``cv2.resize``.
        grad_size: Number of grid cells per side.
        grad_num_per_set: Number of grid cells per component.

    Returns:
        List of ``uint8`` component images, most important group first.

    Raises:
        ValueError: If ``grad_size`` or ``grad_num_per_set`` is not positive.
    """
    if grad_size <= 0 or grad_num_per_set <= 0:
        raise ValueError("grad_size and grad_num_per_set must be positive")

    height, width = image.shape[:2]
    total_cells = grad_size * grad_size

    pool_z = cv2.resize(explanation_mask, (grad_size, grad_size))
    pool_z_flatten = pool_z.flatten()
    # Negating an unsigned array wraps around and would corrupt the ranking.
    if pool_z_flatten.dtype.kind == "u":
        pool_z_flatten = pool_z_flatten.astype(np.int64)
    index = np.argsort(- pool_z_flatten)     # From high to low

    components_image_list = []
    for start in range(0, total_cells, grad_num_per_set):
        binary_mask = np.zeros(total_cells, dtype=np.uint8)
        binary_mask[index[start : start + grad_num_per_set]] = 1
        binary_mask = binary_mask.reshape((grad_size, grad_size))
        # cv2.resize expects dsize as (width, height); passing (height, width)
        # yields a transposed-size mask for non-square images.
        binary_mask = cv2.resize(
            binary_mask, (width, height), interpolation=cv2.INTER_NEAREST)

        components_image_list.append(
            (image * binary_mask[:, :, np.newaxis]).astype(np.uint8)
        )

    return components_image_list

def SubRegionDivision(image, mode="slico", region_size=30, ruler=20.0,
                      slic_iterations=20, num_superpixels=50, num_levels=3,
                      seeds_iterations=10):
    """Split an image into superpixel components.

    Each returned element is a copy of ``image`` in which every pixel outside
    one superpixel is zero.

    Args:
        image: Image array of shape (H, W, C).
        mode: ``"slico"`` uses ``cv2.ximgproc.createSuperpixelSLIC`` (the
            ``algorithm`` argument is left at OpenCV's default); ``"seeds"``
            uses ``cv2.ximgproc.createSuperpixelSEEDS``.
        region_size: ``region_size`` passed to the SLIC constructor.
        ruler: ``ruler`` passed to the SLIC constructor.
        slic_iterations: Number of SLIC iterations.
        num_superpixels: ``num_superpixels`` passed to the SEEDS constructor.
        num_levels: ``num_levels`` passed to the SEEDS constructor.
        seeds_iterations: Number of SEEDS iterations.

    Returns:
        List with one masked image per superpixel label.

    Raises:
        ValueError: If ``mode`` is not ``"slico"`` or ``"seeds"``. Previously
            an unknown mode silently produced an empty list.
    """
    if mode not in ("slico", "seeds"):
        raise ValueError(
            "Unsupported superpixel mode {!r}; expected 'slico' or 'seeds'".format(mode))

    import cv2.ximgproc  # Provided by the opencv-contrib build

    if mode == "slico":
        segmenter = cv2.ximgproc.createSuperpixelSLIC(
            image, region_size=region_size, ruler=ruler)
        segmenter.iterate(slic_iterations)
    else:
        segmenter = cv2.ximgproc.createSuperpixelSEEDS(
            image.shape[1], image.shape[0], image.shape[2],
            num_superpixels=num_superpixels, num_levels=num_levels)
        segmenter.iterate(image, seeds_iterations)

    labels = segmenter.getLabels()
    superpixel_count = segmenter.getNumberOfSuperpixels()

    # Multiplying by the boolean label mask keeps the image dtype and zeroes
    # every pixel that does not belong to superpixel ``i``.
    return [image * (labels == i)[:, :, np.newaxis] for i in range(superpixel_count)]


def Partition_by_patch(image, partition_size=10):
    """Split an image into a ``partition_size`` x ``partition_size`` patch grid.

    Each returned component has the shape and dtype of ``image`` and contains
    one rectangular patch; everything else is zero. Patch height and width are
    derived from the image height and width respectively, and the last row and
    column of patches absorb any remainder so the patches cover the whole image.

    Args:
        image: Image array; its first two axes are used as height and width.
        partition_size: Number of patches per side.

    Returns:
        List of ``partition_size ** 2`` component images in row-major order.

    Raises:
        ValueError: If ``partition_size`` is not positive.
    """
    if partition_size <= 0:
        raise ValueError("partition_size must be a positive integer")

    def _edges(length):
        # Boundaries of ``partition_size`` consecutive intervals; the final
        # boundary is pinned to ``length`` so no pixels are left out.
        step = length // partition_size
        bounds = [k * step for k in range(partition_size)]
        bounds.append(length)
        return bounds

    row_edges = _edges(image.shape[0])
    col_edges = _edges(image.shape[1])

    components_image_list = []
    for i in range(partition_size):
        rows = slice(row_edges[i], row_edges[i + 1])
        for j in range(partition_size):
            cols = slice(col_edges[j], col_edges[j + 1])
            image_tmp = np.zeros_like(image)
            image_tmp[rows, cols] = image[rows, cols]
            components_image_list.append(image_tmp)
    return components_image_list

def visualization(image, submodular_image_set, saved_json_file, save_path, score_threshold=0.85):
    """Plot the selected region and the insertion curve, and save the figure.

    The sub-regions in ``submodular_image_set`` are accumulated in order to form
    the insertion images. The left panel shows ``image`` with the not-yet
    revealed area darkened and outlined at the "best" step; the right panel
    plots the scaled ``consistency_score`` against the fraction of non-zero
    pixels revealed. The "best" step is the first one whose score exceeds
    ``score_threshold`` (falling back to the highest score).

    Args:
        image: Original image array of shape (H, W, C). The channel order is
            reversed before display (presumably BGR as read by OpenCV).
        submodular_image_set: Sequence of sub-region images in selection order.
        saved_json_file: Mapping holding the ``"consistency_score"`` list.
        save_path: Destination of the saved figure.
        score_threshold: Unscaled score that marks the "best" step.

    Returns:
        None. Prints a warning and returns early when there are no sub-regions
        or no scores.
    """
    if len(submodular_image_set) == 0:
        print(f"Warning: Empty submodular_image_set for {save_path}")
        return
    insertion_ours_images_input_results = np.array(saved_json_file["consistency_score"]) * 1e4  # Scale by 10^4
    if len(insertion_ours_images_input_results) == 0:
        print(f"Warning: Empty consistency_score for {save_path}")
        return

    # Cumulative union of the selected sub-regions (insertion images).
    insertion_ours_images = [submodular_image_set[0]]
    for smdl_sub_mask in submodular_image_set[1:]:
        insertion_ours_images.append(insertion_ours_images[-1] + smdl_sub_mask)

    # Scores and images must pair up one-to-one; trim to the common length so
    # a mismatch cannot cause an out-of-range index or an AUC shape error.
    n_points = min(len(insertion_ours_images), len(insertion_ours_images_input_results))
    insertion_ours_images = insertion_ours_images[:n_points]
    insertion_ours_images_input_results = insertion_ours_images_input_results[:n_points]

    above_threshold = np.flatnonzero(insertion_ours_images_input_results > score_threshold * 1e4)
    if above_threshold.size:
        ours_best_index = int(above_threshold[0])
    else:
        ours_best_index = int(np.argmax(insertion_ours_images_input_results))

    total_pixels = image.shape[0] * image.shape[1]
    x = [(insertion_ours_image.sum(-1) != 0).sum() / total_pixels
         for insertion_ours_image in insertion_ours_images]

    # metrics.auc needs at least two points; report NaN for a single step
    # instead of failing before the figure is written.
    if n_points >= 2:
        auc_score = metrics.auc(x, insertion_ours_images_input_results)
    else:
        auc_score = float("nan")
    y_top = insertion_ours_images_input_results.max() * 1.1

    fig, [ax2, ax3] = plt.subplots(1, 2, gridspec_kw={'width_ratios': [1, 1.5]}, figsize=(24, 8))
    try:
        for side in ("left", "right", "top", "bottom"):
            ax2.spines[side].set_visible(False)
        ax2.xaxis.set_visible(False)
        ax2.yaxis.set_visible(False)
        ax2.set_title('Ours', fontsize=54)
        ax2.set_facecolor('white')

        ax3.set_xlim((0, 1))
        ax3.set_ylim((0, y_top))  # Adjust y-limit for scaled values
        ax3.tick_params(axis='both', labelsize=36)
        ax3.set_title('Insertion', fontsize=54)
        ax3.set_ylabel('Recognition Score', fontsize=44)
        ax3.set_xlabel('Percentage of image revealed', fontsize=44)
        ax3.plot(x, insertion_ours_images_input_results, color='dodgerblue', linewidth=3.5)
        ax3.scatter(x[-1], insertion_ours_images_input_results[-1], color='dodgerblue', s=54)
        # Vertical marker at the "best" step.
        ax3.plot([x[ours_best_index], x[ours_best_index]], [0, y_top], color='red', linewidth=3.5)

        # Area not yet revealed at the best step: darken it and draw its
        # dilated outline in the colour [0, 0, 255].
        kernel = np.ones((3, 3), dtype=np.uint8)
        mask = (image - insertion_ours_images[ours_best_index]).mean(-1)
        mask[mask > 0] = 1
        dilate = cv2.dilate(mask, kernel, iterations=3)
        edge = dilate - mask
        image_debug = image.copy()
        image_debug[mask > 0] = image_debug[mask > 0] * 0.5
        image_debug[edge > 0] = np.array([0, 0, 255])
        ax2.imshow(image_debug[..., ::-1])

        print(f"Highest confidence: {insertion_ours_images_input_results.max()}\nFinal confidence: {insertion_ours_images_input_results[-1]}\nInsertion AUC: {auc_score}")
        fig.savefig(save_path, bbox_inches='tight')
    finally:
        # Always release the figure, even if plotting or saving fails.
        plt.close(fig)

# def main(args):
    
#     smdl = CubSubModularExplanationV2(cfg_path=args.cfg, k=args.sub_k, 
#                                     lambda1=args.lambda1, lambda2=args.lambda2, lambda3=args.lambda3, lambda4=args.lambda4)
    
#     with open(args.eval_list, "r") as f:
#         infos = f.read().split('\n')
        
#     # infos = [info for info in infos if info.strip() != "" and int(info.split("/")[0]) >= 83]
    
#     os.makedirs(args.save_dir, exist_ok = True)
#     if args.random_patch:
#         save_dir = os.path.join(args.save_dir, "random_patch-{}x{}".format(args.random_patch_number, args.random_patch_number) + "-" + str(args.sub_k))
#     else:
#         if args.partition == "pixel":
#             save_dir = os.path.join(args.save_dir, args.partition + "-set_num_{}".format(args.pixel_partition_number))
#         elif args.partition == "grad":
#             save_dir = os.path.join(args.save_dir, args.partition + "-{}x{}-{}".format(args.grad_partition_size, args.grad_partition_size, args.grad_number_per_set))
#         os.makedirs(save_dir, exist_ok = True)
#         save_dir = os.path.join(save_dir, args.explanation_method.split("/")[-1] + "-" + str(args.sub_k) + "-{}-{}-{}-{}".format(args.lambda1, args.lambda2, args.lambda3, args.lambda4))
#     os.makedirs(save_dir, exist_ok = True)
    
#     for info in tqdm(infos[:]):
#         id_people = info.split(" ")[-1]
#         # save_people_path = os.path.join(save_dir, id_people)
#         # mkdir(save_people_path)
        
#         image_relative_path = info.split(" ")[0]
        
#         if not args.random_patch:
#             mask_path = os.path.join(args.explanation_method, image_relative_path.replace(".jpg", ".npy"))
            
#             mask = np.load(mask_path)
        
#         # Ground Truth Label
#         # gt_label = int(id_people)
        
#         # Read original image
#         image_path = os.path.join(args.Datasets, image_relative_path)
#         image = cv2.imread(image_path)
#         image = cv2.resize(image, (224, 224))
        
#         if args.random_patch:
#             components_image_list = Partition_by_patch(image, args.random_patch_number)
#         else:
#             if args.partition == "pixel":
#                 components_image_list = Partition_image(image, mask, args.pixel_partition_number)
#             elif args.partition == "grad":
#                 components_image_list = partition_by_mulit_grad(image, mask, args.grad_partition_size, args.grad_number_per_set)

#         start = time.time()
#         submodular_image, submodular_image_set, saved_json_file = smdl(components_image_list)
#         end = time.time()
#         print('程序执行时间: ',end - start)
        
#         # Save the final image
#         save_image_root_path = os.path.join(save_dir, "image-{}".format(args.sub_k))
#         os.makedirs(save_image_root_path, exist_ok = True)
#         os.makedirs(os.path.join(save_image_root_path, id_people), exist_ok = True)
#         save_image_path = os.path.join(
#             save_image_root_path, image_relative_path)
#         cv2.imwrite(save_image_path, submodular_image)

#         # Save npy file
#         save_npy_root_path = os.path.join(save_dir, "npy")
#         os.makedirs(save_npy_root_path, exist_ok = True)
#         os.makedirs(os.path.join(save_npy_root_path, id_people), exist_ok = True)
#         np.save(
#             os.path.join(save_npy_root_path, image_relative_path.replace(".jpg", ".npy")),
#             np.array(submodular_image_set)
#         )

#         # Save json file
#         save_json_root_path = os.path.join(save_dir, "json")
#         os.makedirs(save_json_root_path, exist_ok = True)
#         os.makedirs(os.path.join(save_json_root_path, id_people), exist_ok = True)
#         with open(os.path.join(save_json_root_path, image_relative_path.replace(".jpg", ".json")), "w") as f:
#             f.write(json.dumps(saved_json_file, ensure_ascii=False, indent=4, separators=(',', ':')))

#         # Save GIF
#         save_gif_root_path = os.path.join(save_dir, "gif")
#         os.makedirs(save_gif_root_path, exist_ok = True)
#         save_gif_path = os.path.join(save_gif_root_path, id_people)
#         os.makedirs(save_gif_path, exist_ok = True)

def _prepare_output_path(root_dir, id_people, relative_path):
    """Create the folders for one output file under ``root_dir`` and return its path."""
    # The per-label folder is kept for compatibility with the existing layout.
    os.makedirs(os.path.join(root_dir, id_people), exist_ok=True)
    output_path = os.path.join(root_dir, relative_path)
    os.makedirs(os.path.dirname(output_path), exist_ok=True)
    return output_path


def main(args):
    """Run the submodular explanation over every image in the evaluation list.

    For each ``"<relative image path> <label>"`` entry the image is read,
    resized to 224x224, partitioned into components (random patches, ranked
    pixels, ranked grid cells or superpixels), passed to the submodular model,
    and the resulting image, region set (``.npy``), scores (``.json``) and
    insertion plot (``.png``) are written below ``args.save_dir``.

    Args:
        args: Namespace produced by ``parse_args``.

    Raises:
        FileNotFoundError: If an image listed in the evaluation file cannot be
            read.
        ValueError: If ``args.partition`` is not a supported strategy.
    """
    # Honour --sub-k (default 50) instead of a hard-coded 50. For the
    # non-random partitions k is replaced per image below.
    smdl = CubSubModularExplanationV2(cfg_path=args.cfg, k=args.sub_k,
                                    lambda1=args.lambda1, lambda2=args.lambda2,
                                    lambda3=args.lambda3, lambda4=args.lambda4)

    # Split on any whitespace so trailing spaces or tabs do not break parsing.
    with open(args.eval_list, "r") as f:
        entries = [line.split() for line in f if line.strip()]
    entries = [fields for fields in entries if int(fields[-1]) >= 0]

    method_name = args.explanation_method.split("/")[-1]
    lambda_tag = "-{}-{}-{}-{}".format(args.lambda1, args.lambda2, args.lambda3, args.lambda4)

    os.makedirs(args.save_dir, exist_ok=True)
    if args.random_patch:
        save_dir = os.path.join(args.save_dir, "random_patch-{}x{}".format(args.random_patch_number, args.random_patch_number) + "-" + str(args.sub_k))
    else:
        if args.partition == "pixel":
            save_dir = os.path.join(args.save_dir, args.partition + "-set_num_{}".format(args.pixel_partition_number))
        elif args.partition == "grad":
            save_dir = os.path.join(args.save_dir, args.partition + "-{}x{}-{}".format(args.grad_partition_size, args.grad_partition_size, args.grad_number_per_set))
        elif args.partition in ["slico", "seeds"]:
            save_dir = os.path.join(
                args.save_dir,
                "{}_superpixel".format(args.partition),
                method_name + "-kauto" + lambda_tag
            )
        else:
            raise ValueError("Unsupported partition strategy: {!r}".format(args.partition))
        os.makedirs(save_dir, exist_ok=True)
        # NOTE: the method/k/lambda folder is appended for every non-random
        # partition (including the superpixel ones); kept for path compatibility.
        save_dir = os.path.join(save_dir, method_name + "-" + str(args.sub_k) + lambda_tag)
    os.makedirs(save_dir, exist_ok=True)

    for fields in tqdm(entries):
        image_relative_path = fields[0]
        id_people = fields[-1]
        # Extension-agnostic stem so non-".jpg" inputs still get proper names.
        relative_stem = os.path.splitext(image_relative_path)[0]

        # Read original image
        image_path = os.path.join(args.Datasets, image_relative_path)
        image = cv2.imread(image_path)
        if image is None:
            # cv2.imread signals a missing or unreadable file by returning None.
            raise FileNotFoundError("Could not read image: {}".format(image_path))
        image = cv2.resize(image, (224, 224))

        if args.random_patch:
            components_image_list = Partition_by_patch(image, args.random_patch_number)
        else:
            if args.partition in ["pixel", "grad"]:
                # Only these strategies rank regions by a saliency map; the
                # superpixel modes no longer require the .npy file to exist.
                mask = np.load(os.path.join(args.explanation_method, relative_stem + ".npy"))
                if args.partition == "pixel":
                    components_image_list = Partition_image(image, mask, args.pixel_partition_number)
                else:
                    components_image_list = partition_by_mulit_grad(image, mask, args.grad_partition_size, args.grad_number_per_set)
            else:
                components_image_list = SubRegionDivision(image, mode=args.partition)
            smdl.k = len(components_image_list)
            args.sub_k = len(components_image_list)  # For correct folder naming if needed

        start = time.perf_counter()
        submodular_image, submodular_image_set, saved_json_file = smdl(components_image_list)
        end = time.perf_counter()
        print('程序执行时间: ', end - start)

        # Save the final image
        save_image_path = _prepare_output_path(
            os.path.join(save_dir, "image-{}".format(args.sub_k)), id_people, image_relative_path)
        if not cv2.imwrite(save_image_path, submodular_image):
            print(f"Warning: failed to write image {save_image_path}")

        # Save npy file
        save_npy_path = _prepare_output_path(
            os.path.join(save_dir, "npy"), id_people, relative_stem + ".npy")
        np.save(save_npy_path, np.array(submodular_image_set))

        # Save json file (ensure_ascii=False may emit non-ASCII text, so the
        # encoding is fixed rather than left to the platform default)
        save_json_path = _prepare_output_path(
            os.path.join(save_dir, "json"), id_people, relative_stem + ".json")
        with open(save_json_path, "w", encoding="utf-8") as f:
            f.write(json.dumps(saved_json_file, ensure_ascii=False, indent=4, separators=(',', ':')))

        # GIF export is not implemented yet; only its folder layout is created.
        # TODO: accumulate submodular_image_set frames and write them with imageio.mimsave.
        save_gif_file_path = os.path.join(save_dir, "gif", id_people, relative_stem + ".gif")
        os.makedirs(os.path.dirname(save_gif_file_path), exist_ok=True)

        # Save the insertion-curve visualization
        save_vis_path = _prepare_output_path(
            os.path.join(save_dir, "visualizations"), id_people, relative_stem + ".png")
        visualization(image, submodular_image_set, saved_json_file, save_vis_path)


if __name__ == "__main__":
    # Script entry point: parse the command line and run the pipeline.
    main(parse_args())