import torch
from PIL import Image
import numpy as np
import torchvision.transforms as T
import torchvision.transforms.functional as TF
import warnings
from typing import Tuple, List, Optional
from torch import Tensor
import numbers
from skimage.color import rgb2hed, hed2rgb


def _interpolation_modes_from_int(i: int) -> T.InterpolationMode:
    """Translate a legacy integer interpolation code into a ``T.InterpolationMode``.

    Codes: 0=NEAREST, 1=LANCZOS, 2=BILINEAR, 3=BICUBIC, 4=BOX, 5=HAMMING.
    The transforms below call this when ``interpolation``/``resample`` is given as an int.

    Raises:
        KeyError: if ``i`` is not one of the six codes above.
    """
    mode = T.InterpolationMode
    code_to_mode = {
        0: mode.NEAREST,
        1: mode.LANCZOS,
        2: mode.BILINEAR,
        3: mode.BICUBIC,
        4: mode.BOX,
        5: mode.HAMMING,
    }
    return code_to_mode[i]


"""
****************CNN****************
"""


class Resize3D(torch.nn.Module):
    """Resize every image in a list with one shared set of ``TF.resize`` arguments.

    Args:
        size: int, or a tuple of 1 or 2 values; forwarded to ``TF.resize``.
        interpolation: ``T.InterpolationMode``; a legacy int code is accepted with a warning.
        max_size: forwarded to ``TF.resize``.
        antialias: forwarded to ``TF.resize``.

    Raises:
        TypeError: if ``size`` is neither an int nor a tuple.
        ValueError: if ``size`` is a tuple whose length is not 1 or 2.
    """

    def __init__(self, size, interpolation=T.InterpolationMode.BILINEAR, max_size=None, antialias=None):
        super().__init__()
        if not isinstance(size, (int, tuple)):
            raise TypeError(f"Size should be int or tuple. Got {type(size)}")
        if isinstance(size, tuple) and len(size) not in (1, 2):
            raise ValueError("If size is a tuple, it should have 1 or 2 values")

        # Backward compatibility: integer codes are converted (with a warning).
        if isinstance(interpolation, int):
            warnings.warn(
                "Argument interpolation should be of type InterpolationMode instead of int. "
                "Please, use InterpolationMode enum."
            )
            interpolation = _interpolation_modes_from_int(interpolation)

        self.size = size
        self.max_size = max_size
        self.interpolation = interpolation
        self.antialias = antialias

    def forward(self, imgs):
        """Return a new list in which every element of ``imgs`` has been resized."""
        resize_args = (self.size, self.interpolation, self.max_size, self.antialias)
        return [TF.resize(img, *resize_args) for img in imgs]

    def __repr__(self):
        return (
            f"{self.__class__.__name__}(size={self.size}, interpolation={self.interpolation.value}, "
            f"max_size={self.max_size}, antialias={self.antialias})"
        )


class CenterCrop3D(torch.nn.Module):
    """Center-crop every image in a list to the same ``(h, w)`` size.

    Args:
        size: a number ``n`` (crop ``n x n``), a length-1 tuple or list ``[n]``
            (also ``n x n``), or a length-2 ``(h, w)`` sequence.

    Raises:
        ValueError: if ``size`` is a sequence of any other length.
    """

    @staticmethod
    def _setup_size(size, error_msg):
        """Normalise ``size`` to an ``(h, w)`` pair, raising ``ValueError(error_msg)`` on bad lengths.

        Numbers are converted with ``int``. A length-2 sequence is returned as given.
        """
        if isinstance(size, numbers.Number):
            return int(size), int(size)

        # Lists are accepted as well as tuples, matching torchvision's CenterCrop,
        # which treats any length-1 sequence as a square size.
        if isinstance(size, (tuple, list)) and len(size) == 1:
            return size[0], size[0]

        if len(size) != 2:
            raise ValueError(error_msg)

        return size

    def __init__(self, size):
        super().__init__()
        self.size = self._setup_size(size, error_msg="Please provide only two dimensions (h, w) for size.")

    def forward(self, imgs):
        """Return a new list with every element of ``imgs`` center-cropped to ``self.size``."""
        return [TF.center_crop(img, self.size) for img in imgs]

    def __repr__(self):
        return self.__class__.__name__ + '(size={0})'.format(self.size)


