from util import make_configuration, add_acoustic_config
from tqdm import tqdm
import argparse
import math
import os
import json
import shutil
import glob

import numpy as np
import random

import habitat
import habitat_sim
from habitat_sim.utils.common import quat_from_angle_axis

import soundfile as sf
from scipy.io import wavfile
from scipy.signal import fftconvolve
from util.trajectory import *

import sys
sys.path.append('..')


def normalize_audio(samples, desired_rms=0.1, eps=1e-4):
    """Scale ``samples`` to a target RMS level and hard-clip to [-1, 1].

    The measured RMS is floored at ``eps`` so that (near-)silent input is not
    amplified without bound. The input array is left untouched; a new array
    is returned.
    """
    measured_rms = np.sqrt(np.mean(samples**2))
    gain = desired_rms / np.maximum(eps, measured_rms)
    return np.clip(samples * gain, -1., 1.)


def impulse_response_to_sound(binaural_rir, source_sound, sampling_rate):
    '''
        Render a source sound through a simulated room impulse response.

        binaural_rir: (num_sample, num_channel) impulse response
        source_sound: mono sound, (num_sample)
        sampling_rate: not used by the computation; the RIR and the source
            sound are expected to share the same sampling rate
        returns: one row per RIR channel, each row being the full convolution
            of the source sound with that channel
    '''
    num_channel = binaural_rir.shape[-1]
    rendered_channels = []
    for ch in range(num_channel):
        channel_rir = binaural_rir[:, ch]
        rendered_channels.append(fftconvolve(source_sound, channel_rir))
    return np.array(rendered_channels)


def sample_audio_database(args, settings):
    '''
        Load and loudness-normalize ``args.num_source`` mono clips from an audio database.

        args.audio_database selects the glob pattern ('FMA' or 'LibriSpeech');
        paths are relative to the current working directory.
        args.fixed_ss_rendered_from_ss: when set, the first file (in sorted
            order) is reused for every source with a fixed target RMS;
            otherwise files are drawn without replacement and each gets a
            random target RMS in [0.07, 0.10).
        settings['sample_rate']: at most ``sample_rate * 10`` frames are read
            from each file (10 seconds if the file uses that rate -- the
            file's own rate is not checked here).

        returns: list of 1-D arrays, one per sampled file
        raises: ValueError if args.audio_database is not a supported name
    '''
    database_patterns = {
        'FMA': 'Free-Music-Archive/ProcessedData/*/*/*.wav',
        'LibriSpeech': 'LibriSpeech/ProcessedData/*/*/*/audio.wav',
    }
    if args.audio_database not in database_patterns:
        raise ValueError(
            'Unsupported audio database {!r}; expected one of {}'.format(
                args.audio_database, sorted(database_patterns)))
    # Sorted so that sampling is reproducible for a given random seed.
    database = sorted(glob.glob(database_patterns[args.audio_database]))

    if args.fixed_ss_rendered_from_ss:
        sampled_audio_paths = database[:1] * args.num_source
        fixed_rms = 0.03 * np.array([0.5]) + 0.07
    else:
        sampled_audio_paths = np.random.choice(database, args.num_source, replace=False)

    max_frames = int(settings['sample_rate'] * 10)
    source_sounds = []
    for audio_path in sampled_audio_paths:
        source_sound, _ = sf.read(audio_path, start=0, stop=max_frames, dtype='float32', always_2d=True)
        # always_2d gives (frames, channels); average channels down to mono.
        source_sound = source_sound.mean(-1)
        if args.fixed_ss_rendered_from_ss:
            desired_rms = fixed_rms
        else:
            desired_rms = 0.03 * np.random.rand() + 0.07
        source_sounds.append(normalize_audio(source_sound, desired_rms=desired_rms))
    return source_sounds


def calc_sound_direction_for_agent(pos_sound, pos_agent, angle_agent):
    '''
        Calculate the sound direction relative to the agent's position and heading,
        in the horizontal plane of a frame where +Y is up, -Z is forward and +X is right.

        pos_sound, pos_agent: 3D positions (x, y, z); the y component is ignored
        angle_agent: agent heading in degrees (0 looks along -Z, positive turns left)

        returns: signed angle in degrees in [-180, 180]; sound to the left is
            positive, to the right is negative. If the two positions coincide
            horizontally, the sound is treated as lying along -Z.
    '''
    # 2D frame: first axis is world -Z (forward), second axis is world -X (left).
    agent_2d = np.array([-pos_agent[2], -pos_agent[0]])
    sound_2d = np.array([-pos_sound[2], -pos_sound[0]])
    agent_to_sound = sound_2d - agent_2d
    distance = np.linalg.norm(agent_to_sound)
    if distance == 0:
        agent_to_sound = np.array([1.0, 0.0])
    else:
        agent_to_sound = agent_to_sound / distance

    heading_rad = angle_agent / 180 * np.pi
    heading = np.array([np.cos(heading_rad), np.sin(heading_rad)])

    # Rounding can push the dot product of two unit vectors slightly outside
    # [-1, 1], which would make arccos return NaN.
    cos_angle = np.clip(np.dot(agent_to_sound, heading), -1.0, 1.0)
    # z-component of heading x agent_to_sound, written out because np.cross on
    # 2-element vectors is deprecated in NumPy 2.0.
    cross_z = heading[0] * agent_to_sound[1] - heading[1] * agent_to_sound[0]

    angle = np.rad2deg(np.arccos(cos_angle))
    if cross_z < 0:
        angle = -angle
    return angle

