# This code is based on https://github.com/openai/guided-diffusion
"""
Generate motion and emotion samples from a text-prompt file with two separately
loaded diffusion models, and save the combined result as a numpy file together
with the prompts and sequence lengths.
"""
from utils.fixseed import fixseed
import os
import numpy as np
import torch
from utils.parser_util import generate_args
from utils.model_util import create_model_and_diffusion, load_model_wo_clip
from utils import dist_util
from model.cfg_sampler import ClassifierFreeSampleModel
from data_loaders.get_data import get_dataset_loader
from data_loaders.humanml.scripts.motion_process import recover_from_ric
import data_loaders.humanml.utils.paramUtil as paramUtil
from data_loaders.humanml.utils.plot_script import plot_3d_motion
import shutil
from data_loaders.tensors import collate
from data_loaders.humanml.data.dataset import convert_axisto_rot6d
import copy

def main():
    """Sample motion and emotion sequences for every prompt in ``args.input_text``.

    The prompt file holds one prompt per line. Text before the first ``#`` is
    fed to the motion model and text after it to the emotion model; a line
    without ``#`` gets an empty emotion prompt. Both models are sampled
    ``args.num_repetitions`` times, the results are concatenated, and
    ``results.npy``, ``results.txt`` and ``results_len.txt`` are written to the
    output directory (which is deleted first if it already exists).

    Raises:
        SystemExit: if no prompt file was given.
        OSError: if the prompt file cannot be opened.
        ValueError: if the file holds more prompts than ``args.batch_size``.
    """
    args = generate_args()
    # This script only supports prompt-file driven sampling. Fail loudly instead
    # of exiting silently with status 0.
    if not args.input_text:
        raise SystemExit('No prompt file given: the input_text argument is required.')

    fixseed(args.seed)
    out_path = args.output_dir
    niter = os.path.basename(args.motion_model_path).replace('model', '').replace('.pt', '')
    max_frames = 196 if args.dataset in ['kit', 'humanml', 'CelebVText'] else 60
    fps = 12.5 if args.dataset == 'kit' else 20
    n_frames = min(max_frames, int(args.motion_length*fps))
    dist_util.setup_dist(args.device)
    if out_path == '':
        # Default output directory lives next to the motion checkpoint.
        out_path = os.path.join(os.path.dirname(args.motion_model_path),
                                'samples_{}_{}_seed{}'.format(args.motion_length, niter, args.seed))
        if args.text_prompt != '':
            out_path += '_' + args.text_prompt.replace(' ', '_').replace('.', '')
        else:
            out_path += '_' + os.path.basename(args.input_text).replace('.txt', '').replace(' ', '_').replace('.', '')

    # this block must be called BEFORE the dataset is loaded
    # (it decides num_samples, which becomes the loader batch size below)
    with open(args.input_text, 'r') as fr:
        motion_texts, emotion_texts = _parse_prompt_lines(fr.readlines())
    args.num_samples = len(motion_texts)

    # So why do we need this check? In order to protect GPU from a memory overload in the following line.
    # If your GPU can handle batch size larger than default, you can specify it through --batch_size flag.
    # If it doesn't, and you still want to sample more prompts, run this script with different seeds
    # (specify through the --seed flag)
    # Raised explicitly (not asserted) so the guard survives `python -O`.
    if args.num_samples > args.batch_size:
        raise ValueError(
            f'Please either increase batch_size({args.batch_size}) or reduce num_samples({args.num_samples})')
    args.batch_size = args.num_samples  # Sampling a single batch from the testset, with exactly args.num_samples
    total_num_samples = args.num_samples * args.num_repetitions

    # Independent copies so data_class / model_path can differ per model.
    motion_args = copy.deepcopy(args)
    emotion_args = copy.deepcopy(args)

    print('Loading dataset...')
    motion_args.data_class = "motion"
    emotion_args.data_class = "emotion"

    motion_data = load_dataset(motion_args, max_frames, n_frames)
    emotion_data = load_dataset(emotion_args, max_frames, n_frames)

    print("Creating model and diffusion...")
    # new: load motion and emotion model separately; each falls back to
    # args.model_path when its dedicated path is empty
    motion_args.model_path = args.motion_model_path if args.motion_model_path else args.model_path
    emotion_args.model_path = args.emotion_model_path if args.emotion_model_path else args.model_path
    motion_model, motion_diffusion = create_model_and_diffusion(motion_args, motion_data)
    emotion_model, emotion_diffusion = create_model_and_diffusion(emotion_args, emotion_data)

    print(f"Loading checkpoints from [{motion_args.model_path}]...")
    state_dict = torch.load(motion_args.model_path, map_location='cpu')
    load_model_wo_clip(motion_model, state_dict)

    print(f"Loading checkpoints from [{emotion_args.model_path}]...")
    state_dict = torch.load(emotion_args.model_path, map_location='cpu')
    load_model_wo_clip(emotion_model, state_dict)

    if args.guidance_param != 1:
        print(args.guidance_param)
        motion_model = ClassifierFreeSampleModel(motion_model)   # wrapping model with the classifier-free sampler
        emotion_model = ClassifierFreeSampleModel(emotion_model)   # wrapping model with the classifier-free sampler

    motion_model.to(dist_util.dev())
    motion_model.eval()  # disable random masking

    emotion_model.to(dist_util.dev())
    emotion_model.eval()  # disable random masking

    # Build the conditioning batches straight from the prompt texts. The old
    # "sample conditions from the dataset" path was unreachable (a prompt file
    # is mandatory) and never defined the kwargs used below, so it is gone.
    print(f'n_frames: {n_frames}')
    collate_args = [{'inp': torch.zeros(n_frames), 'tokens': None, 'lengths': n_frames}] * args.num_samples
    collate_args_motion = [dict(arg, text=txt) for arg, txt in zip(collate_args, motion_texts)]
    collate_args_emotion = [dict(arg, text=txt) for arg, txt in zip(collate_args, emotion_texts)]

    _, model_kwargs_motion = collate(collate_args_motion)
    _, model_kwargs_emotion = collate(collate_args_emotion)

    all_motions = []
    all_emotions = []
    all_lengths = []
    all_text = []

    for rep_i in range(args.num_repetitions):
        print(f'### Sampling [repetitions #{rep_i}]')

        # Motion is sampled before emotion on every repetition, as before.
        motion_sample = _sample_batch('motion', motion_model, motion_diffusion, motion_data,
                                      model_kwargs_motion, args.batch_size, n_frames, args.guidance_param)
        emotion_sample = _sample_batch('emotion', emotion_model, emotion_diffusion, emotion_data,
                                       model_kwargs_emotion, args.batch_size, n_frames, args.guidance_param)

        if args.unconstrained:
            all_text += ['unconstrained'] * args.num_samples
        else:
            text_key = 'text'
            # The two prompts are joined with no separator between them.
            all_text += [f"{motion_text}{emotion_text}" for motion_text, emotion_text in
                         zip(model_kwargs_motion['y'][text_key], model_kwargs_emotion['y'][text_key])]

        all_motions.append(motion_sample.cpu().numpy())
        all_emotions.append(emotion_sample.cpu().numpy())
        all_lengths.append(model_kwargs_motion['y']['lengths'].cpu().numpy())

        print(f"created {len(all_motions) * args.batch_size} motion samples and {len(all_emotions) * args.batch_size} emotion samples")

    # NOTE(review): the original notes give these shapes as (N, c1, 1, f) and
    # (N, c2, 1, f); after the 'hml_vec' branch in _sample_batch the arrays have
    # been permuted and squeezed, so confirm which layout axis=2 refers to.
    all_motions = np.concatenate(all_motions, axis=0)
    all_emotions = np.concatenate(all_emotions, axis=0)
    all_motions = np.concatenate((all_emotions, all_motions), axis=2)  # use all_motions to represent from now on
    print(f"all_motions: {all_motions.shape}")

    all_motions = all_motions[:total_num_samples]
    all_text = all_text[:total_num_samples]
    all_lengths = np.concatenate(all_lengths, axis=0)[:total_num_samples]

    # Any previous output directory is wiped so stale results never mix in.
    if os.path.exists(out_path):
        shutil.rmtree(out_path)
    os.makedirs(out_path)

    npy_path = os.path.join(out_path, 'results.npy')
    print(f"saving results file to [{npy_path}]")
    np.save(npy_path,
            {'motion': all_motions, 'text': all_text, 'lengths': all_lengths,
             'num_samples': args.num_samples, 'num_repetitions': args.num_repetitions})
    with open(npy_path.replace('.npy', '.txt'), 'w') as fw:
        fw.write('\n'.join(all_text))
    with open(npy_path.replace('.npy', '_len.txt'), 'w') as fw:
        fw.write('\n'.join(str(length) for length in all_lengths))

    abs_path = os.path.abspath(out_path)
    print(f'[Done] Results are at [{abs_path}]')


