import torch
import torch.nn as nn
from torch_geometric.nn.pool import fps

def down_sample_query(query, pos, down_sample_rate):
    """Down-sample ``query`` with farthest point sampling (FPS) over ``pos``.

    FPS is run separately for each batch element on its point coordinates.
    The query features of the selected points are kept.

    Args:
        query: features of shape (N, B, D) (sequence-first layout).
        pos: point coordinates of shape (B, N, C), e.g. C == 3 for xyz.
            ``pos[b, n]`` must be the position of ``query[n, b]``.
        down_sample_rate: fraction of points to keep per batch element,
            forwarded to ``fps`` as ``ratio``.

    Returns:
        Tensor of shape (num_samples, B, D). ``num_samples`` is the
        per-batch number of points selected by ``fps`` (about
        ``N * down_sample_rate``).

    Raises:
        ValueError: if ``query`` is not 3-D, or ``pos`` is not 3-D with
            leading dimensions (B, N) matching ``query``.
    """
    if query.dim() != 3:
        raise ValueError(f"query must be (N, B, D), got {tuple(query.shape)}")
    N, B, D = query.shape
    # Without this check, any pos with the right element count (e.g. an
    # (N, B, C) tensor) would reshape fine and be silently mispaired.
    if pos.dim() != 3 or pos.shape[0] != B or pos.shape[1] != N:
        raise ValueError(
            f"pos must be (B, N, C) = ({B}, {N}, C) to match query "
            f"(N, B, D) = {tuple(query.shape)}, got {tuple(pos.shape)}"
        )

    # Flatten both tensors to batch-major rows, so that row b * N + n is
    # point n of batch element b. `batch` labels each row for fps.
    pos_flat = pos.reshape(B * N, pos.size(2)).float()
    query_flat = query.transpose(0, 1).reshape(B * N, D)
    batch = torch.arange(B, device=pos.device).repeat_interleave(N)

    sampled_indices = fps(pos_flat, batch, ratio=down_sample_rate)
    # Every batch element has N points, so fps is expected to pick the same
    # count from each one. Integer division recovers that per-batch count.
    num_samples = sampled_indices.shape[0] // B

    # Assumes fps returns global row indices grouped by batch element in
    # ascending batch order, so the gathered rows reshape to (B, S, D).
    sampled_query = query_flat[sampled_indices].reshape(B, num_samples, D)
    return sampled_query.permute(1, 0, 2)


if __name__ == '__main__':
    # query is (N, B, D) and pos is (B, N, 3): 4 point clouds of 1000 points.
    query = torch.randn((1000, 4, 256))
    pos = torch.randn((4, 1000, 3))
    down_sample_rate = 0.1

    q = down_sample_query(query, pos, down_sample_rate)
    print(q.shape)  # about (N * down_sample_rate, B, D) = (100, 4, 256)

    