import pytorch_lightning as pl
import torchvision.transforms as transforms
from torch.utils.data import DataLoader
import os
import torch

from data.dataSet import pachesDataset
from data.dataSet import gradCamDataset
from config import PATH_DATASETS_TEXT, PATH_DATASETS_TEXT_CLEAN, PATH_DATASETS_NO_TEXT, PATH_DATASETS_GRADCAM, PATH_DATASETS_ONLY_TEXT

# Larger batches when a CUDA device is available, a smaller default otherwise.
BATCH_SIZE = 256 if torch.cuda.is_available() else 64
# Use half of the logical CPUs as DataLoader workers. os.cpu_count() returns
# None when the count cannot be determined; fall back to 0 (load in the main
# process) instead of crashing at import time with a TypeError.
NUM_WORKERS = (os.cpu_count() or 0) // 2


class DataModule(pl.LightningDataModule):
    """Lightning data module serving the ``pachesDataset`` train/val/test splits.

    All three splits are built from the same five data directories; only the
    split name (``'train'``, ``'val'``, ``'test'``) and the transform differ.
    Random horizontal/vertical flips are applied to the training split only;
    validation and test splits receive no transform.

    Args:
        data_dir_text: Directory forwarded to ``pachesDataset``.
        data_dir_text_clean: Directory forwarded to ``pachesDataset``.
        data_dir_gadCAM: Directory forwarded to ``pachesDataset``.
        data_dir_no_text: Directory forwarded to ``pachesDataset``.
        data_dir_only_text: Directory forwarded to ``pachesDataset``.
        batch_size: Batch size used by every DataLoader.
        num_workers: Number of DataLoader worker processes.
    """

    def __init__(
        self,
        data_dir_text: str = PATH_DATASETS_TEXT,
        data_dir_text_clean: str = PATH_DATASETS_TEXT_CLEAN,
        data_dir_gadCAM: str = PATH_DATASETS_GRADCAM,
        data_dir_no_text: str = PATH_DATASETS_NO_TEXT,
        data_dir_only_text: str = PATH_DATASETS_ONLY_TEXT,
        batch_size: int = BATCH_SIZE,
        num_workers: int = NUM_WORKERS,
    ):
        super().__init__()
        self.data_dir_text = data_dir_text
        self.data_dir_text_clean = data_dir_text_clean
        self.data_dir_gadCAM = data_dir_gadCAM
        self.data_dir_no_text = data_dir_no_text
        self.data_dir_only_text = data_dir_only_text
        self.batch_size = batch_size
        self.num_workers = num_workers

        # Training-time augmentation: independent horizontal and vertical
        # flips, each with probability 0.4.
        self.transform = transforms.Compose(
            [
                transforms.RandomHorizontalFlip(0.4),
                transforms.RandomVerticalFlip(0.4),
            ]
        )

        # NOTE(review): presumably single-channel 63x63 patches; confirm
        # against pachesDataset.
        self.dims = (1, 63, 63)
        self.num_classes = 2

    def _make_dataset(self, split, transform):
        """Build a ``pachesDataset`` for ``split`` from the configured directories."""
        return pachesDataset(
            self.data_dir_no_text,
            self.data_dir_text,
            self.data_dir_text_clean,
            self.data_dir_gadCAM,
            self.data_dir_only_text,
            split,
            transform,
        )

    def _make_loader(self, dataset, shuffle):
        """Wrap ``dataset`` in a DataLoader using the module's batch/worker settings."""
        return DataLoader(
            dataset,
            shuffle=shuffle,
            batch_size=self.batch_size,
            num_workers=self.num_workers,
        )

    def setup(self, stage=None):
        """Create the datasets needed for ``stage``.

        Args:
            stage: Lightning stage name (``"fit"``, ``"validate"``, ``"test"``)
                or ``None`` to build every split.
        """
        # Training split (augmented) is only needed when fitting.
        if stage in ("fit", None):
            self.img_text_train = self._make_dataset('train', self.transform)

        # Validation split is needed for fitting and for trainer.validate(),
        # which calls setup("validate"). Without this, val_dataloader() would
        # hit a missing attribute. No augmentation when evaluating.
        if stage in ("fit", "validate", None):
            self.img_text_val = self._make_dataset('val', None)

        # Test split for trainer.test().
        if stage in ("test", None):
            self.img_text_test = self._make_dataset('test', None)

    def train_dataloader(self):
        """Return a shuffled DataLoader over the training split."""
        return self._make_loader(self.img_text_train, shuffle=True)

    def val_dataloader(self):
        """Return an ordered (unshuffled) DataLoader over the validation split."""
        return self._make_loader(self.img_text_val, shuffle=False)

    def test_dataloader(self):
        """Return an ordered (unshuffled) DataLoader over the test split."""
        return self._make_loader(self.img_text_test, shuffle=False)

class GradCamDataModule(pl.LightningDataModule):
    """Data module wrapping a single ``gradCamDataset`` built from one directory.

    The dataset is constructed eagerly in ``__init__`` and is exposed as
    ``self.img_text``. Batches are produced by the plain ``dataloader()``
    method, which is not one of Lightning's ``*_dataloader`` hooks.
    NOTE(review): presumably this is consumed directly rather than by a
    Trainer; confirm with callers.

    Args:
        data_dir_text: Directory passed to ``gradCamDataset``.
        batch_size: Batch size of the returned DataLoader.
        num_workers: Number of DataLoader worker processes.
    """

    def __init__(
        self,
        data_dir_text: str = PATH_DATASETS_TEXT,
        batch_size: int = BATCH_SIZE,
        num_workers: int = NUM_WORKERS,
    ):
        super().__init__()
        # Loader settings first, then the data source.
        self.batch_size = batch_size
        self.num_workers = num_workers
        self.data_dir_text = data_dir_text

        # Built immediately (not in setup()), so it is available right after
        # construction.
        self.img_text = gradCamDataset(data_dir_text)

    def dataloader(self):
        """Return a DataLoader over the dataset in its stored order (no shuffling)."""
        loader_options = dict(
            batch_size=self.batch_size,
            num_workers=self.num_workers,
        )
        return DataLoader(self.img_text, **loader_options)