import os
import json
import re

from collections import OrderedDict

# Known sequence names, in the order they should be ranked:
# "e0_raw" .. "e11_raw" followed by "s1_all_raw" .. "s6_all_raw".
raw_list = (
    [f"e{number}_raw" for number in range(12)]
    + [f"s{number}_all_raw" for number in range(1, 7)]
)

# Ordered mapping from each name to its position (index) in raw_list.
ordered_dict = OrderedDict(zip(raw_list, range(len(raw_list))))

#print(ordered_dict)

def extract_frame_number(frame_name):
    """
    Return the number that follows the text ``frame`` in *frame_name*.

    The first occurrence of ``frame`` immediately followed by digits is used,
    and the digits are converted with ``int`` (so leading zeros are dropped).
    For example, '_0q5WkK91xU_13_1_frame0002' gives 2.

    Returns None when *frame_name* contains no ``frame<digits>`` part.
    """
    found = re.search(r'frame(\d+)', frame_name)
    if found is None:
        return None
    digits = found.group(1)
    return int(digits)

def extract_cam_number(frame_name):
    """
    Return the part of *frame_name* before its first underscore as an int.

    For example, '0041_cam12' gives 41. A name without any underscore is
    converted as a whole.

    Returns None when that leading part is empty (the name is empty or
    starts with an underscore). A non-numeric leading part makes ``int``
    raise ValueError.
    """
    # partition() yields the same leading piece as split("_")[0].
    prefix, _, _ = frame_name.partition("_")
    if not prefix:
        return None
    return int(prefix)

def extract_cam_number_l2(frame_name):
    """
    Return the number that follows ``_cam`` in *frame_name* as an int.

    For example, '0041_cam12' gives 12. Only the text between the first and
    the second ``_cam`` (or the end of the name) is converted.

    Special cases:
      * the exact name "0026" gives 0;
      * an empty text after ``_cam`` gives None.

    A name without ``_cam`` (other than "0026") raises IndexError, and a
    non-numeric text after ``_cam`` makes ``int`` raise ValueError.
    """
    if frame_name == "0026":
        return 0

    pieces = frame_name.split("_cam")
    cam_text = pieces[1]
    if not cam_text:
        return None
    return int(cam_text)

def order_list(frame_name, ordered_dict=ordered_dict):
    """
    Return the rank of *frame_name* according to *ordered_dict*.

    The lookup key is the final path component of *frame_name*, cut at the
    first ``_cam`` and lower-cased. The un-lowered key is printed. Names that
    are not in *ordered_dict* rank last (``float('inf')``).
    """
    file_name = os.path.basename(frame_name)
    res = file_name.split("_cam")[0]
    print(res)
    lookup_key = res.lower()
    return ordered_dict.get(lookup_key, float('inf'))


