import torch.nn as nn
from torch.nn import Linear
import torch.nn.functional as F



class FeatureRegressionMLP(nn.Module):
    '''
    Feature regression pretext task head built only from Linear (MLP) layers.

    Objective: given node embeddings, infer the original node feature list.
    Inputs of width ``num_features`` pass through two hidden layers of width
    ``hidden_dim`` and a ``pretext_fc_dim``-wide bottleneck, and are mapped
    back to ``num_features`` outputs with no final activation (regression).
    '''

    # Width of the bottleneck layer between pretext_fc1 and pretext_fc2.
    pretext_fc_dim = 64

    def __init__(self, config):
        '''
        Build the layers from ``config.model_config``.

        Args:
            config: object exposing a ``model_config`` mapping with the keys
                'device', 'dropout', 'num_of_classes', 'num_relations',
                'hidden_dim' and 'activation'.
        '''
        super().__init__()
        # Not referenced inside this class; presumably read by external
        # code — verify before removing.
        self.conv_func = Linear
        self.device = config.model_config['device']
        self.dropout = config.model_config['dropout']

        # NOTE(review): the feature width is read from 'num_of_classes';
        # confirm this key really holds the node-feature dimension.
        self.num_features = config.model_config['num_of_classes']
        # Stored but not used by this class.
        self.num_relations = config.model_config['num_relations']
        self.hidden_dim = config.model_config['hidden_dim']
        # Any activation name other than 'relu' selects leaky ReLU.
        self.activation = F.relu if config.model_config['activation'] == 'relu' else F.leaky_relu

        self.pretext_conv1 = Linear(self.num_features, self.hidden_dim)
        self.pretext_conv2 = Linear(self.hidden_dim, self.hidden_dim)
        self.pretext_fc1 = Linear(self.hidden_dim, self.pretext_fc_dim)
        self.pretext_fc2 = Linear(self.pretext_fc_dim, self.num_features)

        # Move ALL layers to the configured device so that fc1/fc2 live on the
        # same device as conv1/conv2; otherwise forward() fails with a device
        # mismatch whenever the device is not the CPU.
        self.to(self.device)

    # TODO: move the code for pseudo label generation to here instead.
    def generate_pseudo_labels(self, sequence):
        '''Return ``sequence`` unchanged (placeholder for label generation).'''
        return sequence

    def forward(self, x, edge_index, edge_attr):
        '''
        Regress node features from ``x``.

        Args:
            x: tensor whose last dimension is ``num_features``.
            edge_index: unused; accepted only to keep the call signature
                (presumably shared with graph-based pretext heads — confirm).
            edge_attr: unused; see ``edge_index``.

        Returns:
            Tensor with the same leading dimensions as ``x`` and a last
            dimension of ``num_features``.
        '''
        x = self.activation(self.pretext_conv1(x))
        x = F.dropout(x, self.dropout, training=self.training)
        x = self.activation(self.pretext_conv2(x))
        x = F.dropout(x, self.dropout, training=self.training)
        x = self.activation(self.pretext_fc1(x))
        x = self.pretext_fc2(x)  # no activation as this is a regression task
        return x