import torch
import matplotlib
matplotlib.use('Agg')
import matplotlib.pyplot as plt
import numpy as np
import os, sys, pdb
import torch.nn.functional as F
import seaborn
sys.path.append('../../')
sys.path.append('../')
from networkx.drawing import nx_agraph, nx_pydot
import networkx as nx
from PIL import Image
import torch_geometric
from io import BytesIO

# New visualizer that uses the .pt files saved by the graph relation analysis code.

# Relation names; position i is the name of relation index i.
rule_relation_names = [
    'isIn', 'inDFrontOf', 'inSFrontOf', 'atDRearOf', 'atSRearOf',
    'toLeftOf', 'toRightOf', 'near_coll', 'super_near', 'very_near',
    'near', 'visible',
]
num_relations = len(rule_relation_names)

# Relational edge colors for scenegraph visualization, as [relation, edge_color] pairs.
COLOR_MAPPING = [
    ['isIn', 'black'],
    ['near_coll', 'red'],
    ['super_near', 'orange'],
    ['very_near', 'yellow'],
    ['near', 'purple'],
    ['visible', 'green'],
    ['inDFrontOf', 'violet'],
    ['inSFrontOf', 'violet'],
    ['atDRearOf', 'turquoise'],
    ['atSRearOf', 'turquoise'],
    ['toLeftOf', 'blue'],
    ['toRightOf', 'blue'],
]
# Same mapping as a relation -> color lookup table.
RELATION_COLORS = dict(COLOR_MAPPING)

# Actor type names; position j is the type whose node-embedding column j is checked.
ACTOR_NAMES = [
    "ego car", 'car', 'moto', 'bicycle', 'ped', 'lane', 'light', 'sign', 'road',
]
# Node fill colors; only some of the actor types have an entry.
ACTOR_COLORS = {"car": 'green', 'lane': 'yellow', 'ego car': 'red', 'road': 'white'}
num_of_classes = len(ACTOR_NAMES)
feature_list = [f"type_{class_idx}" for class_idx in range(num_of_classes)]


# `graphs` is a list of dictionaries, one dictionary per graph.
# Each dictionary has the keys 'node_embeddings', 'edge_attr', and 'edge_index'.
def get_seq_list(dir) -> list:
    '''
    Return the names of all entries in `dir` whose name ends with '.pt'.

    Only bare entry names are returned (not joined with `dir`), in the order
    that os.listdir yields them.
    '''
    return [entry for entry in os.listdir(dir) if entry.endswith('.pt')]


def load_graphs(path) -> list:
    '''
    Deserialize and return the object stored at `path` using torch.load.

    The script in this file expects the stored object to be a list of graphs.
    NOTE(review): torch.load unpickles the file, so only load trusted files.
    '''
    return torch.load(path)


def get_dense_adj(graphs, num_relations) -> list:
    '''
    Convert sparse (edge_index, edge_attr) graphs to dense adjacency tensors.

    Parameters:
        graphs: iterable of dicts with keys 'edge_index', 'edge_attr' and
            'node_embeddings'. edge_index is indexed as [0][i] (source) and
            [1][i] (target); edge_attr[i] is used as the relation index of
            edge i. Assumes edge_attr holds one integer per edge.
        num_relations: size of the first (relation) dimension of each output.

    Returns:
        A list with one float tensor per graph, shaped
        (num_relations, num_nodes, num_nodes), where entry [r, s, t] is 1 if
        an edge s -> t with relation r exists and 0 otherwise. A list (not a
        single stacked tensor) is returned because num_nodes is taken from
        each graph separately.
    '''
    adjs = []
    for graph in graphs:
        edge_index = graph['edge_index']
        edge_attr = graph['edge_attr']
        num_nodes = graph['node_embeddings'].shape[0]
        adj = torch.zeros((num_relations, num_nodes, num_nodes))

        num_edges = edge_index.shape[1]
        if num_edges:
            # Move the indices to adj's device and to long so they are valid
            # index tensors regardless of how the graph was stored.
            relation = edge_attr[:num_edges].to(device=adj.device, dtype=torch.long)
            source = edge_index[0].to(device=adj.device, dtype=torch.long)
            target = edge_index[1].to(device=adj.device, dtype=torch.long)
            # One indexed assignment marks every edge at once.
            adj[relation, source, target] = 1
        adjs.append(adj)
    return adjs


