import os

from torch.utils.data import Dataset
from torchvision.datasets.folder import default_loader, IMG_EXTENSIONS

from b2b.consts import SPLIT_TRAIN, SPLIT_TEST
from b2b.utils.util import make_power_2
from PIL import Image
import pdb
import numpy as np
import imgaug.augmenters as iaa


class ImageDomainFolder(Dataset):
    """Dataset structure introduced in the CycleGAN paper.

    This dataset expects images to be arranged into subdirectories
    under `path`: `trainA`, `trainB`, `testA`, `testB`. Here, `trainA`
    subdirectory contains training images from domain "a", `trainB`
    subdirectory contains training images from domain "b", and so on.
    The subdirectory name is built as ``split + domain.upper()``.

    Parameters
    ----------
    path : str
        Path where the dataset is located.
    domain : str
        Choices: 'a', 'b'.
    split : str
        Choices: 'train', 'test', 'val'
    transform : Callable or None
        Optional transformation to apply to images.
        E.g. torchvision.transforms.RandomCrop.
        Default: None
    **kwargs
        Forwarded unchanged to the ``Dataset`` base-class constructor.
    """

    def __init__(
        self, path,
        domain        = 'a',
        split         = SPLIT_TRAIN,
        transform     = None,
        **kwargs
    ):
        # NOTE(review): torch's ``Dataset`` does not appear to define its own
        # ``__init__``, so a non-empty ``kwargs`` may raise TypeError here --
        # confirm whether any caller relies on passing extra keywords.
        super().__init__(**kwargs)

        subdir = split + domain.upper()

        self._path      = os.path.join(path, subdir)
        self._imgs      = ImageDomainFolder.find_images_in_dir(self._path)
        self._transform = transform
        self.split      = split

    @staticmethod
    def find_images_in_dir(path):
        """Return the sorted full paths of image files directly inside `path`.

        Only regular files whose extension appears in torchvision's
        ``IMG_EXTENSIONS`` are kept; subdirectories are skipped, not
        descended into. The extension check is case-insensitive (as in
        torchvision's own ``has_file_allowed_extension``), so files such as
        ``photo.JPG`` are included.

        Raises
        ------
        OSError
            If `path` cannot be listed (e.g. it does not exist).
        """
        extensions = {ext.lower() for ext in IMG_EXTENSIONS}

        result = []
        # Names are sorted here; joining them onto the same directory prefix
        # preserves that order, so the returned paths need no second sort.
        for fname in sorted(os.listdir(path)):
            fullpath = os.path.join(path, fname)

            if not os.path.isfile(fullpath):
                continue

            # Lower-case the suffix so upper-case extensions still match.
            if os.path.splitext(fname)[1].lower() not in extensions:
                continue

            result.append(fullpath)

        return result

    def __len__(self):
        """Return the number of images found in the split/domain folder."""
        return len(self._imgs)

    def __getitem__(self, index):
        """Load the image at `index` and return it together with metadata.

        Returns
        -------
        dict
            ``'img'``: the loaded image, passed through `transform` if one
            was given; ``'name'``: the full path of the image file;
            ``'ori_size'``: the ``.size`` of the image as loaded, captured
            before any transform (for PIL images this is (width, height)).
        """
        path     = self._imgs[index]
        result   = default_loader(path)
        # Capture the size before `transform`, which may crop or resize.
        size_img = result.size

        if self._transform is not None:
            result = self._transform(result)

        return {'img': result,
                'name': path,
                'ori_size': size_img}