class RandomHorizontalFlip3D(torch.nn.Module):
    """With probability ``p``, horizontally flip every image in a list.

    A single ``torch.rand`` draw decides for the whole list, so all images are
    either flipped together or left as they are.
    """

    def __init__(self, p=0.5):
        super().__init__()
        self.p = p

    def forward(self, imgs):
        """Return a new flipped list, or ``imgs`` itself when no flip is drawn."""
        flip = torch.rand(1) < self.p
        if not flip:
            return imgs
        return list(map(TF.hflip, imgs))

    def __repr__(self):
        return f"{self.__class__.__name__}(p={self.p})"


class RandomVerticalFlip3D(torch.nn.Module):
    """With probability ``p``, vertically flip every image in a list.

    One random draw per call is shared by the whole list, keeping images aligned.
    """

    def __init__(self, p=0.5):
        super().__init__()
        self.p = p

    def forward(self, imgs):
        """Return a new vertically flipped list, or ``imgs`` unchanged."""
        should_flip = bool(torch.rand(1) < self.p)
        return [TF.vflip(img) for img in imgs] if should_flip else imgs

    def __repr__(self):
        return f"{self.__class__.__name__}(p={self.p})"


class ColorJitter3D(torch.nn.Module):
    """Randomly jitter brightness, contrast, saturation and hue of selected images in a list.

    Each call samples one random order of the four adjustments and one factor per
    enabled adjustment. The same factors are used for every selected image.

    Args:
        brightness, contrast, saturation: a non-negative number ``x`` (range
            ``[max(0, 1 - x), 1 + x]``) or a ``(min, max)`` pair with ``0 <= min <= max``.
        hue: a non-negative number ``x`` (range ``[-x, x]``) or a ``(min, max)`` pair
            within ``[-0.5, 0.5]``.
        apply_idx: optional container of list positions to adjust; ``None`` means all.
            Positions not selected are passed through untouched.
    """

    def __init__(self, brightness=0, contrast=0, saturation=0, hue=0, apply_idx=None):
        super().__init__()
        self.brightness = self._check_input(brightness, 'brightness')
        self.contrast = self._check_input(contrast, 'contrast')
        self.saturation = self._check_input(saturation, 'saturation')
        self.hue = self._check_input(hue, 'hue', center=0, bound=(-0.5, 0.5), clip_first_on_zero=False)
        self.apply_idx = apply_idx

    @torch.jit.unused
    def _check_input(self, value, name, center=1, bound=(0, float('inf')), clip_first_on_zero=True):
        """Validate one jitter argument and return its sampling range, or ``None`` for a no-op.

        A number yields a list ``[low, high]``. A valid pair is returned as given.

        Raises:
            ValueError: negative number, or pair outside ``bound`` / not ordered.
            TypeError: anything other than a number or a length-2 list/tuple.
        """
        if isinstance(value, numbers.Number):
            if value < 0:
                raise ValueError("If {} is a single number, it must be non negative.".format(name))
            low = center - float(value)
            high = center + float(value)
            if clip_first_on_zero:
                low = max(low, 0.0)
            value = [low, high]
        elif not (isinstance(value, (tuple, list)) and len(value) == 2):
            raise TypeError("{} should be a single number or a list/tuple with length 2.".format(name))
        elif not bound[0] <= value[0] <= value[1] <= bound[1]:
            raise ValueError("{} values should be between {}".format(name, bound))

        # A degenerate range sitting exactly on the identity value means "do nothing".
        is_identity = value[0] == value[1] == center
        return None if is_identity else value

    @staticmethod
    def get_params(brightness: Optional[List[float]],
                   contrast: Optional[List[float]],
                   saturation: Optional[List[float]],
                   hue: Optional[List[float]]
                   ) -> Tuple[Tensor, Optional[float], Optional[float], Optional[float], Optional[float]]:
        """Sample an application order for the four ops plus one factor per enabled op.

        Random draws happen in a fixed sequence (permutation, then brightness,
        contrast, saturation, hue), each from torch's global generator.
        """
        order = torch.randperm(4)
        b, c, s, h = [
            None if rng is None else float(torch.empty(1).uniform_(rng[0], rng[1]))
            for rng in (brightness, contrast, saturation, hue)
        ]
        return order, b, c, s, h

    def _jitter_selected(self, imgs, adjust, factor):
        """Return a new list with ``adjust(img, factor)`` applied at the selected positions."""
        return [
            adjust(img, factor) if self.apply_idx is None or pos in self.apply_idx else img
            for pos, img in enumerate(imgs)
        ]

    def forward(self, imgs):
        order, brightness, contrast, saturation, hue = self.get_params(
            self.brightness, self.contrast, self.saturation, self.hue)

        for op in order.tolist():
            if op == 0 and brightness is not None:
                imgs = self._jitter_selected(imgs, TF.adjust_brightness, brightness)
            elif op == 1 and contrast is not None:
                imgs = self._jitter_selected(imgs, TF.adjust_contrast, contrast)
            elif op == 2 and saturation is not None:
                imgs = self._jitter_selected(imgs, TF.adjust_saturation, saturation)
            elif op == 3 and hue is not None:
                imgs = self._jitter_selected(imgs, TF.adjust_hue, hue)

        return imgs

    def __repr__(self):
        return (
            f"{self.__class__.__name__}(brightness={self.brightness}, contrast={self.contrast}, "
            f"saturation={self.saturation}, hue={self.hue})"
        )


