import sys
sys.path.append(".")
import math
import os
import random
import numpy as np
import torch
import torch.nn as nn
import time
from torch import optim
from dataset import batch_denormalization, SoMoFDataset_3dpw_test
from model import JRTransformer

from torch.utils.data import DataLoader, SequentialSampler, DataLoader

from metrics import batch_MPJPE, batch_VIM
import argparse

from torch.utils.tensorboard import SummaryWriter
from transformers.integrations import TensorBoardCallback
from datetime import datetime

def get_adj(num_persons=2, device=None):
    """Build the block-diagonal skeleton adjacency matrix for several people.

    Each person has 13 joints; person ``p`` occupies indices
    ``[13 * p, 13 * p + 13)``. Entry ``[i, j]`` is 1 when ``i == j`` or when
    joints ``i`` and ``j`` of the same person are joined by one of the bones
    listed below, and 0 otherwise (no cross-person edges).

    Args:
        num_persons: number of people stacked along both axes (default 2,
            matching the original 2*13 layout).
        device: device for the returned tensor. Defaults to ``"cuda"`` when
            CUDA is available (the original behavior) and ``"cpu"`` otherwise,
            instead of crashing on CPU-only machines.

    Returns:
        float32 tensor of shape ``(13 * num_persons, 13 * num_persons)``.
    """
    num_joints = 13
    edges = [(0, 1), (1, 8), (8, 7), (7, 0),
             (0, 2), (2, 4),
             (1, 3), (3, 5),
             (7, 9), (9, 11),
             (8, 10), (10, 12),
             (6, 7), (6, 8)]
    adj = np.eye(num_persons * num_joints)
    for person in range(num_persons):
        offset = person * num_joints
        for a, b in edges:
            # Bones are undirected, so fill both triangle entries.
            adj[a + offset, b + offset] = 1
            adj[b + offset, a + offset] = 1
    if device is None:
        device = "cuda" if torch.cuda.is_available() else "cpu"
    return torch.from_numpy(adj).float().to(device)

def get_connect(num_persons=2, device=None):
    """Build the block-diagonal "same person" connectivity mask.

    Entry ``[i, j]`` is 1 when joints ``i`` and ``j`` belong to the same
    person (13 joints per person, person ``p`` at ``[13 * p, 13 * p + 13)``)
    and 0 otherwise.

    Args:
        num_persons: number of people (default 2, matching the original 2*13
            layout).
        device: device for the returned tensor. Defaults to ``"cuda"`` when
            CUDA is available (the original behavior) and ``"cpu"`` otherwise.

    Returns:
        float32 tensor of shape ``(13 * num_persons, 13 * num_persons)``.
    """
    num_joints = 13
    size = num_persons * num_joints
    conn = np.zeros((size, size))
    for person in range(num_persons):
        block = slice(person * num_joints, (person + 1) * num_joints)
        conn[block, block] = 1
    if device is None:
        device = "cuda" if torch.cuda.is_available() else "cpu"
    return torch.from_numpy(conn).float().to(device)

