from __future__ import print_function, division
import os
import torch
from skimage import io
import numpy as np
from torch.utils.data import Dataset
import glob

class pachesDataset(Dataset):
    """Patches dataset with and without text on them.

    Every sample bundles five gray-scale images, each returned as a float32
    tensor with a leading channel dimension:

    * an image with text, its clean counterpart and its gradCAM image; these
      three are looked up by the same index in their sorted file lists, so
      they are expected to share file names across the three directories;
    * an unrelated image without text;
    * an image with only text on a black background.
    """

    def __init__(self,
                 root_dir_without_text,
                 root_dir_with_text,
                 root_dir_with_text_clean,
                 root_dir_gradCAM,
                 data_dir_only_text,
                 stage,
                 transform=None):
        """
        Args:
            root_dir_without_text (string): Path to the images without text on them.
            root_dir_with_text (string): Path to the images with text on them.
            root_dir_with_text_clean (string): Path to images without text that equal to the patches with the text.
            root_dir_gradCAM : Path to the gradCAM images (those images fit the root_dir_with_text after passing in the classifier)
            data_dir_only_text : patches with black background and random text on them.
            stage : can be 'train', 'val' or 'test'.
            transform (callable, optional): applied to the image tensors in __getitem__.

        Every directory argument is a prefix: '_' + stage is appended to it to
        obtain the directory that is actually scanned.
        """
        self.root_dir_without_text = root_dir_without_text + '_' + stage
        self.root_dir_with_text = root_dir_with_text + '_' + stage
        self.root_dir_with_text_clean = root_dir_with_text_clean + '_' + stage
        self.root_dir_gradCAM = root_dir_gradCAM + '_' + stage
        self.data_dir_only_text = data_dir_only_text + '_' + stage
        self.stage = stage
        self.transform = transform

        # Generate the data (sorted lists of image paths, one list per directory).
        # The three paired lists must be sorted so that equal indices refer to the
        # same patch. The two unpaired lists are sorted as well, because glob
        # returns files in an arbitrary, filesystem-dependent order and the
        # index -> image mapping would otherwise differ between machines.
        self.list_without_text = self._list_images(self.root_dir_without_text)
        self.list_with_text = self._list_images(self.root_dir_with_text)
        self.list_with_text_clean = self._list_images(self.root_dir_with_text_clean)
        self.list_gradCAM = self._list_images(self.root_dir_gradCAM)
        self.list_only_text = self._list_images(self.data_dir_only_text)

    @staticmethod
    def _list_images(directory):
        """Return the sorted paths of the '*.jpeg' and '*.png' files directly inside `directory`."""
        paths = []
        for pattern in ('*.jpeg', '*.png'):
            paths.extend(glob.glob(os.path.join(directory, pattern)))
        paths.sort()
        return paths

    @staticmethod
    def _load_gray_tensor(path):
        """Read the image at `path` as gray scale and return it as a float32 tensor with a leading channel axis."""
        image = io.imread(path, as_gray=True).astype(np.float32)
        # Normalization heuristic: a maximum above 1 is taken to mean the image
        # uses the 0-255 range; anything else is treated as already normalized.
        if np.max(image) > 1:
            image = image / 255.0
        # add the channel axis to the 2-D image, then switch to tensor
        return torch.from_numpy(np.expand_dims(image, axis=0))

    def __len__(self):
        """Return the length of the shortest of the five file lists, so every index is valid in all of them."""
        return min(len(self.list_with_text_clean),
                   len(self.list_with_text),
                   len(self.list_without_text),
                   len(self.list_only_text),
                   len(self.list_gradCAM))

    def __getitem__(self, idx):
        """Load the five images stored at position `idx` and return them as a dict of tensors.

        Returns:
            dict with the keys 'image_with_text', 'image_with_text_clean',
            'image_gradCAM_cat', 'image_without_text' and 'image_only_text'.
            'image_gradCAM_cat' is the image with text concatenated with its
            gradCAM image along dimension 0.
        """
        if torch.is_tensor(idx):
            idx = idx.tolist()

        # read 5 images - image with text and the correlated clean image and a different image without text.
        # read the gradCAM image and an image with black background and text on top.
        image_with_text = self._load_gray_tensor(self.list_with_text[idx])
        image_with_text_clean = self._load_gray_tensor(self.list_with_text_clean[idx])
        image_gradCAM = self._load_gray_tensor(self.list_gradCAM[idx])
        image_without_text = self._load_gray_tensor(self.list_without_text[idx])
        image_only_text = self._load_gray_tensor(self.list_only_text[idx])

        if self.transform:
            # transform the three paired images in a single call (stacked as channels)
            # so they are all changed in the same way and the PSNR and SSIM values
            # computed between them stay correct
            stacked = torch.cat((image_with_text, image_with_text_clean, image_gradCAM), 0)
            stacked = self.transform(stacked)
            # separate the stack again and restore the channel dimension of each image
            image_with_text = torch.unsqueeze(stacked[0], 0)
            image_with_text_clean = torch.unsqueeze(stacked[1], 0)
            image_gradCAM = torch.unsqueeze(stacked[2], 0)

            # the unpaired images are transformed on their own
            image_without_text = self.transform(image_without_text)
            image_only_text = self.transform(image_only_text)

        # concat the gradCAM image and the one with the text so they enter the GAN together as 2 channels
        image_gradCAM_cat = torch.cat((image_with_text, image_gradCAM), 0)

        # in each sample there are 5 images -
        # 1. with text
        # 2. the same one without text
        # 3. the gradCAM image of the image with text from 1. this variable has 2 channels - one is the image with text
        #    and the second channel is the gradCAM img.
        # 4. another image without text (different frame than the first two ones)
        # 5. an image with black background and only text
        sample = {'image_with_text': image_with_text,
                  'image_with_text_clean': image_with_text_clean,
                  'image_gradCAM_cat': image_gradCAM_cat,
                  'image_without_text': image_without_text,
                  'image_only_text': image_only_text}

        return sample