class RandomRotation3D(torch.nn.Module):
    """Rotate every image in a list by one shared random angle.

    One angle is drawn uniformly from ``degrees`` per call and passed, together with
    the stored interpolation/expand/center/fill settings, to ``TF.rotate`` for each image.

    Args:
        degrees: a non-negative number ``d`` (range ``[-d, d]``) or a ``(min, max)`` pair.
        interpolation: ``T.InterpolationMode``; a legacy int code is accepted with a warning.
        expand: forwarded to ``TF.rotate``.
        center: optional length-2 sequence, forwarded to ``TF.rotate``.
        fill: number or sequence for pixels outside the rotated image; ``None`` means 0.
        resample: deprecated integer alias for ``interpolation``.
    """

    @staticmethod
    def _check_sequence_input(x, name, req_sizes):
        """Raise TypeError if ``x`` is not a list/tuple, ValueError if its length is not in ``req_sizes``."""
        msg = req_sizes[0] if len(req_sizes) < 2 else " or ".join([str(s) for s in req_sizes])
        if not isinstance(x, (list, tuple)):
            raise TypeError("{} should be a sequence of length {}.".format(name, msg))
        if len(x) not in req_sizes:
            raise ValueError("{} should be sequence of length {}.".format(name, msg))

    @staticmethod
    def _setup_angle(x, name, req_sizes=(2, )):
        """Turn a number ``d`` into ``[-d, d]`` (or validate a sequence) and return it as floats."""
        if isinstance(x, numbers.Number):
            if x < 0:
                raise ValueError("If {} is a single number, it must be positive.".format(name))
            x = [-x, x]
        else:
            RandomRotation3D._check_sequence_input(x, name, req_sizes)

        return [float(d) for d in x]

    def __init__(
        self, degrees, interpolation=T.InterpolationMode.NEAREST, expand=False, center=None, fill=0, resample=None
    ):
        super().__init__()
        if resample is not None:
            warnings.warn(
                "Argument resample is deprecated and will be removed since v0.10.0. Please, use interpolation instead"
            )
            interpolation = _interpolation_modes_from_int(resample)

        # Backward compatibility with integer value
        if isinstance(interpolation, int):
            warnings.warn(
                "Argument interpolation should be of type InterpolationMode instead of int. "
                "Please, use InterpolationMode enum."
            )
            interpolation = _interpolation_modes_from_int(interpolation)

        self.degrees = self._setup_angle(degrees, name="degrees", req_sizes=(2, ))

        if center is not None:
            self._check_sequence_input(center, "center", req_sizes=(2, ))

        self.center = center

        self.resample = self.interpolation = interpolation
        self.expand = expand

        if fill is None:
            fill = 0
        elif not isinstance(fill, (list, tuple, numbers.Number)):
            raise TypeError("Fill should be either a sequence or a number.")

        self.fill = fill

    @staticmethod
    def get_params(degrees: List[float]) -> float:
        """Draw one angle uniformly from ``[degrees[0], degrees[1]]`` using torch's global RNG."""
        angle = float(torch.empty(1).uniform_(float(degrees[0]), float(degrees[1])).item())
        return angle

    def _fill_for(self, img):
        """Return ``self.fill`` in the form ``TF.rotate`` expects for this particular ``img``.

        Tensor images get a list of floats; a scalar fill is repeated once per channel
        of *this* image. Non-tensor images get ``self.fill`` unchanged.
        """
        fill = self.fill
        if not isinstance(img, Tensor):
            return fill
        if isinstance(fill, numbers.Number):
            return [float(fill)] * TF.get_image_num_channels(img)
        return [float(f) for f in fill]

    def forward(self, imgs):
        angle = self.get_params(self.degrees)
        # The fill is built per image: a list may mix images with different channel
        # counts (e.g. a 3-channel image and a 1-channel mask), and a fill sized to the
        # first image's channels would not match the others.
        return [
            TF.rotate(img, angle, self.resample, self.expand, self.center, self._fill_for(img))
            for img in imgs
        ]

    def __repr__(self):
        interpolate_str = self.interpolation.value
        format_string = self.__class__.__name__ + '(degrees={0}'.format(self.degrees)
        format_string += ', interpolation={0}'.format(interpolate_str)
        format_string += ', expand={0}'.format(self.expand)
        if self.center is not None:
            format_string += ', center={0}'.format(self.center)
        if self.fill is not None:
            format_string += ', fill={0}'.format(self.fill)
        format_string += ')'
        return format_string


