import os
import sys
import numpy as np
import pandas as pd
import torch
import json

if sys.version_info[0] == 2:
    import cPickle as pickle
else:
    import pickle

if sys.version_info[0] == 2:
    import xml.etree.cElementTree as ET
else:
    import xml.etree.ElementTree as ET

from PIL import Image

from skimage import io, transform
from skimage.color import gray2rgb
from skimage.transform import resize

import torchvision.transforms as transforms
from torch.utils.data import Dataset
from torchvision.datasets.folder import ImageFolder
from torchvision.datasets.utils import verify_str_arg
from torchvision.datasets import VOCDetection, CocoDetection

# Relative file names for a WordNet-id-to-word table and a validation
# annotation list.
# NOTE(review): neither constant is referenced in the visible part of this
# file -- confirm whether they are still needed before removing them.
WNID_TO_WORD_FILE = 'words.txt'
VAL_ANNO_FILE = 'val/val_annotations.txt'

# Pipeline applied to mask arrays/tensors: convert to a PIL image, resize to
# 256 with nearest-neighbour interpolation, center-crop to 224 and convert
# back to a tensor.
# NOTE(review): nearest-neighbour presumably keeps mask values from being
# blended at region borders -- confirm that this is the intent.
mask_transforms = transforms.Compose([
                transforms.ToPILImage(),
                transforms.Resize(256, interpolation=Image.NEAREST),
                transforms.CenterCrop(224),
                transforms.ToTensor(),
])


class VisualGenome(Dataset):
    """Visual Genome images with per-object label vectors and box masks.

    Each sample yields the (optionally transformed) image together with one
    label vector and one bounding-box mask per annotated object name.

    Args:
        root_dir (str, optional): Directory containing the annotation files
            and the ``images/VG_100K`` folder. When falsy,
            ``./image_data/visual_genome_python_driver-master`` is used.
        transform (callable, optional): Applied to the RGB PIL image.
        max_batch_size (int): Stored as ``_max_batch_size``; it is not read
            anywhere else in this class.
    """

    def __init__(self, root_dir=None, transform=None, max_batch_size=64):
        self._root_dir = root_dir or './image_data/visual_genome_python_driver-master'
        self._transform = transform
        self._samples = self._load_obj()
        self._labels = self._load_labels()
        self._max_batch_size = max_batch_size
        self._lvl_supervision = self._load_obj_map()

    def __len__(self):
        """Return the number of samples in the objects file."""
        return len(self._samples)

    def __getitem__(self, idx):
        """Return ``(img, targets, masks)`` for sample ``idx``.

        Returns:
            tuple: ``img`` is the transformed image (or the RGB PIL image when
            no transform is set); ``targets`` stacks one vector of length
            ``len(self._labels)`` per object name, with a 1 at the label index
            when the name is a known label; ``masks`` stacks the
            ``mask_transforms`` output of one box mask per object name.

        Note:
            ``torch.stack`` raises if the sample has no objects.
        """
        sample = self._samples[idx]
        path = os.path.join(self._root_dir,
                            'images/VG_100K/%d.jpg' % sample['image_id'])

        ori_img = Image.open(path).convert('RGB')
        if self._transform is not None:
            img = self._transform(ori_img)
        else:
            img = ori_img

        # PIL reports size as (width, height) while the mask below is indexed
        # [row, column]; allocate it as (height, width) so boxes land on the
        # right pixels for non-square images.
        width, height = ori_img.size
        num_labels = len(self._labels)

        targets = []
        masks = []
        for obj, boxes in sample['objects'].items():
            target = torch.zeros((num_labels,))
            if obj in self._labels:
                target[self._labels[obj]] = 1

            box_mask = torch.zeros((height, width))
            for box_anno in boxes:
                # Clamp at zero: a negative bound would be interpreted as an
                # index counted from the end of the axis.
                xmin = max(box_anno['x'], 0)
                xmax = max(box_anno['x'] + box_anno['w'], 0)
                ymin = max(box_anno['y'], 0)
                ymax = max(box_anno['y'] + box_anno['h'], 0)
                box_mask[ymin:ymax, xmin:xmax] = 1

            targets.append(target)
            masks.append(mask_transforms(box_mask))

        return img, torch.stack(targets), torch.stack(masks)

    def _load_obj(self):
        """Load ``our_objects_per_image_cleaned.json`` from the root directory."""
        dataFile = os.path.join(self._root_dir, 'our_objects_per_image_cleaned.json')
        with open(dataFile) as f:
            return json.load(f)

    def _load_obj_map(self):
        """Return a tensor holding the length of every value in the object map.

        The values of ``our_object_map_per_image_cleaned.json`` are visited in
        file order; element ``i`` of the result is ``len`` of the ``i``-th one.
        """
        dataFile = os.path.join(self._root_dir, 'our_object_map_per_image_cleaned.json')
        with open(dataFile) as f:
            data = json.load(f)
        count = torch.zeros((len(data),))
        # Walk the values once instead of rebuilding the list per index.
        for i, entries in enumerate(data.values()):
            count[i] = len(entries)
        return count

    def _load_labels(self):
        """Load the pickled object-name -> label-index mapping."""
        dataFile = os.path.join(self._root_dir, 'obj_per_image_cleaned_idx_convert.pkl')
        # NOTE: unpickling can execute arbitrary code; only use trusted files.
        with open(dataFile, 'rb') as f:
            return pickle.load(f)


