import os
import copy
import torch
import librosa
import argparse
import numpy as np
from rich import print
import mediapipe as mp
import albumentations as A
import albumentations.pytorch
from models import create_model
from rich.progress import track
from transformers import Wav2Vec2Processor
from models.componments.wav2vec import Wav2Vec2Model
from options.test_feature_options import TestOptions as FeatureOptions
from utils.util import assign_attributes, parse_config, get_semantic_indices, gauss_smooth_list
from utils.visualizer import frames_to_video, semantic_meshs2figure, get_feature_image, denorm_verticies_via_headpose


# Short aliases into MediaPipe's `solutions` namespace: the drawing-style
# helpers and the face-mesh connection tables.
# NOTE(review): neither alias is referenced in the visible part of this script.
mp_drawing_styles = mp.solutions.drawing_styles
mp_connections = mp.solutions.face_mesh_connections


def create_model_by_opt(opt):
    """Build a model from ``opt``, restore its weights and put it in eval mode.

    Args:
        opt: Option object forwarded to ``create_model`` and ``setup``. Its
            ``load_epoch`` attribute is handed to ``load_networks``.

    Returns:
        The object produced by ``create_model`` after ``setup``,
        ``load_networks`` and ``eval`` have been called on it.
    """
    model = create_model(opt)
    model.setup(opt)

    # Read `load_epoch` only after setup(), exactly when load_networks needs it.
    checkpoint = opt.load_epoch
    model.load_networks(checkpoint)

    model.eval()
    return model


def _str2bool(value):
    """Convert a command-line token such as ``"False"`` or ``"1"`` to a bool.

    Without this converter argparse stores the raw string, and any non-empty
    string (including ``"False"``) is truthy.

    Raises:
        argparse.ArgumentTypeError: if the token is not a recognised boolean.
    """
    if isinstance(value, bool):
        return value
    token = str(value).strip().lower()
    if token in ('1', 'true', 't', 'yes', 'y', 'on'):
        return True
    if token in ('', '0', 'false', 'f', 'no', 'n', 'off'):
        return False
    raise argparse.ArgumentTypeError(f'expected a boolean value, got {value!r}')


# Command-line interface of the script.
parser = argparse.ArgumentParser()
# Forward slashes are used on purpose: in a normal string literal `\f` is a
# form-feed escape, so a backslash-separated default silently corrupts the path.
parser.add_argument('--config',          type=str, default='config/feat_HDTF/WRA_CathyMcMorrisRodgers1_000.yaml', help='[!] Override the below parameters from config')
parser.add_argument('--data_dir',        type=str, default='HDTF')
# "<subject name><TAB><subject id>"; the tab is spelled as an escape so it
# survives editors that convert tabs to spaces.
parser.add_argument('--test_person',     type=str, default='May\t000')
# When true, feature sequences are regenerated even if a cached file exists.
parser.add_argument('--force_up',        type=_str2bool, default=True)
# Comma-separated list of audio files to drive the animation with.
parser.add_argument('--driven_audios',   default='data/May-test/audio.wav,')
parser.add_argument('--out_dir',         type=str, default='test_results/HDTF-LSP/Render')
parser.add_argument('--a2f_ckpt_fp',     type=str, default='ckpts/Audio2FeatureVertices/best_Audio2Feature.pkl')
parser.add_argument('--refiner_ckpt_fp', type=str, default='ckpts/PointRefiner/Audio2FeatureVertices/best_PointRefiner_G.pkl')


# Parse the command line, load the config file named by --config and hand both
# to assign_attributes().
# NOTE(review): presumably this copies config values onto `args` (the --config
# help text says the config overrides the CLI parameters) — confirm in utils.util.
args = parser.parse_args()
opts = parse_config(args.config)
assign_attributes(opts, args)


# Options for the audio-to-feature model. `load_epoch` receives the checkpoint
# path from --a2f_ckpt_fp; the remaining attributes are fixed for this script.
test_opt = FeatureOptions().parse()
test_opt.load_epoch = args.a2f_ckpt_fp
test_opt.extract_wav2vec = True
test_opt.body = 'dualv2'
test_opt.model = 'motp' 
test_opt.subject_head = 'point'

# Black 512x512 3-channel uint8 image; copied per frame before being drawn on.
blank_canvas = np.zeros((512, 512, 3), dtype=np.uint8)