class ToTensor3D:
    """Convert an image list, or an ``[image, mask]`` pair, to tensors.

    With ``no_mask=True`` every element goes through ``TF.to_tensor``. Otherwise only
    elements 0 and 1 are used. Element 0 (the image) goes through ``TF.to_tensor``.
    Element 1 (the mask) is divided by 255 and truncated to ``long``, so a 0/255 mask
    becomes 0/1 labels. Any further elements are dropped.
    """

    def __init__(self, no_mask=False):
        self.no_mask = no_mask

    def __call__(self, imgs):
        if not self.no_mask:
            image = TF.to_tensor(imgs[0])
            # Float division then truncation: values below 255 map to 0, 255 maps to 1.
            mask = torch.from_numpy(np.array(imgs[1]) / 255).long()
            return [image, mask]
        return [TF.to_tensor(img) for img in imgs]

    def __repr__(self):
        return f"{self.__class__.__name__}()"


class Normalize3D(torch.nn.Module):
    """Normalize selected tensors of a list with one shared ``mean``/``std``.

    ``apply_idx`` lists the positions to normalize (``None`` means all). Other
    tensors are returned as the same objects.
    """

    def __init__(self, mean, std, inplace=False, apply_idx=None):
        super().__init__()
        self.mean, self.std = mean, std
        self.inplace = inplace
        self.apply_idx = apply_idx

    def _selected(self, position):
        """True when the tensor at ``position`` should be normalized."""
        return self.apply_idx is None or position in self.apply_idx

    def forward(self, tensors):
        return [
            TF.normalize(t, self.mean, self.std, self.inplace) if self._selected(pos) else t
            for pos, t in enumerate(tensors)
        ]

    def __repr__(self):
        return f"{self.__class__.__name__}(mean={self.mean}, std={self.std})"


