import torch
import matplotlib
matplotlib.use('Agg')
import matplotlib.pyplot as plt
import numpy as np
import os, sys, pdb
import torch.nn.functional as F
import seaborn
sys.path.append('../../')
sys.path.append('../')


# Names of the rule-based relations. They are used as the row labels of the
# similarity heatmap, so the order here is presumably the relation index
# order used in 'edge_attr' -- confirm against the graph extraction code.
rule_relation_names = [
    'isIn',
    'inDFrontOf',
    'inSFrontOf',
    'atDRearOf',
    'atSRearOf',
    'toLeftOf',
    'toRightOf',
    'near_coll',
    'super_near',
    'very_near',
    'near',
    'visible',
]
# Number of relation types (one dense adjacency slice per relation).
num_relations = len(rule_relation_names)

# Node feature names passed to the scenegraph embedding helper: type_0 .. type_8.
num_of_classes = 9
feature_list = [f"type_{class_idx}" for class_idx in range(num_of_classes)]


# A graph sequence is a list of dictionaries, where each dictionary is one graph.
# Each dictionary contains the keys 'node_embeddings', 'edge_attr', and 'edge_index'.
def get_seq_list(dir) -> list:
    '''
    return the names (not full paths) of all entries in dir ending with '.pt'.
    entries are returned in the order given by os.listdir.
    '''
    return [entry for entry in os.listdir(dir) if entry.endswith('.pt')]


def load_graphs(path) -> list:
    '''
    load a saved graph sequence from path with torch.load and return it.
    NOTE: torch.load unpickles the file, so only load files from a trusted source.
    '''
    return torch.load(path)


def get_dense_adj(graphs, num_relations) -> list:
    '''
    convert torch geometric sparse adjacency matrix to dense adjacency matrix.

    graphs: iterable of dicts with keys 'edge_index', 'edge_attr' and
        'node_embeddings'. edge_index is indexed as [2, num_edges] and
        edge_attr[i] is used as the relation index of edge i.
    num_relations: number of relation slices in each dense adjacency.

    returns a list with one float tensor per graph, each of shape
    [num_relations, num_nodes, num_nodes], where entry [r, a, b] is 1 if an
    edge with relation r goes from edge_index[0] node a to edge_index[1] node b.
    '''
    adjs = []
    for graph in graphs:
        edge_index = graph['edge_index']
        edge_attr = graph['edge_attr']
        # the node count is taken from the first dimension of the embeddings.
        num_nodes = graph['node_embeddings'].shape[0]
        adj = torch.zeros((num_relations, num_nodes, num_nodes))
        row_nodes, col_nodes = edge_index[0], edge_index[1]
        for i in range(edge_index.shape[1]):
            # plain ints make this a single in-place element write.
            adj[int(edge_attr[i]), int(row_nodes[i]), int(col_nodes[i])] = 1
        adjs.append(adj)
    return adjs


def compare_learned_and_orig_graphs(learned_adjs, orig_adjs, save_path):
    '''
    compute cosine similarity between each pair of graphs and plot average across dataset.

    learned_adjs, orig_adjs: equally ordered sequences of dense adjacency
        tensors; pair k compares orig_adjs[k] against learned_adjs[k].
    save_path: output path without extension; the heatmap is written to
        save_path + '.png'.

    raises ValueError if there are no graph pairs to compare.
    '''
    # orig_adj is passed first, so heatmap rows follow the rule-based
    # relations and columns follow the learned relations.
    per_graph = [compute_cosine_similarity(orig_adj, learned_adj)
                 for learned_adj, orig_adj in zip(learned_adjs, orig_adjs)]
    if not per_graph:
        raise ValueError('no graph pairs to compare: learned_adjs or orig_adjs is empty')
    # stack once instead of growing a tensor with torch.cat on every iteration.
    cossim_avg = torch.stack(per_graph, dim=0).mean(dim=0)

    #seaborn version:
    fig = plt.figure(figsize=(10,8))
    try:
        seaborn.heatmap(cossim_avg.numpy(),
                        yticklabels=rule_relation_names,
                        xticklabels='auto',
                        annot=True,
                        cmap="Blues")
        plt.xlabel('Learned Relation Index')
        plt.ylabel('Rule-Based Relation Index')
        plt.title("Cosine Similarity")
        plt.savefig(save_path + '.png', dpi=600)
    finally:
        # pyplot keeps figures alive until closed; release it even if plotting fails.
        plt.close(fig)


