import torch
import torch.nn as nn
import torch.nn.functional as F

class SamplingNet(nn.Module):
    """Learned point sampler that picks K pixels from a depth/colour image pair.

    A small convolutional network predicts a per-pixel importance score; K flat
    pixel indices are then selected either by non-maximum suppression followed
    by top-k (``args.sample == 'superpoint'``) or by Gumbel-perturbed top-k over
    the masked scores (any other value of ``args.sample``).

    Attributes read from ``args``:
        max_points: number of points K to sample.
        max_depth: divisor used to normalise the depth input.
        sample: sampling strategy selector (see above).
    """

    def __init__(self, args):
        super().__init__()
        self.args = args
        self.num_samples = args.max_points  # K

        # Feature extractors for depth and colour; both keep the spatial size.
        self.depth_encoder = nn.Conv2d(1, 16, kernel_size=3, stride=1, padding=1)  # [B, 16, H, W]
        self.color_encoder = nn.Conv2d(3, 16, kernel_size=3, stride=1, padding=1)  # [B, 16, H, W]

        # Per-pixel importance head (1x1 convolutions act as a pixel-wise MLP).
        self.score_mlp = nn.Sequential(
            nn.Conv2d(32, 16, kernel_size=1),
            nn.ReLU(),
            nn.Conv2d(16, 1, kernel_size=1),  # [B, 1, H, W] score map
        )

    def forward(self, depth, color, mask, grd_uv_flat, grd_xyz_flat):
        """Score every pixel and gather the K selected uv / xyz entries.

        Args:
            depth: depth image, [B, 1, H, W]; divided by ``args.max_depth``.
            color: colour image, [B, 3, H, W].
            mask: validity mask holding B * H * W elements (e.g. [B, 1, H, W]
                or [B, H, W]); entries equal to 0 are excluded from the
                Gumbel branch.
            grd_uv_flat: per-pixel uv coordinates, expected [B, N, 2], N = H * W.
            grd_xyz_flat: per-pixel xyz coordinates, expected [B, N, 3].

        Returns:
            Tuple ``(sampled_grd_uv, sampled_grd_xyz, topk_indices)`` with
            shapes [B, K, 2], [B, K, 3] and [B, K]; the indices address the
            flattened H * W grid.

        Raises:
            RuntimeError: from ``torch.topk`` if K exceeds the number of pixels.
        """
        B = depth.shape[0]

        depth = depth / self.args.max_depth

        # Step 1: feature extraction and fusion.
        depth_feat = F.relu(self.depth_encoder(depth))  # [B, 16, H, W]
        color_feat = F.relu(self.color_encoder(color))  # [B, 16, H, W]
        fused_feat = torch.cat([depth_feat, color_feat], dim=1)  # [B, 32, H, W]

        # Step 2: per-pixel importance scores (unbounded, may be negative).
        score_map = self.score_mlp(fused_feat)  # [B, 1, H, W]

        # Step 3: choose K flat pixel indices.
        if self.args.sample == 'superpoint':
            # NOTE(review): this branch does not consult `mask` -- confirm that
            # sampling from invalid regions is acceptable here.
            topk_indices = self.extract_keypoints(score_map, self.num_samples)  # [B, K]
        else:
            # Work on flattened maps so the mask lines up with the scores
            # element-wise regardless of whether it carries a channel axis.
            score_flat = score_map.reshape(B, -1)  # [B, N]
            invalid = mask.reshape(B, -1) == 0  # [B, N]
            score_flat = score_flat.masked_fill(invalid, float('-inf'))

            # Gumbel noise; clamp the uniform draw away from 0 so a valid pixel
            # can never receive -inf noise.
            uniform = torch.rand_like(score_flat)
            uniform = uniform.clamp_min(torch.finfo(score_flat.dtype).tiny)
            gumbel_noise = -torch.log(-torch.log(uniform))

            # Rank the perturbed scores directly. A temperature-scaled softmax
            # is monotonic, so it would give the same ordering, but it can
            # underflow small probabilities to 0 and tie them with masked
            # pixels. If fewer than K pixels are valid, masked pixels fill the
            # remaining slots.
            perturbed = score_flat + gumbel_noise
            topk_indices = torch.topk(perturbed, self.num_samples, dim=-1)[1]  # [B, K]

        # Step 4: gather the coordinates of the selected pixels.
        sampled_grd_uv = torch.gather(
            grd_uv_flat, 1, topk_indices.unsqueeze(-1).expand(-1, -1, 2))  # [B, K, 2]
        sampled_grd_xyz = torch.gather(
            grd_xyz_flat, 1, topk_indices.unsqueeze(-1).expand(-1, -1, 3))  # [B, K, 3]

        return sampled_grd_uv, sampled_grd_xyz, topk_indices

    def extract_keypoints(self, confidence, topk=256):
        """Extract key points from a confidence map via NMS followed by top-k.

        Args:
            confidence: torch.Tensor of size (B, C, H, W). Values may have any
                sign; suppressed and border pixels are filled with -inf so they
                always rank below real local maxima.
            topk: number of points to extract from each confidence map.

        Returns:
            A torch.Tensor of size (B, topk) holding indices into the flattened
            (C * H * W) map, sorted by decreasing confidence. For C == 1 the
            column is ``index % W`` and the row is ``index // W``. When fewer
            than ``topk`` maxima survive, the remaining slots hold indices of
            suppressed pixels.
        """
        radius = 1
        neg_inf = torch.full_like(confidence, float('-inf'))

        # Fast non-maximum suppression to remove nearby points.
        def max_pool(x):
            return F.max_pool2d(x, kernel_size=radius * 2 + 1, stride=1, padding=radius)

        max_mask = (confidence == max_pool(confidence))
        for _ in range(2):
            # Pixels within `radius` of an accepted maximum are suppressed;
            # look for further maxima among the pixels that remain.
            supp_mask = max_pool(max_mask.float()) > 0
            supp_confidence = torch.where(supp_mask, neg_inf, confidence)
            new_max_mask = (supp_confidence == max_pool(supp_confidence))
            max_mask = max_mask | (new_max_mask & (~supp_mask))
        confidence = torch.where(max_mask, confidence, neg_inf)

        # Remove borders (safe in place: `confidence` is a fresh tensor here).
        border = radius
        confidence[:, :, :border] = float('-inf')
        confidence[:, :, -border:] = float('-inf')
        confidence[:, :, :, :border] = float('-inf')
        confidence[:, :, :, -border:] = float('-inf')

        # Top-k over the flattened map.
        _, index = confidence.flatten(1).topk(topk, dim=1, largest=True, sorted=True)
        return index