import os
import os.path as osp
import numpy as np
import torch
import cv2
import random
import json
import math
import copy
import pickle
import transforms3d
from pycocotools.coco import COCO
from config import cfg
import lmdb
from common.utils.tsv_file import TSVFile, CompositeTSVFile
from common.utils.image_ops import img_from_base64
from common.utils.posefix import replace_joint_img
from src.utils.tsv_file_ops import load_linelist_file
from common.utils.smpl import SMPL
from common.utils.preprocessing import load_img, get_bbox, process_bbox, generate_patch_image, augmentation, compute_iou, load_img_from_lmdb
from common.utils.transforms import world2cam, cam2pixel, pixel2cam, rigid_align, transform_joint_to_other_db
# from utils.vis import vis_keypoints, vis_mesh, save_obj, vis_keypoints_with_skeleton, vis_bbox
import transforms3d
# Default to GPU 6 for this module, but do not override a CUDA_VISIBLE_DEVICES
# value that the launching environment has already chosen.
os.environ.setdefault("CUDA_VISIBLE_DEVICES", "6")

class My_MuCo(torch.utils.data.Dataset):
    def __init__(self, transform, data_split):
        """Set up file paths, TSV readers, joint-set tables and the sample list.

        Args:
            transform: stored as ``self.transform``; not used further in this
                method.
            data_split: name of the split; ``load_data`` only accepts
                ``'train'``.
        """
        banner = '=' * 20
        print(banner, 'MuCo', banner)
        self.transform = transform
        self.data_split = data_split

        # Annotation files live under cfg.root_dir, TSV archives under cfg.metro_dir.
        muco_data_dir = osp.join(cfg.root_dir, 'data', 'MuCo')
        metro_muco_dir = osp.join(cfg.metro_dir, 'muco')
        self.img_dir = muco_data_dir
        self.annot_path = osp.join(muco_data_dir, 'MuCo-3DHP.json')
        self.smpl_param_path = osp.join(muco_data_dir, 'smpl_param.json')
        self.img_file = osp.join(metro_muco_dir, 'train.img.tsv')
        self.hw_file = osp.join(metro_muco_dir, 'train.hw.tsv')
        self.img_tsv = self.get_tsv_file(self.img_file)
        self.hw_tsv = self.get_tsv_file(self.hw_file)
        self.fitting_thr = 25  # millimeter

        self.linelist_file = osp.join(metro_muco_dir, 'train.linelist.tsv')
        self.line_list = load_linelist_file(self.linelist_file)

        # ---- COCO joint set ----
        self.coco_joint_num = 17  # original: 17
        self.coco_joints_name = ('Nose', 'L_Eye', 'R_Eye', 'L_Ear', 'R_Ear', 'L_Shoulder', 'R_Shoulder', 'L_Elbow', 'R_Elbow', 'L_Wrist', 'R_Wrist', 'L_Hip', 'R_Hip', 'L_Knee', 'R_Knee', 'L_Ankle', 'R_Ankle')

        # ---- MuCo joint set ----
        self.muco_joint_num = 21
        self.muco_joints_name = ('Head_top', 'Thorax', 'R_Shoulder', 'R_Elbow', 'R_Wrist', 'L_Shoulder', 'L_Elbow', 'L_Wrist', 'R_Hip', 'R_Knee', 'R_Ankle', 'L_Hip', 'L_Knee', 'L_Ankle', 'Pelvis', 'Spine', 'Head', 'R_Hand', 'L_Hand', 'R_Toe', 'L_Toe')
        self.muco_flip_pairs = ((2, 5), (3, 6), (4, 7), (8, 11), (9, 12), (10, 13), (17, 18), (19, 20))
        self.muco_skeleton = ((0, 16), (16, 1), (1, 15), (15, 14), (14, 8), (14, 11), (8, 9), (9, 10), (10, 19), (11, 12), (12, 13), (13, 20), (1, 2), (2, 3), (3, 4), (4, 17), (1, 5), (5, 6), (6, 7), (7, 18))
        self.muco_root_joint_idx = self.muco_joints_name.index('Pelvis')
        self.muco_coco_common_jidx = (2, 3, 4, 5, 6, 7, 8, 9, 10, 11, 12, 13)

        # ---- H36M joint set ----
        # The H36M joint regressor is used on the mesh; only a subset of the
        # original MuCo joint set has an H36M counterpart.
        h36m_regressor_path = osp.join(cfg.root_dir, 'data', 'Human36M', 'J_regressor_h36m_correct.npy')
        self.h36m_joint_regressor = np.load(h36m_regressor_path)
        self.h36m_flip_pairs = ((1, 4), (2, 5), (3, 6), (14, 11), (15, 12), (16, 13))
        self.h36m_joints_name = ('Pelvis', 'R_Hip', 'R_Knee', 'R_Ankle', 'L_Hip', 'L_Knee', 'L_Ankle', 'Torso', 'Neck', 'Nose', 'Head_top', 'L_Shoulder', 'L_Elbow', 'L_Wrist', 'R_Shoulder', 'R_Elbow', 'R_Wrist')
        self.h36m_root_joint_idx = self.h36m_joints_name.index('Pelvis')

        # ---- SMPL joint set ----
        self.smpl = SMPL()
        # The neutral body layer is moved to the GPU once, here.
        self.human_model_layer = self.smpl.layer['neutral'].cuda()
        # Mirror the SMPL wrapper's tables on the dataset under the same names.
        for attr in ('face', 'joint_regressor', 'vertex_num', 'joint_num', 'joints_name',
                     'flip_pairs', 'skeleton', 'root_joint_idx', 'face_kps_vertex'):
            setattr(self, attr, getattr(self.smpl, attr))

        self.datalist = self.load_data()
        print("muco data len: ", len(self.datalist))

        # Show the first few TSV keys as a quick sanity check.
        self.image_keys = self.prepare_image_key_to_index()
        for key in list(self.image_keys)[:5]:
            print(key)

    def get_line_no(self, idx):
        """Map dataset index ``idx`` to a TSV row number.

        Returns ``idx`` unchanged when ``self.line_list`` is ``None``;
        otherwise returns ``self.line_list[idx]``.
        """
        if self.line_list is None:
            return idx
        return self.line_list[idx]

    def prepare_image_key_to_index(self):
        """Build a dict mapping each row key of ``self.hw_tsv`` to its row index.

        If a key occurs more than once, the last row index wins.
        """
        hw_rows = self.hw_tsv
        key_to_row = {}
        for row_idx in range(hw_rows.num_rows()):
            key_to_row[hw_rows.get_key(row_idx)] = row_idx
        return key_to_row


    def get_tsv_file(self, tsv_path):
        """Open ``tsv_path`` with a plain ``TSVFile`` reader and return it.

        Only the single-file case is handled; composite TSV inputs and
        YAML-based path resolution are not supported by this method.
        """
        reader = TSVFile(tsv_path)
        return reader

    def get_image(self, idx):
        """Decode the image stored for dataset index ``idx``.

        The index is first translated with ``get_line_no``; the decoded image
        is passed through ``cv2.COLOR_BGR2RGB`` before being returned
        (presumably ``img_from_base64`` yields BGR — confirm against helper).
        """
        tsv_row = self.img_tsv[self.get_line_no(idx)]
        # The encoded image sits in the last column, which also covers the
        # older row format that carried multiple columns.
        decoded = img_from_base64(tsv_row[-1])
        return cv2.cvtColor(decoded, cv2.COLOR_BGR2RGB)

    def load_data(self):
        """Build the list of per-image samples from the MuCo annotations.

        For every image in the COCO-style annotation file, the person whose
        root (pelvis) joint has the smallest camera-space depth is selected.
        Images without annotations, or whose bounding box is rejected by
        ``process_bbox``, are skipped.

        Returns:
            list[dict]: one dict per kept image with the keys ``img_path``,
            ``img_shape`` (height, width), ``bbox``, ``tight_bbox``,
            ``joint_img``, ``joint_cam``, ``joint_valid``, ``cam_param``,
            ``smpl_param`` (``None`` when no SMPL fit exists for the
            annotation), ``near_joints`` and ``num_overlap``.

        Raises:
            AssertionError: if ``self.data_split`` is not ``'train'``.
        """
        if self.data_split != 'train':
            print('Unknown data subset')
            # Raise explicitly: a bare `assert 0` is stripped under `python -O`
            # and would let execution continue with `db` undefined.
            raise AssertionError('Unknown data subset')

        db = COCO(self.annot_path)
        with open(self.smpl_param_path) as f:
            smpl_params = json.load(f)

        datalist = []
        for img in db.imgs.values():
            img_width, img_height = img['width'], img['height']
            # File names are kept relative (not joined with self.img_dir).
            img_path = img['file_name']
            cam_param = {'focal': img['f'], 'princpt': img['c']}

            ann_ids = db.getAnnIds(img['id'])
            anns = db.loadAnns(ann_ids)
            if not anns:
                # No person annotated: nothing to crop, and min() below would fail.
                continue

            # Crop the person closest to the camera.
            root_depths = [ann['keypoints_cam'][self.muco_root_joint_idx][2] for ann in anns]
            pid = root_depths.index(min(root_depths))
            person = anns[pid]

            joint_cam = np.array(person['keypoints_cam'])
            joint_img = np.array(person['keypoints_img'])
            # Attach the camera-space depth as the third image-space column.
            joint_img = np.concatenate([joint_img, joint_cam[:, 2:]], 1)
            joint_valid = np.ones((self.muco_joint_num, 1))
            tight_bbox = np.array(person['bbox'])

            # Collect the 2D joints (in COCO order) of other people whose box
            # overlaps the selected person's box with IoU >= 0.1.
            near_joints = []
            for other in anns[:pid] + anns[pid + 1:]:
                other_tight_bbox = np.array(other['bbox'])
                iou = compute_iou(tight_bbox[None, :], other_tight_bbox[None, :])
                if iou < 0.1:
                    continue
                other_joint = np.array(other['keypoints_img'])
                # Append a column of ones before remapping the joint set.
                other_joint = np.concatenate((other_joint, np.ones_like(other_joint[:, :1])), axis=1)
                other_joint = transform_joint_to_other_db(other_joint, self.muco_joints_name, self.coco_joints_name)
                near_joints.append(other_joint)
            num_overlap = len(near_joints)

            bbox = process_bbox(tight_bbox, img_width, img_height)
            if bbox is None:
                continue

            # SMPL fits are keyed by annotation id; a missing fit becomes None.
            smpl_param = smpl_params.get(str(ann_ids[pid]))

            datalist.append({
                'img_path': img_path,
                'img_shape': (img_height, img_width),
                'bbox': bbox,
                'tight_bbox': tight_bbox,
                'joint_img': joint_img,
                'joint_cam': joint_cam,
                'joint_valid': joint_valid,
                'cam_param': cam_param,
                'smpl_param': smpl_param,
                'near_joints': near_joints,
                'num_overlap': num_overlap
            })

        return datalist

    def get_smpl_coord(self, smpl_param, cam_param, do_flip, img_shape):
        pose, shape, trans = smpl_param['pose'], smpl_param['shape'], smpl_param['trans']
        smpl_pose = torch.FloatTensor(pose).view(1, -1).to('cuda');
        smpl_shape = torch.FloatTensor(shape).view(1, -1).to('cuda');  # smpl parameters (pose: 72 dimension, shape: 10 dimension)
        smpl_trans = torch.FloatTensor(trans).view(1, -1).to('cuda')  # translation vector

        # flip smpl pose parameter (axis-angle)
        if do_flip:
            smpl_pose = smpl_pose.view(-1, 3)
            for pair in self.flip_pairs:
                if pair[0] < len(smpl_pose) and pair[1] < len(smpl_pose):  # face keypoints are already included in self.flip_pairs. However, they are not included in smpl_pose.
                    smpl_pose[pair[0], :], smpl_pose[pair[1], :] = smpl_pose[pair[1], :].clone(), smpl_pose[pair[0], :].clone()
            smpl_pose[:, 1:3] *= -1;  # multiply -1 to y and z axis of axis-angle
            smpl_pose = smpl_pose.view(1, -1)

        # get mesh and joint coordinates
        smpl_mesh_coord, smpl_joint_coord = self.human_model_layer(smpl_pose, smpl_shape, smpl_trans)
        smpl_mesh_coord = smpl_mesh_coord.cpu().numpy()
        smpl_joint_coord = smpl_joint_coord.cpu().numpy()
        # incorporate face keypoints
        smpl_mesh_coord = smpl_mesh_coord.astype(np.float32).reshape(-1, 3);
        # smpl_joint_coord = smpl_joint_coord.numpy().astype(np.float32).reshape(-1,3)
        # smpl_face_kps_coord = smpl_mesh_coord[self.face_kps_vertex,:].reshape(-1,3)
        # smpl_joint_coord = np.concatenate((smpl_joint_coord, smpl_face_kps_coord))
        smpl_joint_coord = np.dot(self.joint_regressor, smpl_mesh_coord)

        # flip translation
        if do_flip:  # avg of old and new root joint should be image center.
            focal, princpt = cam_param['focal'], cam_param['princpt']
            flip_trans_x = 2 * (((img_shape[1] - 1) / 2. - princpt[0]) / focal[0] * (smpl_joint_coord[self.root_joint_idx, 2])) - 2 * smpl_joint_coord[self.root_joint_idx][0]
            smpl_mesh_coord[:, 0] += flip_trans_x
            smpl_joint_coord[:, 0] += flip_trans_x

        # change to mean shape if beta is too far from it
        smpl_shape[(smpl_shape.abs() > 3).any(dim=1)] = 0.

        # meter -> milimeter
        smpl_mesh_coord *= 1000; smpl_joint_coord *= 1000;
        return smpl_mesh_coord, smpl_joint_coord, smpl_pose[0].cpu().numpy(), smpl_shape[0].cpu().numpy()

    def get_fitting_error(self, muco_joint, smpl_mesh, do_flip):
        """Mean per-joint distance between MuCo joints and the SMPL mesh fit.

        Both joint sets are expressed in the H36M joint layout and aligned by
        translation (their means are matched) before the error is taken.

        Args:
            muco_joint: MuCo joint coordinates, indexed as (joint, xyz); the
                caller's array is not modified.
            smpl_mesh: mesh vertices fed to ``self.h36m_joint_regressor``.
            do_flip: if true, mirror ``muco_joint`` along x and swap the
                joints listed in ``self.muco_flip_pairs`` first.

        Returns:
            Mean Euclidean distance over the compared joints.
        """
        # Root-relative joints; the subtraction allocates a fresh array.
        joints = muco_joint - muco_joint[self.muco_root_joint_idx, None, :]
        if do_flip:
            joints[:, 0] = -joints[:, 0]
            for first, second in self.muco_flip_pairs:
                # Fancy indexing copies the right-hand side, so this is a safe swap.
                joints[[first, second], :] = joints[[second, first], :]
        all_valid = np.ones((self.muco_joint_num, 3), dtype=np.float32)

        # Re-express joints and their validity flags in the H36M joint set.
        h36m_joint = transform_joint_to_other_db(joints, self.muco_joints_name, self.h36m_joints_name)
        h36m_joint_valid = transform_joint_to_other_db(all_valid, self.muco_joints_name, self.h36m_joints_name)
        keep = h36m_joint_valid == 1
        target = h36m_joint[keep].reshape(-1, 3)

        # Regress H36M joints from the mesh and keep the same subset.
        fitted = np.dot(self.h36m_joint_regressor, smpl_mesh)
        fitted = fitted[keep].reshape(-1, 3)
        # Translation alignment: move the fitted joints' mean onto the target's mean.
        fitted = fitted - np.mean(fitted, 0)[None, :] + np.mean(target, 0)[None, :]

        per_joint = np.sqrt(np.sum((target - fitted) ** 2, 1))
        return per_joint.mean()

    def __len__(self):
        """Return the number of samples held in ``self.datalist``."""
        sample_count = len(self.datalist)
        return sample_count

    def __getitem__(self, idx):
        # if not hasattr(self, 'aug_lmdb'):
        #     aug_db_path = osp.join(cfg.root_dir, 'data', 'MuCo', 'data', 'augmented_set_lmdb', 'augmented_set_lmdb.lmdb')
        #     aug_env = lmdb.open(aug_db_path,
        #                     subdir=os.path.isdir(aug_db_path),
        #                     readonly=True, lock=False,
        #                     readahead=False, meminit=False)
        #     self.aug_lmdb = aug_env.begin(write=False)
        #     unaug_db_path = osp.join(cfg.root_dir, 'data', 'MuCo', 'data', 'unaugmented_set_lmdb', 'unaugmented_set_lmdb.lmdb')
        #     unaug_env = lmdb.open(unaug_db_path,
        #                     subdir=os.path.isdir(unaug_db_path),
        #                     readonly=True, lock=False,
        #                     readahead=False, meminit=False)
        #     self.unaug_lmdb = unaug_env.begin(write=False)
        data = copy.deepcopy(self.datalist[idx])
        img_path, img_shape, bbox, smpl_param, cam_param = data['img_path'], data['img_shape'], data['bbox'], data['smpl_param'], data['cam_param']
        if smpl_param is not None:
            smpl_mesh_cam_flip, smpl_joint_cam_flip, smpl_pose_flip, smpl_shape_flip = self.get_smpl_coord(smpl_param, cam_param, img_shape=img_shape ,do_flip=True)
            smpl_coord_cam_flip = np.concatenate((smpl_mesh_cam_flip, smpl_joint_cam_flip))
            smpl_coord_img_flip = cam2pixel(smpl_coord_cam_flip, cam_param['focal'], cam_param['princpt'])

            smpl_coord_img_flip = smpl_coord_img_flip[self.vertex_num:]
            smpl_coord_cam_flip = smpl_coord_cam_flip[self.vertex_num:]

            smpl_mesh_cam, smpl_joint_cam, smpl_pose, smpl_shape = self.get_smpl_coord(smpl_param, cam_param, img_shape=img_shape,do_flip=False)
            smpl_coord_cam = np.concatenate((smpl_mesh_cam, smpl_joint_cam))
            smpl_coord_img = cam2pixel(smpl_coord_cam, cam_param['focal'], cam_param['princpt'])

            smpl_coord_img = smpl_coord_img[self.vertex_num:]
            smpl_coord_cam = smpl_coord_cam[self.vertex_num:]

            data['smpl_coord_img_raw'] = smpl_coord_img
            data['smpl_coord_img_flip'] = smpl_coord_img_flip
            data['smpl_coord_cam_raw'] = smpl_coord_cam
            data['smpl_coord_cam_flip'] = smpl_coord_cam_flip
            is_valid_fit = True

        else:
            smpl_joint_img = np.zeros((self.joint_num, 3), dtype=np.float32)  # dummy
            smpl_joint_cam = np.zeros((self.joint_num, 3), dtype=np.float32)  # dummy
            smpl_mesh_img = np.zeros((self.vertex_num, 3), dtype=np.float32)  # dummy
            smpl_pose = np.zeros((72), dtype=np.float32)  # dummy
            smpl_shape = np.zeros((10), dtype=np.float32)  # dummy
            smpl_joint_trunc = np.zeros((self.joint_num, 1), dtype=np.float32)
            smpl_mesh_trunc = np.zeros((self.vertex_num, 1), dtype=np.float32)
            data['smpl_coord_img_raw'] = None
            data['smpl_coord_img_flip'] = None
            data['smpl_coord_cam_raw'] = None
            data['smpl_coord_cam_flip'] = None
            is_valid_fit = False
        data['is_valid_fit'] = is_valid_fit
        return data

