import os, pdb
import sys

sys.path.append(os.path.dirname(sys.path[0]))
import torch
import torch.nn as nn
import torch.nn.functional as F
from torchnlp.nn import Attention
from torch.nn import Linear, LSTM, GRU
from torch_geometric.nn import RGCNConv, TopKPooling, FastRGCNConv, InnerProductDecoder, RGATConv, GAT
from torch_geometric.nn import global_add_pool, global_mean_pool, global_max_pool
from .rgcn_sag_pooling import RGCNSAGPooling
from torch_geometric.data import Data
from torch_geometric.loader import DataLoader
from .pretext.feature_regression import FeatureRegression
from torch_geometric.utils.sparse import dense_to_sparse
import time
import traceback
import math
import torchviz

'''self-supervised implementation of our MRGCN.'''

class SSMRGCN(nn.Module):
    
    def __init__(self, config):
        """Build every sub-module of the model from ``config``.

        Args:
            config: object exposing two mappings, ``model_config`` and
                ``training_config``, indexed with the string keys read below
                (a missing key raises ``KeyError``).

        Raises:
            ValueError: if ``conv_type`` is not one of ``FastRGCNConv``,
                ``RGCNConv``, ``RGATConv`` or ``GATv2``.
            AssertionError: if ``mode_predicter_type`` is neither
                ``'statistic'`` nor ``'regressor'``.
            AttributeError: if ``temporal_model`` is neither ``'lstm'`` nor
                ``'gru'``.
            NotImplementedError: if the pretext task is enabled while
                ``pretext_model`` is falsy.
        """
        super(SSMRGCN, self).__init__()
        model_cfg = config.model_config
        train_cfg = config.training_config

        self.frames_to_pred = model_cfg['frames_to_pred']
        self.probabilistic = train_cfg['probabilistic']
        self.num_relations = model_cfg['num_relations']
        self.num_classes = model_cfg['nclass']
        self.num_layers = model_cfg['num_layers']  # number of graph conv layers
        self.hidden_dim = model_cfg['hidden_dim']
        self.state_length = model_cfg['state_length']
        # layer_spec is either None or a comma separated string of layer widths.
        raw_layer_spec = model_cfg['layer_spec']
        self.layer_spec = None if raw_layer_spec is None else list(map(int, raw_layer_spec.split(',')))
        self.lstm_dim = model_cfg['lstm_dim']
        self.graph_extraction = model_cfg['graph_extraction']
        self.num_features = model_cfg['num_of_classes'] if self.graph_extraction == 'auto_extraction' else self.state_length
        self.dropout = model_cfg['dropout']
        self.conv_type = model_cfg['conv_type']
        # rgcn_func is always called as rgcn_func(in_dim, out_dim, num_relations).
        if self.conv_type == "FastRGCNConv":
            self.rgcn_func = FastRGCNConv
        elif self.conv_type == "RGCNConv":
            self.rgcn_func = RGCNConv
        elif self.conv_type == "RGATConv":
            self.rgcn_func = lambda *args: RGATConv(
                *args,
                edge_dim=self.num_features
            )
        elif self.conv_type == "GATv2":
            # NOTE(review): the third positional argument passed below is
            # num_relations; for GAT it is presumably interpreted as the
            # number of layers rather than a relation count -- confirm.
            self.rgcn_func = lambda *args: GAT(
                *args,
                v2=True,
                edge_dim=self.num_features
            )
        else:
            raise ValueError("conv_type not supported")
        self.activation = F.relu if model_cfg['activation'] == 'relu' else F.leaky_relu
        self.pooling_type = model_cfg['pooling_type']
        self.readout_type = model_cfg['readout_type']
        self.temporal_type = model_cfg['temporal_type']
        self.device = model_cfg['device']

        self.num_lstm_layers = model_cfg['num_lstm_layers']
        self.GCN_readout = train_cfg['GCN_readout']
        self.load_lane_info = train_cfg['load_lane_info']
        self.temporal_model = model_cfg['temporal_model']
        self.enable_lstm_input_regression = train_cfg['enable_lstm_input_regression']
        self.num_modes = model_cfg['num_modes']
        self.mode_predicter_type = model_cfg['mode_predicter_type']
        assert self.mode_predicter_type in ['statistic', 'regressor'], "Invalid mode_predicter_type"
        self.lstm_latent_type = model_cfg['lstm_latent_type']
        self.pred_future_only = model_cfg['pred_future_only']
        self.prediction_tail_type = model_cfg['prediction_tail_type']

        # (relation name, upper distance bound) pairs used by rule based extraction.
        self.proximity_thresholds = [['near_coll',4],['super_near',7],['very_near',10],['near',16],['visible',25]]

        # ~~~~~~~~~~~~ graph convolution stack ~~~~~~~~~~~~
        conv_layers = []
        total_dim = 0  # width of the concatenated per-layer outputs
        self.fc0_5 = None
        if self.num_layers > 0:
            if self.layer_spec is None:
                layer_dims = [self.hidden_dim] * self.num_layers
            else:
                print("using layer specification and ignoring hidden_dim parameter.")
                print("layer_spec: " + str(self.layer_spec))
                layer_dims = [self.layer_spec[i] for i in range(self.num_layers)]
            in_dim = self.num_features
            for out_dim in layer_dims:
                conv_layers.append(self.rgcn_func(in_dim, out_dim, self.num_relations).to(self.device))
                total_dim += out_dim
                in_dim = out_dim
        else:
            # No graph convolution: a single linear layer stands in for it.
            # Its output width must be counted in both the layer_spec and the
            # no-layer_spec case, otherwise fc1 and the pooling layer are
            # built with an input width of 0.
            self.fc0_5 = Linear(self.num_features, self.hidden_dim)
            total_dim += self.hidden_dim

        self.conv = nn.ModuleList(conv_layers)

        # Default to "no pooling layer" so the attribute always exists; it is
        # looked up by name in get_model_size().
        self.pool1 = None
        if self.pooling_type == "sagpool":
            self.pool1 = RGCNSAGPooling(total_dim, self.num_relations, ratio=model_cfg['pooling_ratio'], rgcn_func=model_cfg['conv_type'])
        elif self.pooling_type == "topk":
            self.pool1 = TopKPooling(total_dim, ratio=model_cfg['pooling_ratio'])

        # Projects the graph representation to the temporal model input width.
        # TODO: Why tanh
        self.fc1 = nn.Sequential(
            Linear(total_dim, self.lstm_dim * 2),
            nn.ReLU(),
            Linear(self.lstm_dim * 2, self.lstm_dim),
            nn.Tanh()
        )

        # fc3 / fc4 emit 2 * lstm_dim values, which reparameterization()
        # splits into a mean half and a log-variance half.
        if self.prediction_tail_type in ['lstm_tail_probabilistic', 'lstm_tail']:
            self.fc3 = nn.Sequential(
                Linear(self.lstm_dim, self.lstm_dim * 2)
            )

        if self.prediction_tail_type in ["mtp_tail_probabilistic", "lstm_tail_probabilistic"]:
            self.fc4 = nn.Sequential(
                Linear(self.lstm_dim, self.lstm_dim * 2)
            )

        # The temporal model is stored as self.lstm even when it is a GRU.
        if self.temporal_model == 'lstm':
            self.lstm = LSTM(self.lstm_dim,
                            self.lstm_dim,
                            num_layers=self.num_lstm_layers,
                            batch_first=True)
        elif self.temporal_model == 'gru':
            self.lstm = GRU(self.lstm_dim,
                            self.lstm_dim,
                            num_layers=self.num_lstm_layers,
                            batch_first=True)
        else:
            raise AttributeError("temporal model invalid")

        # Maps one temporal-model output vector to a 2-value prediction.
        self.fc2 = nn.Sequential(
            Linear(self.lstm_dim, self.lstm_dim//2),
            nn.ReLU(),
            Linear(self.lstm_dim//2, 2),
        )

        #~~~~~~~~~~~~SS Encoder~~~~~~~~~~~~~~
        #node encoder
        self.feature_extractor1 = nn.Sequential(
            Linear(self.state_length, self.hidden_dim),
            nn.ReLU(),
            Linear(self.hidden_dim, self.num_features),
            nn.ReLU()
        )

        # One num_features-wide embedding per relation for every node.
        self.feature_extractor2 = nn.Sequential(
            Linear(self.num_features, self.hidden_dim),
            nn.ReLU(),
            Linear(self.hidden_dim, self.num_features * self.num_relations),
        )

        self.lane_extractor = None
        if self.load_lane_info:
            self.lane_extractor = nn.Sequential(
                Linear(2, self.hidden_dim),
                nn.ReLU(),
                Linear(self.hidden_dim, self.num_features),
                nn.ReLU()
            )

        # Placeholder for a (currently disabled) edge encoder.
        self.f2 = None

        if self.mode_predicter_type == 'regressor':
            #Take the last hidden output of the temporal model (h) and output the mode's probability
            self.mode_predicter = nn.Sequential(
                Linear(self.lstm_dim, self.lstm_dim),
                nn.ReLU(),
                Linear(self.lstm_dim, 1),
                nn.Sigmoid()
            )

        #~~~~~~~~SS Pretext Model~~~~~~~~~~~~ #TODO: test performance with/without pretext
        self.enable_pretext = train_cfg['enable_pretext']
        self.pretext = None
        if self.enable_pretext:
            if model_cfg['pretext_model']:
                self.pretext = FeatureRegression(config)
            else:
                raise NotImplementedError
            self.pretext = self.pretext.to(self.device)

        if self.prediction_tail_type == 'mtp_tail':
            # One flat vector: num_modes * (2 * frames_to_pred + 1) values.
            self.fc5 = nn.Sequential(
                Linear(self.lstm_dim, self.lstm_dim * 2),
                nn.ReLU(),
                Linear(self.lstm_dim * 2, self.num_modes * (2 * self.frames_to_pred + 1)),
            )

        if self.prediction_tail_type == 'mtp_tail_probabilistic':
            if self.mode_predicter_type == 'regressor':
                self.fc6 = nn.Sequential(
                    Linear(self.lstm_dim, self.frames_to_pred + 1)
                )
            else:
                self.fc6 = nn.Sequential(
                    Linear(self.lstm_dim, self.frames_to_pred * 2)
                )

        # Prints a size report for the freshly built model.
        self.get_model_size()

        # pdb.set_trace()







    def get_model_size(self):
        total_param_size = total_buffer_size = 0
        for model_name in ['conv', 'pool1', 'fc0_5', 'fc1', 'fc2', 'lstm', 'feature_extractor1', 'feature_extractor2', 'f2', 'pretext']:
        # for model_name in ['pool1', 'fc0_5', 'fc1', 'fc2', 'lstm', 'feature_extractor1', 'feature_extractor2', 'f2', 'pretext']:
            param_size, buffer_size = self.print_model_size(model_name)
            total_param_size += param_size
            total_buffer_size += buffer_size

        param_size = 0
        for param in self.parameters():
            param_size += param.nelement() * param.element_size()
        buffer_size = 0
        for buffer in self.buffers():
            buffer_size += buffer.nelement() * buffer.element_size()
        param_mb = param_size / 1024**2
        buffer_mb = buffer_size / 1024**2
        size_all_mb = param_mb + buffer_mb
        print('model size: {:.3f}MB, param size: {:.3f}MB, buffer size: {:.3f}MB'.format(size_all_mb, param_mb, buffer_mb))
        corrected_param_mb = total_param_size / 1024**2
        corrected_buffer_mb = total_buffer_size / 1024**2
        corrected_size_all_mb = corrected_param_mb + corrected_buffer_mb
        print('model size corrected(?): {:.3f}MB, param size: {:.3f}MB, buffer size: {:.3f}MB'.format(corrected_size_all_mb, corrected_param_mb, corrected_buffer_mb))


    def print_model_size(self, model_name):
        
        model = eval('self.' + model_name)
        if not model:
            return 0, 0
        print(model)
        param_size = 0
        for param in model.parameters():
            param_size += param.nelement() * param.element_size()
        buffer_size = 0
        for buffer in model.buffers():
            buffer_size += buffer.nelement() * buffer.element_size()
        param_mb = param_size / 1024**2
        buffer_mb = buffer_size / 1024**2
        size_all_mb = param_mb + buffer_mb
        print('{}: model size: {:.3f}MB, param size: {:.3f}MB, buffer size: {:.3f}MB'.format(model_name, size_all_mb, param_mb, buffer_mb))
        
        return param_size, buffer_size

    def reparameterization(self, x):
        
        mu = x[:, :self.lstm_dim]
        logvar = x[:, self.lstm_dim:]
        std = torch.exp(0.5 * logvar)
        eps = torch.randn_like(std)
        x = eps * std + mu
        
        return x, (mu, logvar), eps
    
    def pdf(self, x, p=1e-5):
        
        """Evaluate the probability density function over standard normal distribution
        x (torch.Tensor): 
        pow (float): A small scaler to avoid overflow.
        """
        return (torch.exp(-(x**2)/2)/math.sqrt(2*math.pi)).pow(p)

    @staticmethod
    def get_proximity_relations(thresholds, distances):
        relations = []
        for d in distances:
            for relation_ind in range(len(thresholds)):
                if d <= thresholds[relation_ind][1]:
                    relations.append(relation_ind)
                    break
                if relation_ind == len(thresholds) -1:
                    relations.append(relation_ind)
        return relations

    @staticmethod
    def get_euclidiean_distance(x1, y1, x2, y2):
        return math.sqrt((x1 - x2)**2 + (y1 - y2)**2)

    
    def motion_prediction(self, graph_embedding):

        mu_values = []
        logvar_values = []
        translation_pred_list = []
        embedding_pred_list = []
        eps_values = []

        if self.prediction_tail_type in ['mtp_tail', 'mtp_tail_probabilistic']:
            x_predicted, hidden = self.lstm(graph_embedding)
            if self.prediction_tail_type == 'mtp_tail_probabilistic':
                x_dis = self.fc4(hidden[0][-1].unsqueeze(0))
                for _ in range(self.num_modes):
                    x, (mu, logvar), eps = self.reparameterization(x_dis)
                    mu_values.append(mu)
                    logvar_values.append(logvar)
                    eps_values.append(eps)
                    translation_pred = self.fc6(x)
                    embedding_pred_list.append(x_predicted) #TODO: Just to make the code work. Delete this after removing all embedding_pred_list
                    translation_pred_list.append(translation_pred.reshape(1, -1))
                if self.mode_predicter_type == 'statistic':
                    mode_probabilities = self.pdf(torch.stack(eps_values, 0))
                    mode_probabilities = torch.prod(mode_probabilities.view((mode_probabilities.shape[0], -1)), -1).unsqueeze(-1).detach() # No need to backpropagate through this; just to fit the mtp loss
                translation_pred_vector = torch.cat(translation_pred_list+[mode_probabilities.view((1, mode_probabilities.shape[0]))], 1)

            elif self.prediction_tail_type == 'mtp_tail':
                mode_probabilities = None
                translation_pred_vector = self.fc5(hidden[0][-1].unsqueeze(0))
                embedding_pred_list.append(hidden[0][-1].unsqueeze(0)) #TODO: Just to make the code work. Delete this after removing all embedding_pred_list
            else:
                raise NotImplementedError('prediction_tail_type not implemented')
            embedding_pred_vector = torch.cat(embedding_pred_list, 0)
        
        elif self.prediction_tail_type in ['lstm_tail_probabilistic', 'lstm_tail']:
            hidden_outputs = []
            for _ in range(self.num_modes):
                eps = []
                x_dis = self.fc3(graph_embedding)
                x, (mu, logvar), e = self.reparameterization(x_dis)
                eps.append(e)
                mu_values.append(mu)
                logvar_values.append(logvar)
                x_predicted, hidden = self.lstm(x)
                embedding_pred = [x_predicted]
                if self.prediction_tail_type == "lstm_tail_probabilistic":
                    for _ in range(1, self.frames_to_pred):
                        if self.temporal_model == 'lstm':
                            x_dix = self.fc4(hidden[0][-1].unsqueeze(0))
                        elif self.temporal_model == 'gru':
                            x_dix = self.fc4(hidden[-1].unsqueeze(0))
                        else:
                            raise NotImplementedError("Invalid type of temporal model")
                        x, (mu, logvar), e = self.reparameterization(x_dix)
                        eps.append(e)
                        mu_values.append(mu)
                        logvar_values.append(logvar)
                        x_predicted, hidden = self.lstm(x, hidden)
                        embedding_pred.append(x_predicted)
                elif self.prediction_tail_type == "lstm_tail":
                    for _ in range(1, self.frames_to_pred):
                        if self.temporal_model == 'lstm':
                            x_predicted, hidden = self.lstm(hidden[0][-1].unsqueeze(0), hidden)
                        elif self.temporal_model == 'gru':
                            x_predicted, hidden = self.lstm(hidden[-1].unsqueeze(0), hidden)
                        else:
                            raise NotImplementedError("Invalid type of temporal model")
                        embedding_pred.append(x_predicted)
                else:
                    raise NotImplementedError("prediction_tail_type not implemented")

                embedding_pred = torch.cat(embedding_pred, 0)
                translation_pred = self.fc2(embedding_pred)
                    
            # if self.enable_lstm_input_regression and self.prediction_tail_type == 'lstm_tail_probabilistic':
            #     embedding_pred = self.fc4(embedding_pred)[:, :self.lstm_dim]
                if self.temporal_model == 'lstm':
                    hidden_outputs.append(hidden[0][-1])
                elif self.temporal_model == 'gru':
                    hidden_outputs.append(hidden[-1])
                else:
                    raise NotImplementedError("Invalid type of temporal model")
                translation_pred_list.append(translation_pred)
                embedding_pred_list.append(embedding_pred)
                eps = torch.cat(eps, 0)
                eps_values.append(eps)
        
            translation_pred_vector = torch.stack(translation_pred_list, 0)
            embedding_pred_vector = torch.stack(embedding_pred_list, 0)
            if self.mode_predicter_type == 'regressor':
                mode_probabilities = self.mode_predicter(torch.stack(hidden_outputs, 0))
            elif self.mode_predicter_type == 'statistic':
                mode_probabilities = self.pdf(torch.stack(eps_values, 0))
                mode_probabilities = torch.prod(mode_probabilities.view((mode_probabilities.shape[0], -1)), -1).unsqueeze(-1)
            else:
                raise NotImplementedError("Type of mode predictor not exist")
        else:
            raise NotImplementedError("prediction_tail_type not implemented")
            
        return translation_pred_vector, embedding_pred_vector, mode_probabilities, mu_values, logvar_values
        # return translation_pred_vector, None, mode_probabilities, mu_values, logvar_values

    def extract_graph(self, state_sequence, lanes_info_sequence):
        try:
            if self.graph_extraction == 'rule_based':
                graph_list = []
                pretext_gt = []
                for seq_ind in range(len(state_sequence)):
                    seq = state_sequence[seq_ind]
                    graph = dict()
                    dist_list = []
                    index_list_from = []
                    index_list_to = []
                    agent_info = seq[0,:]

                    graph['node_embeddings'] = seq


                    for actor_ind in range(seq.shape[0]):
                        if actor_ind == 0:
                            continue
                        actor_info = seq[actor_ind,:]
                        dist = self.get_euclidiean_distance(actor_info[0], actor_info[1], agent_info[0], agent_info[1])
                        dist_list.append(dist)
                        dist_list.append(dist)
                        index_list_from.append(0)
                        index_list_from.append(actor_ind)
                        index_list_to.append(actor_ind)
                        index_list_to.append(0)
                    graph['edge_attr'] = torch.Tensor(self.get_proximity_relations(self.proximity_thresholds, dist_list)).long()
                    graph['edge_index'] = torch.vstack((torch.Tensor(index_list_from), torch.Tensor(index_list_to))).long()
                    # graph['node_embeddings'] = instance_token_sequence[seq_ind]
                    graph_list.append(graph)
            elif self.graph_extraction == 'auto_extraction':
                pretext_gt = []
                graph_list = []
                for i in range(len(state_sequence)):
                    loop_start = time.time()
                    graph = {}
                    node_feature_list = state_sequence[i]
                    # torchviz.make_dot(node_feature_list).render("/home/bar/xb/research/av/graph-motion-prediction/baselines/train/temp/node_feature_list", format="png")
                    node_embeddings = self.feature_extractor1(node_feature_list)
                    
                    if self.load_lane_info:
                        lanes_info = lanes_info_sequence[i]
                        lanes_embeddings_list = []
                        for lane in lanes_info:
                            lane_seg_embeddings = self.lane_extractor(lane)
                            lane_embedding = torch.mean(lane_seg_embeddings, 0)
                            lanes_embeddings_list.append(lane_embedding.unsqueeze(0))
                        lanes_embeddings = torch.cat(lanes_embeddings_list, 0)
                        node_embeddings = torch.cat([node_embeddings, lanes_embeddings], 0)
                    # pdb.set_trace()
                    n_actors = node_embeddings.shape[0]
                    graph['node_embeddings'] = node_embeddings
                    
                    if self.load_lane_info:
                        pretext_gt.append(node_embeddings)
                    else:
                        pretext_gt.append(node_feature_list)

                    end_node_extraction = time.time()

                    if self.conv_type in ['RGCNConv', 'FastRGCNConv']:
                        multi_embeddings = self.feature_extractor2(graph['node_embeddings']).reshape(n_actors, self.num_features, self.num_relations).transpose(1,2).transpose(0,1)
                        multi_embeddings_norm = F.normalize(multi_embeddings, dim=2)
                        embedding_extraction = time.time()
                        adj_matrix = torch.matmul(multi_embeddings_norm, multi_embeddings_norm.transpose(1,2))
                        identity_mask = 1 - (torch.eye(adj_matrix.shape[-1]).unsqueeze(0).repeat(self.num_relations,1,1)).to(self.device)
                        adj_matrix = adj_matrix * identity_mask
                        adj_mask = torch.logical_or(torch.argmax(adj_matrix, dim=0), adj_matrix > 0.5)
                        adj_matrix = adj_matrix * adj_mask.to(self.device)
                        edge_index, edge_attr = dense_to_sparse(adj_matrix)
                        edge_index = torch.remainder(edge_index, n_actors)
                        edge_attr = edge_index[0].div(n_actors, rounding_mode="floor")
                        
                        graph['edge_index'], graph['edge_type'] = edge_index, edge_attr
                        relation_iteration = time.time()
                        graph_list.append(graph)
                        end_loop = time.time()
                    elif self.conv_type in ['RGATConv', 'GATv2']:
                        # pdb.set_trace()
                        embedding_extraction = time.time()
                        multi_embeddings = self.feature_extractor2(graph['node_embeddings']).reshape(n_actors, self.num_features, self.num_relations).transpose(1,2).transpose(0,1)
                        multi_embeddings_norm = F.normalize(multi_embeddings, dim=2)
                        
                        # pdb.set_trace()
                        adj_matrix_attr = torch.einsum('abc,adc->abdc', multi_embeddings_norm, multi_embeddings_norm).reshape(self.num_relations, n_actors, n_actors, self.num_features)
                        adj_matrix = adj_matrix_attr.sum(-1)
                        # adj_matrix = torch.matmul(multi_embeddings_norm, multi_embeddings_norm.transpose(1,2))
                        
                        identity_mask = 1 - (torch.eye(adj_matrix.shape[-1]).unsqueeze(0).repeat(self.num_relations,1,1)).to(self.device)
                        adj_matrix = adj_matrix * identity_mask
                        adj_mask = torch.logical_or(torch.argmax(adj_matrix, dim=0), adj_matrix > 0.5)
                        adj_matrix = adj_matrix * adj_mask.to(self.device)

                        #dense to sparse: same as edge_index, edge_attr = dense_to_sparse(adj_matrix) but save edge_index_raw
                        edge_index_raw = adj_matrix.nonzero().t()
                        edge_attr = adj_matrix[edge_index_raw[0], edge_index_raw[1], edge_index_raw[2]]
                        batch = edge_index_raw[0] * adj_matrix.size(-1)
                        row = batch + edge_index_raw[1]
                        col = batch + edge_index_raw[2]
                        edge_index = torch.stack([row, col], dim=0)


                        edge_index = torch.remainder(edge_index, n_actors)
                        edge_type = edge_index[0].div(n_actors, rounding_mode="floor")
                        # pdb.set_trace()
                        graph['edge_index'], graph['edge_type'] = edge_index, edge_type
                        # graph['edge_attr'] = edge_attr.unsqueeze(-1)
                        graph['edge_attr'] = adj_matrix_attr[edge_index_raw[0], edge_index_raw[1], edge_index_raw[2]]
                        relation_iteration = time.time()
                        graph_list.append(graph)
                        end_loop = time.time()
                    else:
                        raise NotImplementedError("Conv type not implemented")
                    # print('TIMING in extaction-- node extraction: {:.4f}, graph_building: {:.4f}'\
                    #     .format(end_node_extraction-loop_start, end_loop-end_node_extraction))
            else:
                raise NotImplementedError("graph_extraction type not implemented")
        except:
            traceback.print_exc()
            pdb.set_trace()

        return graph_list, pretext_gt


    def forward(self, state_sequence, lanes_info_sequence):
        #graph extraction component
        # pretext_output = torch.zeros(0).to(self.device)
        graph_list = []
        start = time.time()
        # assert torch.count_nonzero(state_sequence[0][0, :2]) == 0, f"the inital translation should always be equal to 0, {torch.count_nonzero(state_sequence[0][0, :3])} instead. {state_sequence[0][:, :3]}"
        # try:
            # pretext_gt = []
            # for i in range(len(state_sequence)):
            #     loop_start = time.time()
            #     graph = {}
            #     node_feature_list = state_sequence[i]
            #     node_embeddings = self.feature_extractor1(node_feature_list)

            #     if self.load_lane_info:
            #         lanes_info = lanes_info_sequence[i]
            #         lanes_embeddings_list = []
            #         for lane in lanes_info:
            #             lane_seg_embeddings = self.lane_extractor(lane)
            #             lane_embedding = torch.mean(lane_seg_embeddings, 0)
            #             lanes_embeddings_list.append(lane_embedding.unsqueeze(0))
            #         lanes_embeddings = torch.cat(lanes_embeddings_list, 0)
            #         node_embeddings = torch.cat([node_embeddings, lanes_embeddings], 0)

            #     n_actors = node_embeddings.shape[0]
            #     graph['node_embeddings'] = node_embeddings
                
            #     if self.load_lane_info:
            #         pretext_gt.append(node_embeddings)
            #     else:
            #         pretext_gt.append(node_feature_list)

            #     end_node_extraction = time.time()
            #     multi_embeddings = self.feature_extractor2(graph['node_embeddings']) \
            #                         .reshape(n_actors, self.num_features, self.num_relations) \
            #                         .transpose(1,2).transpose(0,1)
            #     multi_embeddings = F.normalize(multi_embeddings, dim=2)
            #     embedding_extraction = time.time()
            #     adj_matrix = torch.matmul(multi_embeddings, multi_embeddings.transpose(1,2))

            #     # check the computational graph is correct for adj matrix
            #     identity_mask = 1 - (torch.eye(adj_matrix.shape[-1]).unsqueeze(0).repeat(self.num_relations,1,1)).to(self.device)
            #     adj_matrix = adj_matrix * identity_mask
            #     adj_mask = torch.logical_or(torch.argmax(adj_matrix, dim=0), adj_matrix > 0.5)
            #     adj_matrix = adj_matrix * adj_mask.to(self.device)
            #     edge_index, edge_attr = dense_to_sparse(adj_matrix)
            #     edge_index = torch.remainder(edge_index, n_actors)
            #     edge_attr = edge_index[0].div(n_actors, rounding_mode="floor")
                
            #     graph['edge_index'], graph['edge_attr'] = edge_index, edge_attr
            #     relation_iteration = time.time()
            #     graph_list.append(graph)
            #     end_loop = time.time()
            #     # print('TIMING in extaction-- node extraction: {:.4f}, graph_building: {:.4f}'\
            #     #     .format(end_node_extraction-loop_start, end_loop-end_node_extraction))
            
        # except:
        #     traceback.print_exc()
        #     pdb.set_trace()
        graph_list, pretext_gt = self.extract_graph(state_sequence, lanes_info_sequence)

        if self.conv_type in ['RGCNConv', 'FastRGCNConv']:
            graph_data_list = [Data(x=g['node_embeddings'], edge_index=g['edge_index'], edge_type=g['edge_type']) for g in graph_list]
        elif self.conv_type in ['RGATConv', 'GATv2']:
            graph_data_list = [Data(x=g['node_embeddings'], edge_index=g['edge_index'], edge_type=g['edge_type'], edge_attr=g['edge_attr']) for g in graph_list]
        else:
            raise NotImplementedError("Conv type not implemented")

        train_loader = DataLoader(graph_data_list, batch_size=len(graph_data_list))
        sequence = next(iter(train_loader)).to(self.device)
        if self.conv_type in ['RGCNConv', 'FastRGCNConv']:
            x, edge_index, edge_type, batch = sequence.x, sequence.edge_index, sequence.edge_type, sequence.batch
        elif self.conv_type in ['RGATConv', 'GATv2']:
            x, edge_index, edge_type, edge_attr, batch = sequence.x, sequence.edge_index, sequence.edge_type, sequence.edge_attr, sequence.batch
        else:
            raise NotImplementedError("Conv type not implemented")
        extract = time.time()

        pretext_start = time.time()
        #pretext model
        pretext_output = None
        if self.enable_pretext:
            pretext_output = self.pretext(x, edge_index, edge_attr)
        pretext_end = time.time()
        MRGCN_start = time.time()
        #MRGCN component. downstream task 
        if self.conv_type in ['RGCNConv', 'FastRGCNConv', 'RGATConv', 'GATv2']:
            attn_weights = dict()
            outputs = []
            if self.num_layers > 0:
                for i in range(self.num_layers):
                    if self.conv_type in ['RGCNConv', 'FastRGCNConv']:
                        x = self.activation(self.conv[i](x, edge_index=edge_index, edge_type=edge_type))
                    elif self.conv_type == 'RGATConv':
                        x = self.activation(self.conv[i](x, edge_index=edge_index, edge_type=edge_type, edge_attr=edge_attr))
                    elif self.conv_type == 'GATv2':
                        x = self.activation(self.conv[i](x, edge_index=edge_index, edge_attr=edge_attr))
                    else:
                        raise NotImplementedError("Conv type not implemented")
                    x = F.dropout(x, self.dropout, training=self.training)
                    outputs.append(x)
                x = torch.cat(outputs, dim=-1)
            else:
                x = self.activation(self.fc0_5(x))
        # elif self.conv_type == 'GATv2':
        #     x = self.conv(x, edge_index=edge_index, edge_attr=edge_attr)
        else:
            raise NotImplementedError("Conv type not implemented")

        # pdb.set_trace()

        graph_pooling = time.time()
        #self-attention graph pooling
        # if self.pooling_type == "sagpool":
        #     x, edge_index, _, attn_weights['batch'], attn_weights['pool_perm'], attn_weights['pool_score'] = self.pool1(x, edge_index, edge_attr=edge_attr, batch=batch)
        # elif self.pooling_type == "topk":
        #     x, edge_index, _, attn_weights['batch'], attn_weights['pool_perm'], attn_weights['pool_score'] = self.pool1(x, edge_index, edge_attr=edge_attr, batch=batch)
        # else: 
        #     attn_weights['batch'] = batch

        if self.GCN_readout:
            if self.readout_type == "add":
                x = global_add_pool(x, attn_weights['batch'])
            elif self.readout_type == "mean":
                x = global_mean_pool(x, attn_weights['batch'])
            elif self.readout_type == "max":
                x = global_max_pool(x, attn_weights['batch'])
            else:
                raise NotImplementedError("readout type not implemented")
        else:
            agent_index = torch.cat((torch.Tensor([True]).to(self.device), (batch[1:] - batch[:-1]) > 0)).bool()
            x = x[agent_index]
        
        MRGCN_end = time.time()
        lstm_start = time.time()

        graph_embedding = self.fc1(x)

        translation_pred_vector, embedding_pred_vector, mode_probabilities, mu_values, logvar_values = self.motion_prediction(graph_embedding)
        #TODO @barry: this entire section below here needs some comments to explain what's going on
            
        lstm_end = time.time()
        gcn = time.time()
        # print(len(sequence))
        # print('TIMING-- total: {}, extract: {}, pretext: {}, mrgcn: {}'.format(gcn-start, extract-start, pretext-extract, gcn-pretext))
        # print('TIMING-- total: {}, extract: {}, pretext_end: {}, MRGCN:{}, Graph_pooling:{},  lstm:{}'\
                # .format(gcn-start, extract-start, pretext_end-pretext_start, graph_pooling-MRGCN_start, MRGCN_end-graph_pooling, lstm_end-lstm_start))
        # pdb.set_trace()
        # return (translation_pred_vector, mode_probabilities, embedding_pred_vector, graph_embedding), \
        #     (pretext_output, torch.cat(pretext_gt, 0)), (torch.cat(mu_values, 0), torch.cat(logvar_values, 0)) if self.probabilistic else None


        # embedding_pred_vector is no longer used, but it is still included in the returned tuple.
        # pdb.set_trace()
        if self.graph_extraction == 'auto_extraction':
            return (translation_pred_vector, mode_probabilities, embedding_pred_vector, graph_embedding), \
                (pretext_output, torch.cat(pretext_gt, 0)), (torch.cat(mu_values, 0), torch.cat(logvar_values, 0)) if mu_values else None
        elif self.graph_extraction == 'rule_based':
            return (translation_pred_vector, mode_probabilities, embedding_pred_vector, graph_embedding), \
            (None, None), (torch.cat(mu_values, 0), torch.cat(logvar_values, 0)) if mu_values else None
        else:
            raise NotImplementedError('graph extraction method not implemented')

        