def convert_to_networkx(graphs, relations=(3,)):
    '''
    Convert graph dicts to networkx multi-relation directed graphs.

    Parameters:
        graphs: iterable of dicts with keys 'node_embeddings', 'edge_index'
            and 'edge_attr'. Column j of a node's embedding row equal to 1
            marks the node as actor type ACTOR_NAMES[j]; edge_attr[i] is the
            relation index of edge i (an index into rule_relation_names).
        relations: relation indices whose edges are added to the output.
            Defaults to (3,), the filter this function has always applied.
            Pass None to keep edges of every relation.
            NOTE(review): the default of relation 3 ('atDRearOf') looks like
            a leftover debugging filter -- confirm the intended default.

    Returns:
        A list with one nx.MultiDiGraph per input graph. Every node is named
        'car_<row index>' regardless of its type; the readable name is stored
        in the 'label' attribute. Edges carry the relation index as 'label'
        and the relation color as 'color'.
    '''
    # Fill color for actor types that have no entry in ACTOR_COLORS;
    # indexing ACTOR_COLORS directly raised KeyError for those types.
    default_fillcolor = 'lightgrey'

    output_graphs = []
    for graph in graphs:
        G = nx.MultiDiGraph()
        car_idx = 0  # running counter so that each 'car' node gets a distinct label
        for i, row in enumerate(graph['node_embeddings'].tolist()):
            for j, actor_name in enumerate(ACTOR_NAMES):
                if row[j] != 1:
                    continue
                if j == 1:
                    label = actor_name + '_' + str(car_idx)
                    car_idx += 1
                else:
                    label = actor_name
                G.add_node('car_' + str(i), label=label, style='filled',
                           fillcolor=ACTOR_COLORS.get(actor_name, default_fillcolor))

        edge_index = graph['edge_index']
        edge_attr = graph['edge_attr']
        for i in range(edge_index.shape[1]):
            relation = edge_attr[i].item()
            if relations is not None and relation not in relations:
                continue
            G.add_edge('car_' + str(edge_index[0][i].item()),
                       'car_' + str(edge_index[1][i].item()),
                       label=relation,
                       color=RELATION_COLORS[rule_relation_names[relation]])
        output_graphs.append(G)
    return output_graphs

def nx_conversion(graphs):
    '''
    Convert graph dicts to networkx graphs through torch_geometric.

    Each dict is first wrapped in a torch_geometric Data object (x from
    'node_embeddings', plus 'edge_index' and 'edge_attr'); every Data object
    is then passed to torch_geometric's to_networkx with default arguments.
    '''
    data_objects = []
    for graph in graphs:
        data_objects.append(
            torch_geometric.data.Data(
                x=graph['node_embeddings'],
                edge_index=graph['edge_index'],
                edge_attr=graph['edge_attr'],
            )
        )
    return [torch_geometric.utils.convert.to_networkx(data) for data in data_objects]


def visualize_learned_and_orig_graphs(learned_graphs, orig_graphs, image_dir=None, save_path="render"):
    '''
    Render the first original and the first learned scenegraph to PNG files.

    Parameters:
        learned_graphs: sequence of graphs; only element 0 is drawn.
        orig_graphs: sequence of graphs; only element 0 is drawn.
        image_dir: currently unused; kept so existing callers keep working.
        save_path: output path prefix. The files written are
            save_path + "_orig.png" and save_path + "_learned.png".

    Raises:
        IndexError: if either sequence is empty.
    '''
    renders = ((orig_graphs, "_orig.png"), (learned_graphs, "_learned.png"))
    for graphs, suffix in renders:
        # Each render gets its own figure so the second image is not drawn
        # on top of the first; closing it afterwards releases the figure.
        fig = plt.figure(figsize=(36, 36))
        try:
            draw_sg(graphs[0], save_path=save_path + suffix)
        finally:
            plt.close(fig)



