import numpy as np
import torch
import joblib

from openai import OpenAI

from retrieval_utils.demo.model import text_encoder
from retrieval_utils.demo.load import load_unit_embeddings, load_splits, load_json
from retrieval_utils.tools.easyconvert import matrix_to, axis_angle_to
from retrieval_utils.transforms.smpl import RotTransDatastruct

import pickle
import argparse
import os
import datetime
import codecs as cs
import orjson  # loading faster than json
import json

import numpy as np
from tqdm import tqdm

import clip
import time
import random

import utils

from retrieval_utils.rots_to_smpl_conversion.rots_to_smpl import convert_rots_to_h3d

import prompt

def humanml3d_keyid_to_babel_rendered_url(keyid):
	"""Build the BABEL render record for a HumanML3D key id.

	Looks the key id up in the module-level ``h3d_index`` mapping and maps its
	AMASS path to a BABEL id through the module-level ``amass_to_babel`` mapping.

	Args:
		keyid: HumanML3D key id (string).

	Returns:
		``None`` when the key id contains ``"M"`` (mirrored sample), when the
		sample path contains ``"humanact12"``, or when the path has no entry in
		``amass_to_babel``. Otherwise a dict with the keys ``url``, ``start``,
		``end``, ``text``, ``keyid``, ``babel_id`` and ``path``.
	"""
	# Mirrored HumanML3D samples are never shown.
	if "M" in keyid:
		return None

	entry = h3d_index[keyid]
	path = entry["path"]

	# HumanAct12 motions are not rendered online, and a motion without a
	# BABEL mapping has no render either; both are skipped for now.
	if "humanact12" in path or path not in amass_to_babel:
		return None

	babel_id = amass_to_babel[path].zfill(6)

	# For the demo, only the first annotation of the sample is used.
	first_ann = entry["annotations"][0]

	return {
		"url": f"https://babel-renders.s3.eu-central-1.amazonaws.com/{babel_id}.mp4",
		"start": first_ann["start"],
		"end": first_ann["end"],
		"text": first_ann["text"],
		"keyid": keyid,
		"babel_id": babel_id,
		"path": path,
	}


def retrieve_helper(
	model,
	unit_motion_embs,
	all_keyids,
	text,
	keyids_index,
	index_keyids,
	split="test",
	nmax=8,
	body_part="original",
):
	"""Rank the motions of one split against ``text`` and return the best records.

	Args:
		model: object exposing ``compute_scores(text, unit_embs=...)``.
		unit_motion_embs: motion embeddings, indexable by a list of row indices.
		all_keyids: mapping from split name to an iterable of key ids.
		text: query text scored against the embeddings.
		keyids_index: mapping from key id to row index in ``unit_motion_embs``.
		index_keyids: not used by this function.
		split: which entry of ``all_keyids`` to search.
		nmax: maximum number of records to return.
		body_part: not used by this function.

	Returns:
		A list of at most ``nmax`` dicts, in descending score order, as produced
		by ``humanml3d_keyid_to_babel_rendered_url`` and extended with ``score``
		(rounded to two decimals) and ``input_text``. Key ids for which that
		helper returns ``None`` are skipped.
	"""
	# Only key ids that actually have an embedding row can be scored.
	candidates = [kid for kid in all_keyids[split] if kid in keyids_index]
	rows = [keyids_index[kid] for kid in candidates]

	scores = model.compute_scores(text, unit_embs=unit_motion_embs[rows])

	# Negating the scores turns argsort's ascending order into best-first.
	order = np.argsort(-scores)
	ranked_keyids = np.array(candidates)[order]
	ranked_scores = scores[order]

	results = []
	for keyid, score in zip(ranked_keyids, ranked_scores):
		if len(results) == nmax:
			break

		record = humanml3d_keyid_to_babel_rendered_url(keyid)
		if record is not None:
			record.update(score=round(float(score), 2), input_text=text)
			results.append(record)

	return results