class Denormalize(torch.nn.Module):
    """Undo a normalization: return ``tensor * std + mean``.

    ``mean``/``std`` are converted to the tensor's dtype and device. 1-D values are
    reshaped to ``(-1, 1, 1)`` so they line up with the third-from-last dimension
    (presumably channels) and broadcast over the last two.
    Works on a clone unless ``inplace=True``.

    Raises:
        ValueError: if any ``std`` entry is zero after dtype conversion.
    """

    def __init__(self, mean, std, inplace=False):
        super().__init__()
        self.mean = mean
        self.std = std
        self.inplace = inplace

    @staticmethod
    def _broadcastable(stat):
        # Per-channel vectors become (C, 1, 1); scalars and higher-rank values are left alone.
        return stat.view(-1, 1, 1) if stat.ndim == 1 else stat

    def forward(self, tensor):
        out = tensor if self.inplace else tensor.clone()
        dtype, device = out.dtype, out.device
        mean = torch.as_tensor(self.mean, dtype=dtype, device=device)
        std = torch.as_tensor(self.std, dtype=dtype, device=device)
        if (std == 0).any():
            raise ValueError(f"std evaluated to zero after conversion to {dtype}, leading to division by zero.")
        return out.mul_(self._broadcastable(std)).add_(self._broadcastable(mean))

    def __repr__(self):
        return f"{self.__class__.__name__}(mean={self.mean}, std={self.std})"


class AdjustHE3D:
    """Randomly rescale the haematoxylin (H) and eosin (E) channels of selected images in a list.

    Each call draws one H factor from ``[1 - h, 1 + h)`` and one E factor from
    ``[1 - e, 1 + e)``, both floored at 0, using NumPy's global RNG. The same pair is
    applied to every selected image. ``apply_idx`` lists the positions to adjust
    (``None`` means all). Other elements pass through unchanged.

    Selected elements must convert via ``np.array`` to an RGB array accepted by
    ``skimage.color.rgb2hed`` (presumably PIL RGB images; confirm with callers).
    Adjusted elements are returned as PIL images.
    """

    def __init__(self, h=0.3, e=0.3, apply_idx=None):
        self.h = h
        self.e = e
        self.apply_idx = apply_idx

    def adjust(self, img, h, e):
        """Scale the H channel by ``h`` and the E channel by ``e`` in HED space; return a PIL image."""
        hed = rgb2hed(np.array(img))
        hed[:, :, 0] *= h
        hed[:, :, 1] *= e
        rgb = hed2rgb(hed)
        # Round to the nearest 8-bit level. A plain uint8 cast truncates and biases
        # every pixel down by up to one level.
        return Image.fromarray(np.clip(np.rint(rgb * 255), 0, 255).astype(np.uint8))

    @staticmethod
    def _sample_factor(max_dev):
        """Draw one factor uniformly from ``[1 - max_dev, 1 + max_dev)``, floored at 0."""
        return max(float(np.random.rand()) * max_dev * 2 + 1 - max_dev, 0.0)

    def __call__(self, imgs):
        # Same draw order as before (H, then E) so NumPy-seeded runs are reproducible.
        h = self._sample_factor(self.h)
        e = self._sample_factor(self.e)
        return [
            self.adjust(img, h, e) if self.apply_idx is None or i in self.apply_idx else img
            for i, img in enumerate(imgs)
        ]

    def __repr__(self):
        return self.__class__.__name__ + '(h={0}, e={1})'.format(self.h, self.e)


"""
****************Graph****************
"""