# Options for the point refiner: a deep copy of `test_opt` as configured so
# far, with only the model name and checkpoint path swapped. Attributes set on
# `test_opt` after this line are NOT reflected in `refiner_opt`.
refiner_opt = copy.deepcopy(test_opt)
refiner_opt.model = 'pointrefiner'
refiner_opt.load_epoch = args.refiner_ckpt_fp
refiner_opt.subject_head = 'point'
# Passed as `n_subjects` to both models at inference time.
n_subjects = 12

# Mapping of region name -> vertex indices, as returned by get_semantic_indices().
all_semantic_indices = get_semantic_indices()

# Sort every region into one of three buckets by its name: regions mentioning
# 'Left'/'Right' count as eye, regions mentioning 'Lip' as lip, the rest as
# non-semantic.
semantic_eye_indices, semantic_lip_indices, non_semantic_indices = [], [], []
for region_name, region_indices in all_semantic_indices.items():
    if 'Left' in region_name or 'Right' in region_name:
        bucket = semantic_eye_indices
    elif 'Lip' in region_name:
        bucket = semantic_lip_indices
    else:
        bucket = non_semantic_indices
    bucket.extend(region_indices)

# Drop from the non-semantic bucket every index that also appears in the eye or
# lip bucket. This must use the eye bucket *before* the iris centers are
# removed, so 468 and 473 stay excluded from the non-semantic set as well.
_semantic_lookup = set(semantic_eye_indices) | set(semantic_lip_indices)
non_semantic_indices = [idx for idx in non_semantic_indices if idx not in _semantic_lookup]

# remove L/R iris center points
semantic_eye_indices = [idx for idx in semantic_eye_indices if idx != 468 and idx != 473]

eye_vertices_len = len(semantic_eye_indices)
lip_vertices_len = len(semantic_lip_indices)

# --test_person has the form "<subject name><TAB><subject id>". A literal tab is
# awkward to pass from a shell, so when no tab is present fall back to splitting
# on the last run of whitespace (e.g. "May 000"). Tab-separated values behave
# exactly as before; a value with no separator still raises ValueError.
if '\t' in args.test_person:
    sub_name, sub_id = args.test_person.split('\t')
else:
    sub_name, sub_id = args.test_person.rsplit(None, 1)

# Flattened feature sizes: three coordinates per selected vertex. These are set
# on `test_opt` only; `refiner_opt` was deep-copied earlier and does not get them.
test_opt.lip_vertice_dim = lip_vertices_len * 3
test_opt.eye_vertice_dim = eye_vertices_len * 3

# Instantiate both networks with their checkpoints loaded, in eval mode.
Audio2Feature = create_model_by_opt(test_opt)
PointRefiner = create_model_by_opt(refiner_opt)

os.makedirs(args.out_dir, exist_ok=True)

# Per-subject reference data lives in <data_dir>/<sub_name>/.
subject_dir = os.path.join(args.data_dir, sub_name)
meta = np.load(os.path.join(subject_dir, 'feature.npz'))

# Subject vertices; the mean over axis 0 is added back onto the predicted
# offsets when frames are rendered.
template = meta['Vertices']
template_mean = np.mean(template, axis=0)

# Eye / lip subsets of the template and their means.
template_eye = template[:, semantic_eye_indices, :]
template_lip = template[:, semantic_lip_indices, :]
template_eye_mean = np.mean(template_eye, axis=0)
template_lip_mean = np.mean(template_lip, axis=0)

# Head motion per sample: head poses, then translations, then the scale
# (given a trailing axis so it can be concatenated), joined on the last axis.
headmotion = np.concatenate([meta['Headposes'], meta['Transposes'], meta['Scales'][..., None]], axis=-1)
headmotion_mean = np.mean(headmotion, axis=0)

shoulder_pts = np.load(os.path.join(subject_dir, 'shoulder.npy'))
shoulder_pts_mean = np.mean(shoulder_pts, axis=0)

# --driven_audios is either a comma-separated string or already a sequence.
if isinstance(args.driven_audios, str):
    driven_audios = args.driven_audios.split(',')
else:
    driven_audios = list(args.driven_audios)
# Drop blank entries: a trailing comma (as in the default value) would otherwise
# produce an empty path that fails when the audio file is loaded.
driven_audios = [path.strip() for path in driven_audios if path.strip()]