def verify_rir_volume(rir):
    '''
        Return False when the impulse response is effectively silent
        (RMS below 1e-6), True otherwise.
    '''
    silence_threshold = 1e-6
    mean_power = np.mean(rir ** 2)
    return not (np.sqrt(mean_power) < silence_threshold)

def verify_unblocked_sound_source(sim, pos_source, pos_agent):
    '''
        Check that the straight line between source and agent is walkable.

        The pair is accepted only when the navmesh shortest path between the
        two points has (almost) the same length as their Euclidean distance,
        i.e. the shortest path is a straight segment. Pairs closer than 0.1,
        or with no path at all, are rejected.

        pos_source, pos_agent: array-like positions supporting subtraction
        returns: falsy when rejected, truthy when accepted
    '''
    offset = pos_source - pos_agent
    straight_line = np.sqrt((offset ** 2).sum())
    # Reject (near-)coincident points outright.
    if straight_line < 0.1:
        return False

    query = habitat_sim.nav.ShortestPath()
    query.requested_start = pos_source
    query.requested_end = pos_agent
    if sim.pathfinder.find_path(query):
        # Any detour makes the geodesic distance longer than the straight line.
        return np.isclose(query.geodesic_distance, straight_line, atol=1e-3)
    return False


def sample_sound_source_location(args, sim, settings, receivers):
    '''
        Goal: sample ``args.num_source`` static sound-source locations (x, y, z)
        around a set of receivers.

        Candidates are random navigable points within ``args.max_distance`` of
        the receivers' centroid (raised by ``settings["sensor_height"]``). A
        candidate is kept only if
          * it is at least ``args.min_distance`` from that centre (when
            ``args.min_distance`` is not None), and
          * it has an unobstructed straight path to EVERY receiver position
            (see verify_unblocked_sound_source).
        Accepted points are then lifted off the floor: to
        ``settings["sensor_height"]`` when ``args.fixed_ss_height`` is set,
        otherwise by a random offset in [0.7, 1.7).

        receivers: sequence of dicts with a 'position' entry
        returns: list of ``args.num_source`` positions, or None if the retry
            budgets run out before enough sources are found
    '''
    n_source = args.num_source

    center_pos = np.zeros(3)
    for receiver in receivers:
        center_pos += receiver['position']
    center_pos = center_pos / len(receivers)
    center_pos += np.array([0, settings["sensor_height"], 0])
    if args.fixed_ss_height:
        center_pos[1] = settings["sensor_height"]

    # Separate retry budgets for the two rejection reasons.
    max_blocked_tries = 500 * n_source
    max_too_close_tries = 500  # TODO
    blocked_tries = 0
    too_close_tries = 0

    sources = []
    while len(sources) < n_source and too_close_tries <= max_too_close_tries:
        pos = sim.pathfinder.get_random_navigable_point_near(
            circle_center=center_pos, radius=args.max_distance, max_tries=1000)

        if args.min_distance is not None and np.linalg.norm(pos - center_pos) < args.min_distance:
            too_close_tries += 1
            continue

        pos = np.array(pos)
        # Each receiver must be checked against its own position.
        unblocked = all(
            verify_unblocked_sound_source(sim, pos, receiver['position'])
            for receiver in receivers
        )
        if not unblocked:
            blocked_tries += 1
            if blocked_tries > max_blocked_tries:
                break
            continue

        if args.fixed_ss_height:
            pos[1] = settings["sensor_height"]
        else:
            pos[1] = pos[1] + np.random.rand() * 1 + 0.7
        sources.append(np.array(pos))

    if len(sources) < n_source:
        return None
    return sources

