import numpy as np
import torch
import torch.nn as nn
import torch.nn.functional as F
import joblib

from openai import OpenAI

from retrieval_utils.demo.model import text_encoder
from retrieval_utils.demo.load import load_unit_embeddings, load_splits, load_json
from retrieval_utils.tools.easyconvert import matrix_to, axis_angle_to
from retrieval_utils.transforms.smpl import RotTransDatastruct

import pickle
import argparse
import os
import datetime
import codecs as cs
import orjson  # loading faster than json
import json

import numpy as np
from tqdm import tqdm

import clip
import time
import random

import utils

from retrieval_utils.rots_to_smpl_conversion.rots_to_smpl import convert_rots_to_h3d

import morag

class RetrievalDatabase(nn.Module):
    """Retrieval database plus encoders for retrieved motions and texts.

    ``__init__`` loads the database arrays from ``retrieval_file`` and builds
    a motion encoder (linear projection, learned positional embedding and a
    stack of ``EncoderLayer`` blocks) and a text ``nn.TransformerEncoder``.
    ``retrieve`` looks up database indices for a single caption, while
    ``forward`` fetches retrieval results from ``morag.morag_retrieve`` and
    encodes them.
    """

    def __init__(self,
                 num_retrieval=None,
                 topk=None,
                 retrieval_file=None,
                 latent_dim=512,
                 output_dim=512,
                 num_layers=2,
                 num_motion_layers=4,
                 kinematic_coef=0.1,
                 max_seq_len=196,
                 num_heads=8,
                 ff_size=1024,
                 stride=4,
                 sa_block_cfg=None,
                 ffn_cfg=None,
                 dropout=0,
                 mean_path="mean.npy",
                 std_path="std.npy"):
        """Load the retrieval database and build the encoders.

        Args:
            num_retrieval: Number of indices returned by ``retrieve`` and the
                number of retrieved items per caption that ``forward``
                expects.
            topk: Length of the precomputed candidate list considered before
                truncating to ``num_retrieval``; ``None`` means no limit.
            retrieval_file: Path handed to ``np.load``. It must provide the
                keys ``text_features``, ``captions``, ``motions``,
                ``m_lengths`` and ``clip_seq_features``; ``train_indexes``
                and ``test_indexes`` are optional.
            latent_dim: Width of the motion projection, the positional
                embedding and the text transformer.
            output_dim: Stored on the module; not used inside this class.
            num_layers: Number of text transformer layers.
            num_motion_layers: Number of ``EncoderLayer`` blocks.
            kinematic_coef: Decay rate of the length-based score used by
                ``retrieve``.
            max_seq_len: Number of rows of the learned motion positional
                embedding.
            num_heads, ff_size, dropout: Text transformer layer settings.
            stride: Temporal subsampling step applied in ``forward``.
            sa_block_cfg, ffn_cfg: Forwarded to every ``EncoderLayer``.
            mean_path, std_path: ``.npy`` files holding the normalization
                statistics applied to retrieved motions in ``forward``.

        Raises:
            TypeError: If ``retrieval_file`` is not given.
        """
        super().__init__()
        if retrieval_file is None:
            raise TypeError("RetrievalDatabase requires a retrieval_file")
        self.num_retrieval = num_retrieval
        self.topk = topk
        self.latent_dim = latent_dim
        self.output_dim = output_dim
        self.stride = stride
        self.kinematic_coef = kinematic_coef
        self.num_layers = num_layers
        self.num_motion_layers = num_motion_layers
        self.max_seq_len = max_seq_len
        self.mean_path = mean_path
        self.std_path = std_path
        # Filled lazily by _load_norm_stats() and reused by every forward().
        self._norm_stats = None

        # The context manager closes the underlying .npz archive; indexing an
        # NpzFile reads each array into memory, so nothing is lost on close.
        with np.load(retrieval_file) as data:
            self.text_features = torch.Tensor(data['text_features'])
            self.captions = data['captions']
            self.motions = data['motions']
            self.m_lengths = data['m_lengths']
            self.clip_seq_features = data['clip_seq_features']
            self.train_indexes = data.get('train_indexes', None)
            self.test_indexes = data.get('test_indexes', None)

        # Submodules are created in the original order so parameter
        # initialisation consumes the RNG exactly as before.
        self.motion_proj = nn.Linear(self.motions.shape[-1], self.latent_dim)
        self.motion_pos_embedding = nn.Parameter(torch.randn(max_seq_len, self.latent_dim))
        # NOTE(review): EncoderLayer is not imported in the visible file;
        # confirm where it is defined.
        self.motion_encoder_blocks = nn.ModuleList(
            EncoderLayer(sa_block_cfg=sa_block_cfg, ffn_cfg=ffn_cfg)
            for _ in range(num_motion_layers)
        )
        text_layer = nn.TransformerEncoderLayer(
            d_model=self.latent_dim,
            nhead=num_heads,
            dim_feedforward=ff_size,
            dropout=dropout,
            activation="gelu")
        self.text_encoder = nn.TransformerEncoder(text_layer, num_layers=num_layers)
        # Memo for the score-based path of retrieve().
        self.results = {}

    def extract_text_feature(self, text, clip_model, device):
        """Return ``clip_model.encode_text`` for a single caption, without grad.

        The caption is tokenized with ``clip.tokenize([text], truncate=True)``,
        so the result has a leading batch dimension of 1.
        """
        tokens = clip.tokenize([text], truncate=True).to(device)
        with torch.no_grad():
            text_features = clip_model.encode_text(tokens)
        return text_features

    def encode_text(self, text, clip_model, device):
        """Return per-token CLIP text features for one caption, batch-first.

        Runs the token embedding plus positional embedding, then
        ``clip_model.transformer`` and ``clip_model.ln_final``, without any
        pooling. All token positions are returned as ``[1, n_ctx, d_model]``.
        """
        with torch.no_grad():
            tokens = clip.tokenize([text], truncate=True).to(device)
            x = clip_model.token_embedding(tokens).type(clip_model.dtype)  # [batch_size, n_ctx, d_model]
            x = x + clip_model.positional_embedding.type(clip_model.dtype)
            x = x.permute(1, 0, 2)  # NLD -> LND for the CLIP transformer
            x = clip_model.transformer(x)
            x = clip_model.ln_final(x).type(clip_model.dtype)
        # LND -> NLD (B, T, D)
        return x.permute(1, 0, 2)

    def _limit_topk(self, candidates):
        """Truncate ``candidates`` to ``self.topk`` entries (``None`` = no limit)."""
        return candidates if self.topk is None else candidates[:self.topk]

    def retrieve(self, caption, length, clip_model, device, idx=None):
        """Return ``num_retrieval`` database indices for one query.

        When ``idx`` is given and the precomputed table for the current mode
        exists, the result comes from that table:

        * training (``train_indexes``): the query's own index is skipped, the
          first ``topk`` candidates are shuffled, and ``num_retrieval`` of them
          are returned;
        * evaluation (``test_indexes``): the first ``topk`` candidates are
          truncated to ``num_retrieval``, in table order.

        Otherwise every database entry is scored by CLIP text similarity times
        ``exp(-kinematic_coef * relative_length_difference)``, and the best
        ``num_retrieval`` are returned.

        Args:
            caption: Query text (used only on the score-based path).
            length: Query length, compared against ``m_lengths``.
            clip_model: CLIP model used to embed ``caption``.
            device: Device used for scoring.
            idx: Optional dataset index (int or integer tensor).

        Raises:
            RuntimeError: If the score-based path finds fewer than
                ``num_retrieval`` admissible entries.
        """
        if idx is not None:
            if self.training and self.train_indexes is not None:
                query = int(idx)
                candidates = [r for r in self.train_indexes[query] if r != query]
                candidates = self._limit_topk(candidates)
                random.shuffle(candidates)
                return candidates[:self.num_retrieval]
            if not self.training and self.test_indexes is not None:
                candidates = self._limit_topk(list(self.test_indexes[int(idx)]))
                return candidates[:self.num_retrieval]
        return self._retrieve_by_score(caption, length, clip_model, device)

    def _retrieve_by_score(self, caption, length, clip_model, device):
        """Score-based retrieval used by ``retrieve`` when no table applies."""
        # The result depends on the caption, the length and the training flag,
        # so all three form the memo key. A bare hash(caption) could collide
        # and would ignore the other two.
        cache_key = (caption, float(length), self.training)
        if cache_key in self.results:
            return self.results[cache_key]

        text_feature = self.extract_text_feature(caption, clip_model, device)
        db_lengths = torch.LongTensor(self.m_lengths).to(device)
        rel_length = torch.abs(db_lengths - length) / torch.clamp(db_lengths, min=length)

        # .float(): CLIP may return half precision while text_features is fp32.
        semantic_score = F.cosine_similarity(
            self.text_features.to(device), text_feature.float())
        kinematic_score = torch.exp(-rel_length * self.kinematic_coef)
        order = torch.argsort(semantic_score * kinematic_score, descending=True).tolist()

        data = []
        for cand in order:
            # During training, entries whose length equals the query length are
            # skipped (presumably to avoid retrieving the query itself; confirm).
            if not self.training or self.m_lengths[cand] != length:
                data.append(cand)
                if len(data) == self.num_retrieval:
                    self.results[cache_key] = data
                    return data
        raise RuntimeError(
            f"only {len(data)} retrieval candidates available, "
            f"{self.num_retrieval} requested")

    def generate_src_mask(self, T, length):
        """Return a ``(len(length), T)`` mask in the default float dtype.

        Entry ``[i, j]`` is 1 when ``j < length[i]`` and 0 otherwise.
        """
        lengths = torch.as_tensor(length).reshape(-1).cpu()
        src_mask = torch.ones(lengths.shape[0], T)
        positions = torch.arange(T).unsqueeze(0)
        src_mask[positions >= lengths.unsqueeze(1)] = 0
        return src_mask

    def _load_norm_stats(self):
        """Load the motion mean/std arrays once and cache them on the module."""
        stats = getattr(self, '_norm_stats', None)
        if stats is None:
            stats = (np.load(getattr(self, 'mean_path', 'mean.npy')),
                     np.load(getattr(self, 'std_path', 'std.npy')))
            self._norm_stats = stats
        return stats

    def forward(self, captions, lengths, clip_model, device, idx=None):
        """Retrieve items for ``captions`` via ``morag`` and encode them.

        ``lengths``, ``clip_model`` and ``idx`` are accepted for interface
        compatibility but are not used here.

        Returns:
            dict with keys
            ``re_text`` (B, num_retrieval, 1, D): last position of the encoded
            text sequence; ``re_motion`` (B, num_retrieval, ceil(T/stride), -1);
            ``re_mask`` (B, num_retrieval, ceil(T/stride)); ``raw_motion``
            (normalized motions, float32); ``raw_motion_length``;
            ``raw_motion_mask`` (full-resolution mask).
        """
        mean, std = self._load_norm_stats()
        B = len(captions)
        B_motions, B_m_lengths, text_feats = morag.morag_retrieve(captions)
        # Vectorized normalization. It builds a new array instead of
        # overwriting the one morag returned.
        motions = (np.asarray(B_motions) - mean) / (std + 1e-9)

        all_motions = torch.from_numpy(motions).float().to(device)
        all_m_lengths = torch.Tensor(B_m_lengths).long()

        T = all_motions.shape[1]
        src_mask = self.generate_src_mask(T, all_m_lengths).to(device)
        raw_src_mask = src_mask.clone()

        # Slice the learned embedding so sequences shorter than max_seq_len
        # also work (identical result when T == max_seq_len).
        re_motion = self.motion_proj(all_motions) + self.motion_pos_embedding[:T].unsqueeze(0)
        for module in self.motion_encoder_blocks:
            re_motion = module(x=re_motion, src_mask=src_mask.unsqueeze(-1))
        # NOTE(review): assumes morag returns exactly num_retrieval items per
        # caption (B * num_retrieval rows); the view fails otherwise.
        re_motion = re_motion.view(B, self.num_retrieval, T, -1).contiguous()
        re_motion = re_motion[:, :, ::self.stride, :].contiguous()

        src_mask = src_mask[:, ::self.stride].contiguous()
        src_mask = src_mask.view(B, self.num_retrieval, -1).contiguous()

        # Assumed (B * num_retrieval, seq_len, D); the encoder is sequence-first.
        all_text_seq_features = torch.Tensor(text_feats).to(device)
        text_len = all_text_seq_features.shape[1]
        all_text_seq_features = all_text_seq_features.permute(1, 0, 2).float()

        re_text = self.text_encoder(all_text_seq_features)
        re_text = re_text.permute(1, 0, 2).view(B, self.num_retrieval, text_len, -1).contiguous()
        re_text = re_text[:, :, -1:, :].contiguous()

        return dict(
            re_text=re_text,
            re_motion=re_motion,
            re_mask=src_mask,
            raw_motion=all_motions,
            raw_motion_length=all_m_lengths,
            raw_motion_mask=raw_src_mask)