def _load_wav2vec(model_name="facebook/wav2vec2-base-960h"):
    """Load the wav2vec 2.0 processor and encoder (encoder on GPU, eval mode)."""
    processor = Wav2Vec2Processor.from_pretrained(model_name)
    audio_encoder = Wav2Vec2Model.from_pretrained(model_name)
    # wav2vec 2.0 weights initialization
    audio_encoder.feature_extractor._freeze_parameters()
    audio_encoder = audio_encoder.cuda()
    audio_encoder.eval()
    return processor, audio_encoder


def _extract_audio_feature(wav_path, processor, audio_encoder, sample_rate=16000, fps=30, step=320000):
    """Encode one audio file into wav2vec hidden states and return them as numpy.

    The waveform is processed in windows of ``2 * step`` samples that advance by
    ``step`` samples; the overlapping halves of consecutive windows are averaged
    before the pieces are concatenated.

    Args:
        wav_path: Path of the audio file to load.
        processor: Wav2Vec2 processor used to normalise the waveform.
        audio_encoder: Encoder called as ``audio_encoder(clip, frame_num=...)``.
        sample_rate: Sampling rate the audio is loaded at.
        fps: Output frames per second; sets ``frame_num`` for each window.
        step: Hop between windows, in samples.

    Returns:
        Squeezed numpy array of the merged ``last_hidden_state`` values.
    """
    speech_array, _ = librosa.load(wav_path, sr=sample_rate)

    # Normalised waveform reshaped to a single-row batch: (1, n_samples).
    waveform = np.squeeze(processor(speech_array, sampling_rate=sample_rate).input_values)
    waveform = np.reshape(waveform, (-1, waveform.shape[0]))
    waveform = torch.FloatTensor(waveform).cuda()

    win_size = step * 2
    total_samples = waveform.shape[1]
    predictions = []

    start = 0
    with torch.no_grad():
        while True:
            clip = waveform[:, start:start + win_size]
            frame_num = int(clip.shape[1] / sample_rate * fps)
            hidden = audio_encoder(clip, frame_num=frame_num).last_hidden_state
            predictions.append(hidden.squeeze())
            # Stop once the current window reaches past the end of the audio.
            if start + win_size > total_samples:
                break
            start += step

    merged = predictions[0]
    if len(predictions) > 1:
        # Consecutive windows overlap by half a full window: average the overlap,
        # then append the part of the next window that is new.
        mid_len = merged.shape[0] // 2
        for next_prediction in predictions[1:]:
            merged[-mid_len:, :] = (merged[-mid_len:, :] + next_prediction[:mid_len, :]) / 2.
            merged = torch.cat([merged, next_prediction[mid_len:, :]], dim=0)

    return np.squeeze(merged.detach().cpu().numpy())


# Frame rate used both for audio feature extraction and for the written video.
video_fps = 30
# (processor, encoder) pair; loaded on the first cache miss and then reused so
# the pretrained weights are not reloaded for every audio file.
_wav2vec = None

