
import cv2
import math
import json
from PIL import Image
import os.path as op
import numpy as np
import code
from config import cfg
from common.utils.tsv_file import TSVFile, CompositeTSVFile
from common.utils.posefix import replace_joint_img,cs_replace_joint_img
from common.utils.tsv_file_ops import load_linelist_file, load_from_yaml_file, find_file_path_in_yaml
from common.utils.image_ops import img_from_base64, crop, flip_img, flip_pose, flip_kp, transform, rot_aa
from common.utils.preprocessing import load_img, process_bbox, augmentation,cs_augmentation, compute_iou, load_img_from_lmdb
from common.utils.vis import vis_keypoints, vis_mesh, save_obj, vis_keypoints_with_skeleton
from common.utils.transforms import world2cam, cam2pixel, pixel2cam, transform_joint_to_other_db
import torch
import torchvision.transforms as transforms


class MeshTSVDataset(object):
    """Human-mesh training dataset backed by TSV files.

    Every sample is assembled from up to four TSV files sharing one row
    index: the image (base64 in the last column), the label JSON, the image
    height/width JSON and an "extra info" JSON holding near-by people's joints
    used by the PoseFix-style perturbation of the input joints.
    ``__getitem__`` returns ``(inputs, targets, meta_info)`` dicts of tensors.
    """

    # MeshTSVYamlDataset assigns both attributes before calling __init__; these
    # class-level defaults only matter when the base class is constructed
    # directly on plain (non-composite) TSV paths.
    is_composite = False
    root = None

    def __init__(self, img_file, label_file=None, hw_file=None,
                 linelist_file=None, extrainfo_file=None, is_train=True,
                 cv2_output=False, scale_factor=1):
        """Open the TSV files and set up augmentation and joint-set constants.

        Args:
            img_file: TSV whose last column is a base64-encoded image.
            label_file: optional TSV whose column 1 is the annotation JSON.
            hw_file: optional TSV whose column 1 is the image height/width.
            linelist_file: optional list of TSV rows to expose.
            extrainfo_file: optional TSV whose column 1 is extra-info JSON
                (``near_joints`` / ``num_overlap``).
            is_train: only the training branch of ``__getitem__`` exists.
            cv2_output: if true, ``get_image`` returns the decoded BGR image
                as float32 instead of converting it to RGB.
            scale_factor: accepted for interface compatibility but NOT used;
                the bbox scale jitter in ``augm_params`` is fixed at 0.25.
        """
        self.img_file = img_file
        self.label_file = label_file
        self.hw_file = hw_file
        self.linelist_file = linelist_file
        self.extrainfo_file = extrainfo_file
        self.img_tsv = self.get_tsv_file(img_file)
        self.label_tsv = None if label_file is None else self.get_tsv_file(label_file)
        self.hw_tsv = None if hw_file is None else self.get_tsv_file(hw_file)
        self.extrainfo_tsv = None if extrainfo_file is None else self.get_tsv_file(extrainfo_file)

        if self.is_composite:
            assert op.isfile(self.linelist_file)
            # Composite TSVs are already restricted to the linelist rows.
            self.line_list = list(range(self.hw_tsv.num_rows()))
        else:
            self.line_list = load_linelist_file(linelist_file)

        self.cv2_output = cv2_output
        self.normalize_img = transforms.Normalize(mean=[0.485, 0.456, 0.406],
                                                  std=[0.229, 0.224, 0.225])
        self.is_train = is_train
        # Split name handed to cs_augmentation. It used to be set only for
        # training, so evaluation-mode access raised AttributeError.
        # NOTE(review): 'test' is assumed to mean "no augmentation" -- confirm
        # against cs_augmentation.
        self.data_split = 'train' if is_train else 'test'
        self.scale_factor = 0.25  # bbox rescale range [1 - f, 1 + f]
        self.noise_factor = 0.4   # per-channel pixel-noise range [1 - f, 1 + f]
        self.rot_factor = 30      # rotation std in degrees, clipped to +/- 2 * rot_factor
        self.img_res = cfg.input_img_shape[0]
        self.sample_res = cfg.output_hm_shape[0]

        self.transform = transforms.ToTensor()
        self.image_keys = self.prepare_image_keys()

        self.joints_name = ('R_Ankle', 'R_Knee', 'R_Hip', 'L_Hip', 'L_Knee', 'L_Ankle', 'R_Wrist', 'R_Elbow', 'R_Shoulder', 'L_Shoulder',
                            'L_Elbow', 'L_Wrist', 'Neck', 'Head_top', 'Pelvis', 'Thorax', 'Spine', 'Jaw', 'Head', 'Nose', 'L_Eye', 'R_Eye', 'L_Ear', 'R_Ear')
        # Left/right joint index pairs in joints_name, swapped on horizontal flip.
        self.flip_pairs = ((0, 5), (1, 4), (2, 3), (6, 11), (7, 10), (8, 9), (20, 21), (22, 23))
        # Index of 'Pelvis' in joints_name (14). This is a keypoint index, NOT
        # the SMPL pose root (row 0 of the axis-angle pose).
        self.root_joint_idx = self.joints_name.index('Pelvis')
        self.coco_joints_name = ('Nose', 'L_Eye', 'R_Eye', 'L_Ear', 'R_Ear', 'L_Shoulder', 'R_Shoulder', 'L_Elbow', 'R_Elbow', 'L_Wrist', 'R_Wrist', 'L_Hip', 'R_Hip', 'L_Knee', 'R_Knee', 'L_Ankle', 'R_Ankle', 'Pelvis')
        # Rows of joints_name whose PoseFix-perturbed values are copied back
        # into the network input (the 12 limb joints plus Pelvis).
        self.coco_common_jidx = (0, 1, 2, 3, 4, 5, 6, 7, 8, 9, 10, 11, 14)
        self.base_joints_number = 24
        self.pelvis_index = self.joints_name.index('Pelvis')

    def get_tsv_file(self, tsv_file):
        """Open ``tsv_file`` as a (composite) TSV reader; None when it is falsy."""
        if not tsv_file:
            return None
        if self.is_composite:
            return CompositeTSVFile(tsv_file, self.linelist_file, root=self.root)
        return TSVFile(find_file_path_in_yaml(tsv_file, self.root))

    def get_valid_tsv(self):
        """Return the cheapest TSV to scan for keys (hw, then label), or None."""
        # Ordered by file size: the hw TSV is the smallest.
        if self.hw_tsv:
            return self.hw_tsv
        if self.label_tsv:
            return self.label_tsv
        return None

    def prepare_image_keys(self):
        """List the key (column 0) of every row of the valid TSV."""
        tsv = self.get_valid_tsv()
        return [tsv.get_key(i) for i in range(tsv.num_rows())]

    def prepare_image_key_to_index(self):
        """Map each row key of the valid TSV to its row number."""
        tsv = self.get_valid_tsv()
        return {tsv.get_key(i): i for i in range(tsv.num_rows())}

    def augm_params(self):
        """Sample (flip, pixel-noise, rotation, scale) augmentation parameters.

        When ``is_train`` is false the identity parameters
        ``(0, ones(3), 0, 1)`` are returned.
        """
        flip = 0           # 1 = horizontal flip
        pn = np.ones(3)    # per-channel pixel-noise multipliers
        rot = 0            # in-plane rotation, degrees
        sc = 1             # bbox scale multiplier
        if self.is_train:
            # Flip with probability 1/2.
            if np.random.uniform() <= 0.5:
                flip = 1
            # Each channel is multiplied by a factor in [1 - noise, 1 + noise].
            pn = np.random.uniform(1 - self.noise_factor, 1 + self.noise_factor, 3)
            # Gaussian rotation (std rot_factor) clipped to [-2 * rot_factor, 2 * rot_factor].
            rot = min(2 * self.rot_factor,
                      max(-2 * self.rot_factor, np.random.randn() * self.rot_factor))
            # Gaussian scale (mean 1, std scale_factor) clipped to [1 - f, 1 + f].
            sc = min(1 + self.scale_factor,
                     max(1 - self.scale_factor, np.random.randn() * self.scale_factor + 1))
            # The rotation is disabled with probability 3/5.
            if np.random.uniform() <= 0.6:
                rot = 0
        return flip, pn, rot, sc

    def rgb_processing(self, rgb_img, center, scale, rot, flip, pn):
        """Crop/rotate, optionally flip and colour-jitter an image.

        Returns a channel-first float32 array divided by 255 (values in [0, 1]
        after clipping to [0, 255]); the original comment gives (3, 224, 224).
        """
        rgb_img = crop(rgb_img, center, scale,
                       [self.img_res, self.img_res], rot=rot)
        if flip:
            rgb_img = flip_img(rgb_img)
        # Channel-wise pixel noise, clipped back into [0, 255]; assigning into
        # the slice keeps rgb_img's dtype exactly as the per-channel form did.
        rgb_img[:, :, :3] = np.minimum(255.0, np.maximum(0.0, rgb_img[:, :, :3] * pn))
        return np.transpose(rgb_img.astype('float32'), (2, 0, 1)) / 255.0

    def j2d_processing(self, kp, center, scale, r, f):
        """Map gt 2D keypoints into the crop and normalize x, y so the crop spans [-1, 1].

        ``kp`` is modified in place (rows are (x, y, ..., last column kept)).
        """
        for i in range(kp.shape[0]):
            kp[i, 0:2] = transform(kp[i, 0:2] + 1, center, scale,
                                   [self.img_res, self.img_res], rot=r)
        kp[:, :-1] = 2. * kp[:, :-1] / self.img_res - 1.
        if f:
            kp = flip_kp(kp)
        return kp.astype('float32')

    def j2d_sample(self, kp, center, scale, r, f):
        """Map gt 2D keypoints into a sample_res x sample_res grid (no normalization).

        ``kp`` is modified in place.
        """
        for i in range(kp.shape[0]):
            kp[i, 0:2] = transform(kp[i, 0:2] + 1, center, scale,
                                   [self.sample_res, self.sample_res], rot=r)
        if f:
            kp = flip_kp(kp)
        return kp.astype('float32')

    def j3d_processing(self, S, r, f):
        """Rotate 3D keypoints by ``-r`` degrees about z, optionally flip; returns float32.

        The last column of ``S`` is left untouched; ``S`` is modified in place.
        """
        rot_mat = np.eye(3)
        if r != 0:
            rot_rad = -r * np.pi / 180
            sn, cs = np.sin(rot_rad), np.cos(rot_rad)
            rot_mat[0, :2] = [cs, -sn]
            rot_mat[1, :2] = [sn, cs]
        S[:, :-1] = np.einsum('ij,kj->ki', rot_mat, S[:, :-1])
        if f:
            S = flip_kp(S)
        return S.astype('float32')

    def pose_processing(self, pose, r, f):
        """Rotate the global orientation (``pose[:3]``) and optionally flip SMPL pose params."""
        pose = pose.astype('float32')  # copy: the caller's array is not modified
        pose[:3] = rot_aa(pose[:3], r)
        if f:
            pose = flip_pose(pose)
        return pose.astype('float32')

    def get_line_no(self, idx):
        """Translate a dataset index into a TSV row number via the line list."""
        return idx if self.line_list is None else self.line_list[idx]

    def get_image(self, idx):
        """Decode the image for ``idx`` (RGB unless ``cv2_output``, then float32 BGR)."""
        line_no = self.get_line_no(idx)
        row = self.img_tsv[line_no]
        # use -1 to support old format with multiple columns.
        cv2_im = img_from_base64(row[-1])
        if self.cv2_output:
            return cv2_im.astype(np.float32, copy=True)
        return cv2.cvtColor(cv2_im, cv2.COLOR_BGR2RGB)

    def _load_json_column(self, tsv, idx):
        """Decode the JSON in column 1 of ``tsv`` for dataset index ``idx``; [] if ``tsv`` is None."""
        if tsv is None:
            return []
        return json.loads(tsv[self.get_line_no(idx)][1])

    def get_hw(self, idx):
        """Return the decoded hw JSON for ``idx`` (``[]`` without an hw TSV)."""
        return self._load_json_column(self.hw_tsv, idx)

    def get_annotations(self, idx):
        """Return the decoded label JSON for ``idx`` (``[]`` without a label TSV)."""
        return self._load_json_column(self.label_tsv, idx)

    def get_extrainfo(self, idx):
        """Return the decoded extra-info JSON for ``idx`` (``[]`` without that TSV)."""
        return self._load_json_column(self.extrainfo_tsv, idx)

    def get_target_from_annotations(self, annotations, img_size, idx):
        # This function will be overwritten by each dataset to
        # decode the labels to specific formats for each task.
        return annotations

    def get_img_info(self, idx):
        """Return ``{"height": h, "width": w}`` for ``idx``, or None without an hw TSV."""
        if self.hw_tsv is None:
            return None
        row = self.hw_tsv[self.get_line_no(idx)]
        try:
            # json string format with "height" and "width" being the keys
            return json.loads(row[1])[0]
        except ValueError:
            # list of strings representing height and width in order
            hw_str = row[1].split(' ')
            return {"height": int(hw_str[0]), "width": int(hw_str[1])}

    def get_img_key(self, idx):
        """Return the row key for ``idx``, read from the cheapest available TSV."""
        line_no = self.get_line_no(idx)
        # based on the overhead of reading each row.
        if self.hw_tsv:
            return self.hw_tsv[line_no][0]
        elif self.label_tsv:
            return self.label_tsv[line_no][0]
        else:
            return self.img_tsv[line_no][0]

    def __len__(self):
        if self.line_list is None:
            return self.img_tsv.num_rows()
        return len(self.line_list)

    @staticmethod
    def _squeeze_person_axis(joints):
        """Keep only the first person when ``joints`` carries a leading person axis (ndim > 2)."""
        return joints[0] if joints.ndim > 2 else joints

    def _swap_flip_pairs(self, arr):
        """Swap left/right rows of ``arr`` in place according to ``flip_pairs``."""
        for a, b in self.flip_pairs:
            arr[a], arr[b] = arr[b].copy(), arr[a].copy()
        return arr

    def _hflip_joints(self, joints, width):
        """Mirror x in place across an image ``width`` pixels wide and swap left/right joints."""
        joints[:, 0] = width - 1 - joints[:, 0]
        return self._swap_flip_pairs(joints)

    @staticmethod
    def _apply_affine(trans, joints):
        """Apply the 2x3 affine ``trans`` to the (x, y) columns of ``joints``; returns (N, 2)."""
        xy1 = np.concatenate((joints[:, :2], np.ones_like(joints[:, :1])), axis=1)
        return np.dot(trans, xy1.T).T

    @staticmethod
    def _crop_to_heatmap(joints):
        """Rescale (x, y) in place from input-crop pixels to output-heatmap cells."""
        joints[:, 0] = joints[:, 0] / cfg.input_img_shape[1] * cfg.output_hm_shape[2]
        joints[:, 1] = joints[:, 1] / cfg.input_img_shape[0] * cfg.output_hm_shape[1]
        return joints

    def _near_joints_to_crop(self, near_joints, img2bb_trans, do_flip, width):
        """Map near-by people's 2D joints into crop coordinates.

        Returns a float32 array of shape (num_people, base_joints_number, 3)
        whose third column is 1, or one all-zero set when there are none.
        """
        if len(near_joints) == 0:
            return np.zeros((1, self.base_joints_number, 3), dtype=np.float32)
        out = np.ones((len(near_joints), self.base_joints_number, 3), dtype=np.float32)
        for k, nj in enumerate(near_joints):
            nj = self._squeeze_person_axis(np.array(nj, dtype=np.float64))
            if do_flip:
                # Bring neighbours into the same mirrored frame (and left/right
                # order) as the target joints before applying img2bb_trans.
                self._hflip_joints(nj, width)
            out[k, :, :2] = self._apply_affine(img2bb_trans, nj)
        return out

    def __getitem__(self, idx):
        """Load, augment and package one training sample.

        Returns:
            inputs: ``img`` (channel-first float tensor divided by 255),
                ``joints`` (PoseFix-perturbed (x, y) in heatmap coordinates)
                and ``joints_mask`` (valid-and-inside-heatmap mask).
            targets: ``joints_2d`` (heatmap coordinates), ``joints_3d``
                (pelvis-relative, augmented), ``pose_param``, ``shape_param``.
            meta_info: validity/truncation masks and the ``has_smpl`` /
                ``has_3d_joints`` flags from the annotation.

        Raises:
            NotImplementedError: when built with ``is_train=False``; only the
                training branch exists (it used to return None silently).
        """
        if not self.is_train:
            raise NotImplementedError(
                'MeshTSVDataset only implements training samples (is_train=True)')

        img = self.get_image(idx)
        img_hw = self.get_hw(idx)
        annotations = self.get_annotations(idx)[0]
        extrainfo = self.get_extrainfo(idx)

        # Width is only needed to mirror x when flipping; fall back to the
        # decoded image when no hw TSV is configured.
        width = img_hw[0]['width'] if img_hw else img.shape[1]
        center = annotations['center']
        scale = annotations['scale']
        has_3d_joints = annotations['has_3d_joints']
        has_smpl = np.asarray(annotations['has_smpl'])
        # Float so the in-place affine/heatmap assignments below never truncate.
        joints_2d = self._squeeze_person_axis(np.asarray(annotations['2d_joints'], dtype=np.float64))
        joints_3d = self._squeeze_person_axis(np.asarray(annotations['3d_joints'], dtype=np.float64))
        near_joints_raw = extrainfo['near_joints']
        num_overlap = extrainfo['num_overlap']

        # A 2D joint is valid when its third column (confidence) is positive.
        joint_valid = (joints_2d[:, 2:3] > 0).astype(np.float32)

        img, img2bb_trans, bb2img_trans, rot, do_flip = cs_augmentation(img, center, scale, self.data_split)
        img = self.transform(img.astype(np.float32)) / 255.

        # Mirror in original-image coordinates first, then map into the crop.
        joint_img = joints_2d
        if do_flip:
            self._hflip_joints(joint_img, width)
            self._swap_flip_pairs(joint_valid)
        joint_img[:, :2] = self._apply_affine(img2bb_trans, joint_img)
        near_joints = self._near_joints_to_crop(near_joints_raw, img2bb_trans, do_flip, width)

        input_joint_img = joint_img.copy()  # crop pixels, perturbed below
        self._crop_to_heatmap(joint_img)

        # Joint is kept when valid and inside the output heatmap.
        in_bounds = ((joint_img[:, 0] >= 0) & (joint_img[:, 0] < cfg.output_hm_shape[2])
                     & (joint_img[:, 1] >= 0) & (joint_img[:, 1] < cfg.output_hm_shape[1]))
        joint_trunc = joint_valid * in_bounds.reshape(-1, 1).astype(np.float32)

        # PoseFix: perturb the *input* joints in COCO order, then copy the
        # joints shared with joints_name back into the input.
        tmp_joint_img = transform_joint_to_other_db(input_joint_img, self.joints_name, self.coco_joints_name)
        tmp_joint_img = cs_replace_joint_img(tmp_joint_img, center, scale, near_joints, num_overlap, img2bb_trans)
        tmp_joint_img = transform_joint_to_other_db(tmp_joint_img, self.coco_joints_name, self.joints_name)
        input_joint_img[self.coco_common_jidx, :2] = tmp_joint_img[self.coco_common_jidx, :2]
        self._crop_to_heatmap(input_joint_img)

        # NOTE(review): 2D joints are mirrored BEFORE img2bb_trans (which
        # presumably contains the rotation), while j3d_processing and
        # pose_processing rotate first and then flip; confirm both orders
        # agree when do_flip and rot != 0.
        joints_3d[:, :-1] = joints_3d[:, :-1] - joints_3d[self.pelvis_index, :-1][None, :]
        joints_3d_transformed = self.j3d_processing(joints_3d.copy(), rot, do_flip)

        # pose_processing already applies the rotation augmentation to the
        # global orientation pose[:3]. The former second rotation indexed the
        # pose with root_joint_idx (14, a keypoint index), corrupting an
        # unrelated body joint, and has been removed.
        smpl_pose = self.pose_processing(np.asarray(annotations['pose']), rot, do_flip)
        smpl_shape = np.asarray(annotations['betas'])

        inputs = {'img': img.float(),
                  'joints': torch.from_numpy(input_joint_img[:, :2]).float(),
                  'joints_mask': torch.from_numpy(joint_trunc).float()}

        targets = {'joints_2d': torch.from_numpy(joint_img).float(),
                   'joints_3d': torch.from_numpy(joints_3d_transformed).float(),
                   'pose_param': torch.from_numpy(smpl_pose).float(),
                   'shape_param': torch.from_numpy(smpl_shape).float()}

        meta_info = {'orig_joint_valid': torch.from_numpy(joint_valid).float(),
                     'orig_joint_trunc': torch.from_numpy(joint_trunc).float(),
                     'has_smpl': has_smpl,
                     'has_3d_joints': has_3d_joints}

        return inputs, targets, meta_info


