import os
import cv2
import shutil
import torch
import pickle
import numpy as np
import PIL.Image as Image
import torch.utils.data as data
from torch.autograd import Variable
from keras.datasets import mnist, cifar10, cifar100
from keras.utils import np_utils
import time
import matplotlib.pyplot as plt
import itertools


class Dataset_CIFAR10(data.Dataset):
    """CIFAR-10 loaded from the original pickled python batches.

    Images are stored as an ``(N, 32, 32, 3)`` uint8-style array (HWC) and
    labels as a flat Python list. ``__getitem__`` returns ``(img, label)``
    where ``img`` is a PIL image, passed through ``transform`` if one is set.

    NOTE: the batch files are read with ``pickle``; only point ``data_path``
    at trusted files.
    """
    train_list = ['data_batch_1', 'data_batch_2', 'data_batch_3', 'data_batch_4', 'data_batch_5']
    test_list = ['test_batch']

    def __init__(self, data_path, train=True, transform=None):
        batch_names = self.train_list if train else self.test_list

        image_chunks = []
        label_seq = []
        for batch_name in batch_names:
            with open(os.path.join(data_path, batch_name), 'rb') as handle:
                batch = pickle.load(handle, encoding='bytes')
            image_chunks.append(batch[b'data'])
            label_seq.extend(batch[b'labels'])

        # Rows are stored channel-first (3x32x32 flattened); convert to HWC.
        stacked = np.vstack(image_chunks)
        self.data = stacked.reshape(-1, 3, 32, 32).transpose((0, 2, 3, 1))
        self.labels = label_seq
        self.transform = transform

    def __getitem__(self, index):
        sample = Image.fromarray(self.data[index])
        label = self.labels[index]
        if self.transform is not None:
            sample = self.transform(sample)
        return sample, label

    def __len__(self):
        return len(self.data)
    

def other_class(n_classes, current_class):
    """
    Draw one random class index different from ``current_class``.

    :param n_classes: number of classes in the task
    :param current_class: the class index to be excluded
    :return: a class index drawn uniformly from ``range(n_classes)`` minus
        ``current_class``, using the global NumPy RNG (``np.random``)
    :raises ValueError: if ``current_class`` is outside ``[0, n_classes - 1]``
        or there is no other class to choose from
    """
    if current_class < 0 or current_class >= n_classes:
        raise ValueError("current_class must be within the range [0, n_classes - 1]")
    if n_classes < 2:
        raise ValueError("n_classes must be at least 2 to pick a different class")

    candidates = list(range(n_classes))
    candidates.remove(current_class)
    # Same single np.random.choice draw as before, so seeded runs are unchanged.
    return np.random.choice(candidates)


def load_mnist(path, kind='train'):
    """Load an MNIST-format split (gzip-compressed IDX files) from ``path``.

    Reads ``<kind>-labels-idx1-ubyte.gz`` and ``<kind>-images-idx3-ubyte.gz``.

    :param path: directory containing the two gzip files
    :param kind: file-name prefix, e.g. ``'train'`` or ``'t10k'``
    :return: ``(images, labels)`` — ``images`` is an ``(N, 784)`` uint8 array,
        ``labels`` an ``(N,)`` uint8 array. Both come from ``np.frombuffer``
        and are therefore read-only; copy them before modifying.
    """
    import gzip

    labels_path = os.path.join(path, '%s-labels-idx1-ubyte.gz' % kind)
    images_path = os.path.join(path, '%s-images-idx3-ubyte.gz' % kind)

    # Skip the fixed-size headers: 8 bytes in the label file, 16 in the image file.
    with gzip.open(labels_path, 'rb') as lbpath:
        labels = np.frombuffer(lbpath.read(), dtype=np.uint8, offset=8)

    with gzip.open(images_path, 'rb') as imgpath:
        images = np.frombuffer(imgpath.read(), dtype=np.uint8,
                               offset=16).reshape(len(labels), 784)

    return images, labels