for driven_audio in driven_audios:

    # Name outputs after the last two path components without the extension.
    # splitext (rather than splitting on the first '.') keeps this correct for
    # paths such as './data/May-test/audio.wav'.
    out_name = '_'.join(os.path.splitext(driven_audio)[0].split('/')[-2:])

    # Audio features are cached next to the results and reused on later runs.
    audio_feature_fp = os.path.join(args.out_dir, out_name + '.npy')
    if os.path.exists(audio_feature_fp):
        audio_feature = np.load(audio_feature_fp)
    else:
        if _wav2vec is None:
            _wav2vec = _load_wav2vec()
        audio_feature = _extract_audio_feature(driven_audio, *_wav2vec, fps=video_fps)
        np.save(audio_feature_fp, audio_feature)
    audio_feature = torch.FloatTensor(audio_feature).cuda().unsqueeze(0)

    # NOTE(review): after unsqueeze(0) the leading axis is the batch axis, so
    # shape[0] is always 1 and av_rate equals 1 / template.shape[0]. This looks
    # unintended — confirm against generate_sequences before changing it.
    av_rate = audio_feature.shape[0] / template.shape[0]

    t_out_dir = os.path.join(args.out_dir, sub_name)
    os.makedirs(t_out_dir, exist_ok=True)

    # Subject conditioning: an integer id for 'onehot' heads, otherwise the
    # subject's mean template vertices.
    sub_info_aud = int(sub_id) if test_opt.subject_head == 'onehot' else template_mean
    sub_info_rfn = int(sub_id) if refiner_opt.subject_head == 'onehot' else template_mean

    print('==== Generare feautre sequence for', out_name, '====')
    seq_fp = os.path.join(t_out_dir, f'seq_pred-{out_name}.npz')
    if not os.path.exists(seq_fp) or args.force_up:
        preds_lip, preds_eye, preds_head, preds_torso = Audio2Feature.generate_sequences(audio_feature, n_subjects=n_subjects, sub_id=sub_info_aud, av_rate=av_rate)
        np.savez(seq_fp, lip=preds_lip, eye=preds_eye, hp=preds_head, shoulder=preds_torso)
    else:
        cached_seq = np.load(seq_fp)
        preds_lip, preds_eye, preds_head, preds_torso = cached_seq['lip'], cached_seq['eye'], cached_seq['hp'], cached_seq['shoulder']

    print('==== Done. ====')

    # reshape vertices 3D: (frames, n_vertices, 3)
    preds_lip    = np.reshape(preds_lip,    (-1, preds_lip.shape[-1]//3, 3))
    preds_eye    = np.reshape(preds_eye,    (-1, preds_eye.shape[-1]//3, 3))
    preds_torso  = np.reshape(preds_torso,  (-1, preds_torso.shape[-1]//3, 3))

    # Here we enlarge the lip, head pose
    preds_lip[..., 1] = preds_lip[..., 1] * 1.8
    preds_lip[..., 0] = preds_lip[..., 0] * 1.5
    preds_head[..., :3] = preds_head[..., :3] * 1.5

    if opts.with_smooth:
        preds_lip          = gauss_smooth_list(preds_lip, opts.smooth.lip)
        preds_eye          = gauss_smooth_list(preds_eye, opts.smooth.refine)
        preds_head[:, :3]  = gauss_smooth_list(preds_head[:, :3], opts.smooth.head_rot)
        preds_head[:, 3:6] = gauss_smooth_list(preds_head[:, 3:6], opts.smooth.head_trans)
        preds_torso        = gauss_smooth_list(preds_torso, opts.smooth.torso)

    # Refine the lip and eye predictions with the point refiner.
    t_pred_lip = torch.from_numpy(preds_lip).to(PointRefiner.device).unsqueeze(0)
    t_pred_eye = torch.from_numpy(preds_eye).to(PointRefiner.device).unsqueeze(0)

    preds_refine = PointRefiner.inference(t_pred_lip, t_pred_eye, n_subjects=n_subjects, sub_id=sub_info_rfn)
    preds_refine = np.reshape(preds_refine, (-1, preds_refine.shape[-1]//3, 3))

    if opts.with_smooth:
        preds_refine[:, non_semantic_indices] = gauss_smooth_list(preds_refine[:, non_semantic_indices], opts.smooth.others)

    n_frames = preds_lip.shape[0]
    frames_lst = []
    for i in track(range(n_frames), description=f'Processing: {sub_name} => {out_name}, total: {n_frames}'):

        # Predictions are offsets from the subject means; add the means back.
        t_refine_tmp    = preds_refine[i] + template_mean
        t_pred_head     = preds_head[i] + headmotion_mean
        t_pred_shoulder = preds_torso[i] + shoulder_pts_mean

        # Head-motion layout follows the concatenation order used to build it:
        # [:3] head pose, [3:6] translation, [6:] scale.
        head_rot, head_trans, head_scale = t_pred_head[:3], t_pred_head[3:6], t_pred_head[6:]
        t_pred_refine = denorm_verticies_via_headpose(t_refine_tmp, head_rot, head_trans, head_scale)
        t_shoulder    = denorm_verticies_via_headpose(t_pred_shoulder, head_rot, head_trans, head_scale)

        # NOTE(review): this mesh render is never used below; the call is kept
        # only in case the helper has side effects — confirm and drop it.
        frame_pred_refine = semantic_meshs2figure(t_pred_refine, blank_canvas.copy())

        # pred face + pred head + pred shoulder. The image is used in the channel
        # order returned by the helper (the former double [..., ::-1] was a no-op).
        frame_pred_feature = get_feature_image(t_pred_refine, t_shoulder, (512, 512))[0]

        frames_lst.append(frame_pred_feature)

    frames_to_video(frames_lst, os.path.join(t_out_dir, f'{out_name}.mp4'), audio_fp=driven_audio, fps=video_fps)
    print('Pred results are saved at: ', f'{t_out_dir}/{out_name}.mp4')

