import csv
import glob
import io
import json
import librosa
import numpy as np
import os
import pickle
from PIL import Image
from PIL import ImageFilter
import random
import scipy
import soundfile as sf
from scipy.io import wavfile
from scipy.signal import fftconvolve
import time
from tqdm import tqdm
import glob
import cv2
import math
import torch
import torch.nn as nn
import torchaudio
import torchvision.transforms as transforms

import sys
sys.path.append('..')
from utils import sound, sourcesep
from data import * 



class wsSLbaseDataset(object):
    def __init__(self, args, pr, list_sample, split='train',dataset_mod='1'): # 0 for simulated,1 for real,2 for mixed
        """Set up the sample list, random state and optional audio database.

        Args:
            args: run options; this constructor reads ``online_render``,
                ``time_sync``, ``not_load_audio``, ``n_view``,
                ``n_ss_position``, ``repeat`` (train split only) and
                ``max_sample`` (train/test splits only).
            pr: parameter object; ``pr.seed`` seeds every random generator.
            list_sample: sample-list source handed to ``get_list_sample``.
            split: split name; ``'train'`` enables repetition and shuffling.
            dataset_mod: string selecting the dataset flavour (see
                ``get_list_sample`` and ``__getitem__`` for the dispatch).
        """
        self.dataset_mod = dataset_mod
        self.debug = True
        self.pr = pr
        self.args = args
        self.split = split
        self.seed = pr.seed
        self.online_render = args.online_render
        self.time_sync = args.time_sync
        self.not_load_audio = args.not_load_audio
        self.n_view = args.n_view
        self.n_source = args.n_ss_position # TODO

        # The sample list is only repeated for training.
        if split == 'train':
            self.repeat = args.repeat
        else:
            self.repeat = 1

        # A sample cap applies to the train and test splits; -1 disables it.
        if split in ['train', 'test']:
            self.max_sample = args.max_sample
        else:
            self.max_sample = -1

        transform_steps = self.generate_image_transform(args, pr)
        self.image_transform = transforms.Compose(transform_steps)

        samples = self.get_list_sample(list_sample, mod=self.dataset_mod)
        if self.max_sample > 0:
            samples = samples[:self.max_sample]
        self.list_sample = samples * self.repeat

        # Seed every generator before the shuffle so the order is reproducible.
        random.seed(self.seed)
        np.random.seed(self.seed)
        torch.manual_seed(self.seed)
        self.rng = np.random.default_rng(self.seed)

        sample_count = len(self.list_sample)
        if self.split == 'train':
            random.shuffle(self.list_sample)

        if self.online_render and dataset_mod != '7':
            self.audio_database = self.generate_audio_database()

        # The source direction distribution is only built for mode '0'.
        if dataset_mod == '0':
            self.source_distribution = self.source_direction_distribution()

        print(f'Audio Dataloader: # sample of {self.split}: {sample_count}')

    def getitem_ws_mot_dataset(self, index):
        """Build one sample from a recording annotated with tracking (MOT) results.

        Reads ``meta.json`` and ``audio/audio.wav`` under the sample directory
        ``self.list_sample[index]``, keeps the track whose id is
        ``meta_dict['top_id'][0]['id']``, picks ``self.n_source`` tracked frames
        and cuts one audio clip of ``self.pr.clip_length`` seconds around each.

        Returns:
            dict with ``'pair_path'`` plus, for every source ``k`` (1-based),
            ``'audio_k'`` (a slice of the transposed audio array) and
            ``'audio_position_k'`` (a tensor holding the frame's ``r_angle``).

        Raises:
            NotImplementedError: if ``meta_dict['ifstereo']`` is ``False``.
        """
        # if self.split == 'train':
        #     index=618
        info = self.list_sample[index]
        pair_path = info
        audio_path = os.path.join(pair_path, 'audio/audio.wav')
        meta_path = os.path.join(pair_path, 'meta.json')
        with open(meta_path, "r") as f:
            meta_dict = json.load(f)

        if meta_dict['ifstereo'] is False:
            print("ifstereo is False:",info)
            raise NotImplementedError
            
        # Keep only the detections that belong to the first entry of 'top_id'.
        tgt_mot_id = int(meta_dict['top_id'][0]['id'])
        tgt_mot_results = [x for x in meta_dict['mot_result_list'] if x['id'] == tgt_mot_id]
        
        # presumably half of a 1280-pixel frame width -- TODO confirm
        half_central=1280/2 # TODOs
        def select_clips_uniform_by_group(tgt_mot_results, audio_length,clip_length): 
            """Pick clips spread over 15-degree angle groups (not called below).

            NOTE(review): the bare ``except`` further down leaves
            ``selected_values`` unbound on failure, and this strategy never
            sets ``item['r_angle']``, which the caller reads -- confirm before
            switching to it.
            """
            # input_list:[(central_point, central_time), ...]
            n_source = self.n_source
            groups = {}
            # group_size = 1280 / 6  # 
            group_size = 90/6

            # Keep frames far enough from both ends for a full clip to fit.
            l_bound_t = clip_length/2
            h_bound_t = audio_length - clip_length/2 - 0.4 # TODO
            l_bound_frame = l_bound_t*5 #debug
            h_bound_frame = h_bound_t*5
            filtered_list = [item for item in tgt_mot_results if l_bound_frame < item['frame'] < h_bound_frame]
            # Bucket the frames by their angle, in steps of group_size degrees.
            for item in filtered_list:
                # group_key = int(item['central_point'][0] / group_size)
                angle=(item['central_point'][0]-half_central)/half_central
                angle = torch.tensor(math.degrees(math.atan(angle)))
                group_key = int(angle / group_size)
                if group_key not in groups:
                    groups[group_key] = []
                groups[group_key].append(item)
                
            # Split n_source across the groups, giving the smaller quotas to
            # the groups that hold fewer frames.
            total_groups_num = [len(group) for group in groups.values()]
            sorted_index=sorted(range(len(total_groups_num)), key=lambda k: total_groups_num[k])
            ori_select_num_from_groups = [n_source // len(groups) + (1 if i < n_source % len(groups) else 0) for i in range(len(groups))]
            ori_select_num_from_groups.sort()
            select_num_from_groups=[0]*len(total_groups_num)
            for i in range(len(sorted_index)):
                select_num_from_groups[sorted_index[i]] = ori_select_num_from_groups[i]
                
            # When a small group cannot fill its quota, move the shortfall to
            # the group at the mirrored position in the size ordering.
            flag=0
            while flag<len(total_groups_num):
                if total_groups_num[sorted_index[flag]]<select_num_from_groups[sorted_index[flag]]:
                    supple=select_num_from_groups[sorted_index[flag]]-total_groups_num[sorted_index[flag]]
                    select_num_from_groups[sorted_index[flag]]=total_groups_num[sorted_index[flag]]
                    select_num_from_groups[sorted_index[len(sorted_index)-flag-1]]+=supple
                    flag+=1
                else:
                    break
             
            try:
                selected_values = [np.random.choice(groups[group], size=select_num_from_groups[i], replace=False) for i,group in enumerate(groups)]
                selected_values = [item for sublist in selected_values for item in sublist]
            except:
                pass
            
            # Changed to compute the clip's start/end times.
            selected_items = []
            def frame_to_time(frame_id):
                # presumably 5 frames per second, shifted by half a frame -- TODO confirm
                return frame_id / 5+ 1/10
            for item in selected_values:
                item['start'] = frame_to_time(item['frame']) - clip_length/2
                item['end'] = frame_to_time(item['frame']) + clip_length/2
                selected_items.append(item)

            if len(selected_items)!=self.n_source:
                print('@@@@@@@@@@@',index,'@@@@@@@@@@@@@')
                assert len(selected_items) == self.n_source
                # print("selected_items length is not correct")
            # 
            return selected_items     
        
        def select_clips_uniform_by_filtering(tgt_mot_results, audio_length,clip_length): 
            """Pick ``n_source`` frames at random, with replacement.

            This is the strategy actually used by the enclosing method. Each
            eligible frame gets ``r_angle`` (degrees) computed from the x
            coordinate of its ``central_point``; each selected frame also gets
            ``start``/``end`` times centred on the frame.
            """
            # input_list:[(central_point, central_time), ...]
            n_source = self.n_source
            groups = {}
            # group_size = 1280 / 6  # 
            group_size = 90/6

            # Keep frames far enough from both ends for a full clip to fit.
            l_bound_t = clip_length/2
            h_bound_t = audio_length - clip_length/2 - 0.4 # TODO
            l_bound_frame = l_bound_t*5 #debug
            h_bound_frame = h_bound_t*5
            filtered_list = [item for item in tgt_mot_results if l_bound_frame < item['frame'] < h_bound_frame]
            for item in filtered_list:
                item['r_angle'] = math.degrees(math.atan((item['central_point'][0]-half_central)/half_central))
            # part_1 = [item for item in filtered_list if abs(item['r_angle']) >= 15]
            # part_2 = [item for item in filtered_list if 10 <= abs(item['r_angle']) < 15]
            # part_2 = random.sample(part_2, int(len(part_2)/2))
            # part_3 = [item for item in filtered_list if abs(item['r_angle']) < 10]
            # part_3 = random.sample(part_3, int(len(part_3)/2))
            # filtered_list = part_1 + part_2 + part_3
            
            # Draws from the global NumPy RNG; the same frame may be picked twice.
            selected_values = np.random.choice(filtered_list, size=n_source, replace=True)
            
            # 
            selected_items = []
            def frame_to_time(frame_id):
                # presumably 5 frames per second, shifted by half a frame -- TODO confirm
                return frame_id / 5+ 1/10
            for item in selected_values:
                item['start'] = frame_to_time(item['frame']) - clip_length/2
                item['end'] = frame_to_time(item['frame']) + clip_length/2
                selected_items.append(item)

            if len(selected_items)!=self.n_source:
                print('elected_items length is not correct@@@@@@@@@@@',index,'@@@@@@@@@@@@@')
                assert len(selected_items) == self.n_source
                # print("selected_items length is not correct")

            return selected_items     
        
        def select_clips_random(tgt_mot_results, audio_length,clip_length): 
            """Pick ``n_source`` distinct frames at random (not called below)."""
            # input_list:[(central_point, central_time), ...]
            n_source = self.n_source
            groups = {}
            # group_size = 1280 / 6  # 
            group_size = 90/6

            # Keep frames far enough from both ends for a full clip to fit.
            l_bound_t = clip_length/2
            h_bound_t = audio_length - clip_length/2 - 0.4 # TODO
            l_bound_frame = l_bound_t*5 #debug
            h_bound_frame = h_bound_t*5
            filtered_list = [item for item in tgt_mot_results if l_bound_frame < item['frame'] < h_bound_frame]
            for item in filtered_list:
                item['r_angle'] = math.degrees(math.atan((item['central_point'][0]-half_central)/half_central))
            # Sampling without replacement: raises if fewer than n_source frames remain.
            selected_values = np.random.choice(filtered_list, size=n_source, replace=False)
            
            # 
            selected_items = []
            def frame_to_time(frame_id):
                # presumably 5 frames per second, shifted by half a frame -- TODO confirm
                return frame_id / 5+ 1/10
            for item in selected_values:
                item['start'] = frame_to_time(item['frame']) - clip_length/2
                item['end'] = frame_to_time(item['frame']) + clip_length/2
                selected_items.append(item)

            if len(selected_items)!=self.n_source:
                print('elected_items length is not correct@@@@@@@@@@@',index,'@@@@@@@@@@@@@')
                assert len(selected_items) == self.n_source
                # print("selected_items length is not correct")

            return selected_items           
        
        
        # Only the "uniform by filtering" strategy is wired in.
        selected_items = select_clips_uniform_by_filtering(tgt_mot_results, audio_length=meta_dict['length'],clip_length=self.pr.clip_length)
        
        batch = {'pair_path': pair_path}
        # Read the whole file once and transpose so clips can be sliced along axis 1.
        audio, _ = self.read_audio(audio_path, start=0)
        audio = audio.T
            
        for ind in range(self.n_source):
            # print(ind,'/',self.n_source)
            # Convert the clip start from seconds to a sample offset.
            start= int(selected_items[ind]['start'] * self.pr.samp_sr)
            audio_segment = audio[:, int(start): int(start + self.pr.clip_length * self.pr.samp_sr)]
            batch[f'audio_{ind+1}'] = audio_segment
            # A short clip is only reported, not rejected.
            if audio_segment.shape[1] != self.pr.clip_length * self.pr.samp_sr:
                print('audio length is not correct')
            
            batch[f'audio_position_{ind+1}'] = torch.tensor(selected_items[ind]['r_angle'])
            
        # if batch['audio_1'].shape[1] != self.pr.clip_length * self.pr.samp_sr:
        #     print("audio length is not correct")
        #     raise NotImplementedError
        return batch

    def getitem_ws_dataset(self, index):
        # if self.split == 'train':
        #     index=618
        info = self.list_sample[index]
        pair_path = info
        audio_path = os.path.join(pair_path, 'audio/audio.wav')
        meta_path = os.path.join(pair_path, 'meta.json')
        with open(meta_path, "r") as f:
            meta_dict = json.load(f)

        if meta_dict['ifstereo'] is False:
            print("ifstereo is False:",info)
            raise NotImplementedError
            

        batch = {'pair_path': pair_path}
        audio, _ = self.read_audio(audio_path, start=0)
        audio = audio.T
        
        audio_list = []
        for i in range(self.n_source):
            audio_list.append(audio)

        def get_start_times(audio_length=int(meta_dict['length']),clip_length=self.pr.clip_length,n=self.n_source):
            random_start_times = np.floor(np.random.uniform(0, audio_length * self.pr.samp_sr - clip_length * self.pr.samp_sr, n))
            
            if self.n_source == 2:
                random_start_times = [1,audio_length * self.pr.samp_sr - clip_length * self.pr.samp_sr]
            return random_start_times
        start_times = get_start_times()
        
        def get_angle(start_times):
            if 'iphone13pro-stereo_clips_5' in audio_path:
                angle = int(meta_dict['u_id'].split('_')[1])
                angles = []
                for start_time in start_times:
                    angles.append(angle)
                return angles
            elif 'iphone1_clips_5' in audio_path:
                if self.split == 'train':
                    superglue_v6_path = audio_path.replace('audio/audio.wav','superglue_v6.json')
                    with open(superglue_v6_path, "r") as f:
                        superglue_v6_dict = json.load(f)
                    total_frames = int(meta_dict["length"])*5
                    sift_frames_gap = 5
                    tgt_sift = [[x,x+sift_frames_gap] for x in range(1,total_frames,sift_frames_gap) if x+sift_frames_gap< total_frames]
                    sift_value = superglue_v6_dict['sift']
                    for i,sift in enumerate(tgt_sift):
                        for value in sift_value:
                            if sift == value[0]:
                                tgt_sift[i] = value

                    for i,sift in enumerate(tgt_sift):
                        if abs(int(sift[2])) == 10000 and i != 0:
                            tgt_sift[i][2] = tgt_sift[i-1][2]

                    angles = []
                    for start_time in start_times:
                        start_time = start_time/self.pr.samp_sr # time rather than rate
                        angle = 0
                        for sift in tgt_sift:
                            if abs(int(sift[2])) == 10000:
                                continue
                            if start_time > sift[0][1]/sift_frames_gap:
                                angle -= sift[2]
                            elif start_time < sift[0][0]/sift_frames_gap:
                                break
                            else:
                                gap = (sift[0][1] - sift[0][0])/sift_frames_gap
                                angle -= sift[2]/gap*(start_time - sift[0][0]/sift_frames_gap)
                        angles.append(angle)
                    return angles
                else:
                    label_path = audio_path.replace('audio/audio.wav','label.json')
                    with open(label_path, "r") as f:
                        label_dict = json.load(f)
                    
                    # refresh start_times
                    labelled = []
                    for angle in label_dict['angle']:
                        try:
                            if abs(int(angle['angle'])) < 180:
                                labelled.append(angle)
                        except:
                            continue
                    
                    selected_angles = np.random.choice(labelled, size=self.n_source, replace=True)
                    start_times = [angle['start_time']*self.pr.clip_length for angle in selected_angles]
                    angles = [angle['angle'] for angle in selected_angles]
                    return angles
            else:
                total_frames = int(meta_dict["length"])*5
                tgt_sift = [[x,x+3] for x in range(1,total_frames,3) if x+3< total_frames]
                sift_value = meta_dict['sift']
                for i,sift in enumerate(tgt_sift):
                    for value in sift_value:
                        if sift == value[0]:
                            tgt_sift[i] = value

                for i,sift in enumerate(tgt_sift):
                    if abs(int(sift[2])) == 10000 and i != 0:
                        tgt_sift[i][2] = tgt_sift[i-1][2]

                angles = []
                for start_time in start_times:
                    start_time = start_time/self.pr.samp_sr
                    angle = 0
                    for sift in tgt_sift:
                        if abs(int(sift[2])) == 10000:
                            continue
                        if start_time > sift[0][1]:
                            angle += sift[2]
                        else:
                            gap = sift[0][1] - sift[0][0]
                            angle += sift[2]/gap*(start_time - sift[0][0])
                    angles.append(angle)
                return angles
        
        try:
            angles = get_angle(start_times)
        except:
            print('error in get_angle:',pair_path)
            print(meta_path)
            
            
        for i in range(len(audio_list)):
            audio_segment = audio_list[i][:, int(start_times[i]): int(start_times[i] + self.pr.clip_length * self.pr.samp_sr)]
            audio_list[i] = torch.tensor(audio_segment).float().unsqueeze(0)

        audio_list = torch.cat(audio_list, dim=0)
            
        for ind in range(self.n_source):
            # print(ind,'/',self.n_source)
            batch[f'audio_{ind+1}'] = audio_list[ind]
            if audio_list[ind].shape[1] != self.pr.clip_length * self.pr.samp_sr:
                print('audio length is not correct')
            
            batch[f'audio_position_{ind+1}'] = torch.tensor(angles[ind])
            
        # if batch['audio_1'].shape[1] != self.pr.clip_length * self.pr.samp_sr:
        #     print("audio length is not correct")
        #     raise NotImplementedError
        return batch
    
    def getitem_ws_avnerf_dataset(self, index):
        info = self.list_sample[index]        
        batch = {}

        def restrain_to_90(x):
            if x >= 0:
                return min(180 - x, x)
            else:
                return max(x, -180 - x)
            return x
                
        audio_list = []; source_relative_angle_list= []
        for ind in range(self.n_source):
            data = self.dataset[info['subset_id']-1][info['clip_id'][ind]]
            audio_segment = data['wav_bi']
            audio_segment = torch.tensor(audio_segment).float().unsqueeze(0)
            audio_list.append(audio_segment)
            if self.args.restrain_to_front:
                relative_angle = restrain_to_90(data['ori']*180)
            source_relative_angle_list.append(relative_angle)
        audio_list = torch.cat(audio_list, dim=0)
        
        for ind in range(self.n_source):
            # batch[f'img{ind+1}_path'] = img_path_list[ind]
            # batch[f'img_{ind+1}'] = img_list[ind]
            batch[f'audio_{ind+1}'] = audio_list[ind]
            # batch[f'angle_between_source1_camera{ind+1}'] = source_direction_list[ind][0]
            # batch[f'angle_bin_between_source1_camera{ind+1}'] = source_direction_list[ind][1]
            # batch[f'camera_{ind+1}_position'] = torch.tensor(camera_posit_list[ind])

        for source_ind in range(self.n_source):
            # batch[f'source_{source_ind+1}_position'] = torch.tensor(meta_dict[f'source_{source_ind}_position'])
            batch[f'source_{source_ind+1}_relative_angle'] = source_relative_angle_list[source_ind]
            # batch[f'angle_bin_between_source{source_ind+1}_camera1'] = source_relative_angle_list[source_ind][1]
            
        return batch
    
    
    def get_audio_segments_from_mot(self, meta_dict,n_source,clip_length):
        tgt_mot_id = int(meta_dict['top_id'][0]['id'])
        tgt_mot_results = [x for x in meta_dict['mot_result_list'] if x['id'] == tgt_mot_id]

    def __getitem__(self, index):
        """Return the sample at ``index`` for the configured dataset mode.

        Modes ``'2'``, ``'1'`` and ``'7'`` are delegated to
        ``getitem_ws_dataset``, ``getitem_ws_mot_dataset`` and
        ``getitem_ws_avnerf_dataset``. Any other mode takes the path below,
        which reads ``metadata.json`` from the sample directory and produces
        one audio clip per source as heard from camera 0.

        Returns:
            dict with ``'pair_path'`` and, per source ``k`` (1-based),
            ``'audio_k'``, ``'camera_k_position'``, ``'source_k_position'``,
            ``'source_k_relative_angle'`` and
            ``'angle_bin_between_sourcek_camera1'``.

        Raises:
            NotImplementedError: if ``self.n_view`` exceeds the number of
                camera-position keys found in the metadata.
        """
        # import pdb; pdb.set_trace()
        if self.dataset_mod == '2':
            return self.getitem_ws_dataset(index) # each clip
        elif self.dataset_mod == '1':
            return self.getitem_ws_mot_dataset(index)
        elif self.dataset_mod == '7': # avnerf dataset
            return self.getitem_ws_avnerf_dataset(index)
        
        info = self.list_sample[index]
        pair_path = info['path']
        
        # random pair_path
        # A second, randomly chosen sample path is forwarded to generate_audio.
        overlap_pair_path = self.rng.choice(self.list_sample)['path']
        
        meta_path = os.path.join(pair_path, 'metadata.json')
        with open(meta_path, "r") as f:
            meta_dict = json.load(f)

        # Count metadata keys that mention both 'camera' and 'position'.
        max_camera_num = np.array([key.find('camera') != -1 and key.find('position') != -1 for key in meta_dict.keys()]).sum()
        if self.n_view > max_camera_num:
            print("camera numbers are not match")
            raise NotImplementedError

        img_list = []
        img_path_list = []
        depth_path_list = []
        audio_list = []
        camera_angle_list = []
        camera_posit_list = []
        source_direction_list = []
        source_relative_angle_list = []

        if self.online_render:
            source_sounds = self.prepare_source_sounds(index=index)
        else:
            source_sounds = None
        
        camera_ind = 0
        
        # Every source is paired with camera 0; the image is read but the
        # image keys of the batch are commented out further down.
        for source_ind in range(self.n_source):
            img_path = os.path.join(pair_path, f'camera_0_rgb.png')
            img = self.read_image(img_path)
            img_list.append(img)
            camera_angle_list.append(meta_dict[f'camera_0_angle'])
            camera_posit_list.append(meta_dict[f'camera_0_position'])
            img_path_list.append(img_path)
            
            relative_angle = float(meta_dict[f"relative_angle_between_sound_{source_ind}_camera_0"])
            if self.not_load_audio:
                # Silent two-channel placeholder, three clip lengths long.
                audio = np.zeros((2, np.rint(self.pr.clip_length * self.pr.samp_sr * 3).astype(int)))
            else:
                audio = self.generate_audio(index = index, 
                                            pair_path = pair_path,
                                            s_index = source_ind,  
                                            source_sounds = source_sounds,
                                            relative_angle = relative_angle, overlap_pair_path=overlap_pair_path)
                
            # The angle is folded only after the audio has been generated, so
            # generate_audio receives the unfolded value.
            if self.args.restrain_to_front:
                def restrain_to_90(x):
                    # Mirror the angle across the +/-90 degree axis.
                    if x >= 0:
                        return min(180 - x, x)
                    else:
                        return max(x, -180 - x)
                    return x
                relative_angle = restrain_to_90(relative_angle)
                
            audio_list.append(audio)
            source_direction_bin = self.calc_source_direction_to_bin(relative_angle)
            source_relative_angle_list.append((relative_angle,source_direction_bin))
            
        # On the train split the per-source lists are permuted half of the time.
        shuffle_flag = self.rng.random() > 0.5 if self.split == 'train' else False
        # shuffle_flag = False
        if shuffle_flag:
            shuffled_inds = self.rng.permutation(self.n_source)
            img_list = [img_list[i] for i in shuffled_inds]
            img_path_list = [img_path_list[i] for i in shuffled_inds]
            audio_list = [audio_list[i] for i in shuffled_inds]
            camera_angle_list = [camera_angle_list[i] for i in shuffled_inds]
            camera_posit_list = [camera_posit_list[i] for i in shuffled_inds]
            source_relative_angle_list = [source_relative_angle_list[i] for i in shuffled_inds]
            # source_direction_list = [source_direction_list[i] for i in shuffled_inds]

        # Clip start = coarse offset (0.5 s steps) + fine offset (0.05 s
        # steps), converted to samples below.
        fast_step = 0.5
        slow_step = 0.05
        slow_point = self.rng.choice(int(fast_step / slow_step + 1))
        if self.time_sync:
            # All sources share one coarse offset.
            fast_point = self.rng.choice(np.floor(self.pr.clip_length * 2 / fast_step).astype(int))
            fast_point = np.array([fast_point] * self.n_source)
        else:
            # Distinct coarse offsets per source (sampling without replacement).
            fast_point = self.rng.choice(np.floor(self.pr.clip_length * 2 / fast_step).astype(int), self.n_source, replace=False)
        
        start_times = (fast_step * fast_point + slow_step * slow_point) * self.pr.samp_sr
        for i in range(len(audio_list)):
            audio_segment = audio_list[i][:, int(start_times[i]): int(start_times[i] + self.pr.clip_length * self.pr.samp_sr)]
            audio_list[i] = torch.tensor(audio_segment).float().unsqueeze(0)

        audio_list = torch.cat(audio_list, dim=0)

        batch = {'pair_path': pair_path}
        # if self.n_source > 1:
        #     for i in range(1, self.n_source): 
        #         relative_camera_angle, relative_camera_angle_bin = self.calculate_relative_rotation([camera_angle_list[0], camera_angle_list[i]], return_bin=True)
        #         batch[f'relative_camera{i}_angle'] = relative_camera_angle
        #         batch[f'relative_camera{i}_angle_bin'] = relative_camera_angle_bin
        #         batch[f'relative_camera{i}_angle_sign'] = np.sign(relative_camera_angle)


        for ind in range(self.n_source):
            # batch[f'img{ind+1}_path'] = img_path_list[ind]
            # batch[f'img_{ind+1}'] = img_list[ind]
            batch[f'audio_{ind+1}'] = audio_list[ind]
            # batch[f'angle_between_source1_camera{ind+1}'] = source_direction_list[ind][0]
            # batch[f'angle_bin_between_source1_camera{ind+1}'] = source_direction_list[ind][1]
            batch[f'camera_{ind+1}_position'] = torch.tensor(camera_posit_list[ind])

        # NOTE(review): source positions are looked up by the unshuffled index
        # while audio and relative angles may have been permuted above --
        # confirm this pairing is intended.
        for source_ind in range(self.n_source):
            batch[f'source_{source_ind+1}_position'] = torch.tensor(meta_dict[f'source_{source_ind}_position'])
            batch[f'source_{source_ind+1}_relative_angle'] = source_relative_angle_list[source_ind][0]
            batch[f'angle_bin_between_source{source_ind+1}_camera1'] = source_relative_angle_list[source_ind][1]

        return batch


    def getitem_test(self, index):
        self.__getitem__(index)


    def __len__(self): 
        return len(self.list_sample)


    def get_list_sample(self, list_sample,mod='0'):
        if mod == '0':
            if isinstance(list_sample, str):
                samples = []
                csv_file = csv.DictReader(open(list_sample, 'r'), delimiter=',')
                for row in csv_file:
                    samples.append(row)
        elif mod in ['1','2']:
            samples = []
            for l_s in list_sample:
                with open(l_s, 'r') as file:
                    s = json.load(file)
                samples += s
        elif mod in ['7']: # avnerf
            from data.avnerf_dataset import RWAVSDataset
            end_subset_ind = 14
            self.dataset = [RWAVSDataset(data_root=f"code/others/RWAVS/release/{str(i)}", split= self.split) for i in range(1,end_subset_ind,1)]
            # set pairs here
            samples = []
            for i in range(1,end_subset_ind,1):
                len_of_subset = len(self.dataset[i-1])
                inds_of_subset = range(len_of_subset)
                # selected_pairs_num = int(len_of_subset*(len(self.dataset[i-1])-1)/20)
                # for _ in range(selected_pairs_num):
                #     pair = random.sample(inds_of_subset, 2)
                #     samples.append({
                #         'subset_id': i,
                #         'clip_id': pair})
                for ind_subset in inds_of_subset:
                    samples.append({
                        'subset_id': i,
                        'clip_id': (ind_subset,ind_subset)})
                    
        return samples 
    

    def generate_audio_database(self):
        audiobase_csv = f'{self.pr.audiobase_path}/{self.split}.csv'
        audio_database = self.get_list_sample(audiobase_csv)
        return audio_database


    def read_audio(self, audio_path, start=0, stop=None):
        """Read an audio file as float32 and bring it to ``self.pr.samp_sr``.

        Args:
            audio_path: file to read.
            start: first frame to read, passed straight to ``sf.read``.
            stop: frame to stop at, or ``None`` to read to the end. When
                given and the file yields fewer than ``stop - start`` frames,
                the audio is tiled and trimmed to that length.

        Returns:
            ``(audio, rate)`` where ``audio`` is 2-D with frames on axis 0 and
            ``rate`` is the rate of the returned data.
        """
        # import pdb; pdb.set_trace()
        samples, native_rate = sf.read(audio_path, start=start, stop=stop, dtype='float32', always_2d=True)

        # Loop short recordings until they cover the requested window.
        if stop is not None:
            wanted_frames = int(stop - start)
            if samples.shape[0] < wanted_frames:
                n_tiles = int(np.ceil(wanted_frames / samples.shape[0]))
                samples = np.tile(samples, (n_tiles, 1))[:wanted_frames, :]

        target_rate = self.pr.samp_sr
        if native_rate == target_rate:
            return samples, native_rate

        resampled_frames = int(samples.shape[0] / native_rate * target_rate)
        samples = scipy.signal.resample(samples, resampled_frames, axis=0)
        return samples, target_rate
    

    def prepare_source_sounds(self,**kwargs):
        '''
            Preload the mono source sounds so every impulse response of a
            sample is rendered with the same sound track.

            Expects ``index`` as a keyword argument; it selects the entry of
            ``self.audio_database`` the sounds are drawn from. With the
            dominant-sound option (and ``ssl_flag`` off) the first source sets
            the reference loudness and later sources are scaled below it by a
            random SNR.

            Returns a list with one normalized sound per source.
        '''
        candidates = [self.audio_database[kwargs['index']]]
        chosen = self.rng.choice(candidates, self.args.n_source, replace=False)
        dominant_rms = None
        prepared = []
        for source_ind in range(self.args.n_source):
            source_dir = chosen[source_ind]['path']
            wav_path = os.path.join(source_dir, 'audio.wav')
            with open(os.path.join(source_dir, 'meta.json'), "r") as meta_fp:
                source_meta = json.load(meta_fp)
            audio_rate = source_meta['audio_sample_rate']
            audio_length = source_meta['audio_length']

            # Read a window three clip lengths long, in frames at the file's rate.
            clip_length = np.rint(self.pr.clip_length * audio_rate * 3).astype(int)
            remain_length = int(audio_length - self.pr.clip_length * 3)

            # Training picks a random start when the file leaves room for it.
            start = 0
            if self.split == 'train' and remain_length > 0:
                start = int(self.rng.choice(remain_length) * audio_rate)

            source_sound, _ = self.read_audio(wav_path, start=start, stop=start+clip_length)
            source_sound = source_sound.mean(-1)  # average the channels

            use_dominant = self.args.with_dominant_sound and not self.args.ssl_flag
            if not use_dominant:
                desired_rms = 0.06 * self.rng.random() + 0.07
            elif source_ind == 0:
                desired_rms = 0.06 * self.rng.random() + 0.08
                dominant_rms = desired_rms
            else:
                snr = self.rng.integers(low=0, high=40)
                desired_rms = np.sqrt(dominant_rms ** 2 / 10 ** (snr / 10.0))

            prepared.append(self.normalize_audio(source_sound, desired_rms=desired_rms))
        return prepared


    def obtain_binaural_rir_with_reverb(self, pair_path, camera_ind, source_ind):
        """Load the binaural impulse response for one source/camera pair.

        When ``self.args.indirect_ratio`` is ``None`` the file under
        ``binaural_rirs`` is returned as read. Otherwise the responses under
        ``binaural_rirs_direct`` and ``binaural_rirs_indirect`` are combined as
        ``direct + indirect * indirect_ratio``.

        Args:
            pair_path: sample directory containing the RIR folders.
            camera_ind: camera index used in the file name.
            source_ind: sound-source index used in the file name.

        Returns:
            2-D array with frames on axis 0.
        """
        rir_name = f'sound_{source_ind}_camera_{camera_ind}_rir.wav'
        if self.args.indirect_ratio is None:
            # load binaural rir
            binaural_rir_path = os.path.join(pair_path, 'binaural_rirs', rir_name)
            binaural_rir, _ = sf.read(binaural_rir_path, dtype='float32', always_2d=True)
            return binaural_rir

        direct_rir, _ = sf.read(os.path.join(pair_path, 'binaural_rirs_direct', rir_name), dtype='float32', always_2d=True)
        indirect_rir, _  = sf.read(os.path.join(pair_path, 'binaural_rirs_indirect', rir_name), dtype='float32', always_2d=True)

        # Zero-pad whichever response is shorter so the two can be summed.
        # Only the direct response used to be padded, which raised on a
        # negative length whenever it was the longer of the two.
        length_gap = indirect_rir.shape[0] - direct_rir.shape[0]
        if length_gap >= 0:
            zero_padding = np.zeros((length_gap, direct_rir.shape[1]))
            direct_rir = np.concatenate((direct_rir, zero_padding), axis=0)
        else:
            zero_padding = np.zeros((-length_gap, indirect_rir.shape[1]))
            indirect_rir = np.concatenate((indirect_rir, zero_padding), axis=0)

        binaural_rir = direct_rir + indirect_rir * self.args.indirect_ratio
        return binaural_rir


    def generate_audio(self, index, pair_path, s_index, source_sounds,**kwargs):
        '''
            goal: build the binaural mixture for camera 0 of one sample
            index: sample index, only used to name the file written when
                   args.save_audio is set
            pair_path: sample directory holding the RIRs / pre-rendered audios
            s_index: sound index used in the RIR and render-audio file names
            source_sounds: list of mono source waveforms, indexed per source
            kwargs:
                overlap_pair_path: required when args.add_overlap is set
                relative_angle: required when args.save_audio is set
            return: mixture shaped as (C, L)
        '''
        audio = None
        dominant_rms = None
        camera_ind = 0
        for source_ind in range(self.args.n_source):
            if self.online_render:
                # load binaural rir
                # NOTE(review): the RIR is selected with s_index for every
                # source_ind -- confirm this is intended.
                binaural_rir = self.obtain_binaural_rir_with_reverb(pair_path, camera_ind, s_index)
                source_sound = source_sounds[source_ind]
                render_audio = self.impulse_response_to_sound(binaural_rir, source_sound)
            else:
                render_audio_path = os.path.join(pair_path, 'render_audios', f'sound_{s_index}_camera_{camera_ind}_audio.wav')
                render_audio, _ = self.read_audio(render_audio_path, start=0, stop=np.rint(self.pr.clip_length * self.pr.samp_sr * 3).astype(int))
                render_audio = render_audio.T

            if self.args.with_dominant_sound and self.args.ssl_flag:
                if source_ind == 0:
                    # the first source is the dominant one and keeps its level
                    dominant_rms = desired_rms = np.sqrt(np.mean(render_audio**2))
                else:
                    # other sources are attenuated by `snr` dB w.r.t. the dominant one
                    snr = self.args.dominant_snr if self.args.dominant_snr else self.rng.integers(low=5, high=30)
                    desired_rms = np.sqrt(dominant_rms ** 2 / 10 ** (snr / 10.0))
                render_audio = self.normalize_audio(render_audio, desired_rms=desired_rms)

            if audio is None:
                audio = render_audio
            else:
                audio += render_audio   # shape as (C, L)

        def get_intermittent_time(sound,intermittent_clips = 5, intermittent_time_propotion = 0.5):
            # Silence `intermittent_clips` randomly placed segments of a (C, L)
            # sound. The segment lengths are a random split of
            # intermittent_time_propotion * L samples; segments may overlap or
            # be cut at the end of the clip.
            n = int(intermittent_time_propotion * sound.shape[1])
            intermittent_starts = random.sample(range(0, sound.shape[1]), intermittent_clips)
            random_numbers = sorted([random.randint(0, n) for _ in range(intermittent_clips - 1)])

            # Include the start and end points
            random_numbers = [0] + random_numbers + [n]

            # Calculate the differences between consecutive elements
            result = [random_numbers[i + 1] - random_numbers[i] for i in range(intermittent_clips)]
            mask = np.ones_like(sound, dtype=bool)
            for s_intermittent, segment_len in zip(intermittent_starts, result):
                e_intermittent = min(s_intermittent + segment_len, sound.shape[1])
                # Starts and lengths are expressed in samples, so the mask has
                # to be applied along the time axis (axis 1), on all channels.
                mask[:, s_intermittent:e_intermittent] = False
            i_new_sound = sound * mask
            return i_new_sound

        if self.args.add_intermittent:
            audio = get_intermittent_time(audio)

        if self.args.add_overlap:
            overlap_pair_path = kwargs['overlap_pair_path']
            binaural_rir = self.obtain_binaural_rir_with_reverb(overlap_pair_path, camera_ind, s_index)
            def get_new_audio(sound):
                # cut the sound into 20 equal clips and shuffle their order
                clip_num = 20
                assert len(sound)%clip_num == 0
                # get random index
                numbers = list(range(0, clip_num))
                random.shuffle(numbers)
                clip_len = len(sound)//clip_num

                audio_clips = [sound[i * clip_len:(i + 1) * clip_len] for i in range(clip_num)]
                new_sound = np.concatenate([audio_clips[i] for i in numbers], axis=0)
                # assert shape is same
                assert new_sound.shape == sound.shape
                return new_sound

            # source_ind still holds the last index visited by the loop above
            source_sound = get_new_audio(source_sounds[source_ind])
            sub_audio = self.impulse_response_to_sound(binaural_rir, source_sound)

            if self.args.add_intermittent:
                sub_audio = get_intermittent_time(sub_audio)
            audio += sub_audio*0.7

        # may move this noise part to other place
        if self.args.add_noise:
            audio = self.add_gaussian_noise_by_snr(audio)

        if self.args.save_audio:
            save_folder = os.path.join('./checkpoints', self.args.exp, 'saved_audio')
            os.makedirs(save_folder, exist_ok=True)

            relative_angle = kwargs['relative_angle']
            save_path = os.path.join(save_folder, f'audio_{str(index).zfill(5)}_audio_{s_index}_{relative_angle}.wav')
            # soundfile expects (L, C)
            sf.write(save_path, audio.T, self.pr.samp_sr)

        return audio


    def normalize_audio(self, samples, desired_rms=0.1, eps=1e-4):
        '''
            goal: rescale samples to a target RMS level
            samples: array of audio samples (left untouched, a new array is returned)
            desired_rms: RMS the output should have before clipping
            eps: lower bound on the measured RMS, so that (near-)silent input
                 is not divided by ~0
            return: rescaled samples, clipped to [-1, 1]
        '''
        measured_rms = np.sqrt(np.mean(samples**2))
        gain = desired_rms / np.maximum(eps, measured_rms)
        scaled = samples * gain
        np.clip(scaled, -1., 1., out=scaled)
        return scaled


    def sum2audio(self, audio_1, audio_2):
        '''
            goal: mix two audio arrays sample by sample
            audio_1, audio_2: arrays that can be added together
            return: new array holding the sum, clipped to [-1, 1]
        '''
        mixed = audio_1 + audio_2
        np.clip(mixed, -1., 1., out=mixed)
        return mixed


    def impulse_response_to_sound(self, binaural_rir, source_sound):
        '''
            goal: render a source through a simulated impulse response
            binaural_rir: (num_sample, num_channel)
            source_sound: mono sound, (num_sample)
            rir and source sound should have same sampling rate
            return: (num_channel, num_sample of source_sound); the
                    convolution tail beyond the source length is dropped
        '''
        n_samples = source_sound.shape[0]
        n_channels = binaural_rir.shape[-1]
        per_channel = []
        for channel in range(n_channels):
            # full convolution of the dry source with this ear's response
            per_channel.append(fftconvolve(source_sound, binaural_rir[:, channel]))
        return np.array(per_channel)[:, :n_samples]


    def add_gaussian_noise_by_snr(self, signal, snr=None):
        '''
            goal: add zero-mean Gaussian noise at a given signal-to-noise ratio
            signal: array of audio samples
            snr: SNR in dB; when None an integer is drawn from [10, 40) with self.rng
            return: signal + noise clipped to [-1, 1]; a signal whose RMS is
                    exactly 0 is returned unchanged
        '''
        if snr is None:
            snr = self.rng.integers(low=10, high=40)

        signal_rms = np.sqrt(np.mean(signal ** 2))
        if signal_rms == 0:
            return signal

        # noise power = signal power / 10^(snr / 10)
        noise_rms = np.sqrt(signal_rms ** 2 / 10 ** (snr / 10.0))
        noisy = signal + self.rng.normal(loc=0.0, scale=noise_rms, size=signal.shape)
        np.clip(noisy, -1., 1., out=noisy)
        return noisy


    def read_image(self, img_path):
        '''
            goal: load an image from disk and apply the dataset transform
            img_path: path of the image file
            return: output of self.image_transform applied to the RGB image
        '''
        # convert() yields an in-memory copy, so the file handle can be
        # released right away instead of waiting for garbage collection.
        with Image.open(img_path) as raw_image:
            image = raw_image.convert('RGB')
        return self.image_transform(image)
    
    
    def generate_image_transform(self, args, pr):
        '''
            goal: list the transforms applied to every loaded image
            args: not used here
            pr: provides img_size, forwarded to transforms.Resize
            return: [Resize(pr.img_size), ToTensor()]
        '''
        # Mean/std normalization is currently not applied; kept for reference:
        # transforms.Normalize(mean=[0.485, 0.456, 0.406], std=[0.229, 0.224, 0.225])
        return [transforms.Resize(pr.img_size), transforms.ToTensor()]


    def relative_angle_distribution(self):
        '''
            goal: gather the relative camera rotation of every training sample
            Reads 'metadata.json' in each sample directory listed by
            pr.list_train and uses its 'camera_0_angle' / 'camera_1_angle'.
            return: 1-D array holding, for each sample, the relative angle
                    followed by its negation
        '''
        angles = []
        for info in self.get_list_sample(self.pr.list_train):
            meta_path = os.path.join(info['path'], 'metadata.json')
            with open(meta_path, "r") as f:
                meta_dict = json.load(f)
            rotation = self.calculate_relative_rotation(
                [meta_dict['camera_0_angle'], meta_dict['camera_1_angle']]
            )
            # both turning directions are recorded
            angles.extend((rotation, -rotation))
        return np.array(angles)


    def calculate_relative_rotation(self, camera_angle_list, return_bin=False):
        '''
            We define turning left as +, turning right as -.
            camera_angle_list: [angle of camera 0, angle of camera 1]
            The difference (camera 1 - camera 0) is wrapped once by 360, so
            for differences inside (-360, 360) the result lies in [-180, 180).
            return: relative_angle, or (relative_angle, relative_angle_bin)
                    when return_bin is True.
            Binning uses the range of |self.angle_distribution| split into
            num_classes // 2 bins: negative angles map to the lower half of
            the classes, non-negative angles to the upper half.
        '''
        relative_angle = camera_angle_list[1] - camera_angle_list[0]
        if relative_angle >= 180:
            relative_angle = relative_angle - 360
        elif relative_angle < -180:
            relative_angle = relative_angle + 360

        if not return_bin:
            return relative_angle

        abs_angles = np.abs(self.angle_distribution)
        largest, smallest = abs_angles.max(), abs_angles.min()
        half_classes = self.pr.num_classes // 2
        bin_size = (largest - smallest) / half_classes
        if relative_angle >= 0:
            raw_bin = (relative_angle - smallest) // bin_size
            relative_angle_bin = np.clip(raw_bin, 0, half_classes - 1) + half_classes
        elif relative_angle < 0:
            raw_bin = (relative_angle + largest) // bin_size
            relative_angle_bin = np.clip(raw_bin, 0, half_classes - 1)
        return relative_angle, relative_angle_bin
    

    def source_direction_distribution(self):
        '''
            goal: gather the source directions of every training sample
            Reads 'metadata.json' in each sample directory listed by
            pr.list_train and keeps every value whose key contains
            'relative_angle_between_sound'.
            return: array of the collected angles. When args.restrain_to_front
                    is set, positive angles a become min(180 - a, a) and the
                    others become max(a, -180 - a).
        '''
        directions = []
        for info in self.get_list_sample(self.pr.list_train):
            with open(os.path.join(info['path'], 'metadata.json'), "r") as f:
                meta_dict = json.load(f)
            directions.extend(
                value for key, value in meta_dict.items()
                if 'relative_angle_between_sound' in key
            )

        arr = np.array(directions)
        if not self.args.restrain_to_front:
            return arr
        return np.where(arr > 0, np.minimum(180 - arr, arr), np.maximum(arr, -180 - arr))


    # def calc_source_direction_to_bin(self, source_angle):
    #     '''
    #         We define turning left as +, turning right as -.
    #         source angle are within (-180, 180]
    #     '''
    #     angle_range = self.source_distribution.max() - self.source_distribution.min()
    #     bin_size = angle_range / self.pr.num_classes      
    #     source_angle_bin = (source_angle - self.source_distribution.min()) // bin_size 
    #     source_angle_bin = np.clip(source_angle_bin, 0, self.pr.num_classes - 1)
    #     return source_angle_bin
    
    def calc_source_direction_to_bin(self, source_angle):
        '''
            We define turning left as +, turning right as -.
            source angle are within (-180, 180]

            source_angle: torch tensor of source angles (torch.div below
                          requires a tensor input)
            return: tensor of bin indices in [0, num_classes - 1], obtained by
                    splitting [min, max] of self.source_distribution into
                    pr.num_classes equal bins; angles outside that range are
                    clamped to the first / last bin
        '''
        lowest = self.source_distribution.min()
        highest = self.source_distribution.max()
        bin_size = (highest - lowest) / self.pr.num_classes
        source_angle_bin = torch.div(source_angle - lowest, bin_size, rounding_mode='floor')
        # Stay in torch for the clamp: a NumPy clip would have to convert the
        # tensor to an ndarray, which fails for CUDA or grad-tracking tensors.
        return torch.clamp(source_angle_bin, 0, self.pr.num_classes - 1)