from numpy.testing import assert_array_almost_equal
def build_for_cifar100(size, noise):
    """Build a ``size x size`` transition matrix that mixes two random classes.

    Two distinct class indices are drawn with the global NumPy RNG. Each of
    them keeps its label with probability ``1 - noise`` and moves to the other
    one with probability ``noise``; every remaining class maps to itself.
    """
    assert 0. <= noise <= 1.

    P = np.eye(size)
    cls1, cls2 = np.random.choice(range(size), size=2, replace=False)
    pair = [cls1, cls2]
    P[pair, pair] = 1.0 - noise
    P[pair, pair[::-1]] = noise

    assert_array_almost_equal(P.sum(axis=1), 1, 1)
    return P


def multiclass_noisify(y, P, random_state=0):
    """Flip integer labels according to a row-stochastic transition matrix.

    Each label ``y[k] == i`` is replaced by a class drawn from ``P[i, :]``.

    :param y: 1-D integer label array with values in ``[0, P.shape[0] - 1]``
    :param P: square row-stochastic matrix; ``P[i, j]`` is the probability
        that class ``i`` becomes class ``j``
    :param random_state: seed of the private ``np.random.RandomState`` used
        for the draws (the global NumPy RNG is not touched)
    :return: a new array of noisy labels; ``y`` itself is not modified
    """
    assert P.shape[0] == P.shape[1]
    # np.max raises on an empty array, so only check the label range when there are labels.
    assert y.shape[0] == 0 or np.max(y) < P.shape[0]

    # row stochastic matrix
    assert_array_almost_equal(P.sum(axis=1), np.ones(P.shape[1]))
    assert (P >= 0.0).all()

    new_y = y.copy()
    flipper = np.random.RandomState(random_state)

    for idx, i in enumerate(y):
        # draw a one-hot vector whose single 1 marks the new class
        flipped = flipper.multinomial(1, P[i, :], 1)[0]
        # argmax yields a scalar index; assigning the 1-element array returned
        # by np.where to a scalar slot is deprecated since NumPy 1.25.
        new_y[idx] = np.argmax(flipped)

    return new_y