class gradCamDataset(Dataset):
    """Dataset for gradCAM generation.

    Serves the images with text one at a time, each together with the path it
    was read from.
    """

    def __init__(self,
                 root_dir_with_text):
        """
        Args:
            root_dir_with_text (string): Path to the images with text on them.
                Unlike pachesDataset, no stage suffix is appended to this path.
        """

        self.root_dir_with_text = root_dir_with_text

        # Generate the data (list of paths of the images in the data set).
        # glob returns files in an arbitrary, filesystem-dependent order, so the
        # list is sorted to make the index -> file mapping reproducible.
        found = (glob.glob(os.path.join(self.root_dir_with_text, '*.png'))
                 + glob.glob(os.path.join(self.root_dir_with_text, '*.jpeg')))
        found.sort()
        self.list_with_text = found

    @staticmethod
    def _load_gray_tensor(path):
        """Read the image at `path` as gray scale and return it as a float32 tensor with a leading channel axis."""
        # change the type to float for the normalization
        image = io.imread(path, as_gray=True).astype(np.float32)
        # Normalization heuristic: a maximum above 1 is taken to mean the image
        # uses the 0-255 range; anything else is treated as already normalized.
        if np.max(image) > 1:
            image = image / 255.0
        # add the channel axis to the 2-D image, then switch to tensor
        return torch.from_numpy(np.expand_dims(image, axis=0))

    def __len__(self):
        """Return the number of image files that were found."""
        return len(self.list_with_text)

    def __getitem__(self, idx):
        """Load the image at position `idx`.

        Returns:
            dict with 'image_with_text' (the image tensor) and
            'image_with_text_name' (the path the image was read from).
        """
        if torch.is_tensor(idx):
            idx = idx.tolist()

        image_path = self.list_with_text[idx]

        # in each sample there are 2 entries -
        # 1. the image with text
        # 2. the path (name) of image 1
        sample = {'image_with_text': self._load_gray_tensor(image_path),
                  'image_with_text_name': image_path}

        return sample