class Tester:
    """Evaluate a trained JRTransformer checkpoint on the SoMoF 3DPW test split.

    The model observes the first ``OBS_LEN`` frames of each ``SEQ_LEN``-frame
    sample; its predictions for the remaining frames are scored with MPJPE and
    VIM against the ground truth of those frames.
    """

    # Frames requested per sample from the dataset (observed + future).
    SEQ_LEN = 30
    # Frames fed to the model as observation; frames from this index on are scored.
    OBS_LEN = 16

    def __init__(self, args):
        """Set up the device, model, test loader and static graph features.

        Args:
            args: parsed CLI namespace; reads ``d_k``, ``depth``,
                ``batch_size``, ``cuda_devices`` and ``path``.
        """
        if torch.cuda.is_available():
            self.device = torch.device("cuda")
            torch.cuda.manual_seed(0)
        else:
            self.device = torch.device("cpu")
        print('Using device:', self.device)
        self.cuda_devices = args.cuda_devices

        self.path = args.path

        # Training parameters
        self.batch_size = args.batch_size

        # Defining models
        # NOTE(review): args.num_heads is parsed by the CLI but not passed here.
        self.model = JRTransformer(feat_size=args.d_k, depth=args.depth).to(self.device)

        dset_test = SoMoFDataset_3dpw_test(in_seq_len=self.SEQ_LEN)
        sampler_test = SequentialSampler(dset_test)
        self.test_loader = DataLoader(
            dset_test,
            sampler=sampler_test,
            batch_size=args.batch_size,
            num_workers=2,
            drop_last=False,
            # Pinned host memory only speeds up host-to-GPU copies.
            pin_memory=self.device.type == "cuda",
        )

        # Not referenced elsewhere in this file; kept for external users.
        self.joint_to_use = np.array([1, 2, 4, 5, 7, 8, 15, 16, 17, 18, 19, 20, 21])

        # Static graph features shaped (1, 26, 26, 1) so they can be repeated over
        # the batch and concatenated with the distance features. Moving them to
        # self.device keeps them on the same device as the model inputs.
        self.adj = get_adj().to(self.device).unsqueeze(0).unsqueeze(-1)
        self.conn = get_connect().to(self.device).unsqueeze(0).unsqueeze(-1)

    def process_pred(self, pred_vel_dct, pred_vel_aux_dct):
        """Split main and auxiliary predictions at index 16 of their third axis.

        Not called within this file.

        Returns:
            ``(main[:, :, :16], main[:, :, 16:], [aux[:, :, :16], ...],
            [aux[:, :, 16:], ...])``.
        """
        pred_vel_x = pred_vel_dct[:, :, :16]
        pred_vel_y = pred_vel_dct[:, :, 16:]
        pred_vel_aux_x = []
        pred_vel_aux_y = []
        for pred_dct_ in pred_vel_aux_dct:
            pred_vel_aux_x.append(pred_dct_[:, :, :16])
            pred_vel_aux_y.append(pred_dct_[:, :, 16:])
        return pred_vel_x, pred_vel_y, pred_vel_aux_x, pred_vel_aux_y

    def _edge_features(self, pos):
        """Return exp(-pairwise joint distance) per frame, concatenated with adj/conn.

        Args:
            pos: tensor ``(B, NJ, T, 3)`` of joint positions.

        Returns:
            tensor ``(B, NJ, NJ, T + 2)``; the last two channels are the static
            adjacency and same-person masks.
        """
        batch_size = pos.shape[0]
        pos_rel = pos.unsqueeze(-3) - pos.unsqueeze(-4)  # B, NJ, NJ, T, 3
        dis = torch.sqrt(torch.pow(pos_rel, 2).sum(-1))
        exp_dis = torch.exp(-dis)
        return torch.cat(
            (exp_dis,
             self.adj.repeat(batch_size, 1, 1, 1),
             self.conn.repeat(batch_size, 1, 1, 1)),
            dim=-1,
        )

    def _prepare_inputs(self, input_total_original):
        """Turn a raw batch into model inputs; the argument is not modified.

        Args:
            input_total_original: assumed ``(B, SEQ_LEN, N, J, 6)`` where
                channels 0-2 are positions and 3-5 per-frame velocities
                (consistent with the cumsum reconstruction below) — TODO confirm.

        Returns:
            ``(tgt, exp_dis, camera_vel)``: observed frames ``(B, N*J, OBS_LEN, 6)``,
            edge features ``(B, N*J, N*J, OBS_LEN + 2)``, and the subtracted mean
            velocity ``(B, 3)`` (in the y/z-swapped frame).
        """
        input_total = input_total_original.clone()
        batch_size = input_total.shape[0]

        # Swap the y and z axes of both positions and velocities.
        input_total[..., [1, 2]] = input_total[..., [2, 1]]
        input_total[..., [4, 5]] = input_total[..., [5, 4]]

        # Remove the mean velocity over frames/people/joints, then re-integrate
        # positions from the first frame with the corrected velocities.
        # NOTE(review): this mean spans frames 1..SEQ_LEN-1, which includes the
        # future frames (index >= OBS_LEN) that are later scored as ground truth.
        # Future information therefore leaks into the model input and into the
        # reconstructed prediction. Confirm this matches the training/evaluation
        # protocol; an observation-only estimate would use frames 1..OBS_LEN-1.
        camera_vel = input_total[:, 1:self.SEQ_LEN, :, :, 3:].mean(dim=(1, 2, 3))  # B, 3
        input_total[..., 3:] -= camera_vel[:, None, None, None]
        input_total[..., :3] = input_total[:, 0:1, :, :, :3] + input_total[..., 3:].cumsum(dim=1)

        # B, T, N, J, 6 -> B, NxJ, T, 6
        input_total = input_total.permute(0, 2, 3, 1, 4).contiguous().view(batch_size, -1, self.SEQ_LEN, 6)

        tgt = input_total[:, :, :self.OBS_LEN]
        exp_dis = self._edge_features(input_total[:, :, :self.OBS_LEN, :3])
        return tgt, exp_dis, camera_vel

    def _reconstruct_motion(self, pred_vel, camera_vel, input_total_original):
        """Integrate predicted velocities into absolute future positions.

        Args:
            pred_vel: model output ``(B, N*J, T_out, 3)``; frames from ``OBS_LEN``
                on are treated as the future (assumes the output covers observed +
                future frames — TODO confirm against the model).
            camera_vel: ``(B, 3)`` velocity removed in ``_prepare_inputs``.
            input_total_original: the unmodified raw batch.

        Returns:
            ``(motion_pred, motion_gt)``: predicted future positions
            ``(B, T_out - OBS_LEN, N*J, 3)`` and all ground-truth positions
            ``(B, SEQ_LEN, N*J, 3)``, both in the original axis order.
        """
        batch_size = input_total_original.shape[0]
        pred_vel = pred_vel[:, :, self.OBS_LEN:].permute(0, 2, 1, 3)  # B, T, NxJ, 3
        # Undo the velocity shift and the y/z swap applied to the inputs.
        pred_vel = pred_vel + camera_vel[:, None, None]
        pred_vel[..., [1, 2]] = pred_vel[..., [2, 1]]
        # Cumsum velocity to position, starting from the last observed pose.
        motion_gt = input_total_original[..., :3].view(batch_size, self.SEQ_LEN, -1, 3)
        motion_pred = pred_vel.cumsum(dim=1) + motion_gt[:, self.OBS_LEN - 1:self.OBS_LEN]
        return motion_pred, motion_gt

    def test(self):
        """Load the checkpoint at ``self.path``, run the test set, print mean MPJPE and VIM."""
        # NOTE(review): torch.load unpickles the file, which can execute code;
        # only load checkpoints from trusted sources.
        checkpoint = torch.load(self.path, map_location=self.device)
        self.model.load_state_dict(checkpoint['net'])
        self.model.eval()

        # Running per-entry sums over the test set. This assumes batch_MPJPE and
        # batch_VIM return batch-summed arrays of 5 values — TODO confirm.
        all_mpjpe = np.zeros(5)
        all_vim = np.zeros(5)
        count = 0
        with torch.no_grad():
            for input_total_original, para in self.test_loader:
                input_total_original = input_total_original.float().to(self.device)
                batch_size = input_total_original.shape[0]

                tgt, exp_dis, camera_vel = self._prepare_inputs(input_total_original)
                pred_vel = self.model.predict(tgt, exp_dis)
                motion_pred, motion_gt = self._reconstruct_motion(pred_vel, camera_vel, input_total_original)

                # Apply denormalization.
                motion_pred = batch_denormalization(motion_pred.cpu(), para).numpy()
                motion_gt = batch_denormalization(motion_gt.cpu(), para).numpy()

                # NOTE(review): only joint slots 0..12 are scored. Because people are
                # flattened outermost (N*J), that is the first person only — confirm
                # this is the intended protocol.
                gt_future = motion_gt[:, self.OBS_LEN:, :13, :]
                pred_future = motion_pred[:, :, :13, :]
                all_mpjpe += batch_MPJPE(gt_future, pred_future)
                all_vim += batch_VIM(gt_future, pred_future)

                count += batch_size

        if count == 0:
            print('No test samples were evaluated; skipping metric report.')
            return

        # Scale by 100 (presumably metres -> centimetres — TODO confirm units)
        # and average over samples.
        all_mpjpe = all_mpjpe * 100 / count
        all_vim = all_vim * 100 / count
        print('Test MPJPE:\t avg: {:.2f}'.format(all_mpjpe.mean()))
        print('Test VIM:\t avg: {:.2f}'.format(all_vim.mean()))

if __name__ == '__main__':
    cli = argparse.ArgumentParser()
    # Every option as (flag, type, default), registered in this order.
    # NOTE(review): --num_heads is parsed but Tester never passes it to the model.
    option_table = (
        ('--d_k', int, 128),
        ('--depth', int, 4),
        ('--num_heads', int, 8),
        ('--batch_size', int, 128),
        ('--cuda_devices', str, "1"),
        ('--path', str, "best.pt"),
    )
    for flag, value_type, default_value in option_table:
        cli.add_argument(flag, type=value_type, default=default_value)
    opts = cli.parse_args()

    # Restrict visible GPUs before Tester is built, since Tester is where CUDA
    # is first queried in this file.
    os.environ["CUDA_VISIBLE_DEVICES"] = opts.cuda_devices

    Tester(opts).test()