NUM_CLASSES = {'mnist': 10, 'fashion': 10, 'cifar-10': 10, 'cifar-100': 100}
def get_data(dataset='mnist', noise_ratio=0, asym=False, random_shuffle=False):
    """
    Get training images with specified ratio of syn/ayn label noise
    """
    if dataset == 'cifar-10':
        (X_train, y_train), (X_test, y_test) = cifar10.load_data()

        X_train = X_train.reshape(-1, 32, 32, 3)
        X_test = X_test.reshape(-1, 32, 32, 3)
    elif dataset == 'cifar-100':
        # num_classes = 100
        (X_train, y_train), (X_test, y_test) = cifar100.load_data()

        X_train = X_train.reshape(-1, 32, 32, 3)
        X_test = X_test.reshape(-1, 32, 32, 3)
    elif dataset == 'mnist':
        (X_train, y_train), (X_test, y_test) = mnist.load_data()

        X_train = X_train.reshape(-1, 28, 28)
        X_test = X_test.reshape(-1, 28, 28)
    elif dataset == 'fashion':
        X_train, y_train = load_mnist('dataset/Fashion-MNIST', kind='train')
        X_test, y_test = load_mnist('dataset/Fashion-MNIST', kind='t10k')

        X_train = X_train.reshape(-1, 28, 28)
        X_test = X_test.reshape(-1, 28, 28)

        # writeable
        X_train = np.array(X_train)
        y_train = np.array(y_train)
        X_test = np.array(X_test)
        y_test = np.array(y_test)
    
    '''
    X_train = X_train / 255.0
    X_test = X_test / 255.0

    means = X_train.mean(axis=0)
    # std = np.std(X_train)
    X_train = (X_train - means)  # / std
    X_test = (X_test - means)  # / std

    X_train = X_train.astype('float32')
    X_test = X_test.astype('float32')
    '''
    
    # they are 2D originally in cifar
    y_train = y_train.ravel()
    y_test = y_test.ravel()

    y_train_clean = np.copy(y_train)
    # generate random noisy labels
    if noise_ratio > 0:
        print(noise_ratio)
        if asym:
            data_file = "data/asym_%s_train_labels_%s.npy" % (dataset, noise_ratio)
            if dataset == 'cifar-100':
                P_file = "data/asym_%s_P_value_%s.npy" % (dataset, noise_ratio)
        else:
            data_file = "data/%s_train_labels_%s.npy" % (dataset, noise_ratio)
        if os.path.isfile(data_file):
            y_train = np.load(data_file)
            if dataset == 'cifar-100' and asym:
                P = np.load(P_file)
        else:
            if asym:
                if dataset == 'mnist':
                    # 1 < - 7, 2 -> 7, 3 -> 8, 5 <-> 6
                    source_class = [7, 1, 3, 5, 6, 4]
                    target_class = [1, 7, 5, 3, 4, 6]
                elif dataset == 'cifar-10':
                    # automobile < - truck, bird -> airplane, cat <-> dog, deer -> horse
                    source_class = [9, 1, 2, 0, 3, 5, 4, 7]
                    target_class = [1, 9, 0, 2, 5, 3, 7, 4]
                elif dataset == 'fashion':
                    # automobile < - truck, bird -> airplane, cat <-> dog, deer -> horse
                    source_class = [0, 6, 2, 4, 7, 9]
                    target_class = [6, 0, 4, 2, 9, 7]

                elif dataset == 'cifar-100':
                        P = np.eye(NUM_CLASSES[dataset])
                        n = noise_ratio/100.0
                        nb_superclasses = 20
                        nb_subclasses = 5

                        if n > 0.0:
                            for i in np.arange(nb_superclasses):
                                init, end = i * nb_subclasses, (i+1) * nb_subclasses
                                P[init:end, init:end] = build_for_cifar100(nb_subclasses, n)

                            y_train_noisy = multiclass_noisify(y_train, P=P,
                                                               random_state=0)
                            actual_noise = (y_train_noisy != y_train).mean()
                            assert actual_noise > 0.0
                            y_train = y_train_noisy
                        np.save(P_file, P)

                else:
                    print('Asymmetric noise is not supported now for dataset: %s' % dataset)
                    return
                if dataset == 'mnist' or dataset == 'cifar-10' or dataset == 'fashion':
                    for s, t in zip(source_class, target_class):
                        cls_idx = np.where(y_train_clean == s)[0]
                        n_noisy = int(noise_ratio * cls_idx.shape[0] / 100)
                        noisy_sample_index = np.random.choice(cls_idx, n_noisy, replace=False)
                        y_train[noisy_sample_index] = t

            else:
                n_samples = y_train.shape[0]
                n_noisy = int(noise_ratio * n_samples / 100)
                class_index = [np.where(y_train_clean == i)[0] for i in range(NUM_CLASSES[dataset])]
                class_noisy = int(n_noisy / NUM_CLASSES[dataset])

                noisy_idx = []
                for d in range(NUM_CLASSES[dataset]):
                    noisy_class_index = np.random.choice(class_index[d], class_noisy, replace=False)
                    noisy_idx.extend(noisy_class_index)

                for i in noisy_idx:
                    y_train[i] = other_class(n_classes=NUM_CLASSES[dataset], current_class=y_train[i])
            np.save(data_file, y_train)

    if random_shuffle:
        # random shuffle
        idx_perm = np.random.permutation(X_train.shape[0])
        X_train, y_train, y_train_clean = X_train[idx_perm], y_train[idx_perm], y_train_clean[idx_perm]

    # one-hot-encode the labels
    y_train_clean = np_utils.to_categorical(y_train_clean, NUM_CLASSES[dataset])
    y_train = np_utils.to_categorical(y_train, NUM_CLASSES[dataset])
    y_test = np_utils.to_categorical(y_test, NUM_CLASSES[dataset])

    print("X_train:", X_train.shape)
    print("y_train:", y_train.shape)
    print("X_test:", X_test.shape)
    print("y_test", y_test.shape)

    return X_train, y_train, y_train_clean, X_test, y_test