class MyCocoClassification(CocoDetection):
    """COCO images paired with a 91-slot multi-hot category vector."""

    def __getitem__(self, index):
        """Return the image and its multi-hot category vector.

        Args:
            index (int): Index into ``self.ids``.

        Returns:
            tuple: ``(image, labels)`` where ``labels`` is a tensor of length
            91 with a 1 at ``category_id - 1`` for every annotation of the
            image.
        """
        image_id = self.ids[index]
        annotations = self.coco.loadAnns(self.coco.getAnnIds(imgIds=image_id))

        labels = torch.zeros((91,))
        for annotation in annotations:
            labels[annotation['category_id'] - 1] = 1

        file_name = self.coco.loadImgs(image_id)[0]['file_name']
        image = Image.open(os.path.join(self.root, file_name)).convert('RGB')

        if self.transforms is not None:
            # The transformed annotations are not returned; only the image is.
            image, annotations = self.transforms(image, annotations)

        return image, labels


class MyCocoSemantic(CocoDetection):
    """COCO images paired with a multi-hot vector over label-embedding entries.

    The label embedding is a mapping loaded with ``torch.load`` that provides
    ``'stoi'`` and ``'itos'`` sub-mappings. The ``'stoi'`` values are indexed
    positionally by ``category_id - 1``; every element of the selected value
    is located among the ``'itos'`` keys and its position is set to 1.
    """

    # Path is relative to the current working directory.
    _LABEL_EMBEDDING_FILE = "./image_data/coco/label_embedding.pth"

    def _get_label_embedding(self):
        """Load the label embedding on first use and reuse it afterwards."""
        cached = getattr(self, '_label_embedding', None)
        if cached is None:
            cached = torch.load(self._LABEL_EMBEDDING_FILE)
            self._label_embedding = cached
        return cached

    def __getitem__(self, index):
        """Return the image and its multi-hot label-embedding vector.

        Args:
            index (int): Index into ``self.ids``.

        Returns:
            tuple: ``(image, target_onehot)`` where ``target_onehot`` has one
            slot per entry of ``label_embedding['itos']``.
        """
        coco = self.coco
        img_id = self.ids[index]
        ann_ids = coco.getAnnIds(imgIds=img_id)
        target = coco.loadAnns(ann_ids)

        # Reading the embedding from disk for every sample is wasteful, so it
        # is cached on the instance; the positional views are built once per
        # call rather than once per annotation / index.
        label_embedding = self._get_label_embedding()
        stoi_values = list(label_embedding['stoi'].values())
        itos_keys = list(label_embedding['itos'].keys())

        target_onehot = torch.zeros((len(label_embedding['itos']),))
        for t in target:
            for idx in stoi_values[t['category_id'] - 1]:
                target_onehot[itos_keys.index(idx)] = 1

        path = coco.loadImgs(img_id)[0]['file_name']

        img = Image.open(os.path.join(self.root, path)).convert('RGB')

        if self.transforms is not None:
            # The transformed annotations are not returned; only the image is.
            img, target = self.transforms(img, target)

        return img, target_onehot