class ResizeG(torch.nn.Module):
    """Resize the leading image entries of a graph-style sample list.

    The trailing entries are passed through untouched: the last two when
    ``link=True``, otherwise the last one. These are presumably the mask and link
    entries, as laid out for ``ToTensorG``; confirm with the dataset.

    Args:
        size: int, or a tuple of 1 or 2 values; forwarded to ``TF.resize``.
        interpolation: ``T.InterpolationMode``; a legacy int code is accepted with a warning.
        max_size, antialias: forwarded to ``TF.resize``.
        link: whether the sample list ends with a link entry after the mask.
    """

    def __init__(self, size, interpolation=T.InterpolationMode.BILINEAR, max_size=None, antialias=None, link=True):
        super().__init__()
        if not isinstance(size, (int, tuple)):
            raise TypeError(f"Size should be int or tuple. Got {type(size)}")
        if isinstance(size, tuple) and len(size) not in (1, 2):
            raise ValueError("If size is a tuple, it should have 1 or 2 values")

        # Backward compatibility: integer codes are converted (with a warning).
        if isinstance(interpolation, int):
            warnings.warn(
                "Argument interpolation should be of type InterpolationMode instead of int. "
                "Please, use InterpolationMode enum."
            )
            interpolation = _interpolation_modes_from_int(interpolation)

        self.size = size
        self.max_size = max_size
        self.link = link
        self.interpolation = interpolation
        self.antialias = antialias

    def forward(self, imgs):
        keep = 2 if self.link else 1
        head, tail = imgs[:-keep], imgs[-keep:]
        resized = [TF.resize(img, self.size, self.interpolation, self.max_size, self.antialias) for img in head]
        return resized + tail

    def __repr__(self):
        return (
            f"{self.__class__.__name__}(size={self.size}, interpolation={self.interpolation.value}, "
            f"max_size={self.max_size}, antialias={self.antialias})"
        )


