import time
import torch
import numpy as np
import glob
import shutil
import os
import colorlog
import random
import six
from six.moves import cPickle
import matplotlib as mpl
from sklearn.cluster import AgglomerativeClustering, KMeans, SpectralClustering, DBSCAN
mpl.use('Agg')
import matplotlib.pyplot as plt

def clustering(dt, n_clusters, cluster_alg, event_threshold, dbscan_eps=7):
    """Cluster the per-frame features of one video and timestamp every frame.

    Args:
        dt: batch dict. ``dt['video_tensor']`` is squeezed on dim 0 and passed
            to sklearn, so it presumably holds a single video (batch size 1)
            of shape (1, frames, feature_dim) — TODO confirm against callers.
            ``dt['video_length'][:, 1]`` is read with ``.item()`` as the video
            duration, which likewise requires a single element.
        n_clusters: cluster count for "agglomerative", "kmeans" and "spectral".
        cluster_alg: one of "agglomerative", "kmeans", "spectral", "dbscan".
        event_threshold: used as DBSCAN ``min_samples``; ignored otherwise.
        dbscan_eps: DBSCAN neighbourhood radius. Defaults to 7, the value that
            used to be hard-coded.

    Returns:
        (labels, time_list): ``labels`` is the per-frame label array from the
        fitted estimator (for DBSCAN shifted by +1, so noise ``-1`` becomes 0).
        ``time_list`` is a float64 array where frame ``i`` is stamped at
        ``i * duration / num_frames``.

    Raises:
        ValueError: if ``cluster_alg`` is not one of the supported names.
    """
    if cluster_alg == "agglomerative":
        model = AgglomerativeClustering(n_clusters=n_clusters)
    elif cluster_alg == "kmeans":
        model = KMeans(n_clusters=n_clusters)
    elif cluster_alg == "spectral":
        model = SpectralClustering(n_clusters=n_clusters)
    elif cluster_alg == "dbscan":
        model = DBSCAN(eps=dbscan_eps, min_samples=event_threshold)
    else:
        # Without this branch an unknown name crashed later with an
        # UnboundLocalError on the never-assigned estimator variable.
        raise ValueError(
            f"unknown cluster_alg {cluster_alg!r}; expected one of "
            "'agglomerative', 'kmeans', 'spectral', 'dbscan'"
        )

    video_array = dt['video_tensor'].squeeze(0).cpu().numpy()
    labels = model.fit(video_array).labels_
    if cluster_alg == "dbscan":
        # DBSCAN labels noise as -1; shift so every label is non-negative.
        labels = labels + 1

    feature_count = video_array.shape[0]
    duration = dt['video_length'][:, 1].item()
    # Evenly spaced frame start times: frame i -> i * duration / feature_count.
    time_list = np.arange(feature_count) * duration / feature_count
    return labels, time_list
def get_events_mid_duration(clustered_labels, time_list, n_clusters, event_threshold):
    """Extract events as runs of consecutive frames that share a cluster label.

    For each label in ``range(n_clusters)``, the frame indices carrying that
    label are split into maximal runs of consecutive indices. Every run with
    strictly more than ``event_threshold`` frames becomes one event.

    Args:
        clustered_labels: np.ndarray of per-frame integer labels, e.g. from
            ``clustering()``.
        time_list: np.ndarray of per-frame timestamps; it is fancy-indexed
            with a list of frame indices.
        n_clusters: labels ``0 .. n_clusters - 1`` are scanned; any other
            label is ignored.
        event_threshold: a run must contain more frames than this to count.

    Returns:
        (event_mid_time, event_duration): parallel lists with one entry per
        event. The mid time is the mean timestamp of the run's frames. The
        duration is the last frame's timestamp minus the first frame's.
        Events are ordered by label and, within a label, from the latest run
        to the earliest. There is no fallback when no run passes the
        threshold, so both lists may be empty.
    """
    runs = []
    for label in range(n_clusters):
        # np.nonzero yields ascending indices; reverse them because
        # split_into_sublists groups descending runs (each next = prev - 1).
        frame_idx = np.nonzero(clustered_labels == label)[0][::-1].tolist()
        runs.extend(run for run in split_into_sublists(frame_idx) if len(run) > event_threshold)

    event_mid_time = []
    event_duration = []
    for run in runs:
        run_times = time_list[run]
        event_mid_time.append(np.average(run_times))
        # run is descending: run_times[0] is the latest frame, [-1] the earliest.
        event_duration.append(run_times[0] - run_times[-1])
    return event_mid_time, event_duration