def CM_print(dataset, noise_ratio, asym):
    if dataset == 'cifar-10':
        classes_name = ['airplane', 'automobile', 'bird', 'cat', 'deer', 'dog', 'frog', 'horse', 'ship', 'truck'] 
    elif dataset == 'mnist':
        classes_name = ['0', '1', '2', '3', '4', '5', '6', '7', '8', '9']
    elif dataset == 'fashion':
        classes_name = ['T-shirt', 'Trouser', 'Pullover', 'Dress', 'Coat', 'Sandal', 'Shirt', 'Sneaker', 'Bag', 'Ankle boot']
    
    _, y_train, y_train_clean, _, _ = get_data(dataset, noise_ratio=noise_ratio, asym=asym, random_shuffle=False)
    y_train = np.argmax(y_train, axis=-1)
    y_train_clean = np.argmax(y_train_clean, axis=-1)
    num_class = len(classes_name)
    cm = np.zeros((num_class, num_class))
    for i, j in itertools.product(range(num_class), range(num_class)):
        cm[i, j] = ((y_train_clean==i) & (y_train==j)).sum().astype(np.float)/(y_train_clean==i).sum()
    print(cm)
    plt.figure()
    plt.imshow(cm, interpolation='nearest', cmap=plt.cm.Blues, vmin=0.0, vmax=1.0)
    ifasym = 'asymmetric' if asym else 'symmetric'
    plt.title('%s with %s %s noise' % (dataset, noise_ratio, ifasym))
    plt.colorbar()
    tick_marks = np.arange(num_class)
    plt.xticks(tick_marks, classes_name, rotation=90)
    plt.yticks(tick_marks, classes_name)

    plt.axis("equal")

    ax = plt.gca()
    left, right = plt.xlim()
    ax.spines['left'].set_position(('data', left))
    ax.spines['right'].set_position(('data', right))
    for edge_i in ['top', 'bottom', 'right', 'left']:
        ax.spines[edge_i].set_edgecolor("white")
        
    thresh = 0.5
    for i, j in itertools.product(range(num_class), range(num_class)):
        num = cm[i, j]
        plt.text(j, i, '{:.2f}'.format(num),
                 verticalalignment='center',
                 horizontalalignment="center",
                 color="white" if num > thresh else "black")
    
    plt.ylabel('Clean Labels')
    plt.xlabel('Flipped Labels')
    
    plt.tight_layout()
    save_path = "data/%s_train_labels_%s_asym.png" % (dataset, noise_ratio) if asym else "data/%s_train_labels_%s_sym.png" % (dataset, noise_ratio)
    plt.savefig(save_path, transparent=True, dpi=800)
    

class ImgDataset(data.Dataset):
    """In-memory tensor dataset with noisy (and optionally true) integer labels.

    ``__getitem__`` returns ``[x, y_noisy]``, followed by ``y_true`` when true
    targets were given and by ``index`` when ``require_index`` is set.
    """

    def __init__(self, data, targets_noisy, targets_true=None, require_index=False, select_list=None, transform=None):
        # A transform is applied eagerly, once per image: every element of
        # ``data`` goes through Image.fromarray and the transform's output
        # must support ``.numpy()``.
        if transform is None:
            self.data = torch.tensor(data)
        else:
            self.data = torch.tensor(np.array([transform(Image.fromarray(d)).numpy() for d in data]))
        self.targets_noisy = torch.LongTensor(targets_noisy)
        self.targets_true = None if targets_true is None else torch.LongTensor(targets_true)
        if select_list is not None:
            select_list = np.array(select_list)
            self.data = self.data[select_list]
            self.targets_noisy = self.targets_noisy[select_list]
            # targets_true is optional; indexing None would raise TypeError.
            if self.targets_true is not None:
                self.targets_true = self.targets_true[select_list]
        self.require_index = require_index
        self.transform = transform

    # target update
    def targets_transfer(self, index, rectified_targets):
        """Overwrite the noisy labels at ``index`` with ``rectified_targets``."""
        assert len(index) == rectified_targets.shape[0], 'Target amount error.'
        self.targets_noisy[index] = rectified_targets

    def save_targets(self, path):
        """Save the current noisy labels to ``path`` with ``np.save``."""
        np.save(path, self.targets_noisy.numpy())

    # img update
    def img_update(self, index, new_imgs):
        """Clamp ``new_imgs`` to [0, 1] (in place) and store them at ``index``."""
        assert len(index) == new_imgs.shape[0], 'Imgs amount error.'
        new_imgs[new_imgs > 1.0] = 1.0
        new_imgs[new_imgs < 0.0] = 0.0
        self.data[index] = new_imgs

    def __getitem__(self, index):
        output = [self.data[index], self.targets_noisy[index]]
        if self.targets_true is not None:
            output.append(self.targets_true[index])
        if self.require_index:
            output.append(index)
        return output

    def __len__(self):
        return len(self.data)


