import os
import torch
import random
import numpy as np
import tqdm
import math
import torchvision.transforms as transforms
from typing import  Callable, Dict, Optional, Tuple
from torchvision.datasets import ImageFolder



class CIFA10Dataset(ImageFolder):
    """ImageFolder dataset that yields a model input plus a companion view per image.

    Two fixed preprocessing modes are supported:

    * ``augmentation=True``: each item carries two independently augmented
      128x128 grayscale views of the same image. The random pipeline is
      invoked twice, so the two views generally differ.
    * ``augmentation=False``: each item carries a deterministic 128x128
      grayscale tensor (the model input) and a 128x128 RGB tensor of the
      same image (intended for visualization).

    Args:
        root: Root directory handed to ``ImageFolder``.
        augmentation: If true, use the random training pipeline. Otherwise use
            the deterministic evaluation pipeline.
        transform: Forwarded to ``ImageFolder`` and stored as
            ``self.transform``. It is NOT applied by ``__getitem__``, which uses
            this class's own fixed pipelines.
        target_transform: Optional callable applied to the class index
            before it is returned.
        is_valid_file: Optional predicate forwarded to ``ImageFolder`` to
            decide which files are included.
    """

    def __init__(
        self,
        root: str,
        augmentation: bool = False,
        transform: Optional[Callable] = None,
        target_transform: Optional[Callable] = None,
        is_valid_file: Optional[Callable[[str], bool]] = None,
    ):
        # Forward the optional arguments instead of silently dropping them.
        # ImageFolder also sets ``self.imgs`` (an alias of ``self.samples``).
        super().__init__(
            root,
            transform=transform,
            target_transform=target_transform,
            is_valid_file=is_valid_file,
        )
        self.augmentation = augmentation

        # Build the pipelines once, rather than on every __getitem__ call.
        # Random transforms still draw new parameters each time they run.
        color_jitter = transforms.ColorJitter(0.8, 0.8, 0.8, 0.2)
        self._train_transform = transforms.Compose([
            transforms.Resize(
                size=[156, 156],
                interpolation=transforms.InterpolationMode.BILINEAR,
            ),
            transforms.RandomCrop(size=[128, 128]),
            transforms.RandomApply([color_jitter], p=0.8),
            transforms.RandomAffine(degrees=10, translate=[0.1, 0.1]),
            transforms.RandomHorizontalFlip(p=0.3),
            transforms.RandomVerticalFlip(p=0.3),
            transforms.Grayscale(),
            transforms.ToTensor(),
        ])
        # Deterministic evaluation input, converted to grayscale.
        self._eval_transform = transforms.Compose([
            transforms.Resize(
                size=[128, 128],
                interpolation=transforms.InterpolationMode.BILINEAR,
            ),
            transforms.Grayscale(),
            transforms.ToTensor(),
        ])
        # Same resize as the evaluation input, but kept in color for visualization.
        self._eval_rgb_transform = transforms.Compose([
            transforms.Resize(
                size=[128, 128],
                interpolation=transforms.InterpolationMode.BILINEAR,
            ),
            transforms.ToTensor(),
        ])

    def __getitem__(self, index: int):
        """Load the sample at ``index`` and return it with a companion view.

        Args:
            index: Position in ``self.samples``.

        Returns:
            tuple: ``(sensor_measurement, label, sensor_measurement_pair)``.
            ``label`` is the class index, passed through ``target_transform``
            when one was given.

            * When ``self.augmentation`` is true, both tensors are independent
              random grayscale augmentations of the image.
            * Otherwise the first tensor is the deterministic grayscale model
              input and the second is an RGB copy of the image for
              visualization.
        """
        path, label = self.samples[index]
        # Decode the file only once. torchvision transforms return new images
        # rather than modifying their input, so both views can share it.
        image = self.loader(path)

        if self.target_transform is not None:
            label = self.target_transform(label)

        if self.augmentation:
            # Two separate calls to the random pipeline give two different views.
            sensor_measurement = self._train_transform(image)
            sensor_measurement_pair = self._train_transform(image)
        else:
            # Grayscale tensor is the model input at test time.
            sensor_measurement = self._eval_transform(image)
            # RGB tensor is for data visualization.
            sensor_measurement_pair = self._eval_rgb_transform(image)

        return sensor_measurement, label, sensor_measurement_pair