def draw_scenegraph_agraph(sg):
    '''
    Render a networkx graph to a PIL image using pygraphviz.

    The graph is converted to an AGraph, laid out with the 'dot' program and
    drawn to in-memory PNG bytes, which are then opened as an image.
    '''
    agraph = nx_agraph.to_agraph(sg)
    agraph.layout('dot')
    png_bytes = agraph.draw(format='png')
    png_buffer = BytesIO(png_bytes)
    return Image.open(png_buffer)


def draw_scenegraph_pydot(sg):
    '''
    Render a networkx graph to a PIL image using pydot.

    The graph is converted to a pydot graph and rendered to in-memory PNG
    bytes, which are then opened as an image.
    '''
    png_bytes = nx_pydot.to_pydot(sg).create_png()
    return Image.open(BytesIO(png_bytes))


def draw_sg(sg, image = None, save_path = None):
    '''
    Draw a scenegraph onto the current matplotlib figure and optionally save it.

    Parameters:
        sg: networkx graph, rendered through draw_scenegraph_agraph.
        image: currently unused. It belongs to a disabled side-by-side layout
            that showed the simulation image next to the scenegraph; kept so
            existing callers keep working.
        save_path: file path passed to plt.savefig. When None (the default)
            the figure is drawn but not saved; calling plt.savefig(None)
            would raise instead.
    '''
    sg_img = draw_scenegraph_agraph(sg)
    plt.imshow(sg_img)
    plt.title("New Scenegraph")
    plt.axis('off')

    if save_path is not None:
        plt.savefig(save_path)


if __name__ == '__main__':
    # Script entry point: load the learned and original graphs of one scene
    # from a training run and render them.
    run_name = 'floating-rocket-32'
    learned_graphs_folder = f'learned_graphs_{run_name}/correct/'
    orig_graphs_folder = f'orig_graphs_{run_name}/correct/'
    save_path = f'learned_graphs_{run_name}/render'

    print('loading graphs for run: ', run_name)
    # Alternative: load every sequence of the run and flatten into single lists.
    # learned_graphs = [load_graphs(learned_graphs_folder + seq) for seq in get_seq_list(learned_graphs_folder)]
    # orig_graphs = [load_graphs(orig_graphs_folder + seq) for seq in get_seq_list(learned_graphs_folder)]
    # assert len(orig_graphs) == len(learned_graphs)
    # learned_graphs = [i for sublist in learned_graphs for i in sublist]
    # orig_graphs = [i for sublist in orig_graphs for i in sublist]

    # Single-sequence mode: process just one scene.
    scene = '0_lanechange'
    image_dir = f'/media/NAS-Temp/louisccc/av/synthesis_data/1043_carla/{scene}/raw_images/'

    print('loading scene: ', scene)
    scene_file = scene + '.pt'
    learned_graphs = load_graphs(learned_graphs_folder + scene_file)
    orig_graphs = load_graphs(orig_graphs_folder + scene_file)

    print('converting to networkx graphs...')
    # Learned graphs are dicts and need converting; the original entries
    # already carry a networkx graph in their `.g` attribute.
    learned_graphs = convert_to_networkx(learned_graphs)
    orig_graphs = [scenegraph.g for scenegraph in orig_graphs]

    visualize_learned_and_orig_graphs(
        learned_graphs,
        orig_graphs,
        image_dir=image_dir,
        save_path=save_path,
    )
    print('completed.')