from cmath import sin
import os
import numpy as np
import json
import torch
from torch.utils.data import Dataset
from mpl_toolkits import mplot3d
import numpy as np
import matplotlib.pyplot as plt
import pickle as pkl
import math
import scipy.linalg as linalg
from torch.autograd import Variable


from scipy.fftpack import dct, idct

# 3DPW sequence names used for each split. Every list starts with the
# entry 'all', followed by the individual sequence names.
testset = [
    'all',
    'courtyard_rangeOfMotions_00',
    'courtyard_goodNews_00',
    'courtyard_arguing_00',
    'courtyard_shakeHands_00',
    'courtyard_captureSelfies_00',
    'courtyard_basketball_00',
    'courtyard_giveDirections_00',
    'courtyard_capoeira_00',
    'courtyard_dancing_01',
    'courtyard_warmWelcome_00',
]
trainset = [
    'all',
    'downtown_rampAndStairs_00',
    'downtown_bar_00',
    'downtown_car_00',
    'downtown_walking_00',
    'downtown_cafe_00',
    'downtown_runForBus_00',
    'downtown_bus_00',
    'downtown_runForBus_01',
    'downtown_warmWelcome_00',
    'downtown_crossStreets_00',
    'downtown_sitOnStairs_00',
    'office_phoneCall_00',
    'downtown_arguing_00',
]
validset = [
    'all',
    'courtyard_hug_00',
    'courtyard_drinking_00',
    'courtyard_rangeOfMotions_01',
    'courtyard_dancing_00',
]


def batch_denormalization(data, para):
    '''
    Undo per-sample centring for a batch by adding each sample's offset back.

    :data: [B, T, N, J, 6] or [B, T, J, 3]. For the 5-D layout only the first
           three channels (positions) are shifted; any remaining channels
           (e.g. velocities) are left untouched.
    :para: [B, 3] per-sample offsets, e.g. the batched ``para`` returned by the
           datasets in this module.
    :return: ``data`` itself — it is modified in place.
    '''
    if data.ndim == 5:
        # Decide the layout by rank rather than by ``data.shape[2] == 2``: the
        # old test only recognised the 5-D layout when exactly two people
        # were present and mis-routed any other N into the 4-D branch.
        data[..., :3] += para[:, None, None, None, :]
    else:
        data += para[:, None, None, :]
    return data

def normalize(data):
    '''
    Centre one (unbatched) sample on a reference point.

    Notice: without batch operation.

    The reference point is the mean over joints of ``data[0, 0]``; the
    datasets in this module pass arrays laid out as (agent, time, joint,
    coord), so that is agent 0 at frame 0.

    :data: array with at least four axes, last axis = coordinates.
    :return: (centred copy of ``data``, offset with shape [coord]); adding the
             offset back restores the input.
    '''
    reference_joints = data[0, 0, :]
    offset = np.mean(reference_joints, axis=0)  # shape: 3 for xyz input
    centred = data - offset[np.newaxis, np.newaxis, np.newaxis, :]
    return centred, offset

def rotate_Y(input, beta):
    '''
    Rotate 3-D points about the Y axis.

    :input: array whose last axis holds (x, y, z) in channels 0..2 (the
            training dataset passes (agents, time, joints, 3)). Any channels
            beyond the first three come back as zeros.
    :beta: rotation angle in degrees.
    :return: new array with the same shape as ``input``. Integer input is
             promoted to float64 so rotated coordinates are not truncated;
             floating input keeps its dtype.
    '''
    points = np.asarray(input)
    if not np.issubdtype(points.dtype, np.inexact):
        # zeros_like() on an int array would silently truncate the result.
        points = points.astype(np.float64)
    output = np.zeros_like(points)

    beta = beta * (np.pi / 180)  # angle to radian
    cos_b, sin_b = np.cos(beta), np.sin(beta)
    # Index the last axis with `...` so any leading batch layout works.
    x, y, z = points[..., 0], points[..., 1], points[..., 2]
    output[..., 0] = cos_b * x + sin_b * z
    output[..., 1] = y
    output[..., 2] = -sin_b * x + cos_b * z
    return output