# TODO
'''
Sample camera locations for the trajectory of a sound source.
'''
def sample_camera_location(args, sim, settings, audio_set):
    '''
        Goal: sample candidate camera (receiver) positions for a set of
        sound-source positions.

        Candidates are random navigable points within a radius of 3 around the
        centroid of the positions in ``audio_set`` (raised by
        ``settings["sensor_height"]``). A candidate is kept only if it is
        navigable and has an unobstructed straight path to EVERY position in
        ``audio_set`` (see verify_unblocked_sound_source).

        audio_set: sequence of dicts with a 'position' entry
        returns: list of up to ``100 * args.num_source`` positions, or None if
            fewer than ``args.num_source`` were found before the retry budget
            ran out
    '''
    n_source = args.num_source
    n_wanted = 100 * n_source
    max_rejections = 500 * n_source * 100

    center_pos = np.zeros(3)
    for audio in audio_set:
        center_pos += audio['position']
    center_pos = center_pos / len(audio_set)
    center_pos += np.array([0, settings["sensor_height"], 0])

    cameras = []
    rejections = 0
    while len(cameras) < n_wanted:
        pos = np.array(sim.pathfinder.get_random_navigable_point_near(
            circle_center=center_pos, radius=3, max_tries=1000))

        # Every rejection counts against the budget so the loop cannot spin
        # forever when no usable point exists.
        accepted = sim.pathfinder.is_navigable(pos) and all(
            verify_unblocked_sound_source(sim, audio['position'], pos)
            for audio in audio_set
        )
        if not accepted:
            rejections += 1
            if rejections > max_rejections:
                break
            continue

        # NOTE: no audio-sensor visibility check is performed here; positions
        # are accepted on the navmesh straight-path test alone, and they stay
        # at floor height.
        cameras.append(pos)

    if len(cameras) < n_source:
        return None
    return cameras

def _all_sources_visible(sim, camera_set, sources):
    '''
        Return True only if the audio sensor reports every source as visible
        from every camera pose in ``camera_set``. Moves agent 0 as a side effect.
    '''
    agent = sim.get_agent(0)
    for camera in camera_set:
        state = agent.get_state()
        state.position = camera['position']
        state.rotation = quat_from_angle_axis(math.radians(camera['angle']), np.array([0, 1, 0]))
        state.sensor_states = {}
        agent.set_state(state, True)

        audio_sensor = agent._sensors["audio_sensor"]
        for source in sources:
            # Only the source location is set; the agent location is set implicitly.
            audio_sensor.setAudioSourceTransform(np.array(source))
            # An observation pass is run before querying visibility
            # (presumably it refreshes the sensor's visibility state).
            sim.get_sensor_observations()
            if not audio_sensor.sourceIsVisible():
                return False
    return True


def sample_sound_source_location_for_camera_set(args, scene_id, settings, camera_sets):
    '''
    Goal: sample sound-source locations (x, y, z) for every camera set, keeping
    only the sets for which all sampled sources are visible from all cameras.

    camera_sets: iterable of camera sets; each camera is a dict with
        'position' and 'angle' (degrees about the Y axis)
    returns: list of dicts with keys 'cameras' and 'sources', plus
        'source_sounds' when ``args.rir_to_sound`` is set. Camera sets for
        which sampling fails or a source is not visible are dropped.

    Source sounds are drawn once per accepted camera set (they do not depend
    on the camera), so no audio is loaded for rejected sets.
    '''
    cfg = make_configuration(args, scene_id, settings, add_semantic=False, visual_sensor=False)
    sim = habitat_sim.Simulator(cfg)

    camera_sets_with_sound = []
    try:
        sim = add_acoustic_config(sim, args, settings, indirect=False)
        sim.seed(settings['seed'])

        for camera_set in tqdm(camera_sets, total=len(camera_sets), desc='sampling sound source'):
            sources = sample_sound_source_location(args, sim, settings, camera_set)
            if sources is None:
                continue
            if not _all_sources_visible(sim, camera_set, sources):
                continue

            camera_set_info = {
                'cameras': camera_set,  # [{'position', 'angle'}]
                'sources': sources,
            }
            if args.rir_to_sound:
                camera_set_info['source_sounds'] = sample_audio_database(args, settings)
            camera_sets_with_sound.append(camera_set_info)
    finally:
        # Release the simulator even if sampling raises.
        sim.close()

    return camera_sets_with_sound


def sound_position_pair_matching_for_motion(args, settings, cameras, i, j):
    '''
        Placeholder pair filter for moving sources: every (i, j) pair is
        accepted and none of the arguments are inspected.
    '''
    return True