class MyCocoJoint(CocoDetection):
    """COCO images with both a category vector and a label-embedding vector.

    The label embedding is a mapping loaded with ``torch.load`` that provides
    ``'stoi'`` and ``'itos'`` sub-mappings. The ``'stoi'`` values are indexed
    positionally by ``category_id - 1``; every element of the selected value
    is located among the ``'itos'`` keys and its position is set to 1.
    """

    # Path is relative to the current working directory.
    _LABEL_EMBEDDING_FILE = "./image_data/coco/label_embedding.pth"

    def _get_label_embedding(self):
        """Load the label embedding on first use and reuse it afterwards."""
        cached = getattr(self, '_label_embedding', None)
        if cached is None:
            cached = torch.load(self._LABEL_EMBEDDING_FILE)
            self._label_embedding = cached
        return cached

    def __getitem__(self, index):
        """Return the image plus its two multi-hot vectors.

        Args:
            index (int): Index into ``self.ids``.

        Returns:
            tuple: ``(image, cls_onehot, vsf_onehot)``. ``cls_onehot`` has 91
            slots with a 1 at ``category_id - 1`` per annotation;
            ``vsf_onehot`` has one slot per entry of
            ``label_embedding['itos']``.
        """
        coco = self.coco
        img_id = self.ids[index]
        ann_ids = coco.getAnnIds(imgIds=img_id)
        target = coco.loadAnns(ann_ids)

        # Reading the embedding from disk for every sample is wasteful, so it
        # is cached on the instance; the positional views are built once per
        # call rather than once per annotation / index.
        label_embedding = self._get_label_embedding()
        stoi_values = list(label_embedding['stoi'].values())
        itos_keys = list(label_embedding['itos'].keys())

        cls_onehot = torch.zeros((91,))
        vsf_onehot = torch.zeros((len(label_embedding['itos']),))
        for t in target:
            position = t['category_id'] - 1
            cls_onehot[position] = 1
            for idx in stoi_values[position]:
                vsf_onehot[itos_keys.index(idx)] = 1

        path = coco.loadImgs(img_id)[0]['file_name']

        img = Image.open(os.path.join(self.root, path)).convert('RGB')

        if self.transforms is not None:
            # The transformed annotations are not returned; only the image is.
            img, target = self.transforms(img, target)

        return img, cls_onehot, vsf_onehot


class MyCocoDetection(CocoDetection):
    """COCO images with a per-annotation list of masks and embedding entries.

    The label embedding is a mapping loaded with ``torch.load`` that provides
    a ``'stoi'`` sub-mapping whose keys and values are indexed positionally
    by ``category_id - 1``.
    """

    # Path is relative to the current working directory.
    _LABEL_EMBEDDING_FILE = "./image_data/coco/label_embedding.pth"

    def _get_label_embedding(self):
        """Load the label embedding on first use and reuse it afterwards."""
        cached = getattr(self, '_label_embedding', None)
        if cached is None:
            cached = torch.load(self._LABEL_EMBEDDING_FILE)
            self._label_embedding = cached
        return cached

    def __getitem__(self, index):
        """Return the image and one record per annotation.

        Args:
            index (int): Index into ``self.ids``.

        Returns:
            tuple: ``(image, target_list)`` where every element of
            ``target_list`` is a dict with keys ``'mask'`` (the annotation
            mask resized to 224x224 as a NumPy array), ``'object'`` (the
            ``'stoi'`` key at position ``category_id - 1``) and ``'idx'``
            (the ``'stoi'`` value at that position). The ``'idx'`` objects
            are shared with the cached embedding; treat them as read-only.
        """
        coco = self.coco
        img_id = self.ids[index]
        ann_ids = coco.getAnnIds(imgIds=img_id)
        target = coco.loadAnns(ann_ids)

        # Reading the embedding from disk for every sample is wasteful, so it
        # is cached on the instance; the positional views are built once per
        # call rather than once per annotation.
        label_embedding = self._get_label_embedding()
        stoi_keys = list(label_embedding['stoi'].keys())
        stoi_values = list(label_embedding['stoi'].values())

        target_list = []
        for t in target:
            im = Image.fromarray(coco.annToMask(t))
            mask = np.array(im.resize((224, 224)))
            position = t['category_id'] - 1
            target_list.append({
                'mask': mask,
                'object': stoi_keys[position],
                'idx': stoi_values[position],
            })

        path = coco.loadImgs(img_id)[0]['file_name']

        img = Image.open(os.path.join(self.root, path)).convert('RGB')

        if self.transforms is not None:
            # The transformed annotations are not returned; only the image is.
            img, target = self.transforms(img, target)

        return img, target_list


