import numpy as np
import torch
import joblib

from openai import OpenAI

from retrieval_utils.demo.model import text_encoder
from retrieval_utils.demo.load import load_unit_embeddings, load_splits, load_json
from retrieval_utils.tools.easyconvert import matrix_to, axis_angle_to
from retrieval_utils.transforms.smpl import RotTransDatastruct
from retrieval_utils.rots_to_smpl_conversion.rots_to_smpl import convert_rots_to_h3d

import pickle
import argparse
import os
import datetime
import codecs as cs
import orjson  # loading faster than json
import json

import numpy as np
from tqdm import tqdm

import clip
import time
import random

from retrieval_utils.rots_to_smpl_conversion.rots_to_smpl import convert_rots_to_h3d


# Dataset identifier; interpolated into the part-database paths below.
DATASET = "humanml3d"

amass_tar_path = "amass.pth.tar"  # AMASS data path
ANNOTATIONS_PATH = f"humanml3d/humanml3d_annotations.txt"  # HumanML3D annotations path


# Pre-computed part-specific databases for the torso, hands and legs
# (one directory per body part, each path built from DATASET).
TORSO_PATH = f"torso_outputs/outputs/{DATASET}_guoh3dfeats"
HANDS_PATH = f"hands_outputs/outputs/{DATASET}_guoh3dfeats"
LEGS_PATH = f"legs_outputs/outputs/{DATASET}_guoh3dfeats"

# Data retrieval split, one of "Unseen" or "all": "all" includes every motion
# sample in the retrieval process, "Unseen" includes only unseen samples.
splits_choice = "Unseen"

# Maps AMASS sub-dataset directory names (the second path component of an
# AMASS file path) to the alias that `amass_name_getter` prepends to the name.
# Most aliases simply drop underscores; 'BioMotionLab_NTroje' maps to 'BMLrub'.
mapper = {'ACCAD': 'ACCAD',
 'BioMotionLab_NTroje': 'BMLrub',
 'CMU': 'CMU',
 'BMLmovi': 'BMLmovi',
 'EKUT': 'EKUT',
 'DFaust_67': 'DFaust67',
 'HumanEva': 'HumanEva',
 'Eyes_Japan_Dataset': 'EyesJapanDataset',
 'KIT': 'KIT',
 'MPI_HDM05': 'MPIHDM05',
 'MPI_Limits': 'MPILimits',
 'MPI_mosh': 'MPImosh',
 'SFU': 'SFU',
 'SSM_synced': 'SSMsynced',
 'TCD_handMocap': 'TCDhandMocap',
 'TotalCapture': 'TotalCapture',
 'Transitions_mocap': 'Transitionsmocap',
 'DanceDB': 'DanceDB',
 'BMLhandball': 'BMLhandball'}

# Ordered joint names: a joint's position in this list is the index used to
# select it (looked up with `list.index` when building per-part index lists).
# The first 22 entries are the body joints, pelvis through right_wrist; they
# are followed by 15 left-hand and 15 right-hand finger joints, then face
# points, toe/heel points and fingertip points (73 names in total).
# NOTE(review): confirm this order matches the joint layout of the pose data
# being indexed.
smplh_joints = [
	"pelvis", "left_hip", "right_hip", "spine1", "left_knee", "right_knee",
	"spine2", "left_ankle", "right_ankle", "spine3", "left_foot", "right_foot",
	"neck", "left_collar", "right_collar", "head", "left_shoulder",
	"right_shoulder", "left_elbow", "right_elbow", "left_wrist", "right_wrist",
	"left_index1", "left_index2", "left_index3", "left_middle1",
	"left_middle2", "left_middle3", "left_pinky1", "left_pinky2",
	"left_pinky3", "left_ring1", "left_ring2", "left_ring3", "left_thumb1",
	"left_thumb2", "left_thumb3", "right_index1", "right_index2",
	"right_index3", "right_middle1", "right_middle2", "right_middle3",
	"right_pinky1", "right_pinky2", "right_pinky3", "right_ring1",
	"right_ring2", "right_ring3", "right_thumb1", "right_thumb2",
	"right_thumb3", "nose", "right_eye", "left_eye", "right_ear", "left_ear",
	"left_big_toe", "left_small_toe", "left_heel", "right_big_toe",
	"right_small_toe", "right_heel", "left_thumb", "left_index", "left_middle",
	"left_ring", "left_pinky", "right_thumb", "right_index", "right_middle",
	"right_ring", "right_pinky"
]