def generate_new_json(data_dir, output_json_path, training_testing):
    """
    Index paired input / shape images under *data_dir* and save them as JSON.

    Files are looked up at::

        data_dir/<dir>/<frame_name>/<frame_name>_inputs.jpg
        data_dir/<dir>/<frame_name>/<frame_name>_shape_detail_images.jpg

    where ``<dir>`` is an entry of *data_dir* whose text before the first
    underscore is one of the selected subject ids (e.g. ``0041_cam12``).

    A directory is used only if every frame folder that has one of the two
    files also has the other, and it contributes at least 10 complete pairs;
    otherwise it is reported as invalid and contributes nothing.

    The written JSON maps each clip id (the inputs file name cut at its first
    ``_cam``) to ``{"clip_data_list": [...]}``: a list of 27 slots, one per
    camera index 12..38, each holding parallel ``frame_path_list`` and
    ``shapes_list`` lists.

    Args:
        data_dir: Root directory to scan.
        output_json_path: Path of the JSON file to write.
        training_testing: ``"training"`` selects the training ids; any other
            value selects the testing ids.

    Raises:
        IndexError: If an inputs file name has no ``_cam`` part, or its
            camera index is outside 12..38.
        ValueError: If the text after ``_cam`` in an inputs file name is not
            an integer.
    """
    training = ['0041', '0048', '0094', '0100', '0116', '0168', '0175', '0189', '0232', '0259', '0250']
    testing = ['0278', '0290', '0295', '0297', '0026', '0156', '0195', '0099', "0262"]
    target = set(training if training_testing == "training" else testing)

    first_cam = 12            # camera index stored in slot 0
    cam_slots = 27            # number of per-camera slots per clip
    min_frames = 10           # fewer complete pairs than this => directory is invalid
    checkpoint_every = 10000  # write a snapshot of the JSON every this many frames

    def frame_sort_key(name):
        # Names without a frame number sort after all numbered ones instead
        # of making the sort compare None with an int.
        number = extract_frame_number(name)
        return (number is None, 0 if number is None else number)

    def write_snapshot():
        with open(output_json_path, 'w') as f:
            json.dump(data_dict, f, indent=2)

    # Keep only the selected subjects BEFORE sorting, so the numeric sort key
    # is never evaluated on unrelated entries of data_dir.
    cam_list = sorted(
        (name for name in os.listdir(data_dir) if name.split("_")[0] in target),
        key=lambda x: (extract_cam_number(x), extract_cam_number_l2(x)),
    )

    frame_path_list = []
    shapes_list = []

    # Walk the data directory, e.g. "0041_cam12".
    for video_name in cam_list:
        video_path = os.path.join(data_dir, video_name)
        print(f"process {video_name}")
        if not os.path.isdir(video_path):
            continue  # skip plain files

        # Collect this directory's pairs separately so that an invalid
        # directory can be dropped as a whole.
        video_inputs = []
        video_shapes = []
        has_unpaired_frame = False

        for frame_name in sorted(os.listdir(video_path), key=frame_sort_key):
            frame_path = os.path.join(video_path, frame_name)
            if not os.path.isdir(frame_path):
                continue  # skip plain files

            inputs_path = os.path.join(frame_path, f"{frame_name}_inputs.jpg")
            shapes_path = os.path.join(frame_path, f"{frame_name}_shape_detail_images.jpg")

            has_inputs = os.path.exists(inputs_path)
            has_shapes = os.path.exists(shapes_path)
            if has_inputs and has_shapes:
                # Append both together so the two lists stay index-aligned.
                video_inputs.append(inputs_path)
                video_shapes.append(shapes_path)
            elif has_inputs or has_shapes:
                has_unpaired_frame = True

        if has_unpaired_frame or len(video_inputs) < min_frames:
            print(f"!!! invalid {video_name} !!!")
            continue

        frame_path_list.extend(video_inputs)
        shapes_list.extend(video_shapes)

    print(len(frame_path_list))

    data_dict = {}
    for idx, (inputs_path, shapes_path) in enumerate(zip(frame_path_list, shapes_list)):
        name_parts = os.path.basename(inputs_path).split("_cam")
        vid = name_parts[0]

        if vid not in data_dict:
            print(vid)
            data_dict[vid] = {
                'clip_data_list': [
                    {'frame_path_list': [], 'shapes_list': []}
                    for _ in range(cam_slots)
                ]
            }

        cam_idx = int(name_parts[1].split("_")[0])
        slot = cam_idx - first_cam
        # A negative offset would silently wrap around to the end of the list
        # and merge two cameras into one slot, so reject it explicitly.
        if not 0 <= slot < cam_slots:
            raise IndexError(
                f"camera index {cam_idx} in {inputs_path!r} is outside the "
                f"supported range {first_cam}..{first_cam + cam_slots - 1}"
            )

        clip = data_dict[vid]['clip_data_list'][slot]
        clip['frame_path_list'].append(inputs_path)
        clip['shapes_list'].append(shapes_path)

        # Periodic snapshot; idx 0 also triggers a write, so an unwritable
        # output path fails on the first frame rather than at the very end.
        if idx % checkpoint_every == 0:
            write_snapshot()

    # Write the final result to the JSON file.
    write_snapshot()

# Script entry: configure the paths and the split, then build the JSON index.
data_directory = '/mnt/znzz/jus/RenderMe-360_release_data_20id_full/multiview_flame_0811'  # data directory
output_json_file = '/root/users/jusjus/AniPortrait/dataset/renderme_multi_detailed.json'  # output JSON file path
training_testing = "training"
generate_new_json(
    data_dir=data_directory,
    output_json_path=output_json_file,
    training_testing=training_testing,
)

"""def count_root_entries(json_path):
    with open(json_path, 'r') as f:
        data = json.load(f)
    return len(data)

num_entries = count_root_entries(output_json_file)
print(f"根级条目数量: {num_entries}")
"""