import os
import random
from collections import defaultdict
from prettyprinter import cpprint  # noqa

import mmcv

# Fixed seed so the random video subset drawn by `process` (via `random.sample`)
# is the same on every run, given the same input files.
SEED = 5546
random.seed(SEED)


# COCO category id -> YouTube-VIS 2021 category id.
# Several COCO ids may map to the same VIS id (e.g. 2 and 4 -> 23, 35 and 36 -> 31).
COCO_TO_YTVIS_2021 = {
    1: 26, 2: 23, 3: 5, 4: 23, 5: 1, 7: 36, 8: 37, 9: 4, 16: 3, 17: 6, 18: 9, 19: 19, 21: 7, 22: 12, 23: 2, 24: 40,
    25: 18, 34: 14, 35: 31, 36: 31, 41: 29, 42: 33, 43: 34
}

# COCO category id -> OVIS category id (e.g. 3, 6 and 8 all map to 25).
COCO_TO_OVIS = {
    1: 1, 2: 21, 3: 25, 4: 22, 5: 23, 6: 25, 8: 25, 9: 24, 17: 3, 18: 4, 19: 5, 20: 6, 21: 7, 22: 8, 23: 9, 24: 10,
    25: 11
}

# Dataset root directory; overridable through the DETECTRON2_DATASETS env var.
_root = os.getenv("DETECTRON2_DATASETS", "../datasets")


# Per-dataset configuration consumed by main():
#   annos       - source VIS annotation file
#   label_map   - COCO id -> VIS id; its values select which VIS categories are kept
#   out_path    - output path for the filtered / subsampled VIS annotations
#   video_num   - number of videos randomly kept from the filtered VIS set
#   coco_annos  - source COCO annotation file to relabel into VIS ids
#   coco_output - output path for the relabelled COCO annotations
pseudo_video_dataset_list = {
    "ytvis2021*": {
        "annos": os.path.join(_root, "ytvis_2021/train.json"),
        "label_map": COCO_TO_YTVIS_2021,
        "out_path": os.path.join(_root, "ytvis_2021/sub_train.json"),
        "video_num": 421,
        "coco_annos": os.path.join(_root, "coco/annotations/coco2ytvis2021_train.json"),
        "coco_output": os.path.join(_root, "coco/annotations/pseudo_coco2ytvis2021_train.json"),
    },
    "ovis*": {
        "annos": os.path.join(_root, "ovis/annotations/train.json"),
        "label_map": COCO_TO_OVIS,
        "out_path": os.path.join(_root, "ovis/annotations/sub_train.json"),
        "video_num": 140,
        "coco_annos": os.path.join(_root, "coco/annotations/coco2ovis_train.json"),
        "coco_output": os.path.join(_root, "coco/annotations/pseudo_coco2ovis_train.json"),
    },
}


def process(anno_path, label_map, video_num=None):
    """Build a category-filtered, optionally subsampled copy of a VIS annotation file.

    Only annotations whose ``category_id`` is one of the VIS ids in
    ``label_map.values()`` are kept. A video is kept only if it has at least one
    such annotation. If ``video_num`` is truthy, a random subset of at most
    ``video_num`` of those videos is kept instead.

    Args:
        anno_path: Path to the VIS annotation file (read with ``mmcv.load``).
        label_map: Mapping from COCO category id to VIS category id.
        video_num: Maximum number of videos to keep. ``None`` (or ``0``) keeps
            every video that has a matching annotation. Values larger than the
            number of available videos keep all of them.

    Returns:
        dict: All top-level entries of the input. ``videos`` and
        ``annotations`` are filtered as described above; the other entries are
        copied by reference.
    """
    vis_annos = mmcv.load(anno_path)

    # Several COCO ids can map to the same VIS id; a set gives O(1) lookups.
    vis_labels = set(label_map.values())

    # Keep "annotations" as the first key, as before, so the dumped JSON key order is unchanged.
    out_json = {"annotations": []}
    out_json.update((k, v) for k, v in vis_annos.items() if k != "annotations")

    kept_annos = [anno for anno in vis_annos["annotations"] if anno["category_id"] in vis_labels]

    # Keep the exact list(set(...)) construction so that, under a fixed seed,
    # random.sample below draws the same videos as the original implementation.
    matched_video_ids = [anno["video_id"] for anno in kept_annos]
    exist_video_key = list(set(matched_video_ids))

    if video_num:
        # This process is random. Clamp so that requesting more videos than exist
        # keeps all of them instead of raising ValueError.
        exist_video_key = random.sample(exist_video_key, min(video_num, len(exist_video_key)))

    keep_ids = set(exist_video_key)

    # Filter out videos that were not selected.
    print('Start deleting videos.')
    out_json["videos"] = [video for video in out_json["videos"] if video["id"] in keep_ids]  # noqa

    # Filter out annotations belonging to dropped videos.
    out_json["annotations"] = [anno for anno in kept_annos if anno["video_id"] in keep_ids]
    return out_json