#Mapping of smpl joints to `torso`, `hands` and `legs`
# Assignment of joint names to the three body parts used for part-wise
# retrieval ("torso", "hands", "legs"). 'pelvis' is grouped with "legs";
# collars, shoulders, elbows, wrists and all finger joints belong to "hands".
# Names in smplh_joints not listed here (face points, toes/heels, fingertip
# points) are not assigned to any part.
smpl_bps = {
	"torso" : ['spine1', 'spine2', 'spine3', 'neck', 'head'],
	"hands": ['left_collar', 'left_shoulder', 'left_elbow', 'left_wrist', 'right_collar', 'right_shoulder', 'right_elbow', 'right_wrist',
	"left_index1", "left_index2", "left_index3", "left_middle1",
	"left_middle2", "left_middle3", "left_pinky1", "left_pinky2",
	"left_pinky3", "left_ring1", "left_ring2", "left_ring3", "left_thumb1",
	"left_thumb2", "left_thumb3", "right_index1", "right_index2",
	"right_index3", "right_middle1", "right_middle2", "right_middle3",
	"right_pinky1", "right_pinky2", "right_pinky3", "right_ring1",
	"right_ring2", "right_ring3", "right_thumb1", "right_thumb2",
	"right_thumb3"],
	"legs" : ['left_hip', 'left_knee', 'left_ankle', 'left_foot', 'right_hip', 'right_knee', 'right_ankle', 'right_foot', 'pelvis']
}
# Positions in smplh_joints of each part's joints, in the order listed in
# smpl_bps. `list.index` raises ValueError at import time if a name is missing.
torso_joints = [smplh_joints.index(x) for x in smpl_bps["torso"]]
hands_joints = [smplh_joints.index(x) for x in smpl_bps["hands"]]
legs_joints = [smplh_joints.index(x) for x in smpl_bps["legs"]]


def amass_name_getter(fname):
	"""Build the aliased name for a '/'-separated AMASS file path.

	The first path component is discarded. The second component is looked up
	in the module-level `mapper` to obtain its alias, and the result is
	``alias + "/" + "/".join(components[1:])``.

	Example: "root/ACCAD/sub/file.npz" -> "ACCAD/ACCAD/sub/file.npz".

	Paths with fewer than two components produce an empty string.
	Raises KeyError if the second component is not a key of `mapper`.
	"""
	parts = fname.split("/")
	if len(parts) < 2:
		return ""
	return mapper[parts[1]] + "/" + "/".join(parts[1:])


def smpl_data_to_matrix_and_trans(data, nohands=True):
	"""Convert axis-angle poses into rotation matrices packed with translations.

	Args:
		data: mapping with 'poses' and 'trans' entries. 'trans' is sized by
			`len()` to obtain the frame count. 'poses' must support
			`.reshape` (numpy array or torch tensor) and hold
			`nframes * J * 3` values for some joint count J.
		nohands: if True, keep only the first 22 joints (the body joints,
			dropping the hand joints) before conversion.

	Returns:
		RotTransDatastruct(rots=<matrix poses>, trans=data['trans']).

	Raises:
		KeyError: if 'poses' or 'trans' is missing from `data`.
		ValueError: if 'poses' cannot be reshaped to (nframes, -1, 3).
	"""
	trans = data['trans']
	nframes = len(trans)
	poses = data['poses']
	try:
		axis_angle_poses = poses.reshape(nframes, -1, 3)
	except (RuntimeError, ValueError) as err:
		# numpy raises ValueError and torch raises RuntimeError on a bad reshape.
		# The previous bare `except: breakpoint()` would hang or silently continue.
		raise ValueError(
			f"cannot reshape 'poses' of shape {tuple(poses.shape)} to ({nframes}, -1, 3)"
		) from err

	if nohands:
		axis_angle_poses = axis_angle_poses[:, :22]

	matrix_poses = axis_angle_to("matrix", axis_angle_poses)

	return RotTransDatastruct(rots=matrix_poses, trans=trans)


def get_rots_trans(data, frames):
	"""Select `frames` from numpy pose/translation arrays and convert them.

	Both data['poses'][frames] and data['trans'][frames] are wrapped with
	`torch.from_numpy` and cast to float32. They are then passed to
	`smpl_data_to_matrix_and_trans` with hand joints kept.

	Returns:
		(rots, trans) taken from the resulting RotTransDatastruct.
	"""
	selected = {
		key: torch.from_numpy(data[key][frames]).float()
		for key in ("poses", "trans")
	}
	converted = smpl_data_to_matrix_and_trans(selected, nohands=False)
	return converted.rots, converted.trans

def get_h3d(merged_rots, merged_trans):
	"""Convert rotations/translations to H3D features of exactly `crop_size` frames.

	The output of `convert_rots_to_h3d` is randomly cropped to `crop_size`
	frames when it is long enough. The start offset is drawn with
	`random.randint`, so the result depends on the global `random` state.
	Shorter outputs are padded at the end with float32 zeros along the first
	axis.
	"""
	# NOTE(review): `crop_size` is read from module scope and is not defined in
	# the visible part of this file — confirm it is set before this is called.
	h3d = convert_rots_to_h3d(merged_rots, merged_trans)
	n_frames = len(h3d)
	shortfall = crop_size - n_frames
	if shortfall > 0:
		pad = np.zeros((shortfall, *h3d.shape[1:]), dtype=np.float32)
		h3d = np.concatenate([h3d, pad], axis=0)
	else:
		start = random.randint(0, -shortfall)
		h3d = h3d[start:start + crop_size]

	assert len(h3d) == crop_size

	return h3d
	