class MyDataset(data.Dataset):
    """Image dataset with mutable noisy targets for label correction.

    Target rows are reduced with ``argmax`` (one-hot or score rows).
    ``__getitem__`` returns ``[x, y_noisy]``, followed by ``y_true`` when true
    targets were given and by ``index`` when ``require_index`` is set. When
    ``self.random`` is True, the noisy label is sampled in proportion to the
    accumulated per-class scores in ``targets_noisy_seq`` instead of taking the
    argmax of ``targets_noisy``.
    """

    def __init__(self, data, targets_noisy, targets_true=None, require_index=False, select_list=None, dup=True, transform=None, max_epoch=200):
        # max_epoch is accepted for interface compatibility; it is not used.
        self.data = data
        self.targets_noisy = torch.LongTensor(targets_noisy)
        # Pristine copy of the initial targets, restored by update_ramdom_noise().
        self.targets_init = torch.LongTensor(targets_noisy).clone()
        # Accumulated per-class scores, seeded with the initial targets.
        self.targets_noisy_seq = self.targets_noisy.clone().float()
        if targets_true is None:
            self.targets_true = None
        else:
            self.targets_true = torch.LongTensor(targets_true)
            # True where the given noisy label disagrees with the true label.
            self.noise_idx = self.targets_noisy.argmax(dim=1) != self.targets_true.argmax(dim=1)
        if select_list is not None:
            select_list = np.array(select_list)
            if dup:
                # Pad with extra distinct picks from the selection so the dataset
                # keeps its pre-selection size (needs >= half of it selected).
                n_total = len(self.data)
                extra = np.random.choice(select_list, n_total - len(select_list), replace=False)
                select_list = np.hstack((select_list, extra))
            self.data = self.data[select_list]
            self.targets_noisy = self.targets_noisy[select_list]
            self.targets_init = self.targets_init[select_list]
            # Select whole rows (one per sample), keeping all class columns.
            self.targets_noisy_seq = self.targets_noisy_seq[select_list]
            if self.targets_true is not None:
                self.targets_true = self.targets_true[select_list]
                self.noise_idx = self.noise_idx[select_list]
            print('Data num', len(set(select_list)))
            if self.targets_true is not None:
                print('Noise rate:', (self.targets_true.argmax(dim=-1) != self.targets_noisy.argmax(dim=-1)).float().mean())
        self.require_index = require_index
        self.transform = transform
        self.random = False

    # target update
    def targets_update(self, index, rectified_targets):
        """Overwrite the noisy target rows at ``index``."""
        assert len(index) == rectified_targets.shape[0], 'Target amount error.'
        self.targets_noisy[index] = rectified_targets

    # target update
    def targets_seq_update(self, index, rectified_targets, n_class=10):
        """Add ``rectified_targets`` to the accumulated scores at ``index``.

        1-D integer targets are one-hot encoded with ``n_class`` columns first;
        negative entries are zeroed (in place on the passed tensor).
        """
        if len(rectified_targets.shape) == 1:
            rectified_targets = torch.eye(n_class)[rectified_targets].long()
        rectified_targets[rectified_targets < 0] = 0
        self.targets_noisy_seq[index] += rectified_targets.float()

    # target ramdom noise
    def update_ramdom_noise(self, noise_ratio=1.0, clean_ratio=0.0):
        """Reset targets, then re-draw random labels for a random subset.

        Each originally-noisy sample is picked with probability ``noise_ratio``
        and each originally-clean sample with probability ``clean_ratio``.
        Requires ``targets_true`` to have been given.
        """
        self.targets_noisy = self.targets_init.clone()
        noise_idx = self.noise_idx & (torch.rand(self.noise_idx.shape) < noise_ratio)
        clean_idx = (~self.noise_idx) & (torch.rand(self.noise_idx.shape) < clean_ratio)
        random_idx = noise_idx | clean_idx
        N = int(random_idx.sum())
        K = self.targets_true.shape[1]
        new_noise = torch.randint(0, K, (N,))
        self.targets_noisy[random_idx] = torch.eye(K)[new_noise].long()

    def save_targets(self, path):
        """Print label accuracy against ``targets_true`` and save the noisy targets."""
        print('Save targets, acc:', (torch.argmax(self.targets_noisy, dim=-1) == torch.argmax(self.targets_true, dim=-1)).float().mean())
        np.save(path, self.targets_noisy.numpy())

    # img update
    def img_update(self, index, new_imgs):
        """Store float images in [0, 1] at ``index`` as uint8 pixels."""
        assert len(index) == new_imgs.shape[0], 'Imgs amount error.'
        # Clip in float space *before* the uint8 cast: casting out-of-range
        # floats to uint8 does not saturate, so clamping afterwards was a no-op.
        new_imgs = np.clip(new_imgs * 255, 0, 255).astype(np.uint8)
        self.data[index] = new_imgs

    def __getitem__(self, index):
        x = self.data[index]
        if self.random:
            # Sample a class with probability proportional to its accumulated score.
            targets_seq = torch.tensor(list(itertools.accumulate(self.targets_noisy_seq[index])))
            rand_threshold = torch.rand(1) * targets_seq[-1]
            y_n = torch.where(targets_seq > rand_threshold)[0].min()
        else:
            y_n = self.targets_noisy[index].argmax()

        if self.transform:
            x = self.transform(Image.fromarray(self.data[index]))
        output = [x, y_n]
        if self.targets_true is not None:
            output.append(self.targets_true[index])
        if self.require_index:
            output.append(index)
        return output

    def __len__(self):
        return len(self.data)