class RandomHorizontalFlipG(torch.nn.Module):
    """With probability ``p``, horizontally flip a graph-style sample list.

    ``fea_size`` is ``ceil(input_size / 32)`` per axis. When ``link=True`` the last
    element must be an integer tensor of indices into the row-major flattened
    ``fea_size`` grid. It is not flipped as an image: each index is remapped to the
    index of its horizontally mirrored cell. All other elements go through ``TF.hflip``.
    With ``link=False`` every element is flipped.

    Args:
        input_size: int (square) or a tuple of 1 (square) or 2 values ``(h, w)``.
        p: flip probability.
        link: whether the last element is a link-index tensor.
    """

    def __init__(self, input_size, p=0.5, link=True):
        super().__init__()
        if not isinstance(input_size, (int, tuple)):
            raise TypeError("input_size should be int or tuple. Got {}".format(type(input_size)))
        if isinstance(input_size, tuple) and len(input_size) not in (1, 2):
            raise ValueError("If input_size is a tuple, it should have 1 or 2 values")
        # An int or a 1-tuple is a square size; expand it so fea_size always has two entries.
        if isinstance(input_size, int):
            input_size = (input_size, input_size)
        elif len(input_size) == 1:
            input_size = (input_size[0], input_size[0])
        # Ceil-divide by 32 to get the grid size along each axis.
        self.fea_size = [(size + 31) // 32 for size in input_size]
        self.p = p
        # lut[k] is the flat index of the horizontally mirrored cell of flat index k.
        grid = torch.arange(self.fea_size[0] * self.fea_size[1]).view(self.fea_size[0], self.fea_size[1])
        self.lut = torch.flip(grid, dims=[1]).view(-1)
        self.link = link

    def forward(self, imgs):
        if torch.rand(1) < self.p:
            if not self.link:
                return [TF.hflip(img) for img in imgs]
            link = imgs[-1]
            res = [TF.hflip(img) for img in imgs[:-1]]
            # One gather replaces a masked assignment per unique value. The result
            # keeps the link tensor's shape, dtype and device.
            res.append(self.lut.to(link.device)[link.long()].to(link.dtype))
            return res
        return imgs

    def __repr__(self):
        return self.__class__.__name__ + f'(fea_size={self.fea_size}, p={self.p})'


class RandomVerticalFlipG(torch.nn.Module):
    """With probability ``p``, vertically flip a graph-style sample list.

    ``fea_size`` is ``ceil(input_size / 32)`` per axis. When ``link=True`` the last
    element must be an integer tensor of indices into the row-major flattened
    ``fea_size`` grid. It is not flipped as an image: each index is remapped to the
    index of its vertically mirrored cell. All other elements go through ``TF.vflip``.
    With ``link=False`` every element is flipped.

    Args:
        input_size: int (square) or a tuple of 1 (square) or 2 values ``(h, w)``.
        p: flip probability.
        link: whether the last element is a link-index tensor.
    """

    def __init__(self, input_size, p=0.5, link=True):
        super().__init__()
        if not isinstance(input_size, (int, tuple)):
            raise TypeError("input_size should be int or tuple. Got {}".format(type(input_size)))
        if isinstance(input_size, tuple) and len(input_size) not in (1, 2):
            raise ValueError("If input_size is a tuple, it should have 1 or 2 values")
        # An int or a 1-tuple is a square size; expand it so fea_size always has two entries.
        if isinstance(input_size, int):
            input_size = (input_size, input_size)
        elif len(input_size) == 1:
            input_size = (input_size[0], input_size[0])
        # Ceil-divide by 32 to get the grid size along each axis.
        self.fea_size = [(size + 31) // 32 for size in input_size]
        self.p = p
        # lut[k] is the flat index of the vertically mirrored cell of flat index k.
        grid = torch.arange(self.fea_size[0] * self.fea_size[1]).view(self.fea_size[0], self.fea_size[1])
        self.lut = torch.flip(grid, dims=[0]).view(-1)
        self.link = link

    def forward(self, imgs):
        if torch.rand(1) < self.p:
            if not self.link:
                return [TF.vflip(img) for img in imgs]
            link = imgs[-1]
            res = [TF.vflip(img) for img in imgs[:-1]]
            # One gather replaces a masked assignment per unique value. The result
            # keeps the link tensor's shape, dtype and device.
            res.append(self.lut.to(link.device)[link.long()].to(link.dtype))
            return res
        return imgs

    def __repr__(self):
        return self.__class__.__name__ + f'(fea_size={self.fea_size}, p={self.p})'


class ColorJitterG(torch.nn.Module):
    """Randomly jitter brightness, contrast, saturation and hue of a graph-style sample list.

    Each call samples one random order of the four adjustments and one factor per
    enabled adjustment. Every element except the trailing pass-through entries is
    adjusted. The pass-through entries are the last two when ``link=True`` and the
    last one otherwise.

    Args:
        brightness, contrast, saturation: a non-negative number ``x`` (range
            ``[max(0, 1 - x), 1 + x]``) or a ``(min, max)`` pair with ``0 <= min <= max``.
        hue: a non-negative number ``x`` (range ``[-x, x]``) or a ``(min, max)`` pair
            within ``[-0.5, 0.5]``.
        link: whether the sample list ends with a link entry after the mask.
    """

    def __init__(self, brightness=0, contrast=0, saturation=0, hue=0, link=True):
        super().__init__()
        self.brightness = self._check_input(brightness, 'brightness')
        self.contrast = self._check_input(contrast, 'contrast')
        self.saturation = self._check_input(saturation, 'saturation')
        self.hue = self._check_input(hue, 'hue', center=0, bound=(-0.5, 0.5), clip_first_on_zero=False)
        self.link = link

    @torch.jit.unused
    def _check_input(self, value, name, center=1, bound=(0, float('inf')), clip_first_on_zero=True):
        """Validate one jitter argument and return its sampling range, or ``None`` for a no-op.

        Raises:
            ValueError: negative number, or pair outside ``bound`` / not ordered.
            TypeError: anything other than a number or a length-2 list/tuple.
        """
        if isinstance(value, numbers.Number):
            if value < 0:
                raise ValueError("If {} is a single number, it must be non negative.".format(name))
            lo, hi = center - float(value), center + float(value)
            value = [max(lo, 0.0) if clip_first_on_zero else lo, hi]
        elif not (isinstance(value, (tuple, list)) and len(value) == 2):
            raise TypeError("{} should be a single number or a list/tuple with length 2.".format(name))
        elif not bound[0] <= value[0] <= value[1] <= bound[1]:
            raise ValueError("{} values should be between {}".format(name, bound))

        # A degenerate range on the identity value disables this adjustment.
        if value[0] == value[1] == center:
            return None
        return value

    @staticmethod
    def get_params(brightness: Optional[List[float]],
                   contrast: Optional[List[float]],
                   saturation: Optional[List[float]],
                   hue: Optional[List[float]]
                   ) -> Tuple[Tensor, Optional[float], Optional[float], Optional[float], Optional[float]]:
        """Sample an application order for the four ops plus one factor per enabled op.

        Random draws happen in a fixed sequence (permutation, then brightness,
        contrast, saturation, hue), each from torch's global generator.
        """
        order = torch.randperm(4)
        b, c, s, h = [
            None if rng is None else float(torch.empty(1).uniform_(rng[0], rng[1]))
            for rng in (brightness, contrast, saturation, hue)
        ]
        return order, b, c, s, h

    def _jitter_head(self, imgs, adjust, factor):
        """Apply ``adjust(img, factor)`` to all but the trailing pass-through entries."""
        split = -2 if self.link else -1
        return [adjust(img, factor) for img in imgs[:split]] + imgs[split:]

    def forward(self, imgs):
        order, brightness, contrast, saturation, hue = self.get_params(
            self.brightness, self.contrast, self.saturation, self.hue)

        for op in order.tolist():
            if op == 0 and brightness is not None:
                imgs = self._jitter_head(imgs, TF.adjust_brightness, brightness)
            elif op == 1 and contrast is not None:
                imgs = self._jitter_head(imgs, TF.adjust_contrast, contrast)
            elif op == 2 and saturation is not None:
                imgs = self._jitter_head(imgs, TF.adjust_saturation, saturation)
            elif op == 3 and hue is not None:
                imgs = self._jitter_head(imgs, TF.adjust_hue, hue)

        return imgs

    def __repr__(self):
        return (
            f"{self.__class__.__name__}(brightness={self.brightness}, contrast={self.contrast}, "
            f"saturation={self.saturation}, hue={self.hue})"
        )


class ToTensorG:
    """Convert a graph-style sample list to tensors.

    The list is laid out as ``[*images, mask, link]`` when ``link=True`` and as
    ``[*images, mask]`` otherwise.
    - Images go through ``TF.to_tensor``.
    - The mask becomes a ``long`` tensor. With ``multilable=True`` the values are kept
      as-is; otherwise they are divided by 255 and truncated (0/255 becomes 0/1).
    - The link entry is returned untouched.
    """

    def __init__(self, link=True, multilable=False):
        self.link = link
        self.multilable = multilable

    def _mask_to_tensor(self, mask):
        """Convert one mask to a ``long`` tensor, rescaling 0/255 masks unless multilabel."""
        arr = np.array(mask)
        if not self.multilable:
            arr = arr / 255
        return torch.from_numpy(arr).long()

    def __call__(self, imgs):
        cut = -2 if self.link else -1
        tensors = [TF.to_tensor(img) for img in imgs[:cut]]
        mask = self._mask_to_tensor(imgs[cut])
        trailing = [imgs[-1]] if self.link else []
        return [*tensors, mask, *trailing]

    def __repr__(self):
        return f"{self.__class__.__name__}()"
