# Copyright (c) OpenMMLab. All rights reserved.
import os
import os.path
import pickle

import numpy as np
import torch.distributed as dist
from mmcv.runner import get_dist_info

from .builder import DATASETS
from .cifar import CIFAR10
from .utils import download_and_extract_archive


@DATASETS.register_module()
class CoarseCIFAR100(CIFAR10):
    """`CIFAR100 <https://www.cs.toronto.edu/~kriz/cifar.html>`_ Dataset
    labelled with its 20 coarse superclasses.

    The archive, data files and images are the regular CIFAR-100 ones. The
    difference is in the labels: every fine label (0-99) read from the data
    files is remapped to its coarse superclass (0-19) through
    :attr:`CLASS_MAP`. :attr:`CLASSES` therefore lists the 20 superclass
    names, indexed by the produced ``gt_label``. The original 100 fine class
    names are kept in :attr:`FINE_CLASSES`, indexed by fine label.
    """

    base_folder = 'cifar-100-python'
    url = 'https://www.cs.toronto.edu/~kriz/cifar-100-python.tar.gz'
    filename = 'cifar-100-python.tar.gz'
    tgz_md5 = 'eb9058c3a382ffc7106e4002c42a8d85'
    train_list = [
        ['train', '16019d7e3df5f24257cddd939b257f8d'],
    ]

    test_list = [
        ['test', 'f0ef6b0ae62326f3e7ffdfab6717acfc'],
    ]
    meta = {
        'filename': 'meta',
        # The labels this dataset produces are coarse (see CLASS_MAP), so the
        # class names taken from the meta file must be the coarse ones too.
        # NOTE(review): this relies on the inherited ``_load_meta`` reading
        # the class names stored under ``meta['key']`` -- confirm against
        # CIFAR10. The meta file (and hence its md5) is the same either way.
        'key': 'coarse_label_names',
        'md5': '7973b15100ade9c7d40fb424638fde48',
    }

    # Names of the 20 coarse superclasses, indexed by the coarse label that
    # CLASS_MAP produces (i.e. by the ``gt_label`` values of this dataset).
    CLASSES = [
        'aquatic_mammals', 'fish', 'flowers', 'food_containers',
        'fruit_and_vegetables', 'household_electrical_devices',
        'household_furniture', 'insects', 'large_carnivores',
        'large_man-made_outdoor_things', 'large_natural_outdoor_scenes',
        'large_omnivores_and_herbivores', 'medium_mammals',
        'non-insect_invertebrates', 'people', 'reptiles', 'small_mammals',
        'trees', 'vehicles_1', 'vehicles_2'
    ]

    # Names of the 100 fine classes, indexed by fine label (the keys of
    # CLASS_MAP). Kept so fine labels can still be interpreted.
    FINE_CLASSES = [
        'apple', 'aquarium_fish', 'baby', 'bear', 'beaver', 'bed', 'bee',
        'beetle', 'bicycle', 'bottle', 'bowl', 'boy', 'bridge', 'bus',
        'butterfly', 'camel', 'can', 'castle', 'caterpillar', 'cattle',
        'chair', 'chimpanzee', 'clock', 'cloud', 'cockroach', 'couch', 'crab',
        'crocodile', 'cup', 'dinosaur', 'dolphin', 'elephant', 'flatfish',
        'forest', 'fox', 'girl', 'hamster', 'house', 'kangaroo', 'keyboard',
        'lamp', 'lawn_mower', 'leopard', 'lion', 'lizard', 'lobster', 'man',
        'maple_tree', 'motorcycle', 'mountain', 'mouse', 'mushroom',
        'oak_tree', 'orange', 'orchid', 'otter', 'palm_tree', 'pear',
        'pickup_truck', 'pine_tree', 'plain', 'plate', 'poppy', 'porcupine',
        'possum', 'rabbit', 'raccoon', 'ray', 'road', 'rocket', 'rose', 'sea',
        'seal', 'shark', 'shrew', 'skunk', 'skyscraper', 'snail', 'snake',
        'spider', 'squirrel', 'streetcar', 'sunflower', 'sweet_pepper',
        'table', 'tank', 'telephone', 'television', 'tiger', 'tractor',
        'train', 'trout', 'tulip', 'turtle', 'wardrobe', 'whale',
        'willow_tree', 'wolf', 'woman', 'worm'
    ]

    # Fine label (index into FINE_CLASSES) -> coarse label (index into
    # CLASSES).
    CLASS_MAP = {
        0: 4, 1: 1, 2: 14, 3: 8, 4: 0,
        5: 6, 6: 7, 7: 7, 8: 18, 9: 3,
        10: 3, 11: 14, 12: 9, 13: 18, 14: 7,
        15: 11, 16: 3, 17: 9, 18: 7, 19: 11,
        20: 6, 21: 11, 22: 5, 23: 10, 24: 7,
        25: 6, 26: 13, 27: 15, 28: 3, 29: 15,
        30: 0, 31: 11, 32: 1, 33: 10, 34: 12,
        35: 14, 36: 16, 37: 9, 38: 11, 39: 5,
        40: 5, 41: 19, 42: 8, 43: 8, 44: 15,
        45: 13, 46: 14, 47: 17, 48: 18, 49: 10,
        50: 16, 51: 4, 52: 17, 53: 4, 54: 2,
        55: 0, 56: 17, 57: 4, 58: 18, 59: 17,
        60: 10, 61: 3, 62: 2, 63: 12, 64: 12,
        65: 16, 66: 12, 67: 1, 68: 9, 69: 19,
        70: 2, 71: 10, 72: 0, 73: 1, 74: 16,
        75: 12, 76: 9, 77: 13, 78: 15, 79: 13,
        80: 16, 81: 19, 82: 2, 83: 4, 84: 6,
        85: 19, 86: 5, 87: 5, 88: 8, 89: 19,
        90: 18, 91: 1, 92: 2, 93: 15, 94: 6,
        95: 0, 96: 17, 97: 8, 98: 14, 99: 13
    }

    def load_annotations(self):
        """Load CIFAR-100 images together with their coarse labels.

        Only rank 0 downloads and extracts the archive, and only when the
        integrity check fails. In distributed runs every rank then waits at
        a barrier and re-checks integrity, which requires that all ranks see
        the same ``data_prefix``.

        Side effects:
            Sets ``self.imgs`` to the stacked image array in HWC layout and
            ``self.gt_labels`` to the list of coarse labels, then calls
            ``self._load_meta()``.

        Returns:
            list[dict]: One dict per image with key ``'img'`` (an HWC image
            array) and key ``'gt_label'`` (a 0-d ``np.int64`` array holding
            the coarse label).
        """
        rank, world_size = get_dist_info()

        if rank == 0 and not self._check_integrity():
            download_and_extract_archive(
                self.url,
                self.data_prefix,
                filename=self.filename,
                md5=self.tgz_md5)

        if world_size > 1:
            # Let rank 0 finish downloading before other ranks read the files.
            dist.barrier()
            assert self._check_integrity(), \
                'Shared storage seems unavailable. ' \
                f'Please download the dataset manually through {self.url}.'

        downloaded_list = self.test_list if self.test_mode else self.train_list

        self.imgs = []
        self.gt_labels = []

        # Load the pickled numpy arrays; the per-file checksum is unused here.
        for file_name, _ in downloaded_list:
            file_path = os.path.join(self.data_prefix, self.base_folder,
                                     file_name)
            with open(file_path, 'rb') as f:
                entry = pickle.load(f, encoding='latin1')
                self.imgs.append(entry['data'])
                # CIFAR-10 style files use 'labels', CIFAR-100 'fine_labels'.
                if 'labels' in entry:
                    self.gt_labels.extend(entry['labels'])
                else:
                    self.gt_labels.extend(entry['fine_labels'])

        self.imgs = np.vstack(self.imgs).reshape(-1, 3, 32, 32)
        self.imgs = self.imgs.transpose((0, 2, 3, 1))  # convert to HWC
        # Collapse the 100 fine labels into the 20 coarse superclasses.
        self.gt_labels = [self.CLASS_MAP[k] for k in self.gt_labels]

        self._load_meta()

        return [{
            'img': img,
            'gt_label': np.array(gt_label, dtype=np.int64)
        } for img, gt_label in zip(self.imgs, self.gt_labels)]