class CoDataset(data.Dataset):
    """Image dataset with noisy targets and two accumulated score tables.

    ``targets_seq_a`` and ``targets_seq_b`` start as float copies of the noisy
    targets and are updated independently via ``targets_seq_update``.
    ``__getitem__`` returns ``[x, y_noisy_a, y_noisy_b]``, followed by
    ``y_true`` when true targets were given and by ``index`` when
    ``require_index`` is set.
    """

    def __init__(self, data, targets_noisy, targets_true=None, require_index=False, transform=None):
        self.data = data
        self.targets_noisy = torch.LongTensor(targets_noisy)
        # Pristine copy of the initial targets, restored by update_ramdom_noise().
        self.targets_init = torch.LongTensor(targets_noisy).clone()
        self.targets_seq_a = self.targets_noisy.clone().float()
        self.targets_seq_b = self.targets_noisy.clone().float()
        if targets_true is None:
            self.targets_true = None
        else:
            self.targets_true = torch.LongTensor(targets_true)
            # True where the given noisy label disagrees with the true label.
            self.noise_idx = self.targets_noisy.argmax(dim=1) != self.targets_true.argmax(dim=1)
        self.require_index = require_index
        self.transform = transform
        self.random = False

    # target update
    def targets_update(self, index, rectified_targets):
        """Overwrite the noisy target rows at ``index``."""
        assert len(index) == rectified_targets.shape[0], 'Target amount error.'
        self.targets_noisy[index] = rectified_targets

    # target update
    def targets_seq_update(self, seq_idx, index, rectified_targets, n_class=10):
        """Add ``rectified_targets`` to score table ``seq_idx`` (0 -> a, 1 -> b).

        1-D integer targets are one-hot encoded with ``n_class`` columns first;
        negative entries are zeroed (in place on the passed tensor).

        :raises ValueError: if ``seq_idx`` is not 0 or 1
        """
        if seq_idx == 0:
            targets_noisy_seq = self.targets_seq_a
        elif seq_idx == 1:
            targets_noisy_seq = self.targets_seq_b
        else:
            raise ValueError('seq_idx must be 0 or 1, got %r' % (seq_idx,))
        if len(rectified_targets.shape) == 1:
            rectified_targets = torch.eye(n_class)[rectified_targets]
        rectified_targets[rectified_targets < 0] = 0
        targets_noisy_seq[index] += rectified_targets.float()

    # target ramdom noise
    def update_ramdom_noise(self, noise_ratio=1.0, clean_ratio=0.0):
        """Reset targets, then re-draw random labels for a random subset.

        Each originally-noisy sample is picked with probability ``noise_ratio``
        and each originally-clean sample with probability ``clean_ratio``.
        Requires ``targets_true`` to have been given.
        """
        self.targets_noisy = self.targets_init.clone()
        noise_idx = self.noise_idx & (torch.rand(self.noise_idx.shape) < noise_ratio)
        clean_idx = (~self.noise_idx) & (torch.rand(self.noise_idx.shape) < clean_ratio)
        random_idx = noise_idx | clean_idx
        N = int(random_idx.sum())
        K = self.targets_true.shape[1]
        new_noise = torch.randint(0, K, (N,))
        self.targets_noisy[random_idx] = torch.eye(K)[new_noise].long()

    def save_targets(self, path):
        """Print label accuracy against ``targets_true`` and save the noisy targets."""
        print('Save targets, acc:', (torch.argmax(self.targets_noisy, dim=-1) == torch.argmax(self.targets_true, dim=-1)).float().mean())
        np.save(path, self.targets_noisy.numpy())

    # img update
    def img_update(self, index, new_imgs):
        """Store float images in [0, 1] at ``index`` as uint8 pixels."""
        assert len(index) == new_imgs.shape[0], 'Imgs amount error.'
        # Clip in float space *before* the uint8 cast: casting out-of-range
        # floats to uint8 does not saturate, so clamping afterwards was a no-op.
        new_imgs = np.clip(new_imgs * 255, 0, 255).astype(np.uint8)
        self.data[index] = new_imgs

    def __getitem__(self, index):
        x = self.data[index]
        if self.random:
            # Sample each label with probability proportional to its table's scores.
            seq_a = torch.tensor(list(itertools.accumulate(self.targets_seq_a[index])))
            seq_b = torch.tensor(list(itertools.accumulate(self.targets_seq_b[index])))
            rand_threshold_a = torch.rand(1) * seq_a[-1]
            rand_threshold_b = torch.rand(1) * seq_b[-1]
            y_n_a = torch.where(seq_a > rand_threshold_a)[0].min()
            y_n_b = torch.where(seq_b > rand_threshold_b)[0].min()
        else:
            y_n_a = y_n_b = self.targets_noisy[index].argmax()

        if self.transform:
            x = self.transform(Image.fromarray(self.data[index]))
        output = [x, y_n_a, y_n_b]
        if self.targets_true is not None:
            output.append(self.targets_true[index])
        if self.require_index:
            output.append(index)
        return output

    def __len__(self):
        return len(self.data)

