import numpy as np
from scipy.stats import gaussian_kde
from scipy.signal import find_peaks

def split(split_time, predict_result, dataset_name, logger):
    """Partition prediction rows into a clean pool and a poison pool by score.

    The value at index 2 of every row is treated as a score. Scores falling in
    sparsely populated histogram bins are discarded, a Gaussian KDE is fitted
    to the remaining scores, and a threshold is derived according to
    ``split_time``. Rows are then assigned to the pools by strict comparison
    against that threshold.

    Args:
        split_time: Stage identifier that selects the strategy:

            * ``'CTMv1'`` (exact): threshold is the last KDE local maximum;
              clean = score below threshold, poison = score above.
            * contains ``'CTM'``: threshold is the last KDE local minimum;
              clean below, poison above.
            * contains ``'PTMv1'`` or equals ``'Split_Clean'``: threshold is
              the first KDE local minimum; clean above, poison below.
            * contains ``'PTM'`` or ``'TrapPre'``: threshold is the largest
              retained score below 5; clean above, poison below.
            * contains ``'TrapModel'``: threshold is the two-element list
              ``[largest retained score below 5, smallest retained score
              above 5]``; clean = score below the first entry, poison = score
              above the second entry.

            The exact value ``'TrapModelvEnd'`` additionally forces the bin
            filter fraction to 0.00005 regardless of ``dataset_name``.
        predict_result: Sized iterable of indexable rows; ``row[2]`` is the
            numeric score and ``row[3]`` is the value collected into the
            returned pools (an image path, judging by the variable names —
            confirm against the caller).
        dataset_name: ``'gtsrb'`` or ``'cifar10'``; selects the minimum
            fraction of rows a histogram bin must exceed to be retained.
            Ignored when ``split_time == 'TrapModelvEnd'``.
        logger: Object exposing ``info(message)``.

    Returns:
        A tuple ``(clean_pool_image_paths, poison_pool_image_paths,
        threshold)``. ``threshold`` is a scalar, or a two-element list for
        the ``'TrapModel'`` strategies.

    Raises:
        ValueError: If ``split_time`` matches no known strategy, if
            ``dataset_name`` is unsupported (and ``split_time`` is not
            ``'TrapModelvEnd'``), if ``predict_result`` is empty, or if no
            retained score lies on the required side of 5.
        IndexError: If the KDE has no local extremum of the kind the selected
            strategy needs.

    Note:
        All pool comparisons are strict, so a row whose score equals a
        threshold is placed in neither pool.
    """
    # Resolve the strategy before doing any numeric work so that an
    # unsupported configuration fails fast with a clear message instead of an
    # UnboundLocalError further down. The precedence (CTM, then PTM-like,
    # then TrapModel) mirrors how rows are assigned to the pools.
    if 'CTM' in split_time:
        mode = 'ctm'
    elif ('PTM' in split_time or 'Split_Clean' == split_time
          or 'TrapPre' in split_time):
        mode = 'ptm'
    elif 'TrapModel' in split_time:
        mode = 'trap_model'
    else:
        raise ValueError(f"unsupported split_time: {split_time!r}")

    # Minimum fraction of all rows a histogram bin must exceed to be kept.
    # 'TrapModelvEnd' overrides the per-dataset setting.
    if split_time == 'TrapModelvEnd':
        min_bin_fraction = 0.00005
    elif dataset_name == 'gtsrb':
        min_bin_fraction = 0.00005
    elif dataset_name == 'cifar10':
        min_bin_fraction = 0.001
    else:
        raise ValueError(
            f"unsupported dataset_name: {dataset_name!r} "
            "(expected 'gtsrb' or 'cifar10')"
        )

    data = np.array([row[2] for row in predict_result])
    if data.size == 0:
        raise ValueError("predict_result is empty; nothing to split")

    # Drop scores that fall into sparsely populated bins.
    hist, bin_edges = np.histogram(data, bins=100)
    valid_bins = hist > len(predict_result) * min_bin_fraction
    bin_indices = np.digitize(data, bin_edges) - 1
    # np.digitize places values equal to the right-most edge one past the
    # last bin, whereas np.histogram counts them in the last bin; clamp so
    # both agree.
    bin_indices[bin_indices >= len(valid_bins)] = len(valid_bins) - 1
    filtered_data = data[valid_bins[bin_indices]]

    # Estimate the score density and locate its local extrema.
    kde = gaussian_kde(filtered_data)
    x = np.linspace(filtered_data.min(), filtered_data.max(), 1000)
    pdf = kde(x)
    peaks_max, _ = find_peaks(pdf)
    peaks_min, _ = find_peaks(-pdf)
    logger.info(f"peaks_max\n{x[peaks_max]}")
    logger.info(f"peaks_min\n{x[peaks_min]}")

    # Fixed score separating the low and high groups for the strategies that
    # do not use the KDE. NOTE(review): presumably tuned to the upstream
    # score scale — confirm.
    score_boundary = 5

    if mode == 'ctm':
        if 'CTMv1' == split_time:
            threshold = x[peaks_max[-1]]
        else:
            threshold = x[peaks_min[-1]]
    elif mode == 'ptm':
        if 'PTMv1' in split_time or 'Split_Clean' == split_time:
            threshold = x[peaks_min[0]]
        else:
            threshold = np.max(filtered_data[filtered_data < score_boundary])
    else:
        threshold = [
            np.max(filtered_data[filtered_data < score_boundary]),
            np.min(filtered_data[filtered_data > score_boundary]),
        ]
    logger.info(f"threshold is:{threshold}")

    # Strict comparisons on purpose (unchanged behaviour): rows whose score
    # equals a threshold end up in neither pool.
    if mode == 'ctm':
        clean_pool_image_paths = [row[3] for row in predict_result if row[2] < threshold]
        poison_pool_image_paths = [row[3] for row in predict_result if row[2] > threshold]
    elif mode == 'ptm':
        clean_pool_image_paths = [row[3] for row in predict_result if row[2] > threshold]
        poison_pool_image_paths = [row[3] for row in predict_result if row[2] < threshold]
    else:
        clean_pool_image_paths = [row[3] for row in predict_result if row[2] < threshold[0]]
        poison_pool_image_paths = [row[3] for row in predict_result if row[2] > threshold[1]]

    logger.info(f"splited clean_pool_image_paths len:{len(clean_pool_image_paths)}")
    logger.info(f"splited poison_pool_image_paths len:{len(poison_pool_image_paths)}")
    return clean_pool_image_paths, poison_pool_image_paths, threshold