import sys
import numpy as np
import os
from torch.utils.data import Dataset
import os
from utils.pc_utils import crop_processing
from scipy.spatial.distance import cdist

class SPDemoDataset(Dataset):
    """Single-sample dataset wrapping one point cloud for demo inference.

    The dataset always reports a length of 1 and every index returns the same
    sample: ``input_pts`` cropped with ``crop_processing`` around
    ``centroid_xyz`` down to ``num_skel_pts`` points.

    Args:
        input_pts: Point data for the sample (ndarray or nested sequence); it
            is converted to a float32 copy before cropping. Presumably an
            (N, 4) array — TODO confirm the column layout crop_processing
            expects.
        centroid_xyz: Value forwarded to ``crop_processing`` as
            ``centroid_xyz`` after ``np.array`` conversion. Presumably an
            (x, y, z) triple — verify against callers.
        num_skel_pts: Number of points requested from ``crop_processing``
            (``num_near_pts``); also the row count of the returned array.
    """

    def __init__(self, input_pts=None, centroid_xyz=None, num_skel_pts=512):
        # Inputs are stored as given; conversion and cropping happen in
        # __getitem__.
        self.input_pts = input_pts
        self.centroid_xyz = centroid_xyz
        self.num_skel_pts = num_skel_pts

    def __getitem__(self, index):
        """Return the cropped sample as a ``(num_skel_pts, 4)`` float64 array.

        ``index`` is ignored because the dataset holds exactly one sample.

        Raises:
            ValueError: if no ``input_pts`` were supplied at construction.
        """
        if self.input_pts is None:
            # Without this guard the failure surfaces as an opaque
            # AttributeError on None.
            raise ValueError("SPDemoDataset requires input_pts; got None")

        # np.array always copies, so crop_processing cannot alter the stored
        # input in place; it also accepts plain nested lists.
        data_dict = {'input': np.array(self.input_pts, dtype=np.float32)}
        data_dict = crop_processing(
            data_dict,
            num_near_pts=self.num_skel_pts,
            centroid_xyz=np.array(self.centroid_xyz),
        )

        # Copy into a fixed-size float64 buffer: numpy raises if the cropped
        # points cannot be broadcast to (num_skel_pts, 4), and the output
        # dtype is pinned to float64.
        input_img = np.zeros((self.num_skel_pts, 4))
        input_img[:, :] = data_dict['input']
        return input_img

    def __len__(self):
        """Return 1: the dataset contains exactly one sample."""
        return 1