def get_proposal_from_cluster(dt, event_mid_time, event_duration):
    """Turn event mid-times into a normalised proposal tensor plus a mask.

    Args:
        dt: batch dict. ``dt['video_tensor']`` supplies the batch size
            (``shape[0]``) and the output device. ``dt['video_length'][0, 1]``
            is read as the duration used for normalisation.
        event_mid_time: sequence of event mid-times, in the same units as the
            duration.
        event_duration: accepted for interface compatibility but unused; the
            duration channel of the proposal is currently disabled.

    Returns:
        (proposals, mask): ``proposals`` has shape (batch, n_events, 1) and
        holds ``mid_time / duration`` for every event, repeated across the
        batch. ``mask`` has shape (batch, n_events) with every entry 1.0.
        Both use torch's default floating dtype.
    """
    frames = dt['video_tensor']
    target = frames.device
    batch_size = frames.shape[0]
    num_events = len(event_mid_time)

    proposals = torch.zeros((batch_size, num_events, 1), device=target)
    mask = torch.zeros((batch_size, num_events), device=target)
    total_length = dt['video_length'][0, 1].item()

    for slot, centre in enumerate(event_mid_time):
        proposals[:, slot, 0] = centre / total_length
    # Every event slot is valid, so the whole mask becomes 1.0.
    mask.fill_(1.0)
    return proposals, mask
def split_into_sublists(A):
    """Split ``A`` into maximal runs where each element is one less than the previous.

    Example: ``[9, 8, 7, 5, 4, 1]`` -> ``[[9, 8, 7], [5, 4], [1]]``.
    Returns an empty list when ``A`` is empty.
    """
    if not A:
        return []

    runs = []
    for value in A:
        # Extend the current run only while values keep stepping down by one.
        if runs and value == runs[-1][-1] - 1:
            runs[-1].append(value)
        else:
            runs.append([value])
    return runs

def decide_two_stage(transformer_input_type, dt, criterion, cnum, event_threshold, cluster_alg):
    """Build cluster-derived event proposals for one video.

    Args:
        transformer_input_type: not used by this function; kept for interface
            compatibility.
        dt: batch dict consumed by ``clustering()`` and
            ``get_proposal_from_cluster()``.
        criterion: not used by this function.
        cnum: number of clusters for the non-DBSCAN algorithms.
        event_threshold: a run must have more frames than this to become an
            event; it is also passed to ``clustering()`` (DBSCAN
            ``min_samples``).
        cluster_alg: algorithm name forwarded to ``clustering()``.

    Returns:
        Tuple ``(two_stage, disable_iterative_refine, proposals,
        proposals_mask, clustering_flag, clustered_labels)``. Here
        ``two_stage`` is False, ``disable_iterative_refine`` is False,
        ``proposals_mask`` is None and ``clustering_flag`` is True.
    """
    n_clusters = cnum
    two_stage = False
    clustered_labels, time_list = clustering(dt, n_clusters, cluster_alg, event_threshold)
    if cluster_alg == "dbscan":
        # clustering() shifts DBSCAN labels by +1 (noise -1 -> 0), so labels
        # span 0..max. Counting unique labels undercounted by one when there
        # was no noise (labels 1..k), silently dropping cluster k. Scan every
        # label up to the maximum instead. The noise label 0 is still scanned,
        # as before.
        n_clusters = int(clustered_labels.max()) + 1 if clustered_labels.size else 0
    event_mid_time, event_duration = get_events_mid_duration(clustered_labels, time_list, n_clusters, event_threshold)
    proposals, proposals_mask = get_proposal_from_cluster(dt, event_mid_time, event_duration)
    # NOTE(review): the mask built above is discarded and None is returned
    # instead — kept as-is because callers may depend on it; confirm intent.
    proposals_mask = None
    disable_iterative_refine = False
    clustering_flag = True
    return two_stage, disable_iterative_refine, proposals, proposals_mask, clustering_flag, clustered_labels