import os, pdb
import sys

sys.path.append(os.path.dirname(sys.path[0]))
import torch
import torch.nn as nn
import torch.nn.functional as F
from torchnlp.nn import Attention
from torch.nn import Linear, LSTM, GRU
from torch_geometric.nn import RGCNConv, TopKPooling, FastRGCNConv, InnerProductDecoder
from torch_geometric.nn import global_add_pool, global_mean_pool, global_max_pool
from .rgcn_sag_pooling import RGCNSAGPooling
from torch_geometric.data import Data
from torch_geometric.loader import DataLoader
from .pretext.feature_regression import FeatureRegression
from torch_geometric.utils.sparse import dense_to_sparse
import time
import math
import traceback

'''self-supervised implementation of our MRGCN.'''

class SSMRGCN(nn.Module):
    
    def __init__(self, config):
        """Build the graph encoder, temporal model and prediction heads.

        Args:
            config: object exposing ``model_config`` and ``training_config``
                mappings. The keys that are read are exactly the ones indexed
                in the body below.

        Raises:
            AssertionError: if ``mode_predicter_type`` is not ``'statistic'``
                or ``'regressor'``.
            AttributeError: if ``temporal_model`` is neither ``'lstm'`` nor
                ``'gru'``.
            NotImplementedError: if pretext is enabled but no
                ``pretext_model`` is configured.
        """
        super().__init__()
        model_config = config.model_config
        training_config = config.training_config

        self.frames_to_pred = model_config['frames_to_pred']
        self.probabilistic = training_config['probabilistic']
        self.num_features = model_config['num_of_classes']
        self.num_relations = model_config['num_relations']
        self.num_classes = model_config['nclass']
        self.num_layers = model_config['num_layers']  # number of RGCN conv layers.
        self.hidden_dim = model_config['hidden_dim']
        # Optional comma-separated list of per-layer output widths.
        if model_config['layer_spec'] is None:
            self.layer_spec = None
        else:
            self.layer_spec = list(map(int, model_config['layer_spec'].split(',')))
        self.lstm_dim = model_config['lstm_dim']
        self.rgcn_func = FastRGCNConv if model_config['conv_type'] == "FastRGCNConv" else RGCNConv
        self.activation = F.relu if model_config['activation'] == 'relu' else F.leaky_relu
        self.pooling_type = model_config['pooling_type']
        self.readout_type = model_config['readout_type']
        self.temporal_type = model_config['temporal_type']
        self.device = model_config['device']
        self.dropout = model_config['dropout']
        self.num_lstm_layers = model_config['num_lstm_layers']
        self.GCN_readout = training_config['GCN_readout']
        self.load_lane_info = training_config['load_lane_info']
        self.temporal_model = model_config['temporal_model']
        self.enable_lstm_input_regression = training_config['enable_lstm_input_regression']
        self.num_modes = model_config['num_modes']
        self.mode_predicter_type = model_config['mode_predicter_type']
        assert self.mode_predicter_type in ['statistic', 'regressor'], "Invalid mode_predicter_type"
        self.lstm_latent_type = model_config['lstm_latent_type']
        self.pred_future_only = model_config['pred_future_only']
        self.prediction_tail_type = model_config['prediction_tail_type']
        # [relation name, upper distance bound]; the list index is the edge type.
        self.proximity_thresholds = [['near_coll',4],['super_near',7],['very_near',10],['near',16],['visible',25]]

        print("load_lane_info:", self.load_lane_info)

        # ~~~~~~~~~~~~ graph encoder ~~~~~~~~~~~~
        # total_dim is the width of the per-node embedding handed to pooling
        # and fc1: the concatenation of every conv layer's output, or the
        # width of the linear fallback when no conv layer is configured.
        conv_layers = []
        total_dim = 0
        self.fc0_5 = None
        if self.num_layers > 0:
            if self.layer_spec is None:
                layer_dims = [self.hidden_dim] * self.num_layers
            else:
                print("using layer specification and ignoring hidden_dim parameter.")
                print("layer_spec: " + str(self.layer_spec))
                # Indexing (not slicing) so a too-short spec still raises IndexError.
                layer_dims = [self.layer_spec[i] for i in range(self.num_layers)]
            in_dim = self.num_features
            for out_dim in layer_dims:
                conv_layers.append(self.rgcn_func(in_dim, out_dim, self.num_relations).to(self.device))
                total_dim += out_dim
                in_dim = out_dim
        else:
            # No conv layers: a single linear projection stands in for them.
            # Its width must be counted in both the layer_spec and the
            # no-layer_spec case, otherwise fc1 is built with 0 input features.
            self.fc0_5 = Linear(self.num_features, self.hidden_dim)
            total_dim += self.hidden_dim
        self.conv = nn.ModuleList(conv_layers)

        # Always define pool1 so that size reporting can look it up even when
        # the configured pooling type is neither 'sagpool' nor 'topk'.
        self.pool1 = None
        if self.pooling_type == "sagpool":
            self.pool1 = RGCNSAGPooling(total_dim, self.num_relations, ratio=model_config['pooling_ratio'], rgcn_func=model_config['conv_type'])
        elif self.pooling_type == "topk":
            self.pool1 = TopKPooling(total_dim, ratio=model_config['pooling_ratio'])

        # Projects the node embedding down to the temporal model's input width.
        self.fc1 = nn.Sequential(
            Linear(total_dim, self.lstm_dim * 2),
            nn.ReLU(),
            Linear(self.lstm_dim * 2, self.lstm_dim),
            nn.Tanh()
        )

        # fc3 / fc4 emit 2 * lstm_dim values that reparameterization() splits
        # into a mean half and a log-variance half.
        if self.prediction_tail_type in ['lstm_tail_probabilistic', 'lstm_tail']:
            self.fc3 = nn.Sequential(
                Linear(self.lstm_dim, self.lstm_dim * 2)
            )
        if self.prediction_tail_type in ["mtp_tail_probabilistic", "lstm_tail_probabilistic"]:
            self.fc4 = nn.Sequential(
                Linear(self.lstm_dim, self.lstm_dim * 2)
            )

        # ~~~~~~~~~~~~ temporal model ~~~~~~~~~~~~
        # Stored as self.lstm for both recurrent cell types.
        if self.temporal_model == 'lstm':
            self.lstm = LSTM(self.lstm_dim,
                             self.lstm_dim,
                             num_layers=self.num_lstm_layers,
                             batch_first=True)
        elif self.temporal_model == 'gru':
            self.lstm = GRU(self.lstm_dim,
                            self.lstm_dim,
                            num_layers=self.num_lstm_layers,
                            batch_first=True)
        else:
            raise AttributeError("temporal model invalid")

        # Maps each temporal output vector to 2 values.
        self.fc2 = nn.Sequential(
            Linear(self.lstm_dim, self.lstm_dim//2),
            nn.ReLU(),
            Linear(self.lstm_dim//2, 2),
        )

        #~~~~~~~~~~~~SS Encoder~~~~~~~~~~~~~~
        #node encoder (input width 33 is hard-coded)
        self.feature_extractor1 = nn.Sequential(
            Linear(33, self.hidden_dim),
            nn.ReLU(),
            Linear(self.hidden_dim, self.num_features),
            nn.ReLU()
        )

        self.feature_extractor2 = nn.Sequential(
            Linear(self.num_features, self.hidden_dim),
            nn.ReLU(),
            Linear(self.hidden_dim, self.num_features * self.num_relations),
        )

        self.lane_extractor = None
        if self.load_lane_info:
            self.lane_extractor = nn.Sequential(
                Linear(2, self.hidden_dim),
                nn.ReLU(),
                Linear(self.hidden_dim, self.num_features),
                nn.ReLU()
            )

        #edge encoder. takes in two node embeddings and returns multilabel edge selection.
        # Currently disabled; the attribute is kept so size reporting can look it up.
        self.f2 = None

        if self.mode_predicter_type == 'regressor':
            #Take the last hidden output of the temporal model (h) and output the mode's probability
            self.mode_predicter = nn.Sequential(
                Linear(self.lstm_dim, self.lstm_dim),
                nn.ReLU(),
                Linear(self.lstm_dim, 1),
                nn.Sigmoid()
            )

        #~~~~~~~~SS Pretext Model~~~~~~~~~~~~ #TODO: test performance with/without pretext
        self.enable_pretext = training_config['enable_pretext']
        self.pretext = None
        if self.enable_pretext:
            if model_config['pretext_model']:
                self.pretext = FeatureRegression(config)
            else:
                raise NotImplementedError
            self.pretext = self.pretext.to(self.device)

        # ~~~~~~~~~~~~ MTP prediction heads ~~~~~~~~~~~~
        if self.prediction_tail_type == 'mtp_tail':
            # One flat vector: per mode 2 * frames_to_pred + 1 values.
            self.fc5 = nn.Sequential(
                Linear(self.lstm_dim, self.lstm_dim * 2),
                nn.ReLU(),
                Linear(self.lstm_dim * 2, self.num_modes * (2 * self.frames_to_pred + 1)),
            )

        if self.prediction_tail_type == 'mtp_tail_probabilistic':
            if self.mode_predicter_type == 'regressor':
                self.fc6 = nn.Sequential(
                    Linear(self.lstm_dim, self.frames_to_pred + 1)
                )
            else:
                self.fc6 = nn.Sequential(
                    Linear(self.lstm_dim, self.frames_to_pred * 2)
                )

        # Report sizes only once every sub-module has been registered, so the
        # whole-model total includes the MTP heads.
        self.get_model_size()


    def get_model_size(self):
        total_param_size = total_buffer_size = 0
        for model_name in [f'conv[{i}]' for i in range(self.num_layers)] + ['pool1', 'fc0_5', 'fc1', 'fc2', 'lstm', 'feature_extractor1', 'feature_extractor2', 'f2', 'pretext']:
        # for model_name in ['pool1', 'fc0_5', 'fc1', 'fc2', 'lstm', 'feature_extractor1', 'feature_extractor2', 'f2', 'pretext']:
            param_size, buffer_size = self.print_model_size(model_name)
            total_param_size += param_size
            total_buffer_size += buffer_size

        param_size = 0
        for param in self.parameters():
            param_size += param.nelement() * param.element_size()
        buffer_size = 0
        for buffer in self.buffers():
            buffer_size += buffer.nelement() * buffer.element_size()
        param_mb = param_size / 1024**2
        buffer_mb = buffer_size / 1024**2
        size_all_mb = param_mb + buffer_mb
        print('model size: {:.3f}MB, param size: {:.3f}MB, buffer size: {:.3f}MB'.format(size_all_mb, param_mb, buffer_mb))
        corrected_param_mb = total_param_size / 1024**2
        corrected_buffer_mb = total_buffer_size / 1024**2
        corrected_size_all_mb = corrected_param_mb + corrected_buffer_mb
        print('model size corrected(?): {:.3f}MB, param size: {:.3f}MB, buffer size: {:.3f}MB'.format(corrected_size_all_mb, corrected_param_mb, corrected_buffer_mb))


    def print_model_size(self, model_name):
        model = eval('self.' + model_name)
        if not model:
            return 0, 0
        print(model)
        param_size = 0
        for param in model.parameters():
            param_size += param.nelement() * param.element_size()
        buffer_size = 0
        for buffer in model.buffers():
            buffer_size += buffer.nelement() * buffer.element_size()
        param_mb = param_size / 1024**2
        buffer_mb = buffer_size / 1024**2
        size_all_mb = param_mb + buffer_mb
        print('{}: model size: {:.3f}MB, param size: {:.3f}MB, buffer size: {:.3f}MB'.format(model_name, size_all_mb, param_mb, buffer_mb))
        return param_size, buffer_size

    def reparameterization(self, x):
        
        mu = x[:, :self.lstm_dim]
        logvar = x[:, self.lstm_dim:]
        std = torch.exp(0.5 * logvar)
        eps = torch.randn_like(std)
        x = eps * std + mu
        
        return x, (mu, logvar), eps
    
    def pdf(self, x, p=1e-5):
        
        """Evaluate the probability density function over standard normal distribution
        x (torch.Tensor): 
        pow (float): A small scaler to avoid overflow.
        """
        return (torch.exp(-(x**2)/2)/math.sqrt(2*math.pi)).pow(p)
    
    def motion_prediction(self, graph_embedding):

        mu_values = []
        logvar_values = []
        translation_pred_list = []
        embedding_pred_list = []
        eps_values = []

        if self.prediction_tail_type in ['mtp_tail', 'mtp_tail_probabilistic']:
            x_predicted, hidden = self.lstm(graph_embedding)
            if self.prediction_tail_type == 'mtp_tail_probabilistic':
                for _ in range(self.num_modes):
                    x = self.fc4(hidden[0][-1].unsqueeze(0))
                    x, (mu, logvar), eps = self.reparameterization(x)
                    mu_values.append(mu)
                    logvar_values.append(logvar)
                    eps_values.append(eps)
                    translation_pred = self.fc6(x)
                    embedding_pred_list.append(x_predicted) #TODO: Just to make the code work. Delete this after removing all embedding_pred_list
                    translation_pred_list.append(translation_pred.reshape(1, -1))
                if self.mode_predicter_type == 'statistic':
                    mode_probabilities = self.pdf(torch.stack(eps_values, 0))
                    mode_probabilities = torch.prod(mode_probabilities.view((mode_probabilities.shape[0], -1)), -1).unsqueeze(-1).detach() # No need to backpropagate through this; just to fit the mtp loss
                translation_pred_vector = torch.cat(translation_pred_list+[mode_probabilities.view((1, mode_probabilities.shape[0]))], 1)

            elif self.prediction_tail_type == 'mtp_tail':
                mode_probabilities = None
                translation_pred_vector = self.fc5(hidden[0][-1].unsqueeze(0))
                embedding_pred_list.append(hidden[0][-1].unsqueeze(0)) #TODO: Just to make the code work. Delete this after removing all embedding_pred_list
            else:
                raise NotImplementedError('prediction_tail_type not implemented')
            embedding_pred_vector = torch.cat(embedding_pred_list, 0)
        
        elif self.prediction_tail_type in ['lstm_tail_probabilistic', 'lstm_tail']:
            hidden_outputs = []
            for _ in range(self.num_modes):
                eps = []
                x = self.fc3(graph_embedding)
                x, (mu, logvar), e = self.reparameterization(x)
                eps.append(e)
                mu_values.append(mu)
                logvar_values.append(logvar)
                x_predicted, hidden = self.lstm(x)
                embedding_pred = [x_predicted]
                if self.prediction_tail_type == "lstm_tail_probabilistic":
                    for _ in range(1, self.frames_to_pred):
                        if self.temporal_model == 'lstm':
                            x = self.fc4(hidden[0][-1].unsqueeze(0))
                        elif self.temporal_model == 'gru':
                            x = self.fc4(hidden[-1].unsqueeze(0))
                        else:
                            raise NotImplementedError("Invalid tyoe of temporal model")
                        x, (mu, logvar), e = self.reparameterization(x)
                        eps.append(e)
                        mu_values.append(mu)
                        logvar_values.append(logvar)
                        x_predicted, hidden = self.lstm(x, hidden)
                        embedding_pred.append(x_predicted)
                elif self.prediction_tail_type == "lstm_tail":
                    for _ in range(1, self.frames_to_pred):
                        if self.temporal_model == 'lstm':
                            x_predicted, hidden = self.lstm(hidden[0][-1].unsqueeze(0), hidden)
                        elif self.temporal_model == 'gru':
                            x_predicted, hidden = self.lstm(hidden[-1].unsqueeze(0), hidden)
                        else:
                            raise NotImplementedError("Invalid tyoe of temporal model")
                        embedding_pred.append(x_predicted)
                else:
                    raise NotImplementedError("prediction_tail_type not implemented")

                embedding_pred = torch.cat(embedding_pred, 0)
                translation_pred = self.fc2(embedding_pred)
                    
            # if self.enable_lstm_input_regression and self.prediction_tail_type == 'lstm_tail_probabilistic':
            #     embedding_pred = self.fc4(embedding_pred)[:, :self.lstm_dim]
                if self.temporal_model == 'lstm':
                    hidden_outputs.append(hidden[0][-1])
                elif self.temporal_model == 'gru':
                    hidden_outputs.append(hidden[-1])
                else:
                    raise NotImplementedError("Invalid tyoe of temporal model")
                translation_pred_list.append(translation_pred)
                embedding_pred_list.append(embedding_pred)
                eps = torch.cat(eps, 0)
                eps_values.append(eps)
        
            translation_pred_vector = torch.stack(translation_pred_list, 0)
            embedding_pred_vector = torch.stack(embedding_pred_list, 0)
            if self.mode_predicter_type == 'regressor':
                mode_probabilities = self.mode_predicter(torch.stack(hidden_outputs, 0))
            elif self.mode_predicter_type == 'statistic':
                mode_probabilities = self.pdf(torch.stack(eps_values, 0))
                mode_probabilities = torch.prod(mode_probabilities.view((mode_probabilities.shape[0], -1)), -1).unsqueeze(-1)
            else:
                raise NotImplementedError("Type of mode predictor not exist")
        else:
            raise NotImplementedError("prediction_tail_type not implemented")
            
        return translation_pred_vector, embedding_pred_vector, mode_probabilities, mu_values, logvar_values
        # return translation_pred_vector, None, mode_probabilities, mu_values, logvar_values

    @staticmethod
    def get_proximity_relations(thresholds, distances):
        relations = []
        for d in distances:
            for relation_ind in range(len(thresholds)):
                if d <= thresholds[relation_ind][1]:
                    relations.append(relation_ind)
                    break
                if relation_ind == len(thresholds) -1:
                    relations.append(relation_ind)
        return relations

    @staticmethod
    def get_euclidiean_distance(x1, y1, x2, y2):
        return math.sqrt((x1 - x2)**2 + (y1 - y2)**2)


    def forward(self, state_sequence, lanes_info_sequence, pred_frames=12):
        """Build one star graph per frame, encode it and predict motion.

        Args:
            state_sequence: sequence of 2-D tensors, one per frame. Row 0 is
                treated as the reference agent; columns 0 and 1 of every row
                are used as the coordinates for the distance computation. The
                full rows are used as node features.
            lanes_info_sequence: accepted for interface compatibility; not
                read by this method.
            pred_frames: accepted for interface compatibility; not read by
                this method (the horizon comes from ``self.frames_to_pred``).

        Returns:
            A 3-tuple:
            ``(translation_pred_vector, mode_probabilities,
            embedding_pred_vector, graph_embedding)``,
            ``(None, None)`` (placeholder for the pretext output / target), and
            ``(mu, logvar)`` concatenated over all latent draws, or ``None``
            when no latent was sampled.

        Raises:
            NotImplementedError: if ``GCN_readout`` is enabled with an
                unknown ``readout_type``.
        """
        # ---- graph extraction -------------------------------------------
        # Errors are left to propagate: continuing with a half-built graph
        # list (or dropping into a debugger) would only hide the real cause.
        graph_data_list = []
        for seq in state_sequence:
            agent_info = seq[0, :]
            dist_list = []
            index_list_from = []
            index_list_to = []
            for actor_ind in range(1, seq.shape[0]):
                actor_info = seq[actor_ind, :]
                dist = self.get_euclidiean_distance(actor_info[0], actor_info[1], agent_info[0], agent_info[1])
                # One edge per direction between the agent (node 0) and the actor.
                dist_list += [dist, dist]
                index_list_from += [0, actor_ind]
                index_list_to += [actor_ind, 0]
            # Edge type = proximity bucket of the agent-actor distance.
            edge_attr = torch.tensor(
                self.get_proximity_relations(self.proximity_thresholds, dist_list),
                dtype=torch.long)
            # Shape (2, num_edges); (2, 0) when the frame holds only the agent.
            edge_index = torch.tensor([index_list_from, index_list_to], dtype=torch.long)
            graph_data_list.append(Data(x=seq, edge_index=edge_index, edge_attr=edge_attr))

        # Collate every frame graph into a single batch.
        loader = DataLoader(graph_data_list, batch_size=len(graph_data_list))
        sequence = next(iter(loader)).to(self.device)
        x, edge_index, edge_attr, batch = sequence.x, sequence.edge_index, sequence.edge_attr, sequence.batch

        # ---- pretext model ------------------------------------------------
        # NOTE: the pretext output is computed but not returned at present.
        pretext_output = None
        if self.enable_pretext:
            pretext_output = self.pretext(x, edge_index, edge_attr)

        # ---- MRGCN component (downstream task) ------------------------------
        if self.num_layers > 0:
            outputs = []
            for i in range(self.num_layers):
                x = self.activation(self.conv[i](x, edge_index, edge_attr))
                x = F.dropout(x, self.dropout, training=self.training)
                outputs.append(x)
            # Concatenate the outputs of all conv layers per node.
            x = torch.cat(outputs, dim=-1)
        else:
            x = self.activation(self.fc0_5(x))

        # ---- per-graph readout ----------------------------------------------
        # Graph pooling is disabled, so the node-to-graph assignment is the
        # collated `batch` vector itself.
        if self.GCN_readout:
            if self.readout_type == "add":
                x = global_add_pool(x, batch)
            elif self.readout_type == "mean":
                x = global_mean_pool(x, batch)
            elif self.readout_type == "max":
                x = global_max_pool(x, batch)
            else:
                raise NotImplementedError("readout type not implemented")
        else:
            # Keep only the first node of every graph (the reference agent):
            # a graph starts wherever the batch id increases.
            first_node = torch.ones(1, dtype=torch.bool, device=batch.device)
            agent_index = torch.cat((first_node, batch[1:] > batch[:-1]))
            x = x[agent_index]

        # ---- temporal model and prediction heads ------------------------------
        graph_embedding = self.fc1(x)

        translation_pred_vector, embedding_pred_vector, mode_probabilities, mu_values, logvar_values = self.motion_prediction(graph_embedding)

        # embedding_pred_vector is not used any more
        return (translation_pred_vector, mode_probabilities, embedding_pred_vector, graph_embedding), \
            (None, None), (torch.cat(mu_values, 0), torch.cat(logvar_values, 0)) if mu_values else None

        