def _parse_prompt_lines(lines):
    """Split prompt lines into parallel motion-text and emotion-text lists.

    Each line is stripped of surrounding whitespace. Text before the first
    ``#`` is the motion prompt and text after it is the emotion prompt; a line
    without ``#`` yields an empty emotion prompt.
    """
    motion_texts = []
    emotion_texts = []
    for line in lines:
        motion_text, _, emotion_text = line.strip().partition('#')
        motion_texts.append(motion_text)
        emotion_texts.append(emotion_text)  # empty when the line has no emotion description
    return motion_texts, emotion_texts


def _sample_batch(label, model, diffusion, data, model_kwargs, batch_size, n_frames, guidance_param):
    """Draw one batch from ``diffusion`` and post-process it for ``model``.

    ``label`` is only used in the progress message. When ``model.data_rep`` is
    ``'hml_vec'`` the sample is passed through the dataset's ``inv_transform``.
    """
    # add CFG scale to batch
    if guidance_param != 1:
        model_kwargs['y']['scale'] = torch.ones(batch_size, device=dist_util.dev()) * guidance_param

    sample = diffusion.p_sample_loop(
        model,
        (batch_size, model.njoints, model.nfeats, n_frames),
        clip_denoised=False,
        model_kwargs=model_kwargs,
        skip_timesteps=0,
        init_image=None,
        progress=True,
        dump_steps=None,
        noise=None,
        const_noise=False,
    )
    print(f'{label} sample shape: {sample.shape}')

    if model.data_rep == 'hml_vec':
        # Move the first sampled axis last before inv_transform, then drop the
        # axis at index 1 if it is singleton (the original note gives
        # [1, 1, 196, 6] as an example shape here — TODO confirm).
        sample = data.dataset.t2m_dataset.inv_transform(sample.cpu().permute(0, 2, 3, 1)).float()
        sample = sample.squeeze(1)
    return sample


def load_dataset(args, max_frames, n_frames):
    """Create the test-split, text-only data loader selected by ``args``.

    The loader is requested with ``num_frames=max_frames``. For the 'kit',
    'humanml' and 'CelebVText' datasets the underlying ``t2m_dataset`` also has
    its ``fixed_length`` set to ``n_frames``; other datasets are returned
    untouched.
    """
    loader = get_dataset_loader(
        name=args.dataset,
        batch_size=args.batch_size,
        num_frames=max_frames,
        split='test',
        hml_mode='text_only',
        args=args,
    )

    has_t2m_dataset = args.dataset in ('kit', 'humanml', 'CelebVText')
    if not has_t2m_dataset:
        return loader

    # Pin every sequence from these datasets to the requested frame count.
    loader.dataset.t2m_dataset.fixed_length = n_frames
    return loader


# Run the sampler only when this file is executed as a script, not on import.
if __name__ == "__main__":
    main()