def sample_moving_sound_source_location(args, sim, settings):
    '''
        Goal: sample ``args.num_source`` straight-line sound-source trajectories.

        For each attempt a random receiver point and a random start point are
        drawn from the navmesh; the end point is drawn within ``args.distance``
        of the start and forced to the start's height. The segment is split
        with divide_line_segment(start, end, args.n_trajectory_points), and the
        trajectory is kept only if every resulting point has an unobstructed
        straight path to the receiver (see verify_unblocked_sound_source).

        returns: list of arrays, one per trajectory, each stacking the (x, y, z)
            points of that trajectory; None if the retry budget
            (500 * args.num_source rejected attempts) is exhausted first.
            The receiver used for the check is not returned.
    '''
    n_source = args.num_source
    max_rejections = 500 * n_source
    distance = args.distance
    n_trajectory_points = args.n_trajectory_points
    height_offset = np.array([0, settings["sensor_height"], 0])

    trajectories = []
    rejections = 0
    # Every rejected attempt counts against the budget, so the loop also
    # terminates when the navmesh never yields navigable points.
    while len(trajectories) < n_source and rejections <= max_rejections:
        receiver = sim.pathfinder.get_random_navigable_point(max_tries=500)
        # NOTE(review): the receiver is shifted up in place by the sensor
        # height, so the straight-path test below runs against the raised
        # point rather than the floor point. Kept as-is -- confirm intent.
        receiver += height_offset

        start_pos = sim.pathfinder.get_random_navigable_point(max_tries=500)
        if not sim.pathfinder.is_navigable(start_pos):
            rejections += 1
            continue

        # TODO
        end_pos = sim.pathfinder.get_random_navigable_point_near(
            circle_center=start_pos, radius=distance, max_tries=1000)
        end_pos[1] = start_pos[1]  # keep the trajectory at constant height
        if not sim.pathfinder.is_navigable(end_pos):
            rejections += 1
            continue

        trajectory_positions = divide_line_segment(start_pos, end_pos, n_trajectory_points)
        if not all(
            verify_unblocked_sound_source(sim, point, receiver)
            for point in trajectory_positions
        ):
            rejections += 1
            continue

        # Heights are kept at floor level (no random lift is applied).
        trajectories.append(np.array(
            [np.array([pos[0], pos[1], pos[2]]) for pos in trajectory_positions]))

    if len(trajectories) < n_source:
        return None
    return trajectories


def sample_sound_source_location_of_trajectory(args, scene_id, settings):
    '''
        Goal: sample ``args.num_camera`` sound-source trajectories in a scene and
        return each as an "audio set": a list of {'position': point} dicts, one
        per trajectory point, in trajectory order.

        For now only linear motion is sampled (see
        sample_moving_sound_source_location), and only the first trajectory of
        each sampling round is used.

        args.dataset: 'replica' (the semantic mesh path is derived from
            ``scene_id``), 'hm3d' or 'gibson' (``scene_id`` is used as-is)
        returns: list of audio sets. A trajectory with fewer than
            ``args.num_view - 1`` matched points is counted but contributes
            no audio set.
        raises: ValueError if args.dataset is not supported

        Sampling is retried until ``args.num_camera`` rounds succeed; there is
        no upper bound on the number of failed rounds.
    '''
    if args.dataset == 'replica':
        semantic_scene_id = scene_id.replace('mesh.ply', 'habitat/mesh_semantic.ply')
    elif args.dataset in ['hm3d', 'gibson']:
        semantic_scene_id = scene_id
    else:
        raise ValueError('Unsupported dataset: {!r}'.format(args.dataset))

    cfg = make_configuration(args, semantic_scene_id, settings, add_semantic=True)
    sim = habitat_sim.Simulator(cfg)

    n_pos = args.num_camera  # better to be 100 or 200, n_trajectory
    min_view = args.num_view
    final_audio_sets = []
    try:
        sim = add_acoustic_config(sim, args, settings, indirect=False)
        sim.seed(settings['seed'])

        success_count = 0
        while success_count < n_pos:
            trajectories = sample_moving_sound_source_location(args, sim, settings)
            if trajectories is None:
                continue
            success_count += 1
            audio_positions = trajectories[0]  # TODO: only use the first trajectory

            # Points that pair with the first point of the trajectory.
            matched = [
                jnd for jnd in range(len(audio_positions))
                if sound_position_pair_matching_for_motion(args, settings, audio_positions, 0, jnd)
            ]
            if len(matched) < min_view - 1:
                continue

            # NOTE(review): only the NUMBER of matched points is used; the
            # set always takes the leading points of the trajectory. This is
            # equivalent while the matcher accepts every pair.
            audio_set = [{'position': audio_positions[0]}]
            audio_set += [
                {'position': audio_positions[jnd]} for jnd in range(1, len(matched))
            ]
            final_audio_sets.append(audio_set)
    finally:
        # Release the simulator even if sampling raises.
        sim.close()

    return final_audio_sets