import cv2
import os
import os.path as osp
import numpy as np
from config import cfg
import copy
import json
import scipy.io as sio
import random
import math
import torch
import transforms3d
from pycocotools.coco import COCO
import lmdb
from common.utils.posefix import replace_joint_img
from common.utils.smpl import SMPL
from common.utils.preprocessing import load_img, process_bbox, augmentation, compute_iou, load_img_from_lmdb,augmentation_simple
from common.utils.vis import vis_keypoints, vis_mesh, save_obj, vis_keypoints_with_skeleton
from common.utils.transforms import world2cam, cam2pixel, pixel2cam, transform_joint_to_other_db
# Default this process to GPU 6, but respect a device selection that the user
# (or a job scheduler) already made through CUDA_VISIBLE_DEVICES instead of
# silently overriding it. The variable only has an effect if it is set before
# CUDA is first initialised.
os.environ.setdefault("CUDA_VISIBLE_DEVICES", "6")

class CrowdPose(torch.utils.data.Dataset):
    """CrowdPose person annotations with optional NeuralAnnot SMPL pseudo ground truth.

    Items are per-person annotation dicts built by ``load_data``. ``__getitem__``
    extends each one with SMPL joint coordinates in camera and pixel space, for
    both the original and the horizontally flipped SMPL fit.
    """

    # mscoco skeleton
    coco_joint_num = 18  # original: 17, manually added pelvis
    coco_joints_name = ('Nose', 'L_Eye', 'R_Eye', 'L_Ear', 'R_Ear', 'L_Shoulder', 'R_Shoulder', 'L_Elbow', 'R_Elbow', 'L_Wrist', 'R_Wrist', 'L_Hip', 'R_Hip', 'L_Knee', 'R_Knee', 'L_Ankle', 'R_Ankle', 'Pelvis')
    coco_skeleton = ((1, 2), (0, 1), (0, 2), (2, 4), (1, 3), (6, 8), (8, 10), (5, 7), (7, 9), (12, 14), (14, 16), (11, 13), (13, 15), (5, 6), (11, 12))
    coco_flip_pairs = ((1, 2), (3, 4), (5, 6), (7, 8), (9, 10), (11, 12), (13, 14), (15, 16))

    # crowdpose skeleton
    crowdpose_joint_num = 14 + 1  # manually added pelvis
    crowdpose_jonit_num = crowdpose_joint_num  # legacy misspelled name, kept for backward compatibility
    crowdpose_joints_name = ('L_Shoulder', 'R_Shoulder', 'L_Elbow', 'R_Elbow', 'L_Wrist', 'R_Wrist', 'L_Hip', 'R_Hip', 'L_Knee', 'R_Knee', 'L_Ankle', 'R_Ankle', 'Head_top', 'Neck', 'Pelvis')
    crowdpose_skeleton = ((0, 2), (0, 13), (1, 3), (1, 13), (2, 4), (3, 5), (6, 14), (7, 14), (6, 8), (7, 9), (8, 10), (9, 11), (12, 13), (13, 14))
    # Left/right index pairs: shoulders, elbows, wrists, hips, knees, ankles.
    crowdpose_flip_pairs = ((0, 1), (2, 3), (4, 5), (6, 7), (8, 9), (10, 11))
    # Used for posefix.
    # NOTE(review): this tuple includes index 14 (Pelvis) and omits 12/13
    # (Head_top/Neck), although the original remark said "exclude pelvis".
    # Confirm which joints are intended.
    crowdpose_coco_common_jidx = (0, 1, 2, 3, 4, 5, 6, 7, 8, 9, 10, 11, 14)

    def __init__(self, transform, data_split):
        """Load annotations and set up the neutral SMPL model on the GPU.

        Args:
            transform: stored as ``self.transform``; not applied by this class.
            data_split: ``'train'`` loads the CrowdPose train annotations; any
                other value yields an empty dataset (see ``load_data``).
        """
        print('='*20, 'CrowdPose', '='*20)
        self.transform = transform
        self.data_split = data_split
        data_root = osp.join(cfg.root_dir, 'data')
        self.img_path = osp.join(data_root, 'CrowdPose', 'images')
        self.annot_path = osp.join(data_root, 'CrowdPose', 'annotations')
        self.target_data_split = 'val'
        self.fitting_thr = 5.0  # pixel in cfg.output_hm_shape space

        self.coco_joint_regressor = np.load(osp.join(data_root, 'MSCOCO', 'J_regressor_coco_hip_smpl.npy'))

        # smpl skeleton
        self.smpl = SMPL()
        self.human_model_layer = self.smpl.layer['neutral'].cuda()
        self.face = self.smpl.face
        self.joint_regressor = self.smpl.joint_regressor
        self.vertex_num = self.smpl.vertex_num
        self.joint_num = self.smpl.joint_num
        self.joints_name = self.smpl.joints_name
        self.flip_pairs = self.smpl.flip_pairs
        self.skeleton = self.smpl.skeleton
        self.root_joint_idx = self.smpl.root_joint_idx
        self.face_kps_vertex = self.smpl.face_kps_vertex

        self.datalist = self.load_data()
        print("crowdpose data len: ", len(self.datalist))

    def add_pelvis(self, joint_coord):
        """Append a pelvis joint at the midpoint of the two hips.

        Args:
            joint_coord: (N, 3) array of ``(x, y, visibility)`` rows in CrowdPose
                joint order, without a pelvis row.

        Returns:
            An (N + 1, 3) array. The appended row's x/y is the hip midpoint.
            Its third column is the product of the two hip visibility values,
            so it is non-zero only when both hips are annotated.
        """
        lhip_idx = self.crowdpose_joints_name.index('L_Hip')
        rhip_idx = self.crowdpose_joints_name.index('R_Hip')
        pelvis = (joint_coord[lhip_idx, :] + joint_coord[rhip_idx, :]) * 0.5
        pelvis[2] = joint_coord[lhip_idx, 2] * joint_coord[rhip_idx, 2]  # joint_valid
        return np.concatenate((joint_coord, pelvis.reshape(1, 3)))

    def _select_smpl_param(self, smpl_params, aid):
        """Return the NeuralAnnot entry for annotation ``aid`` if its fit is good enough.

        ``smpl_params`` is keyed by the string form of the annotation id. An
        entry is kept only when its ``fit_err`` does not exceed
        ``self.fitting_thr``. Missing or poorly fitted entries yield ``None``.
        """
        smpl_param = smpl_params.get(str(aid))
        if smpl_param is None or smpl_param['fit_err'] > self.fitting_thr:
            return None
        return smpl_param

    def _attach_near_joints(self, persons):
        """Record, in place, which other people in the same image overlap each person.

        Sets ``num_overlap`` and ``near_joints`` on every dict in ``persons``.
        ``near_joints`` holds the joints of every other person whose tight box
        has IoU >= 0.1 with this person's, converted to the COCO joint set.
        ``num_overlap`` is the number of such people.
        """
        for i, person in enumerate(persons):
            tight_bbox = person['tight_bbox']
            near_joints = []
            for j, other in enumerate(persons):
                if j == i:
                    continue
                iou = compute_iou(tight_bbox[None, :], other['tight_bbox'][None, :])
                if iou < 0.1:
                    continue
                near_joints.append(transform_joint_to_other_db(other['joint_img'], self.crowdpose_joints_name, self.coco_joints_name))
            person['num_overlap'] = len(near_joints)
            person['near_joints'] = near_joints

    def load_data(self):
        """Build the list of per-person annotation dicts.

        Only ``data_split == 'train'`` loads annotations; any other split
        returns an empty list.

        Each entry has the following keys:

        - ``img_path`` and ``img_shape`` (height, width).
        - ``bbox``: the output of ``process_bbox``.
        - ``tight_bbox``: ``[x, y, w, h]`` from the annotation.
        - ``joint_img``: keypoints with an appended pelvis row; the third column
          is replaced by a 0/1 validity flag.
        - ``joint_valid``: that flag as an (N, 1) array.
        - ``neural_annot_result``: the NeuralAnnot entry, or ``None``.
        - ``num_overlap`` and ``near_joints``: see ``_attach_near_joints``.
        """
        datalist = []
        if self.data_split != 'train':
            # NOTE(review): non-train splits are never loaded, even though
            # self.target_data_split is 'val'. Confirm this is intended.
            return datalist

        split_list = ['train']
        for split in split_list:
            db = COCO(osp.join(self.annot_path, f'crowdpose_{split}.json'))
            # smpl parameter load
            with open(osp.join(self.annot_path, f'CrowdPose_{split}_SMPL_NeuralAnnot.json'), 'r') as f:
                smpl_params = json.load(f)

            for iid in db.imgs.keys():  # loop over every image
                img = db.imgs[iid]
                img_path = osp.join(self.img_path, img['file_name'])
                if split != 'val':  # correct reversed img width,height info
                    width, height = img['height'], img['width']
                else:
                    width, height = img['width'], img['height']

                tmplist = []
                for aid in db.getAnnIds([iid]):  # every person annotation in this image
                    ann = db.anns[aid]
                    if sum(ann['keypoints']) == 0:
                        continue

                    tight_bbox = np.array(ann['bbox'])  # [x, y, w, h]
                    bbox = process_bbox(tight_bbox, width, height)  # slightly enlarged tight_bbox, [x, y, w, h]
                    if bbox is None:
                        continue

                    # keypoints as (x, y, visibility) rows, plus a synthesized pelvis
                    joint_img = np.array(ann['keypoints'], dtype=np.float32).reshape(-1, 3)
                    joint_img = self.add_pelvis(joint_img)
                    joint_valid = (joint_img[:, 2:3] > 0).astype(np.float32)
                    joint_img[:, 2] = joint_valid[:, 0]  # for posefix, only good for 2d datasets

                    tmplist.append({'img_path': img_path,
                                    'img_shape': (height, width),
                                    'bbox': bbox,
                                    'tight_bbox': tight_bbox,
                                    'joint_img': joint_img,
                                    'joint_valid': joint_valid,
                                    'neural_annot_result': self._select_smpl_param(smpl_params, aid)})

                self._attach_near_joints(tmplist)
                datalist.extend(tmplist)

        return datalist

    def get_smpl_coord(self, smpl_param, cam_param, do_flip, img_shape):
        """Run the neutral SMPL layer and return camera-space mesh and joint coordinates.

        Args:
            smpl_param: dict with ``'pose'`` (72 axis-angle values), ``'shape'``
                (10 betas) and ``'trans'`` (translation vector).
            cam_param: dict with ``'focal'`` and ``'princpt'``; only used when
                ``do_flip`` is set.
            do_flip: mirror pose and translation for a horizontally flipped image.
            img_shape: (height, width); only used when ``do_flip`` is set.

        Returns:
            ``(mesh_coord, joint_coord, pose, shape)``:

            - ``mesh_coord``: the mesh vertices as a float32 (V, 3) array.
            - ``joint_coord``: ``self.joint_regressor`` applied to the mesh.
            - ``pose`` and ``shape``: the (possibly flipped) pose and the shape
              parameters as 1-D numpy arrays.

            If any |beta| > 3, the returned shape is zeroed. The mesh itself is
            still computed from the original betas.
        """
        pose, shape, trans = smpl_param['pose'], smpl_param['shape'], smpl_param['trans']
        smpl_pose = torch.FloatTensor(pose).view(1, -1).to('cuda')
        smpl_shape = torch.FloatTensor(shape).view(1, -1).to('cuda')
        smpl_trans = torch.FloatTensor(trans).view(1, -1).to('cuda')

        # flip smpl pose parameter (axis-angle)
        if do_flip:
            smpl_pose = smpl_pose.view(-1, 3)
            num_pose_joints = len(smpl_pose)
            for l_idx, r_idx in self.flip_pairs:
                # self.flip_pairs also lists face keypoints, which have no pose parameters
                if l_idx < num_pose_joints and r_idx < num_pose_joints:
                    smpl_pose[l_idx, :], smpl_pose[r_idx, :] = smpl_pose[r_idx, :].clone(), smpl_pose[l_idx, :].clone()
            smpl_pose[:, 1:3] *= -1  # multiply -1 to y and z axis of axis-angle
            smpl_pose = smpl_pose.view(1, -1)

        # Inference only: no autograd graph is needed for the forward pass.
        with torch.no_grad():
            smpl_mesh_coord, _ = self.human_model_layer(smpl_pose, smpl_shape, smpl_trans)
        smpl_mesh_coord = smpl_mesh_coord.cpu().numpy().astype(np.float32).reshape(-1, 3)
        # Joints are regressed from the mesh rather than taken from the layer output.
        smpl_joint_coord = np.dot(self.joint_regressor, smpl_mesh_coord)

        # flip translation
        if do_flip:  # avg of old and new root joint should be image center.
            focal, princpt = cam_param['focal'], cam_param['princpt']
            root_x = smpl_joint_coord[self.root_joint_idx, 0]
            root_z = smpl_joint_coord[self.root_joint_idx, 2]
            flip_trans_x = 2 * (((img_shape[1] - 1) / 2. - princpt[0]) / focal[0] * root_z) - 2 * root_x
            smpl_mesh_coord[:, 0] += flip_trans_x
            smpl_joint_coord[:, 0] += flip_trans_x

        # change to mean shape if beta is too far from it
        smpl_shape[(smpl_shape.abs() > 3).any(dim=1)] = 0.

        return smpl_mesh_coord, smpl_joint_coord, smpl_pose[0].cpu().numpy(), smpl_shape[0].cpu().numpy()

    def __len__(self):
        return len(self.datalist)

    def _project_smpl_joints(self, smpl_param, cam_param, img_shape, do_flip):
        """Return ``(joint_img, joint_cam)`` for one SMPL fit.

        ``joint_img`` is the SMPL joints projected with ``cam2pixel``;
        ``joint_cam`` is the same joints in camera space. The mesh vertices are
        projected together with the joints and then sliced off.
        """
        mesh_cam, joint_cam, _, _ = self.get_smpl_coord(smpl_param, cam_param, img_shape=img_shape, do_flip=do_flip)
        coord_cam = np.concatenate((mesh_cam, joint_cam))
        coord_img = cam2pixel(coord_cam, cam_param['focal'], cam_param['princpt'])
        # keep only the joint rows (drop the vertex rows)
        return coord_img[self.vertex_num:], coord_cam[self.vertex_num:]

    def __getitem__(self, idx):
        """Return a deep copy of annotation ``idx`` with SMPL joint coordinates added.

        Adds the following keys, for the original and the horizontally flipped
        fit:

        - ``smpl_coord_img_raw`` / ``smpl_coord_img_flip``: joints in pixels.
        - ``smpl_coord_cam_raw`` / ``smpl_coord_cam_flip``: joints in camera space.
        - ``is_valid_fit``.

        Without a usable NeuralAnnot fit, the four coordinate entries are
        ``None`` and ``is_valid_fit`` is ``False``. No image is loaded here.
        """
        data = copy.deepcopy(self.datalist[idx])
        if cfg.update_bbox:
            height, width = data['img_shape']
            bbox = process_bbox(data['tight_bbox'], width, height)
            # Store the recomputed box; otherwise the update would have no effect.
            if bbox is not None:
                data['bbox'] = bbox

        neural_annot_result = data['neural_annot_result']
        is_valid_fit = neural_annot_result is not None
        if is_valid_fit:
            # use fitted mesh
            smpl_param, cam_param = neural_annot_result['smpl_param'], neural_annot_result['cam_param']
            img_shape = data['img_shape']
            img_flip, cam_flip = self._project_smpl_joints(smpl_param, cam_param, img_shape, do_flip=True)
            img_raw, cam_raw = self._project_smpl_joints(smpl_param, cam_param, img_shape, do_flip=False)
        else:
            img_flip = cam_flip = img_raw = cam_raw = None

        data['smpl_coord_img_raw'] = img_raw
        data['smpl_coord_img_flip'] = img_flip
        data['smpl_coord_cam_raw'] = cam_raw
        data['smpl_coord_cam_flip'] = cam_flip
        data['is_valid_fit'] = is_valid_fit
        return data
        