def retrieve_function(
	all_keyids,
	text,
	torso_model,
	torso_unit_motion_embs,
	torso_keyids_index,
	torso_index_keyids,
	hands_model,
	hands_unit_motion_embs,
	hands_keyids_index,
	hands_index_keyids,
	legs_model,
	legs_unit_motion_embs,
	legs_keyids_index,
	legs_index_keyids,
	torso_input,
	hands_input,
	legs_input,
	split="test",
	nmax=8,
):
	"""Retrieve torso, hands and legs motions for ``text`` and interleave them by rank.

	Each body part is retrieved with its own model / embedding table through
	``retrieve_helper``. The query for a part is the user-supplied description
	(``torso_input`` / ``hands_input`` / ``legs_input``) when it is non-empty;
	any missing description is generated by prompting the LLM with ``text``.

	Args:
		all_keyids: mapping from split name to key ids, forwarded to ``retrieve_helper``.
		text: full-body action description, substituted for ``[ACTION]`` in the prompt.
		torso_* / hands_* / legs_*: per-part model, embeddings and index mappings,
			forwarded unchanged to ``retrieve_helper``.
		torso_input, hands_input, legs_input: optional per-part descriptions;
			an empty value means "generate it with the LLM".
		split: split to search.
		nmax: maximum number of results per body part.

	Returns:
		A flat list ``[torso_0, hands_0, legs_0, torso_1, hands_1, legs_1, ...]``.
		Only ranks for which all three parts returned a result are included, so
		the list length is always a multiple of three.
	"""
	# Using part-specific descriptions if the user provided any
	torso_text = torso_input
	hands_text = hands_input
	legs_text = legs_input

	# Prompting the LLM only for the part-specific descriptions the user has not
	# provided; descriptions that were provided are kept as they are.
	if not (torso_text and hands_text and legs_text):
		# NOTE(review): unless rebound elsewhere in the file, `prompt` is the module
		# imported at the top; confirm it really exposes a `replace` callable.
		req_prompt = prompt.replace("[ACTION]", str(text))
		gpt_texts = GPT_Completion([req_prompt])[0]
		if not torso_text:
			torso_text = gpt_texts.split("1) Torso:")[1].split("2) Hands:")[0].strip()
		if not hands_text:
			hands_text = gpt_texts.split("2) Hands:")[1].split("3) Fingers:")[0].strip()
		if not legs_text:
			legs_text = gpt_texts.split("4) Legs:")[1].strip()

	# Retrieval of part-specific data
	torso_datas = retrieve_helper(torso_model, torso_unit_motion_embs, all_keyids, torso_text, torso_keyids_index, torso_index_keyids, split=split, nmax=nmax, body_part="torso")
	hands_datas = retrieve_helper(hands_model, hands_unit_motion_embs, all_keyids, hands_text, hands_keyids_index, hands_index_keyids, split=split, nmax=nmax, body_part="hands")
	legs_datas = retrieve_helper(legs_model, legs_unit_motion_embs, all_keyids, legs_text, legs_keyids_index, legs_index_keyids, split=split, nmax=nmax, body_part="legs")

	# Returning part-specific data in retrieval ranking order, as we follow the
	# rank-by-rank composition method. `zip` stops at the shortest list, so a
	# part that returned fewer hits cannot cause an IndexError.
	return_datas = []
	for rank_triple in zip(torso_datas, hands_datas, legs_datas):
		return_datas.extend(rank_triple)

	return return_datas