class SoMoFDataseta_3dpw(Dataset):
    '''
    Two-person SoMoF/3DPW sequences read from ``<dset_path>/poseData_overlap.pkl``.

    Every stored sample has shape (agents=2, time=30, joints=13, dim), where
    ``dim`` is 6 (xyz positions followed by per-frame velocities) when ``vel``
    is true and 3 otherwise. Positions are centred with ``normalize``; the
    matching offset is kept in ``self.data_para``.
    '''

    def __init__(self, dset_path="somof_data_3dpw/", in_seq_len=16, entire_seq_len=30, split_name='train', data_augment=True, vel=True, permute=True):
        '''
        :dset_path: directory containing ``poseData_overlap.pkl``.
        :in_seq_len: stored as an attribute; not used by this class.
        :entire_seq_len: stored as an attribute; the reshape below always uses 30 frames.
        :split_name: key into the pickled dict (e.g. 'train').
        :data_augment: for the 'train' split only, also add a rotated and a
                       time-reversed copy of every sample.
        :vel: append velocity channels to the positions.
        :permute: also add a copy of every sample with the two agents swapped.
        '''
        self.in_seq_len = in_seq_len
        self.entire_seq_len = entire_seq_len

        filename = "poseData_overlap.pkl"
        # NOTE(review): pickle.load can execute arbitrary code; only load trusted files.
        with open(os.path.join(dset_path, filename), 'rb') as pkl_file:
            oridata = pkl.load(pkl_file)[str(split_name)]
        if permute:
            # Rolling by one along axis 1 of the stacked data swaps the two
            # agents (presumably axis 1 is the agent axis; TODO confirm).
            p_oridata = np.roll(np.array(oridata), 1, axis=1)
            oridata = np.concatenate((oridata, p_oridata), axis=0)

        self.data = []
        self.data_para = []

        agentsNum = 2
        timeStepsNum = 30
        jointsNum = 13
        coordsNum = 3  # x y z
        self.dim = 6 if vel else 3
        augment = data_augment and split_name == 'train'

        for sample in oridata:
            raw = np.array(sample).reshape((agentsNum, timeStepsNum, jointsNum, coordsNum))
            # normalize() returns a new array, so `raw` keeps the original coordinates
            # for the rotation below.
            curr_data, curr_data_para = normalize(raw)
            self._add_sample(curr_data, curr_data_para, vel)

            if augment:
                # rotate 120 degrees about the Y axis, then re-centre
                rotate_data, rotate_data_para = normalize(rotate_Y(raw, 120))
                self._add_sample(rotate_data, rotate_data_para, vel)

                # reverse: play the motion backwards in time. In the
                # (agents, time, joints, coords) layout time is axis 1;
                # flipping axis 2 reversed the joint order instead and
                # produced invalid skeletons.
                self._add_sample(np.flip(curr_data, axis=1), curr_data_para, vel)

    def _add_sample(self, positions, para, vel):
        '''Append one centred sample (plus velocity channels when ``vel``) and its offset.'''
        if vel:
            # velocity[:, t] = positions[:, t] - positions[:, t-1]; frame 0 stays zero.
            velocity = np.zeros(positions.shape)
            velocity[:, 1:] = np.diff(positions, axis=1)
            data = np.concatenate((positions, velocity), axis=3)
        else:
            data = positions.copy()
        self.data.append(data)
        self.data_para.append(para)

    def __getitem__(self, idx: int):
        '''Return (sample transposed to (time, agents, joints, dim), centring offset).'''
        data = self.data[idx].transpose((1, 0, 2, 3))
        para = self.data_para[idx]
        return data, para

    def __len__(self):
        return len(self.data)

class SoMoFDataset_3dpw_test(Dataset):
    '''
    Observed-input SoMoF/3DPW sequences read from ``<dset_path>/<split_name>_in.json``.

    Every input sequence is stored twice, once as given and once with the two
    agents swapped. Each copy has shape (agents=2, time=in_seq_len, joints=13, 6):
    centred xyz positions followed by per-frame velocities. The centring offset
    from ``normalize`` is kept in ``self.data_para``.
    '''

    def __init__(self, dset_path="somof_data_3dpw/", in_seq_len=16, split_name='test'):
        self.in_seq_len = in_seq_len

        # `with` guarantees the file is closed even if json.load raises.
        json_path = os.path.join(dset_path, "{}_in.json".format(split_name))
        with open(json_path, 'r', encoding='utf-8') as jsfile:
            oridata_in = json.load(jsfile)

        oridata = []
        for sample in oridata_in:
            oridata.append(sample)
            # Rolling axis 0 by one swaps the two agents (presumably axis 0 of
            # the raw sample is the agent axis; TODO confirm).
            oridata.append(np.roll(np.array(sample), 1, axis=0))

        self.data = []
        self.data_para = []
        agentsNum = 2
        timeStepsNum = in_seq_len
        jointsNum = 13
        coordsNum = 3  # x y z

        for sample in oridata:
            positions = np.array(sample).reshape((agentsNum, timeStepsNum, jointsNum, coordsNum))
            positions, para = normalize(positions)
            # velocity[:, t] = positions[:, t] - positions[:, t-1]; frame 0 stays zero.
            velocity = np.zeros((agentsNum, timeStepsNum, jointsNum, coordsNum))
            velocity[:, 1:] = np.diff(positions, axis=1)
            self.data.append(np.concatenate((positions, velocity), axis=3))
            self.data_para.append(para)

    def __getitem__(self, idx: int):
        '''Return (sample transposed to (time, agents, joints, 6), centring offset).'''
        data = self.data[idx].transpose((1, 0, 2, 3))
        para = self.data_para[idx]
        return data, para

    def __len__(self):
        return len(self.data)