class MeshTSVYamlDataset(MeshTSVDataset):
    """ TSVDataset taking a Yaml file for easy function call

    YAML keys read: ``img`` (required), ``label``, ``hw``, ``extrainfo``,
    ``linelist`` and ``composite`` (default False). ``hw`` is required when
    ``composite`` is true. Non-composite paths are resolved against the YAML
    file's directory here.
    """
    def __init__(self, yaml_file, is_train=True, cv2_output=False, scale_factor=1):
        self.cfg = load_from_yaml_file(yaml_file)
        self.is_composite = self.cfg.get('composite', False)
        self.root = op.dirname(yaml_file)

        # Use truthiness like MeshTSVDataset does, so e.g. `composite: null`
        # selects the same branch here as in the base class.
        if not self.is_composite:
            img_file = find_file_path_in_yaml(self.cfg['img'], self.root)
            label_file = find_file_path_in_yaml(self.cfg.get('label', None), self.root)
            hw_file = find_file_path_in_yaml(self.cfg.get('hw', None), self.root)
            extrainfo_file = find_file_path_in_yaml(self.cfg.get('extrainfo', None), self.root)
        else:
            # Composite entries are passed on unresolved; the base class hands
            # them to CompositeTSVFile together with self.root.
            img_file = self.cfg['img']
            hw_file = self.cfg['hw']
            label_file = self.cfg.get('label', None)
            extrainfo_file = self.cfg.get('extrainfo', None)
        linelist_file = find_file_path_in_yaml(self.cfg.get('linelist', None), self.root)

        super(MeshTSVYamlDataset, self).__init__(
            img_file, label_file, hw_file, linelist_file, extrainfo_file, is_train,
            cv2_output=cv2_output, scale_factor=scale_factor)