# NOTE(review): the string literal below is a disabled, older MyDataset
# variant that kept per-branch lists of target snapshots. It is never
# executed (the active MyDataset class is defined earlier in this file) and
# is kept only for reference. If revived, note that
# `[[...]] * branch_num` makes every branch share one inner list, and
# `self.targets_noisy_seq[0, 0, select_list]` indexes a list with a tuple.
# Consider deleting it and relying on version-control history instead.
'''
class MyDataset(data.Dataset):
    def __init__(self, data, targets_noisy, targets_true=None, require_index=False, select_list=None, dup=True, transform=None, max_epoch=200, branch_num=1):
        self.data = data
        self.targets_noisy = torch.LongTensor(targets_noisy)
        self.targets_noisy_seq = [[self.targets_noisy.clone()]] * branch_num
        self.targets_true = None if targets_true is None else torch.LongTensor(targets_true)
        if not select_list is None:
            select_list = np.array(select_list)
            #select_list = np.hstack((np.random.choice(select_list, 18000, replace=False), np.setdiff1d(np.arange(50000), np.load('result/model/clean_idx.npy'))))
            if dup: select_list = np.hstack((select_list, np.random.choice(select_list, 50000 - len(select_list), replace=False)))
            self.data = self.data[select_list]
            self.targets_noisy = self.targets_noisy[select_list]
            self.targets_noisy_seq = self.targets_noisy_seq[0, 0, select_list]
            self.targets_true = None if targets_true is None else self.targets_true[select_list]
            print('Data num', len(set(select_list)))
            print('Noise rate:', (np.argmax(self.targets_true, axis=-1) != np.argmax(self.targets_noisy, axis=-1)).float().mean())
        self.require_index = require_index
        self.transform = transform

    # target update
    def targets_update(self, index, rectified_targets):
        assert len(index)==rectified_targets.shape[0], 'Target amount error.'
        self.targets_noisy[index] = rectified_targets
        
    # target update
    def targets_seq_update(self, epoch, index, rectified_targets, target_branch=0):
        assert len(index)==rectified_targets.shape[0], 'Target amount error.'
        while len(self.targets_noisy_seq[target_branch]) <= epoch:
            self.targets_noisy_seq[target_branch].append(self.targets_noisy_seq[target_branch][-1].clone())
        self.targets_noisy_seq[target_branch][-1][index] = rectified_targets
        
    def save_targets(self, path):
        print('Save targets, acc:', (torch.argmax(self.targets_noisy, dim=-1)==torch.argmax(self.targets_true, dim=-1)).float().mean())
        np.save(path, self.targets_noisy.numpy())

    # img update
    def img_update(self, index, new_imgs):
        assert len(index)==new_imgs.shape[0], 'Imgs amount error.'
        new_imgs = (new_imgs * 255).astype(np.uint8)
        new_imgs[new_imgs>255] = 255
        new_imgs[new_imgs<0] = 0
        self.data[index] = new_imgs

    def __getitem__(self, index):
        x = self.data[index]
        branch_num = len(self.targets_noisy_seq)
        seq_idx = np.random.randint(len(self.targets_noisy_seq[0]))
        y_n = [self.targets_noisy_seq[branch_idx][seq_idx][index] for branch_idx in range(branch_num)]
        if not self.targets_true is None: y_t = self.targets_true[index]

        if self.transform:
            x = Image.fromarray(self.data[index])
            x = self.transform(x)
        output = [x, *y_n]
        if not self.targets_true is None:
            output.append(y_t)
        if self.require_index:
            output.append(index)
        return output

    def __len__(self):
        return len(self.data)
'''