def motion_encode(motion_encoder, motion, motion_length, motion_mask, device, batch_size=32):
	"""Encode motions in mini-batches with gradient tracking disabled.

	Args:
		motion_encoder: callable invoked as
			`motion_encoder(motion, motion_length, motion_mask)` on each batch.
		motion: tensor whose first dimension indexes samples.
		motion_length, motion_mask: tensors sliced along dim 0 in lockstep
			with `motion`.
		device: target passed to `Tensor.to`; each batch is moved there first.
		batch_size: number of samples per encoder call (default 32).

	Returns:
		The per-batch encoder outputs concatenated along dim 0.

	Raises:
		ValueError: if `batch_size` is less than 1 (it would never advance).
		RuntimeError: from `torch.cat` when `motion` has no samples.
	"""
	if batch_size < 1:
		raise ValueError(f"batch_size must be >= 1, got {batch_size}")
	motion_emb = []
	with torch.no_grad():
		for start in range(0, motion.shape[0], batch_size):
			stop = start + batch_size
			cur_motion = motion[start:stop].to(device)
			cur_motion_length = motion_length[start:stop].to(device)
			cur_motion_mask = motion_mask[start:stop].to(device)
			motion_emb.append(motion_encoder(cur_motion, cur_motion_length, cur_motion_mask))
	return torch.cat(motion_emb, dim=0)

def text_encode(text_encoder, text, token, device, batch_size=32):
	"""Encode texts in mini-batches with gradient tracking disabled.

	Args:
		text_encoder: callable invoked as `text_encoder(text, token, device)`
			on each batch.
		text: sliceable sequence of samples; `len(text)` sets the sample count.
		token: sliceable sequence sliced in lockstep with `text`.
		device: forwarded unchanged to `text_encoder`.
		batch_size: number of samples per encoder call (default 32).

	Returns:
		The per-batch encoder outputs concatenated along dim 0.

	Raises:
		ValueError: if `batch_size` is less than 1 (it would never advance).
		RuntimeError: from `torch.cat` when `text` is empty.
	"""
	if batch_size < 1:
		raise ValueError(f"batch_size must be >= 1, got {batch_size}")
	text_emb = []
	with torch.no_grad():
		for start in range(0, len(text), batch_size):
			stop = start + batch_size
			text_emb.append(text_encoder(text[start:stop], token[start:stop], device))
	return torch.cat(text_emb, dim=0)


def build_from_cfg(cfg, registry, default_args=None):
	"""Return None for a None config, else delegate to `MMCV_MODELS.build_func`.

	The call made is `MMCV_MODELS.build_func(cfg, registry, default_args)`.
	"""
	# NOTE(review): MMCV_MODELS is neither defined nor imported in the visible
	# source — confirm it is provided elsewhere before relying on this.
	return None if cfg is None else MMCV_MODELS.build_func(cfg, registry, default_args)

def build_loss(cfg):
	"""Build a loss by delegating to `LOSSES.build(cfg)`."""
	# NOTE(review): LOSSES is not defined or imported in the visible source.
	builder = LOSSES.build
	return builder(cfg)

def build_architecture(cfg):
	"""Build a framework by delegating to `ARCHITECTURES.build(cfg)`."""
	# NOTE(review): ARCHITECTURES is not defined or imported in the visible source.
	builder = ARCHITECTURES.build
	return builder(cfg)

def build_submodule(cfg):
	"""Build a submodule by delegating to `SUBMODULES.build(cfg)`."""
	# NOTE(review): SUBMODULES is not defined or imported in the visible source.
	builder = SUBMODULES.build
	return builder(cfg)

def build_attention(cfg):
	"""Build an attention module by delegating to `ATTENTIONS.build(cfg)`."""
	# NOTE(review): ATTENTIONS is not defined or imported in the visible source.
	builder = ATTENTIONS.build
	return builder(cfg)


def get_motion_model(name, ckpt_path):
	"""Build a 'T2MMotionEncoder' submodule and load weights from `ckpt_path`.

	`name` is accepted for interface compatibility but is not used.
	Returns the encoder after calling its `load_pretrained(ckpt_path)`.
	"""
	# NOTE(review): input_size=263 presumably matches the HumanML3D motion
	# feature dimension — confirm against the features being encoded.
	encoder_cfg = {
		'type': 'T2MMotionEncoder',
		'input_size': 263,
		'movement_hidden_size': 512,
		'movement_latent_size': 512,
		'motion_hidden_size': 1024,
		'motion_latent_size': 512,
	}
	encoder = build_submodule(encoder_cfg)
	encoder.load_pretrained(ckpt_path)
	return encoder

def get_text_model(name, ckpt_path):
	"""Build a 'T2MTextEncoder' submodule and load weights from `ckpt_path`.

	`name` is accepted for interface compatibility but is not used.
	Returns the encoder after calling its `load_pretrained(ckpt_path)`.
	"""
	encoder_cfg = {
		'type': 'T2MTextEncoder',
		'word_size': 300,
		'pos_size': 15,
		'hidden_size': 512,
		'output_size': 512,
		'max_text_len': 20,
	}
	encoder = build_submodule(encoder_cfg)
	encoder.load_pretrained(ckpt_path)
	return encoder