def compute_cosine_similarity(adj1, adj2) -> torch.Tensor:
    '''
    compute cosine similarity across adjacency matrix for each relation type.

    adj1, adj2: dense adjacency tensors indexed as [relation, node, node].
        each relation slice is flattened to a vector before comparison, so
        both inputs must have the same number of entries per slice.

    returned cossim is [num_r1, num_r2] where rows=adj1, columns=adj2, i.e.
    cossim[i][j] compares relation i of adj1 with relation j of adj2.
    '''
    # flatten every relation slice once instead of inside the double loop.
    flat1 = adj1.reshape(adj1.shape[0], -1)
    flat2 = adj2.reshape(adj2.shape[0], -1)
    cossim = torch.zeros((flat1.shape[0], flat2.shape[0]))
    for i in range(flat1.shape[0]):
        for j in range(flat2.shape[0]):
            # to divide by the frequency of the adj1 relation instead, append:
            #   / max(torch.sum(adj1[i]).cpu().item(), 1)
            cossim[i][j] = F.cosine_similarity(flat1[i], flat2[j], dim=0)
    return cossim

def preprocess_orig_graphs(graphs):
    '''
    manually applies torch_geom conversion to original scenegraphs.

    for every scenegraph the first node of its graph is dropped (the root
    road, which the self-learned graphs do not have), then node and edge
    embeddings are extracted into a dict with keys 'node_embeddings',
    'edge_index' and 'edge_attr'.

    NOTE: each scenegraph's .g attribute is replaced by the pruned copy, so
    the input objects are modified.
    '''
    converted = []
    for sg in graphs:
        # work on a copy so the graph object originally held in sg.g is untouched.
        pruned = sg.g.copy()
        root = list(pruned.nodes())[0]
        pruned.remove_node(root)
        sg.g = pruned

        # node name -> position, computed after the root has been removed.
        index_of = {name: pos for pos, name in enumerate(sg.g.nodes)}
        node_emb = sg.get_real_image_node_embeddings(feature_list)
        edge_index, edge_attr = sg.get_real_image_edge_embeddings(index_of)
        converted.append({
            'node_embeddings': node_emb,
            'edge_index': edge_index,
            'edge_attr': edge_attr,
        })
    return converted


# Script entry point: load the learned and original graph sequences for one
# run, convert both to dense adjacency tensors and save the averaged
# cosine-similarity heatmap.
if __name__ == '__main__':
    run_name = 'dauntless-violet-11'
    learned_graphs_folder = 'learned_graphs_'+run_name+'/correct/'
    orig_graphs_folder = 'orig_graphs_'+run_name+'/correct/'
    save_path = 'learned_graphs_'+ run_name + '/graph_cossim_heatmap'

    print('loading graphs for run: ', run_name)
    # list the sequence files once (sorted for a reproducible order) and use the
    # same names for both folders so learned/original sequences stay paired.
    seq_list = sorted(get_seq_list(learned_graphs_folder))
    learned_graphs = [load_graphs(os.path.join(learned_graphs_folder, seq)) for seq in seq_list]
    orig_graphs = [load_graphs(os.path.join(orig_graphs_folder, seq)) for seq in seq_list]

    # graphs are compared position by position, so every sequence must hold the
    # same number of graphs in both folders; otherwise pairs would be misaligned.
    for seq, learned_seq, orig_seq in zip(seq_list, learned_graphs, orig_graphs):
        if len(learned_seq) != len(orig_seq):
            raise ValueError('graph count mismatch in ' + seq + ': '
                             + str(len(learned_seq)) + ' learned vs '
                             + str(len(orig_seq)) + ' original')

    learned_graphs = [i for sublist in learned_graphs for i in sublist] #unroll into single list
    orig_graphs = [i for sublist in orig_graphs for i in sublist]

    #use this instead to just run a single graph sequence:
    # learned_graphs = load_graphs('learned_graphs/correct/0_lanechange.pt')
    # orig_graphs = load_graphs('orig_graphs/correct/0_lanechange.pt')

    print('graphs loaded, preprocessing...')
    orig_graphs = preprocess_orig_graphs(orig_graphs)

    print('preprocessing done, computing cosine similarity...')
    learned_adjs = get_dense_adj(learned_graphs, num_relations)
    orig_adjs = get_dense_adj(orig_graphs, num_relations)

    compare_learned_and_orig_graphs(learned_adjs, orig_adjs, save_path)
    print('completed.')