import torch
import torch.nn as nn
from torch.nn import Linear
import torch.nn.functional as F
from torch_geometric.nn import RGCNConv, FastRGCNConv


class FeatureRegressionMRGCN(nn.Module):
    '''
    Feature-regression pretext task: an MRGCN encoder followed by an MLP head.

    Objective: given a graph, infer each node's original feature vector from the
    node embeddings produced by two relational graph convolutions.

    Pipeline:
        conv1 (num_features -> hidden_dim) -> activation -> dropout
        conv2 (hidden_dim -> hidden_dim)   -> activation -> dropout
        concat(conv1 output, conv2 output) -> fc1 (2*hidden_dim -> 64) -> activation
        fc2 (64 -> num_features), with no output activation (regression).
    '''

    def __init__(self, config):
        '''
        Build all pretext layers from ``config.model_config``.

        Keys read: 'conv_type', 'device', 'dropout', 'num_of_classes',
        'num_relations', 'hidden_dim', 'activation'.
        '''
        super().__init__()
        model_config = config.model_config
        # Any conv_type other than "FastRGCNConv" selects the plain RGCNConv.
        self.rgcn_func = FastRGCNConv if model_config['conv_type'] == "FastRGCNConv" else RGCNConv
        self.device = model_config['device']
        self.dropout = model_config['dropout']

        # Both the input width and the regression-target width come from 'num_of_classes'.
        # NOTE(review): presumably node features are class encodings here -- confirm.
        self.num_features = model_config['num_of_classes']
        self.num_relations = model_config['num_relations']
        self.hidden_dim = model_config['hidden_dim']
        # Any activation other than 'relu' selects leaky_relu.
        self.activation = F.relu if model_config['activation'] == 'relu' else F.leaky_relu

        self.pretext_conv1 = self.rgcn_func(self.num_features, self.hidden_dim, self.num_relations).to(self.device)
        self.pretext_conv2 = self.rgcn_func(self.hidden_dim, self.hidden_dim, self.num_relations).to(self.device)
        # The head consumes both conv outputs concatenated, hence 2*hidden_dim inputs.
        # It must sit on the same device as the conv layers. Otherwise forward() fails
        # on non-CPU devices unless the caller moves the whole module.
        self.pretext_fc1 = Linear(2 * self.hidden_dim, 64).to(self.device)
        self.pretext_fc2 = Linear(64, self.num_features).to(self.device)


    #TODO: move the code for pseudo label generation to here instead.
    def generate_pseudo_labels(self, sequence):
        '''Placeholder: currently returns ``sequence`` unchanged.'''
        return sequence


    def forward(self, x, edge_index, edge_attr):
        '''
        Predict the original node features.

        Args:
            x: node input for the first conv layer. If it is a tensor, it is moved
                to ``self.device``; other values (e.g. None) are passed through unchanged.
            edge_index: graph connectivity; moved to ``self.device``.
            edge_attr: third argument of each conv layer (presumably per-edge relation
                types -- confirm against caller). It is moved to ``self.device`` when
                it is a tensor.

        Returns:
            Raw (un-activated) regression output from ``pretext_fc2``, whose last
            dimension is ``self.num_features``.
        '''
        edge_index = edge_index.to(self.device)
        # Keep every tensor input on the same device as edge_index and the layer weights.
        if isinstance(x, torch.Tensor):
            x = x.to(self.device)
        if isinstance(edge_attr, torch.Tensor):
            edge_attr = edge_attr.to(self.device)

        pretext_x1 = self.activation(self.pretext_conv1(x, edge_index, edge_attr))
        pretext_x1 = F.dropout(pretext_x1, self.dropout, training=self.training)
        pretext_x2 = self.activation(self.pretext_conv2(pretext_x1, edge_index, edge_attr))
        pretext_x2 = F.dropout(pretext_x2, self.dropout, training=self.training)
        # Concatenate the features produced after each conv layer.
        pretext_x = torch.cat([pretext_x1, pretext_x2], dim=-1)
        pretext_x = self.activation(self.pretext_fc1(pretext_x))
        # No activation on the output: this is a regression task.
        pretext_output = self.pretext_fc2(pretext_x)
        return pretext_output