class ClothingDataset(data.Dataset):
    """Folder-per-class image dataset: ``<path>/<label>/*.jpg``.

    Labels are the integer folder names ``0 .. num_classes - 1``. Images are
    opened lazily with PIL and converted to RGB; ``__getitem__`` returns
    ``(image, label, index)``.

    :param path: root directory (with or without a trailing separator)
    :param transform: optional callable applied to each PIL image
    :param num_classes: number of class folders to scan (default 14)
    """

    def __init__(self, path, transform=None, num_classes=14):
        patch_file_list = []
        label_list = []
        for label in range(num_classes):
            # os.path.join works whether or not ``path`` ends with a separator;
            # plain concatenation (path + str(label)) required a trailing slash.
            class_dir = os.path.join(path, str(label))
            for file in os.listdir(class_dir):
                if file.endswith('.jpg'):
                    patch_file_list.append(os.path.join(class_dir, file))
                    label_list.append(label)
        self.patch_file_list = np.array(patch_file_list)
        self.label_list = np.array(label_list)
        self.transform = transform

    def __getitem__(self, index):
        label = self.label_list[index]
        patch_file = self.patch_file_list[index]
        patch_img = Image.open(patch_file).convert('RGB')
        if self.transform is not None:
            patch_img = self.transform(patch_img)
        return patch_img, label, index

    def __len__(self):
        return len(self.patch_file_list)