class MyCocoSegmentation(CocoDetection):
    """COCO dataset indexed per annotation instead of per image.

    Every sample is one annotation: the image it belongs to, a multi-hot
    vector over label-embedding entries for its category, and its mask.

    Args:
        root (str): Directory the image file names are resolved against.
        annFile (str): Path of the COCO annotation file.
        transform (callable, optional): Applied to the RGB PIL image.
        target_transform (callable, optional): Forwarded to the base class.
        transforms (callable, optional): Forwarded to the base class.
    """

    # Path is relative to the current working directory.
    _LABEL_EMBEDDING_FILE = "./image_data/coco/label_embedding.pth"

    def __init__(self, root, annFile, transform=None, target_transform=None, transforms=None):
        # Calls the initialiser of the class that follows CocoDetection in the
        # MRO, bypassing CocoDetection.__init__, so ``ids`` can be built from
        # annotation ids below.
        # NOTE(review): relies on that class accepting
        # (root, transforms, transform, target_transform) -- verify against
        # the installed torchvision version.
        super(CocoDetection, self).__init__(root, transforms, transform, target_transform)
        from pycocotools.coco import COCO
        self.coco = COCO(annFile)
        self.ids = list(sorted(self.coco.anns.keys()))

    def _get_label_embedding(self):
        """Load the label embedding on first use and reuse it afterwards."""
        cached = getattr(self, '_label_embedding', None)
        if cached is None:
            cached = torch.load(self._LABEL_EMBEDDING_FILE)
            self._label_embedding = cached
        return cached

    def __getitem__(self, index):
        """Return ``(image, target_onehot, mask)`` for one annotation.

        Args:
            index (int): Index into ``self.ids`` (sorted annotation ids).

        Returns:
            tuple: the (optionally transformed) image, a vector with one slot
            per entry of ``label_embedding['itos']``, and the annotation mask
            after ``mask_transforms``, rescaled to span [0, 1] when it is not
            all zero.
        """
        coco = self.coco
        ann_id = self.ids[index]
        ann = self.coco.anns[ann_id]
        img_id = ann['image_id']

        mask = coco.annToMask(ann)

        # Reading the embedding from disk for every sample is wasteful, so it
        # is cached on the instance.
        label_embedding = self._get_label_embedding()
        stoi_values = list(label_embedding['stoi'].values())
        itos_keys = list(label_embedding['itos'].keys())

        target_onehot = torch.zeros((len(label_embedding['itos']),))
        for idx in stoi_values[ann['category_id'] - 1]:
            target_onehot[itos_keys.index(idx)] = 1

        path = coco.loadImgs(img_id)[0]['file_name']

        img = Image.open(os.path.join(self.root, path)).convert('RGB')

        if self.transform is not None:
            img = self.transform(img)

        # Apply the mask pipeline exactly once: running it a second time would
        # resize and center-crop the already-cropped mask again.
        mask = mask_transforms(mask)
        mask = mask - torch.min(mask)
        mask_peak = torch.max(mask)
        if mask_peak > 0:
            mask = mask / mask_peak

        return img, target_onehot, mask


class PoliticianFB(Dataset):
    """Image dataset backed by a pickled dict of paths and label pairs.

    Args:
        pkl_file (str): Pickle file name, resolved relative to ``root_dir``.
            It must provide ``'image_inputs'``, ``'gender_labels'`` and
            ``'party_labels'``.
        root_dir (str): Directory the pickle and image paths are joined onto.
        transforms (callable, optional): Applied to the RGB image array.
    """

    def __init__(self, pkl_file, root_dir, transforms=None):
        self._pkl_file = pkl_file
        self._root_dir = root_dir
        self._transform = transforms
        self._samples = self._extract_and_load()

    def __len__(self):
        """Return the number of rows in the sample table."""
        return len(self._samples)

    def __getitem__(self, idx):
        """Return a dict with the raw image, transformed image and labels.

        Keys: ``'images'`` (resized RGB array), ``'photos'`` (transform
        output, or the same array without a transform), ``'genders'`` and
        ``'parties'`` (two-element NumPy arrays).
        """
        table = self._samples
        image_path = os.path.join(self._root_dir, table['image_inputs'].iloc[idx])

        # Raw sizes are not consistent, so everything is resized to 224x224.
        # The resized result is scaled by 255 and cast to uint8.
        resized = transform.resize(io.imread(image_path), (224, 224))
        pixels = np.uint8(resized * 255)

        # Expand to RGB because the model expects three channels.
        ori_img = gray2rgb(pixels)

        photo = ori_img if self._transform is None else self._transform(ori_img)

        genders = np.array([
            table['gender_labels_0'].iloc[idx],
            table['gender_labels_1'].iloc[idx],
        ])
        parties = np.array([
            table['party_labels_0'].iloc[idx],
            table['party_labels_1'].iloc[idx],
        ])

        return {
            'images': ori_img,
            'photos': photo,
            'genders': genders,
            'parties': parties,
        }

    def _extract_and_load(self):
        """Unpickle the sample dict and flatten it into a single DataFrame."""
        with open(os.path.join(self._root_dir, self._pkl_file), 'rb') as handle:
            raw = pickle.load(handle)

        # One frame per field, placed side by side in a fixed column order.
        frames = [
            pd.DataFrame(raw['image_inputs'], columns=['image_inputs']),
            pd.DataFrame(raw['gender_labels'], columns=['gender_labels_0', 'gender_labels_1']),
            pd.DataFrame(raw['party_labels'], columns=['party_labels_0', 'party_labels_1']),
        ]
        return pd.concat(frames, axis=1)