def convert_coco(coco_annos, label_map, cate):
    """Load a COCO annotation file and relabel it into the VIS category space.

    Args:
        coco_annos: Path to the COCO annotation file (read with ``mmcv.load``).
        label_map: Mapping from COCO category id to VIS category id. Every
            annotation's ``category_id`` must be a key here, otherwise
            ``KeyError`` is raised.
        cate: Category list that replaces the file's ``categories`` entry.

    Returns:
        dict: The loaded annotations, with ``categories`` replaced and every
        ``category_id`` rewritten in place.
    """
    converted = mmcv.load(coco_annos)
    converted["categories"] = cate

    for annotation in converted["annotations"]:
        coco_id = annotation["category_id"]
        annotation["category_id"] = label_map[coco_id]

    return converted


def print_information(anno, coco_annos, label_map, name):
    """Print summary statistics for one generated (pseudo-COCO, VIS subset) pair.

    Args:
        anno: VIS annotations, as produced by ``process``. Reads ``categories``
            (each with ``id`` and ``name``), ``videos`` and ``annotations``. Each
            annotation has ``category_id`` and a per-frame ``segmentations`` list.
        coco_annos: COCO annotations, as produced by ``convert_coco``. Reads
            ``images`` and ``annotations``; category ids are expected to
            already be VIS ids.
        label_map: Mapping from COCO category id to VIS category id.
        name: Dataset name shown in the header.
    """
    print("###################################################")
    print(f"The information of {name}: ")

    # Step 1: Print Categories List.
    # Look names up by category id rather than list position, so unsorted or
    # non-contiguous category lists are handled correctly.
    id_to_name = {cat["id"]: cat["name"] for cat in anno["categories"]}
    cate_list = {label_id: id_to_name[label_id] for label_id in label_map.values()}
    print("Categories List: ", list(cate_list.values()))

    # Step 2: Training / Testing dataset information
    # Training (each COCO annotation is counted as one mask)
    print("Training: ")
    print("Images: ", len(coco_annos["images"]))
    print("Instances: ", len(coco_annos["annotations"]))
    print("Masks: ", len(coco_annos["annotations"]))

    # Testing
    print("Testing: ")
    print("Videos: ", len(anno["videos"]))
    print("Instances: ", len(anno["annotations"]))

    # Per-frame segmentation lists use None for frames where the instance is
    # absent (YTVIS-style format), so count only real masks.
    num_masks = sum(
        seg is not None
        for vis_anno in anno["annotations"]
        for seg in vis_anno["segmentations"]
    )
    print("Masks: ", num_masks)

    # Step 3: Category Distribution (instances per VIS category id).
    # COCO labels were already remapped to VIS ids, so both sides share ids.
    train_dis = defaultdict(int)
    for annotation in coco_annos["annotations"]:
        train_dis[annotation["category_id"]] += 1

    test_dis = defaultdict(int)
    for annotation in anno["annotations"]:
        test_dis[annotation["category_id"]] += 1

    print("Category Distribution (category id: instances): ")
    print("Training: ", dict(sorted(train_dis.items())))
    print("Testing: ", dict(sorted(test_dis.items())))

    print("###################################################")


def main():
    """Generate the sub-VIS and pseudo-COCO annotation files for every configured dataset.

    For each entry of ``pseudo_video_dataset_list``:
      1. write the category-filtered, subsampled VIS annotations to ``out_path``;
      2. write the COCO annotations, relabelled to VIS ids and using the VIS
         category list, to ``coco_output``;
      3. print summary statistics.
    """
    for dataset_name, cfg in pseudo_video_dataset_list.items():
        mapping = cfg["label_map"]
        vis_src = cfg["annos"]
        vis_dst = cfg["out_path"]
        n_videos = cfg["video_num"]
        coco_src = cfg["coco_annos"]
        coco_dst = cfg["coco_output"]

        # Sub-VIS dataset restricted to the mapped categories.
        sub_vis = process(vis_src, mapping, n_videos)
        mmcv.dump(sub_vis, vis_dst)

        # COCO annotations relabelled into the VIS category space.
        pseudo_coco = convert_coco(coco_src, mapping, sub_vis["categories"])
        mmcv.dump(pseudo_coco, coco_dst)

        print_information(sub_vis, pseudo_coco, mapping, dataset_name)

        print(f"{dataset_name} is finished! {len(sub_vis['videos'])} videos. ")


# Run the conversion only when executed as a script, not when imported.
if __name__ == '__main__':
    main()
