import torch
import torch.nn as nn
import torch.nn.functional as F

"""
    Aim: Use an MLP-based proposal module to predict the coordinate offsets, objectness score and radius for each neuron point.
"""

class ProposalModule(nn.Module):
    """Per-point MLP proposal head.

    A stack of 1x1 ``Conv1d`` layers (a shared MLP applied to every point)
    maps per-point features to ``out_channels`` raw outputs, which ``forward``
    splits as follows:

    * channels ``0:2`` -> ``'objectness_scores'`` (raw values, no softmax),
    * channel ``2``    -> ``'radius'`` (raw value, no activation),
    * channels ``3:``  -> ``'offsets'``, added to ``'input_xyz'`` to form
      ``'center'``.

    Args:
        in_channels: Number of input feature channels per point.
        out_channels: Number of output channels. Must be at least 4 so that
            the objectness, radius and offset slices are all non-empty.
            Defaults to 6 (2 objectness + 1 radius + 3 offsets). The previous
            default of 2 made ``forward`` fail with an ``IndexError``.

    Raises:
        ValueError: If ``out_channels`` is smaller than 4.
    """

    def __init__(self, in_channels: int = 1216, out_channels: int = 6):
        super().__init__()

        # forward() reads channels 0:2, 2 and 3:, so fewer than 4 channels
        # cannot produce a radius plus at least one offset component.
        if out_channels < 4:
            raise ValueError(
                f"out_channels must be >= 4 (2 objectness + 1 radius + "
                f">=1 offset), got {out_channels}"
            )

        # BatchNorm with track_running_stats=False normalizes with the
        # current batch's statistics in both train and eval mode.
        self.conv1 = nn.Sequential(
            nn.Conv1d(in_channels, 512, 1),
            nn.BatchNorm1d(512, track_running_stats=False),
            nn.ReLU(inplace=True),
        )

        self.conv2 = nn.Sequential(
            nn.Conv1d(512, 256, 1),
            nn.BatchNorm1d(256, track_running_stats=False),
            nn.ReLU(inplace=True),
        )

        self.conv3 = nn.Sequential(
            nn.Conv1d(256, 128, 1),
            nn.BatchNorm1d(128, track_running_stats=False),
            nn.ReLU(inplace=True),
        )

        # Final linear projection: no normalization or activation, raw outputs.
        self.conv4 = nn.Conv1d(128, out_channels, 1)

    def forward(self, x: torch.Tensor, end_points: dict) -> dict:
        """Predict per-point proposals and store them in ``end_points``.

        Args:
            x: Per-point features of shape ``(B, in_channels, N)``.
            end_points: Dict that must contain ``'input_xyz'``. It is updated
                in place and returned. ``input_xyz`` must broadcast with
                ``(B, N, out_channels - 3)``; presumably it holds ``(B, N, 3)``
                point coordinates (TODO confirm with caller).

        Returns:
            The same ``end_points`` dict with these entries added:

            * ``'objectness_scores'`` of shape ``(B, N, 2)``,
            * ``'radius'`` of shape ``(B, N)``,
            * ``'offsets'`` of shape ``(B, N, out_channels - 3)``,
            * ``'center'`` equal to ``input_xyz + offsets``.
        """
        # --------- PROPOSAL GENERATION ---------
        features = self.conv3(self.conv2(self.conv1(x)))
        # (B, out_channels, N) -> (B, N, out_channels): one row per point.
        out_transposed = self.conv4(features).transpose(2, 1)

        end_points['objectness_scores'] = out_transposed[:, :, 0:2]
        end_points['radius'] = out_transposed[:, :, 2]
        end_points['offsets'] = out_transposed[:, :, 3:]
        end_points['center'] = end_points['input_xyz'] + end_points['offsets']

        return end_points