class SUNSelected(ImageFolder):
    """Selected SUN Database, loaded through ``ImageFolder``.

    Args:
        root (string): Root directory of the dataset; ``~`` is expanded.
        **kwargs: Forwarded unchanged to ``ImageFolder.__init__``.
    """

    def __init__(self, root, **kwargs):
        expanded_root = os.path.expanduser(root)
        self.root = expanded_root

        super(SUNSelected, self).__init__(expanded_root, **kwargs)
        # Re-assign after the base initialiser so ``root`` holds the expanded
        # path regardless of what the base class stored.
        self.root = expanded_root


class SUNObjectDatabase(VOCDetection):
    """SUN2012 object annotations stored in Pascal VOC format.

    Expects ``<root>/SUN2012pascalformat`` with ``JPEGImages``,
    ``Annotations`` and ``ImageSets/Main/<image_set>.txt``.

    Args:
        root (str): Directory containing ``SUN2012pascalformat``.
        image_set (str): ``'train'`` or ``'test'``.
        transform (callable, optional): Forwarded to the base class.
        target_transform (callable, optional): Forwarded to the base class.
        transforms (callable, optional): Forwarded to the base class.

    Raises:
        RuntimeError: If ``<root>/SUN2012pascalformat`` is not a directory.
    """

    def __init__(self,
                 root,
                 image_set='train',
                 transform=None,
                 target_transform=None,
                 transforms=None):
        # Calls the initialiser of the class that follows VOCDetection in the
        # MRO, bypassing VOCDetection.__init__.
        # NOTE(review): relies on that class accepting
        # (root, transforms, transform, target_transform) -- verify against
        # the installed torchvision version.
        super(VOCDetection, self).__init__(root, transforms, transform, target_transform)
        valid_sets = ["train", "test"]
        self.image_set = verify_str_arg(image_set, "image_set", valid_sets)

        base_dir = 'SUN2012pascalformat'
        sun_root = os.path.join(self.root, base_dir)
        image_dir = os.path.join(sun_root, 'JPEGImages')
        annotation_dir = os.path.join(sun_root, 'Annotations')

        if not os.path.isdir(sun_root):
            # This class has no download option, so only report the path.
            raise RuntimeError('Dataset not found or corrupted: expected directory ' + sun_root)

        splits_dir = os.path.join(sun_root, 'ImageSets/Main')
        split_f = os.path.join(splits_dir, image_set.rstrip('\n') + '.txt')

        with open(split_f, "r") as f:
            # Skip blank lines so they do not turn into bogus ".jpg" entries.
            file_names = [name for name in (line.strip() for line in f) if name]

        self.images = [os.path.join(image_dir, x + ".jpg") for x in file_names]
        self.annotations = [os.path.join(annotation_dir, x + ".xml") for x in file_names]

    def __getitem__(self, index):
        """Return the image and its parsed annotation.

        Args:
            index (int): Index

        Returns:
            tuple: (image, target) where target is a dictionary of the XML
            tree, or an empty dict when the annotation cannot be read or
            parsed.
        """
        img = Image.open(self.images[index]).convert('RGB')
        try:
            target = self.parse_voc_xml(ET.parse(self.annotations[index]).getroot())
        except Exception:
            # Best effort: a missing or malformed annotation yields an empty
            # target. KeyboardInterrupt and SystemExit still propagate.
            target = {}

        if self.transforms is not None:
            img, target = self.transforms(img, target)

        return img, target