def morag_retrieve(text):
	"""Retrieve part-specific motions for ``text`` and compose them into full-body motions.

	For every retrieved rank, the torso, hands and legs motions are cut to the
	same number of frames, their joint rotations are merged into one 52-joint
	rotation tensor, the translation is taken from the legs motion, and the
	result is converted with ``get_h3d``.

	Side effect: binds the module-level ``h3d_index`` and ``amass_to_babel``
	tables, which ``humanml3d_keyid_to_babel_rendered_url`` reads.

	Args:
		text: input description to retrieve and compose motions for.

	Returns:
		A tuple ``(motions, lengths, text_feats)``: the composed motions
		concatenated along axis 0, a list with the frame count of each composed
		motion, and the text features of ``text`` (one entry per motion)
		concatenated along axis 0. When nothing could be retrieved, both arrays
		are empty and ``lengths`` is an empty list.
	"""
	# The URL helper used during retrieval looks these two tables up in module
	# scope, so they have to be bound as globals rather than as locals.
	global h3d_index, amass_to_babel

	device = torch.device("cuda" if torch.cuda.is_available() else "cpu")

	# Part-specific text encoders (to embed the part-specific descriptions generated for the input query)
	torso_model = text_encoder(TORSO_PATH).to(device)
	hands_model = text_encoder(HANDS_PATH).to(device)
	legs_model = text_encoder(LEGS_PATH).to(device)

	# Loading pre-computed part-specific databases to perform retrieval
	torso_unit_motion_embs, torso_keyids_index, torso_index_keyids = load_unit_embeddings(
		TORSO_PATH, DATASET, device
	)
	hands_unit_motion_embs, hands_keyids_index, hands_index_keyids = load_unit_embeddings(
		HANDS_PATH, DATASET, device
	)
	legs_unit_motion_embs, legs_keyids_index, legs_index_keyids = load_unit_embeddings(
		LEGS_PATH, DATASET, device
	)

	all_keyids = load_splits(DATASET, splits=["test", "all"])

	h3d_index = load_json(f"{DATASET}/annotations.json")
	amass_to_babel = load_json("amass_to_babel.json")
	amass2babel = load_json("amass-path2babel.json")

	# Reverse lookup: BABEL id -> AMASS path
	babel2amass = {amass2babel[amass_path]["babel_id"]: amass_path for amass_path in amass2babel}

	# NOTE: joblib.load unpickles its input; only point amass_tar_path at trusted files.
	amass_data = joblib.load(amass_tar_path)

	# Sample name (path after "UNZIPPED_datasets/", extension dropped) -> position in amass_data
	amass_name_idx_map = {
		sample['fname'].split("UNZIPPED_datasets/")[-1].split(".")[0]: i
		for i, sample in enumerate(amass_data)
	}

	split = "test" if splits_choice == "Unseen" else "all"

	nvids = 1 #Number of motions to retrieve

	# If the user wants to provide their own part-specific descriptions instead of using the LLM to generate them
	torso_input = ""
	hands_input = ""
	legs_input = ""

	# part-specific data is retrieved into `datas`, laid out as
	# [torso_0, hands_0, legs_0, torso_1, hands_1, legs_1, ...]
	datas = retrieve_function(all_keyids = all_keyids, text=text,
		torso_model=torso_model, torso_unit_motion_embs=torso_unit_motion_embs, torso_keyids_index=torso_keyids_index, torso_index_keyids=torso_index_keyids,
		hands_model=hands_model, hands_unit_motion_embs=hands_unit_motion_embs, hands_keyids_index=hands_keyids_index, hands_index_keyids=hands_index_keyids,
		legs_model=legs_model, legs_unit_motion_embs=legs_unit_motion_embs, legs_keyids_index=legs_keyids_index, legs_index_keyids=legs_index_keyids,
		torso_input=torso_input, hands_input=hands_input, legs_input=legs_input,
		split=split, nmax=nvids)

	parts = ("torso", "hands", "legs")

	motions = []
	return_m_lengths = []
	text_feats_per_motion = []

	# Retrieval may return fewer complete (torso, hands, legs) triples than requested.
	n_ranks = min(nvids, len(datas) // 3)

	for vid in range(n_ranks):
		part_samples = {}
		part_frames = {}
		for offset, part in enumerate(parts):
			data = datas[vid*3 + offset]

			# babel_id is zero-padded in `datas`; the reverse table is keyed without padding
			babel_id = str(int(data['babel_id']))
			fname = "/".join(babel2amass[babel_id].split("/")[1:]).split(".")[0]
			sample = amass_data[amass_name_idx_map[fname]]

			part_samples[part] = sample
			# Annotation start/end are multiplied by the sample fps to get frame indices
			part_frames[part] = np.arange(int(sample['fps']*data['start']), int(sample['fps']*data['end']))

		# STEP 1 - Calculating the minimum frames in the retrieved part-specific motions, on which we will perform spatial composition
		frames = min(len(part_frames[part]) for part in parts)

		# calculating motion rots and trans for the required number of frames for torso, hands and legs retrieved samples
		motion_data_rots = {}
		motion_data_trans = {}
		for part in parts:
			motion_data_rots[part], motion_data_trans[part] = get_rots_trans(part_samples[part], part_frames[part][:frames])

		# STEP 2 - spatial composition of three retrieved motion sequences in rank-by-rank method
		compositioned_rots = torch.zeros(motion_data_rots["torso"].shape[0],52,3,3)
		compositioned_rots[:, torso_joints] = motion_data_rots["torso"][:, torso_joints] #copying torso joints from torso retrieved sample
		compositioned_rots[:, hands_joints] = motion_data_rots["hands"][:, hands_joints] #copying hands joints from hands retrieved sample
		compositioned_rots[:, legs_joints] = motion_data_rots["legs"][:, legs_joints] #copying legs joints from legs retrieved sample

		# STEP 3 - copying translation from leg retrieved sample
		compositioned_trans = motion_data_trans["legs"]

		# convert to h3d representation which is suitable for conditioning on diffusion model
		h3d = np.array(get_h3d(compositioned_rots,compositioned_trans))

		motions.append(np.expand_dims(h3d, axis=0))
		return_m_lengths.append(len(compositioned_rots))
		text_feats_per_motion.append(encode_text(text).cpu())

	if not motions:
		return np.array([]), return_m_lengths, np.array([])

	# NOTE(review): concatenating along axis 0 requires all composed motions to have
	# the same frame count; fine for nvids == 1, needs padding/cropping if nvids grows.
	return_motions = np.concatenate(motions, axis=0)
	return_text_feats = np.concatenate(text_feats_per_motion, axis=0)

	# returning the composed motions and text features to diffusion model for conditioning
	return return_motions, return_m_lengths, return_text_feats

if __name__ == "__main__":
	# Input text for which we want to perform MoRAG retrieval.
	text = "A person is standing on one leg with hands wide open"

	# `output` is the (motions, lengths, text features) tuple returned by morag_retrieve.
	output = morag_retrieve(text)