import os, pdb
import sys

sys.path.append(os.path.dirname(sys.path[0]))
import torch
import torch.nn as nn
import torch.nn.functional as F
from torchnlp.nn import Attention
from torch.nn import Linear, LSTM, GRU
from torch_geometric.nn import RGCNConv, TopKPooling, FastRGCNConv, InnerProductDecoder
from torch_geometric.nn import global_add_pool, global_mean_pool, global_max_pool
from .rgcn_sag_pooling import RGCNSAGPooling
from torch_geometric.data import Data
from torch_geometric.loader import DataLoader
from .pretext.feature_regression import FeatureRegression
from torch_geometric.utils.sparse import dense_to_sparse
import time
import math
import traceback

'''self-supervised implementation of our MRGCN.'''

class SSMRGCN(nn.Module):
    
    def __init__(self, config):
        """Build the relational graph encoder, temporal model and optional pretext head.

        Args:
            config: object exposing two mappings, ``model_config`` and
                ``training_config``. Keys read from ``model_config``:
                ``num_of_classes``, ``num_relations``, ``nclass``, ``num_layers``,
                ``hidden_dim``, ``layer_spec``, ``lstm_dim``, ``conv_type``,
                ``activation``, ``pooling_type``, ``pooling_ratio`` (only when
                pooling is enabled), ``readout_type``, ``temporal_type``,
                ``device``, ``dropout``, ``num_lstm_layers``, ``temporal_model``
                and ``pretext_model`` (only when the pretext task is enabled).
                Keys read from ``training_config``: ``probabilistic``,
                ``GCN_readout``, ``load_lane_info``,
                ``enable_lstm_input_regression`` and ``enable_pretext``.

        Raises:
            AttributeError: if ``temporal_model`` is neither ``'lstm'`` nor ``'gru'``.
            NotImplementedError: if the pretext task is enabled without a
                ``pretext_model`` entry.
        """
        super(SSMRGCN, self).__init__()
        model_cfg = config.model_config
        training_cfg = config.training_config

        self.probabilistic = training_cfg['probabilistic']
        self.num_features = model_cfg['num_of_classes']
        self.num_relations = model_cfg['num_relations']
        self.num_classes = model_cfg['nclass']
        self.num_layers = model_cfg['num_layers']  # number of RGCN conv layers
        self.hidden_dim = model_cfg['hidden_dim']
        # layer_spec is a comma-separated string of per-layer output widths, or None.
        raw_layer_spec = model_cfg['layer_spec']
        self.layer_spec = None if raw_layer_spec is None else list(map(int, raw_layer_spec.split(',')))
        self.lstm_dim = model_cfg['lstm_dim']
        self.rgcn_func = FastRGCNConv if model_cfg['conv_type'] == "FastRGCNConv" else RGCNConv
        self.activation = F.relu if model_cfg['activation'] == 'relu' else F.leaky_relu
        self.pooling_type = model_cfg['pooling_type']
        self.readout_type = model_cfg['readout_type']
        self.temporal_type = model_cfg['temporal_type']
        self.device = model_cfg['device']
        self.dropout = model_cfg['dropout']
        self.num_lstm_layers = model_cfg['num_lstm_layers']
        self.GCN_readout = training_cfg['GCN_readout']
        self.load_lane_info = training_cfg['load_lane_info']
        self.temporal_model = model_cfg['temporal_model']
        self.enable_lstm_input_regression = training_cfg['enable_lstm_input_regression']
        # (relation name, upper distance bound) pairs in ascending order; the
        # position in this list is used as the edge relation id.
        self.proximity_thresholds = [['near_coll',4],['super_near',7],['very_near',10],['near',16],['visible',25]]

        print("load_lane_info:", self.load_lane_info)

        # nn.ModuleList (not a plain list) so the conv layers are registered as
        # submodules: their weights then appear in parameters()/state_dict() and
        # follow .to()/.train()/.eval(). NOTE: checkpoints written while this was
        # a plain list contain no 'conv.*' entries.
        self.conv = nn.ModuleList()
        total_dim = 0  # width of the concatenated per-layer outputs fed to fc1

        self.fc0_5 = None
        # Defined up front so code that inspects the attribute does not fail when
        # no pooling layer is configured.
        self.pool1 = None

        if self.num_layers > 0:
            if self.layer_spec is None:
                layer_dims = [self.hidden_dim] * self.num_layers
            else:
                print("using layer specification and ignoring hidden_dim parameter.")
                print("layer_spec: " + str(self.layer_spec))
                # Explicit indexing keeps the IndexError for a too-short layer_spec.
                layer_dims = [self.layer_spec[i] for i in range(self.num_layers)]
            in_dim = self.num_features
            for out_dim in layer_dims:
                self.conv.append(self.rgcn_func(in_dim, out_dim, self.num_relations).to(self.device))
                total_dim += out_dim
                in_dim = out_dim
        else:
            # No graph convolutions: a single linear projection stands in for them.
            # total_dim must be counted here regardless of layer_spec, otherwise
            # fc1 would be created with zero input features.
            self.fc0_5 = Linear(self.num_features, self.hidden_dim)
            total_dim += self.hidden_dim

        if self.pooling_type == "sagpool":
            self.pool1 = RGCNSAGPooling(total_dim, self.num_relations, ratio=model_cfg['pooling_ratio'], rgcn_func=model_cfg['conv_type'])
        elif self.pooling_type == "topk":
            self.pool1 = TopKPooling(total_dim, ratio=model_cfg['pooling_ratio'])

        # Maps the graph-level feature vector to the recurrent model's input width.
        self.fc1 = nn.Sequential(
            Linear(total_dim, self.lstm_dim * 2),
            nn.ReLU(),
            Linear(self.lstm_dim * 2, self.lstm_dim),
            nn.Tanh()
        )

        if self.probabilistic:
            # Produces 2 * lstm_dim values; forward() splits them into mean and
            # log-variance halves.
            self.f3 = nn.Sequential(
                Linear(self.lstm_dim, self.lstm_dim * 2)
            )

        # The recurrent module is stored as ``self.lstm`` for both cell types.
        if self.temporal_model == 'lstm':
            self.lstm = LSTM(self.lstm_dim,
                             self.lstm_dim,
                             num_layers=self.num_lstm_layers,
                             batch_first=True)
        elif self.temporal_model == 'gru':
            self.lstm = GRU(self.lstm_dim,
                            self.lstm_dim,
                            num_layers=self.num_lstm_layers,
                            batch_first=True)
        else:
            raise AttributeError("temporal model invalid")

        # Two-value output head applied to every recurrent output step.
        self.fc2 = nn.Sequential(
            Linear(self.lstm_dim, self.lstm_dim//2),
            nn.ReLU(),
            Linear(self.lstm_dim//2, 2),
        )

        #~~~~~~~~~~~~SS Encoder~~~~~~~~~~~~~~
        # Node encoder. NOTE(review): the input width is hard-coded to 33 and
        # forward() does not currently call this module - confirm it is still needed.
        self.feature_extractor1 = nn.Sequential(
            Linear(33, self.hidden_dim),
            nn.ReLU(),
            Linear(self.hidden_dim, self.num_features),
            nn.ReLU()
        )

        self.feature_extractor2 = nn.Sequential(
            Linear(self.num_features, self.hidden_dim),
            nn.ReLU(),
            Linear(self.hidden_dim, self.num_features * self.num_relations),
        )

        self.lane_extractor = None
        if self.load_lane_info:
            self.lane_extractor = nn.Sequential(
                Linear(2, self.hidden_dim),
                nn.ReLU(),
                Linear(self.hidden_dim, self.num_features),
                nn.ReLU()
            )

        # Edge encoder placeholder (currently disabled).
        self.f2 = None

        #~~~~~~~~SS Pretext Model~~~~~~~~~~~~ #TODO: test performance with/without pretext
        self.enable_pretext = training_cfg['enable_pretext']
        self.pretext = None
        if self.enable_pretext:
            if not model_cfg['pretext_model']:
                raise NotImplementedError
            self.pretext = FeatureRegression(config).to(self.device)

        self.get_model_size()


    def get_model_size(self):
        """Print the memory footprint of the model in two different ways.

        First every named component is reported through ``print_model_size`` and
        its byte counts are accumulated. Then the totals obtained from
        ``self.parameters()`` / ``self.buffers()`` are printed, followed by the
        accumulated per-component totals (labelled "corrected").
        """
        bytes_per_mb = 1024**2

        component_names = [f'conv[{layer}]' for layer in range(self.num_layers)]
        component_names.extend(
            ['pool1', 'fc0_5', 'fc1', 'fc2', 'lstm',
             'feature_extractor1', 'feature_extractor2', 'f2', 'pretext']
        )

        # Per-component accounting (each call also prints its own line).
        summed_param_bytes = 0
        summed_buffer_bytes = 0
        for component in component_names:
            component_params, component_buffers = self.print_model_size(component)
            summed_param_bytes += component_params
            summed_buffer_bytes += component_buffers

        # Accounting based on what this module itself has registered.
        registered_param_bytes = 0
        for tensor in self.parameters():
            registered_param_bytes += tensor.nelement() * tensor.element_size()
        registered_buffer_bytes = 0
        for tensor in self.buffers():
            registered_buffer_bytes += tensor.nelement() * tensor.element_size()

        registered_param_mb = registered_param_bytes / bytes_per_mb
        registered_buffer_mb = registered_buffer_bytes / bytes_per_mb
        print('model size: {:.3f}MB, param size: {:.3f}MB, buffer size: {:.3f}MB'.format(
            registered_param_mb + registered_buffer_mb, registered_param_mb, registered_buffer_mb))

        summed_param_mb = summed_param_bytes / bytes_per_mb
        summed_buffer_mb = summed_buffer_bytes / bytes_per_mb
        print('model size corrected(?): {:.3f}MB, param size: {:.3f}MB, buffer size: {:.3f}MB'.format(
            summed_param_mb + summed_buffer_mb, summed_param_mb, summed_buffer_mb))


    def print_model_size(self, model_name):
        """Print and return the parameter/buffer footprint of one component.

        Args:
            model_name: attribute name of the component on ``self``, optionally
                followed by an integer index in brackets (e.g. ``'fc1'`` or
                ``'conv[0]'``).

        Returns:
            ``(param_bytes, buffer_bytes)``. ``(0, 0)`` is returned, and nothing
            is printed, when the attribute does not exist, is ``None`` or is
            otherwise falsy (e.g. an empty container module).
        """
        # Resolve the component explicitly instead of eval()-ing a string: this
        # avoids executing arbitrary text and lets an attribute that was never
        # assigned be reported as empty instead of raising AttributeError.
        attr_name, bracket, index_text = model_name.partition('[')
        model = getattr(self, attr_name, None)
        if bracket and model is not None:
            model = model[int(index_text.rstrip(']'))]
        if not model:
            return 0, 0

        print(model)
        param_size = sum(param.nelement() * param.element_size() for param in model.parameters())
        buffer_size = sum(buffer.nelement() * buffer.element_size() for buffer in model.buffers())
        param_mb = param_size / 1024**2
        buffer_mb = buffer_size / 1024**2
        size_all_mb = param_mb + buffer_mb
        print('{}: model size: {:.3f}MB, param size: {:.3f}MB, buffer size: {:.3f}MB'.format(model_name, size_all_mb, param_mb, buffer_mb))
        return param_size, buffer_size

    @staticmethod
    def get_proximity_relations(thresholds, distances):
        relations = []
        for d in distances:
            for relation_ind in range(len(thresholds)):
                if d <= thresholds[relation_ind][1]:
                    relations.append(relation_ind)
                    break
                if relation_ind == len(thresholds) -1:
                    relations.append(relation_ind)
        return relations

    @staticmethod
    def get_euclidiean_distance(x1, y1, x2, y2):
        return math.sqrt((x1 - x2)**2 + (y1 - y2)**2)


    def forward(self, state_sequence, lanes_info_sequence, pred_frames=12):
        """Encode a sequence of frames as graphs and roll the temporal model forward.

        Every element of ``state_sequence`` is turned into one graph whose edges
        connect node 0 with each other node and are labelled with a distance
        bucket. The graphs are batched, passed through the RGCN layers (or the
        linear fallback), optionally pooled, reduced to one vector per graph and
        projected by ``fc1``. The recurrent model consumes these vectors and is
        then fed its own last hidden state for the remaining steps.

        Args:
            state_sequence: sequence of 2-D node-feature tensors, one per frame.
                Row 0 is treated as the agent; columns 0 and 1 of every row are
                used as x/y when computing distances to it.
            lanes_info_sequence: accepted for interface compatibility; it is not
                read by the current implementation.
            pred_frames: total number of recurrent calls. The first call consumes
                all frame embeddings; each of the remaining ``pred_frames - 1``
                calls consumes a single step.

        Returns:
            A 3-tuple ``((translation_pred, embedding_pred, graph_embedding),
            None, latent)``. ``latent`` is ``(mu, logvar)``, each concatenated
            over all sampling steps, when ``self.probabilistic`` is set and
            ``None`` otherwise.
        """
        # --- graph extraction -------------------------------------------------
        # Errors are deliberately allowed to propagate: the previous bare
        # ``except`` opened a pdb session (which hangs non-interactive runs) and
        # then carried on with a partially built graph list.
        graph_data_list = [self._build_proximity_graph(seq) for seq in state_sequence]

        # A loader whose batch size equals the number of graphs yields a single
        # batch holding the whole sequence.
        loader = DataLoader(graph_data_list, batch_size=len(graph_data_list))
        sequence = next(iter(loader)).to(self.device)
        x, edge_index, edge_attr, batch = sequence.x, sequence.edge_index, sequence.edge_attr, sequence.batch

        # --- pretext model ----------------------------------------------------
        if self.enable_pretext:
            # NOTE(review): the pretext output is computed but not returned (the
            # second element of the return value is always None) - confirm intent.
            self.pretext(x, edge_index, edge_attr)

        # --- MRGCN component (downstream task) ----------------------------------
        attn_weights = dict()
        if self.num_layers > 0:
            layer_outputs = []
            for layer_ind in range(self.num_layers):
                x = self.activation(self.conv[layer_ind](x, edge_index, edge_attr))
                x = F.dropout(x, self.dropout, training=self.training)
                layer_outputs.append(x)
            # Every layer's output is kept and concatenated feature-wise.
            x = torch.cat(layer_outputs, dim=-1)
        else:
            x = self.activation(self.fc0_5(x))

        # --- graph pooling ----------------------------------------------------
        if self.pooling_type in ("sagpool", "topk"):
            x, edge_index, _, attn_weights['batch'], attn_weights['pool_perm'], attn_weights['pool_score'] = self.pool1(x, edge_index, edge_attr=edge_attr, batch=batch)
        else:
            attn_weights['batch'] = batch

        # --- per-graph readout ------------------------------------------------
        if self.GCN_readout:
            if self.readout_type == "add":
                x = global_add_pool(x, attn_weights['batch'])
            elif self.readout_type == "mean":
                x = global_mean_pool(x, attn_weights['batch'])
            elif self.readout_type == "max":
                x = global_max_pool(x, attn_weights['batch'])
        else:
            # Keep only the first node of every graph, i.e. the positions where
            # the (un-pooled) batch assignment increases, plus position 0.
            # NOTE(review): this mask is built from the un-pooled ``batch``, so it
            # presumably is only meant to be used with pooling disabled - confirm.
            first_of_graph = torch.ones_like(batch, dtype=torch.bool)
            first_of_graph[1:] = batch[1:] > batch[:-1]
            x = x[first_of_graph]

        # --- temporal model ---------------------------------------------------
        graph_embedding = self.fc1(x)

        mu_values = []
        logvar_values = []

        # First call: the recurrent model consumes the per-frame embeddings.
        step_input = graph_embedding
        if self.probabilistic:
            step_input = self._sample_latent(step_input, mu_values, logvar_values)
        step_output, state = self.lstm(step_input)
        step_outputs = [step_output]

        # Remaining calls: feed back the top layer's last hidden state as a
        # one-step input, carrying the recurrent state along.
        for _ in range(1, pred_frames):
            # An LSTM returns (h, c) as its state, a GRU returns h alone.
            hidden = state[0] if isinstance(state, tuple) else state
            step_input = hidden[-1].unsqueeze(0)
            if self.probabilistic:
                step_input = self._sample_latent(step_input, mu_values, logvar_values)
            step_output, state = self.lstm(step_input, state)
            step_outputs.append(step_output)

        embedding_pred = torch.cat(step_outputs, 0)
        translation_pred = self.fc2(embedding_pred)

        if self.enable_lstm_input_regression and self.probabilistic:
            # The original referenced ``self.fc3``, which is never defined; ``f3``
            # is the only layer with a 2 * lstm_dim output to slice, so it is
            # used here. Only its mean half is kept.
            embedding_pred = self.f3(embedding_pred)[:, :self.lstm_dim]

        return (translation_pred, embedding_pred, graph_embedding), None, (torch.cat(mu_values, 0), torch.cat(logvar_values, 0)) if self.probabilistic else None

    def _build_proximity_graph(self, seq):
        """Build one graph linking node 0 to every other node of ``seq``, both ways."""
        agent_info = seq[0, :]
        sources, targets, distances = [], [], []
        for actor_ind in range(1, seq.shape[0]):
            actor_info = seq[actor_ind, :]
            dist = self.get_euclidiean_distance(actor_info[0], actor_info[1], agent_info[0], agent_info[1])
            # One edge per direction, both labelled with the same distance.
            sources += [0, actor_ind]
            targets += [actor_ind, 0]
            distances += [dist, dist]
        relations = self.get_proximity_relations(self.proximity_thresholds, distances)
        # Built directly as int64 rather than converted from float tensors.
        edge_attr = torch.tensor(relations, dtype=torch.long)
        edge_index = torch.tensor([sources, targets], dtype=torch.long)
        return Data(x=seq, edge_index=edge_index, edge_attr=edge_attr)

    def _sample_latent(self, features, mu_values, logvar_values):
        """Project ``features`` with ``f3``, record mean/log-variance, return a sample."""
        stats = self.f3(features)
        mu = stats[:, :self.lstm_dim]
        logvar = stats[:, self.lstm_dim:]
        mu_values.append(mu)
        logvar_values.append(logvar)
        if not self.training:
            # Evaluation uses the mean directly, without sampling noise.
            return mu
        std = torch.exp(0.5 * logvar)
        return torch.randn_like(std) * std + mu
