import torch
import math
import torch.nn as nn
import torch.nn.functional as F
import torchvision.models
import torchaudio
import numpy as np

import sys
sys.path.append('..')
from config import init_args, params
import models
from models import FloatEmbeddingSine



class wsSLNet(nn.Module):
    # SLfM pretext net
    def __init__(self, args, pr, nets, dataset_mod = None):
        """Build the SLfM pretext network (generative net + azimuth head).

        Args:
            args: run configuration. Besides whatever ``construct_feature_head``
                and ``freeze_param`` read, this constructor reads ``n_view``,
                ``no_vision``, ``no_cond_audio``, ``add_geometric``,
                ``use_gt_rotation``, ``generative_loss_ratio``,
                ``dataset_mod`` (only when the ``dataset_mod`` argument is
                None), ``azimuth_loss_type``, ``n_ss_position_clf``
                (classification only), ``restrain_to_front`` and ``loss_type``.
            pr: parameter object; ``pr.generative_net`` names the class that is
                looked up in ``models``; ``pr.visual_feature_dim`` is read when
                ``args.use_gt_rotation`` is set.
            nets: audio network, stored as ``self.audio_net``.
            dataset_mod: dataset mode; falls back to ``args.dataset_mod`` when None.

        Raises:
            ValueError: if ``args.azimuth_loss_type`` is neither
                ``'classification'`` nor ``'regression'``.
        """
        super(wsSLNet, self).__init__()
        # Fail fast with a clear message: an unknown head type would otherwise
        # leave ``self.n_clf`` unset and crash later with an AttributeError.
        if args.azimuth_loss_type not in ('classification', 'regression'):
            raise ValueError(
                f"unsupported azimuth_loss_type {args.azimuth_loss_type!r}; "
                "expected 'classification' or 'regression'")
        self.pr = pr
        self.args = args
        self.n_view = args.n_view
        self.no_vision = args.no_vision
        self.no_cond_audio = args.no_cond_audio
        self.add_geometric = args.add_geometric
        self.use_gt_rotation = args.use_gt_rotation
        self.generative_loss_ratio = args.generative_loss_ratio
        if dataset_mod is None:
            self.dataset_mod = args.dataset_mod
        else:
            self.dataset_mod = dataset_mod

        # Sub-modules are registered in the original order so parameter order
        # (and hence optimizer state layout) is unchanged.
        self.audio_net = nets
        self.generative_net = models.__dict__[pr.generative_net](args, pr)

        self.azimuth_loss_type = args.azimuth_loss_type
        if self.azimuth_loss_type == 'classification':
            self.n_clf = args.n_ss_position_clf # TODO
        else:  # 'regression' (validated above): a single scalar output
            self.n_clf = 1
        # Predicted azimuths are scaled to +/- angle_range_max degrees.
        if args.restrain_to_front:
            self.angle_range_max = 90.0
        else:
            self.angle_range_max = 180.0

        self.angle_head = self.construct_feature_head(args, pr, n_clf = self.n_clf)
        self.loss_type = args.loss_type

        if self.use_gt_rotation:
            self.rota_embedding = FloatEmbeddingSine(num_pos_feats=pr.visual_feature_dim, scale=1)

        self.freeze_param(args)

    def set_dataset_mod(self, dataset_mod):
        """Store ``dataset_mod`` as the active dataset mode.

        Note that ``forward`` and ``calc_loss`` both overwrite
        ``self.dataset_mod`` with their own ``dataset_mod`` argument, so this
        setting only lasts until the next such call.
        """
        self.dataset_mod = dataset_mod
        
    def forward(self, inputs, dataset_mod='0', loss=False, evaluate=False, inference=False, return_angle=False):
        """Run the generative net and angle head, then dispatch on the mode flags.

        ``self.dataset_mod`` is overwritten with ``dataset_mod`` on every call.

        The flags are checked in priority order
        ``return_angle`` > ``loss`` > ``evaluate`` > ``inference``:

        * ``return_angle``: ``calc_loss(..., return_angle=True)``.
        * ``loss``: ``calc_loss(..., evaluate=False)``.
        * ``evaluate``: ``calc_loss(..., evaluate=True)``.
        * ``inference``: ``self.inference(inputs, pred_audio, ...)``.
        * none set: the raw output of ``angle_head``.
        """
        self.dataset_mod = dataset_mod
        audio_input = self.generate_audio_pair(inputs)
        # Conditional features are currently disabled; the generative net
        # receives None as its conditioning input.
        cond_feats = None
        pred_audio = self.generative_net(audio_input, cond_feats)
        pred_angle = self.angle_head(pred_audio)

        if return_angle:
            return self.calc_loss(inputs, pred_angle, dataset_mod=dataset_mod, evaluate=False, return_angle=return_angle)
        if loss:
            return self.calc_loss(inputs, pred_angle, dataset_mod=dataset_mod, evaluate=False)
        if evaluate:
            return self.calc_loss(inputs, pred_angle, dataset_mod=dataset_mod, evaluate=True)
        if inference:
            return self.inference(inputs, pred_audio, dataset_mod=dataset_mod)
        return pred_angle
    
    def inverse_sound_direction(self, pred, half_turn=180.0):
        """Mirror azimuths across the +/-90 axis (keep sin, negate cos).

        Each value ``a`` maps to ``half_turn - a`` when ``a > 0`` and to
        ``-half_turn - a`` when ``a <= 0``.

        Args:
            pred: tensor of angles, in degrees by default.
            half_turn: value of a half turn in the units of ``pred``.
                Use ``math.pi`` for radians.

        Returns:
            A new tensor; ``pred`` is not modified.
        """
        # Both branches are selected from the ORIGINAL values. Writing in place
        # and then re-masking would transform inputs below -half_turn twice.
        return torch.where(pred <= 0., -half_turn - pred, half_turn - pred)

    def restrict_to_180(self, angle):
        angle_ = torch.clone(angle)
        angle_[abs(angle_) > 180] = np.sign(angle_[abs(angle_) > 180]) * (360 - abs(angle_[abs(angle_) > 180]))
        return angle_
    
    def trunk_rough_angle(self,angle, n_comp_clf = 6, return_degree = True):
        if self.dataset_mod=='0':   
            piece = 2 * self.angle_range_max / n_comp_clf
            angle = torch.div(angle, piece, rounding_mode='floor')*piece+piece/2
        elif self.dataset_mod=='1':
            piece = 180 / n_comp_clf # TODO
            angle = torch.div(angle, piece, rounding_mode='floor')*piece+piece/2
        if return_degree is False:
            angle = angle * torch.pi / 180
        return angle
                

    def restrict_to_180(self,angle):
        """Wrap angles in degrees into [-180, 180].

        Values already inside [-180, 180] are returned unchanged, so both +180
        and -180 are preserved. Any other value is shifted by a multiple of
        360 into [-180, 180). This works for arbitrarily large magnitudes;
        the previous formula was only correct for |angle| < 540.

        Returns:
            A new tensor; ``angle`` is not modified.
        """
        # torch.remainder takes the sign of the divisor, so the result is in [0, 360).
        wrapped = torch.remainder(angle + 180, 360) - 180
        return torch.where(angle.abs() > 180, wrapped, angle)
          
    def calc_comparison_loss_with_permutation_rotation(self, pred_angles_expectation, relative_angle, n_ss_p ):
        """Pairwise angle-difference L1 loss, tolerant to front/back flips.

        With ``args.restrain_to_front``, for each pair (i, j) with i < j, the
        predicted difference ``pred[i] - pred[j]`` is compared with
        ``trunk_rough_angle(relative_angle[i] - relative_angle[j])``. The summed
        L1 error is returned.

        Otherwise ``pred[j] - pred[i]`` is compared with
        ``trunk_rough_angle(relative_angle[j] - relative_angle[i])`` under four
        hypotheses: no prediction mirrored, only j, only i, or both mirrored
        via ``inverse_sound_direction``. The per-sample minimum over the
        hypotheses is summed.

        Args:
            pred_angles_expectation: list of ``n_ss_p`` predicted azimuths in
                degrees (``logit2angle`` output).
            relative_angle: list of ``n_ss_p`` ground-truth angles; assumed to
                be degrees like the predictions (TODO confirm with the dataset).
            n_ss_p: number of views.

        Returns:
            Scalar loss tensor.
        """
        if self.args.restrain_to_front:
            criterion_for_angle_comp = nn.L1Loss(reduction='sum')
            pred_angle_comps = []; rough_angle_comps = []
            for i in range(0, n_ss_p):
                for j in range(i+1, n_ss_p):
                    pred_angle_comp = (pred_angles_expectation[i]-pred_angles_expectation[j]).float()
                    angle_comparison = relative_angle[i] - relative_angle[j]
                    rough_angle_comp = self.trunk_rough_angle(angle_comparison)
                    pred_angle_comps.append(pred_angle_comp)
                    rough_angle_comps.append(rough_angle_comp)
            pred_angle_comp = torch.cat(pred_angle_comps, dim=0)
            rough_angle_comp = torch.cat(rough_angle_comps, dim=0)
            return criterion_for_angle_comp(pred_angle_comp, rough_angle_comp)

        criterion_for_angle_comp = nn.L1Loss(reduction='none')
        pairs = [(i, j) for i in range(n_ss_p) for j in range(i + 1, n_ss_p)]
        # (mirror view i?, mirror view j?) for each hypothesis.
        hypotheses = ((False, False), (False, True), (True, False), (True, True))
        # The target stays in degrees to match the units of the predictions.
        # It does not depend on the hypothesis, so it is computed once per pair.
        rough_per_pair = [self.trunk_rough_angle(relative_angle[j] - relative_angle[i]) for i, j in pairs]

        pred_angle_comps = []
        for flip_i, flip_j in hypotheses:
            for i, j in pairs:
                angle_1 = pred_angles_expectation[i]
                angle_2 = pred_angles_expectation[j]
                if flip_i:
                    angle_1 = self.inverse_sound_direction(angle_1)
                if flip_j:
                    angle_2 = self.inverse_sound_direction(angle_2)
                pred_angle_comps.append((angle_2 - angle_1).float())

        pred_angle_comp = torch.cat(pred_angle_comps, dim=0)
        rough_angle_comp = torch.cat(rough_per_pair * len(hypotheses), dim=0)
        angle_comp_loss = criterion_for_angle_comp(pred_angle_comp, rough_angle_comp)
        # Layout is (hypothesis, pair, batch); keep the best hypothesis per sample.
        angle_comp_loss = angle_comp_loss.view(len(hypotheses), len(pairs), -1)
        return angle_comp_loss.min(dim=0)[0].sum()

    def calc_comparison_loss_with_permutation_angle(self, pred_angles_expectation, relative_angle, n_ss_p):
        if self.args.restrain_to_front:
            criterion_for_angle_comp = nn.L1Loss(reduction='sum') 
            pred_angle_comps = []; rough_angle_comps = []
            # pred_angle_comp = torch.sigmoid(pred_angles_expectt5ation[0]-pred_angles_expectation[1]).float()
            # rough_angle_comp = restrain_to_0_1(angle_comparison).float()
            for i in range(0, n_ss_p):
                for j in range(i+1, n_ss_p):
                    # relative_angle[i] = restrain_to_0_1(relative_angle[i]).float()
                    pred_angle_comp = (pred_angles_expectation[i]-pred_angles_expectation[j]).float()
                    angle_comparison = relative_angle[i] - relative_angle[j]
                    if self.loss_type == 'debug01':
                        angle_comparison = torch.where(angle_comparison > 0, torch.tensor(90).to(angle_comparison.device), torch.tensor(-90).to(angle_comparison.device))
                        rough_angle_comp = angle_comparison
                    else:
                        rough_angle_comp = self.trunk_rough_angle(angle_comparison)
                    pred_angle_comps.append(pred_angle_comp)
                    rough_angle_comps.append(rough_angle_comp)
            pred_angle_comp = torch.cat(pred_angle_comps, dim=0)
            rough_angle_comp = torch.cat(rough_angle_comps, dim=0)
            angle_comp_loss = criterion_for_angle_comp(pred_angle_comp, rough_angle_comp)
            return angle_comp_loss
        
        criterion_for_angle_comp = nn.L1Loss(reduction='none') 
        

        pred_angle_comps = []; rough_angle_comps = []
        ref_angle = []; tgt_angle = []
        # pred_angle_comp = torch.sigmoid(pred_angles_expectt5ation[0]-pred_angles_expectation[1]).float()
        # rough_angle_comp = restrain_to_0_1(angle_comparison).float()
        for mod in [0]:
            for i in range(0, n_ss_p):
                for j in range(i+1, n_ss_p):
                    # relative_angle[i] = restrain_to_0_1(relative_angle[i]).float()
                    if mod == 0:
                        angle_1 = pred_angles_expectation[i]
                        angle_2 = pred_angles_expectation[j]
                    elif mod == 1:
                        angle_1 = pred_angles_expectation[i]
                        angle_2 = self.inverse_sound_direction(pred_angles_expectation[j])
                    elif mod == 2:
                        angle_1 = self.inverse_sound_direction(pred_angles_expectation[i])
                        angle_2 = pred_angles_expectation[j]
                    elif mod == 3:
                        angle_1 = self.inverse_sound_direction(pred_angles_expectation[i])
                        angle_2 = self.inverse_sound_direction(pred_angles_expectation[j])

                    ref_angle.append(angle_1)
                    tgt_angle.append(angle_2)
                    angle_comparison = relative_angle[j] - relative_angle[i]
                    rough_angle_comp = self.restrict_to_180(self.trunk_rough_angle(angle_comparison))
                    # rough_angle_comp = self.trunk_rough_angle(angle_comparison)
                    rough_angle_comp = angle_comparison
                    pred_angle_comp = (angle_2-angle_1).float()
                    pred_angle_comps.append(pred_angle_comp)
                    rough_angle_comps.append(rough_angle_comp)
                        
        # pred_angle_comp = torch.cat(pred_angle_comps, dim=0)
        def inverse_sound_direction(pred):
            pred_inv = torch.clone(pred)
            pred_inv[pred_inv <= 0.] = - math.pi - pred_inv[pred_inv <= 0]
            pred_inv[pred_inv > 0.] = math.pi - pred_inv[pred_inv > 0]
            return pred_inv
    
        rough_angle_comp = torch.cat(rough_angle_comps, dim=0)
        ref_angle = torch.stack(ref_angle, dim=0)
        ref_angle_inv = inverse_sound_direction(ref_angle)
        ref_angle = torch.cat([
            ref_angle.unsqueeze(-1), ref_angle_inv.unsqueeze(-1),
            ref_angle.unsqueeze(-1), ref_angle_inv.unsqueeze(-1)], dim=-1)
        ref_angle = ref_angle.view(-1)
        
        tgt_angle = torch.stack(tgt_angle, dim=0)
        tgt_angle_inv = inverse_sound_direction(tgt_angle)
        tgt_angle = torch.cat([
            tgt_angle.unsqueeze(-1), tgt_angle_inv.unsqueeze(-1),
            tgt_angle.unsqueeze(-1), tgt_angle_inv.unsqueeze(-1)], dim=-1)
        tgt_angle = tgt_angle.view(-1)
        
        theta = rough_angle_comp
        rots = [torch.cos(theta).unsqueeze(-1), -torch.sin(theta).unsqueeze(-1), 
                torch.sin(theta).unsqueeze(-1), torch.cos(theta).unsqueeze(-1)]
        rots = torch.cat(rots, dim=-1)
        rots = rots.contiguous().view(-1, 2, 2).float()
        rots = rots.repeat_interleave(4, dim=0) # match the permutation number
         
        ref_vec = [torch.cos(ref_angle).unsqueeze(-1), torch.sin(ref_angle).unsqueeze(-1)]
        ref_vec = torch.cat(ref_vec, dim=-1)
        tgt_vec = [torch.cos(tgt_angle).unsqueeze(-1), torch.sin(tgt_angle).unsqueeze(-1)]
        tgt_vec = torch.cat(tgt_vec, dim=-1)
        rotated_ref_vec = torch.matmul(rots, ref_vec.unsqueeze(-1)).squeeze(-1)
        dot_product = (rotated_ref_vec * tgt_vec).sum(dim=-1)
        dot_product_target = torch.ones_like(dot_product)
        
        geometric_loss = F.l1_loss(dot_product, dot_product_target, reduction='none')
        geometric_loss =  geometric_loss.view(geometric_loss.shape[0] // 4, 4, -1).mean(dim=-1)
        geometric_loss = geometric_loss.min(dim=-1)[0]
        geometric_loss = geometric_loss.mean(dim=-1)

        return geometric_loss
            
        # angle_comp_loss = criterion_for_angle_comp(pred_angle_comp, rough_angle_comp)
        # angle_comp_loss = angle_comp_loss.view(4,n_ss_p*(n_ss_p-1)//2,-1)
        # angle_comp_loss = angle_comp_loss.min(dim=0)[0]
        # angle_comp_loss = angle_comp_loss.sum()
        # return angle_comp_loss  
    
    def calc_comparison_loss_with_rotation(self, pred_angles_expectation, relative_angle, n_ss_p):
        """Pairwise rotation-consistency loss on sin/cos of the predicted angles.

        With ``args.restrain_to_front``, the summed L1 error is returned between
        ``pred[i] - pred[j]`` and
        ``trunk_rough_angle(relative_angle[i] - relative_angle[j])``.

        Otherwise, for each pair (i, j) with i < j, ``pred[i]`` is rotated by
        the binned ground-truth difference ``relative_angle[j] - relative_angle[i]``.
        Its sin and cos are L1-compared (summed) with those of ``pred[j]``.

        Args:
            pred_angles_expectation: list of ``n_ss_p`` predicted azimuths in
                degrees (``logit2angle`` output).
            relative_angle: list of ``n_ss_p`` ground-truth angles, assumed to
                be degrees (TODO confirm).
            n_ss_p: number of views.

        Returns:
            Scalar float32 loss tensor.
        """
        if self.args.restrain_to_front:
            criterion_for_angle_comp = nn.L1Loss(reduction='sum')
            pred_angle_comps = []; rough_angle_comps = []
            for i in range(0, n_ss_p):
                for j in range(i+1, n_ss_p):
                    pred_angle_comp = (pred_angles_expectation[i]-pred_angles_expectation[j]).float()
                    angle_comparison = relative_angle[i] - relative_angle[j]
                    rough_angle_comp = self.trunk_rough_angle(angle_comparison)
                    pred_angle_comps.append(pred_angle_comp)
                    rough_angle_comps.append(rough_angle_comp)
            pred_angle_comp = torch.cat(pred_angle_comps, dim=0)
            rough_angle_comp = torch.cat(rough_angle_comps, dim=0)
            return criterion_for_angle_comp(pred_angle_comp, rough_angle_comp)

        rotated_sin = []; rotated_cos = []  # view i rotated by the GT rotation
        target_sin = []; target_cos = []    # predicted direction of view j
        for i in range(0, n_ss_p):
            for j in range(i+1, n_ss_p):
                angle_1 = pred_angles_expectation[i] * torch.pi / 180
                angle_2 = pred_angles_expectation[j] * torch.pi / 180
                target_sin.append(torch.sin(angle_2).to(torch.float32))
                target_cos.append(torch.cos(angle_2).to(torch.float32))

                # sin/cos need radians; trunk_rough_angle returns degrees by default.
                rough_angle_comp = self.trunk_rough_angle(relative_angle[j] - relative_angle[i], return_degree=False)
                rotated = angle_1 + rough_angle_comp
                rotated_sin.append(torch.sin(rotated).to(torch.float32))
                rotated_cos.append(torch.cos(rotated).to(torch.float32))

        criterion = nn.L1Loss(reduction='sum')
        sin_loss = criterion(torch.cat(rotated_sin, dim=0), torch.cat(target_sin, dim=0))
        cos_loss = criterion(torch.cat(rotated_cos, dim=0), torch.cat(target_cos, dim=0))
        return sin_loss.to(torch.float32) + cos_loss.to(torch.float32)
                

    def calc_comparison_loss_with_atan(self, pred_angles_expectation, relative_angle, n_ss_p):
        """Rotation-consistency loss using a wrapped, binned GT rotation.

        For each pair (i, j) with i < j, ``pred[i]`` is rotated by
        ``restrict_to_180(trunk_rough_angle(relative_angle[j] - relative_angle[i]))``.
        Its sin and cos are L1-compared (summed) with those of ``pred[j]``.

        Args:
            pred_angles_expectation: list of ``n_ss_p`` predicted azimuths in
                degrees (``logit2angle`` output).
            relative_angle: list of ``n_ss_p`` ground-truth angles, assumed to
                be degrees (TODO confirm).
            n_ss_p: number of views.

        Returns:
            Scalar float32 loss tensor.
        """
        deg2rad = torch.pi / 180
        rotated_sin = []; rotated_cos = []  # view i rotated by the GT rotation
        target_sin = []; target_cos = []    # predicted direction of view j
        for i in range(0, n_ss_p):
            for j in range(i+1, n_ss_p):
                angle_1 = pred_angles_expectation[i] * deg2rad
                angle_2 = pred_angles_expectation[j] * deg2rad
                target_sin.append(torch.sin(angle_2).to(torch.float32))
                target_cos.append(torch.cos(angle_2).to(torch.float32))

                angle_comparison = relative_angle[j] - relative_angle[i]
                # Binned and wrapped in degrees, then converted for sin/cos.
                rough_angle_comp = self.restrict_to_180(self.trunk_rough_angle(angle_comparison)) * deg2rad
                rotated = angle_1 + rough_angle_comp
                rotated_sin.append(torch.sin(rotated).to(torch.float32))
                rotated_cos.append(torch.cos(rotated).to(torch.float32))

        criterion = nn.L1Loss(reduction='sum')
        sin_loss = criterion(torch.cat(rotated_sin, dim=0), torch.cat(target_sin, dim=0))
        cos_loss = criterion(torch.cat(rotated_cos, dim=0), torch.cat(target_cos, dim=0))
        return sin_loss.to(torch.float32) + cos_loss.to(torch.float32)
    
        
            
    def logit2angle(self,pred_angle,n_ss_p): # [angle_1_1, angle_1_2, angle_2_1, angle_2_2, ..., angle_B_1, angle_B_2]
        """Convert angle-head outputs into per-view azimuths in degrees.

        ``pred_angle`` rows are ordered sample-major, as in the comment on the
        signature. They are viewed as ``(B, n_ss_p, -1)``.

        * regression: ``tanh(logit) * angle_range_max``.
        * classification: softmax over ``n_clf`` equal-width bins spanning
          ``[-angle_range_max, angle_range_max]``. The expectation over bin
          centres is clamped to that range.

        Args:
            pred_angle: tensor with ``B * n_ss_p`` rows.
            n_ss_p: number of views per sample.

        Returns:
            ``(pred_angles, pred_angles_expectation)``: two lists of length
            ``n_ss_p``. ``pred_angles[i]`` is ``(B, C)``; it holds raw outputs
            for regression and probabilities for classification.
            ``pred_angles_expectation[i]`` is ``(B,)`` in degrees.

        Raises:
            ValueError: for an unknown ``azimuth_loss_type``.
        """
        pred_angle = pred_angle.view(pred_angle.shape[0]//n_ss_p, n_ss_p, -1)
        if self.azimuth_loss_type == 'regression':
            pred_angles = [pred_angle[:, i, :] for i in range(n_ss_p)]
            pred_angles_expectation = [torch.tanh(p.squeeze(-1)) * self.angle_range_max for p in pred_angles]   # [B]
        elif self.azimuth_loss_type == 'classification':
            probs = F.softmax(pred_angle, dim=2)
            pred_angles = [probs[:, i, :] for i in range(n_ss_p)]
            # True division: floor-dividing the bin width shifted the centres
            # whenever n_clf does not divide 2 * angle_range_max.
            bin_width = 2 * self.angle_range_max / self.n_clf
            centers = -self.angle_range_max + bin_width * (
                torch.arange(self.n_clf, device=probs.device, dtype=probs.dtype) + 0.5)
            pred_angles_expectation = [
                torch.clamp(torch.mv(p, centers), -self.angle_range_max, self.angle_range_max)
                for p in pred_angles]
        else:
            raise ValueError(f"unsupported azimuth_loss_type {self.azimuth_loss_type!r}")
        return pred_angles,pred_angles_expectation
               
    def get_relative_angle(self,inputs):
        """Collect the ground-truth angle of each of the ``n_ss_position`` views.

        Reads ``inputs['source_{k}_relative_angle']`` for dataset modes '0' and
        '7', and ``inputs['source_{k}_position']`` for modes '1' and '2'
        (k = 1 .. n_ss_position). Any other mode yields an empty list.
        """
        n_ss_p = self.args.n_ss_position
        if self.dataset_mod in ('0', '7'):
            return [inputs[f'source_{i+1}_relative_angle'] for i in range(n_ss_p)]
        if self.dataset_mod in ('1', '2'):
            return [inputs[f'source_{i+1}_position'] for i in range(n_ss_p)]
        return []
                   
    def calc_loss(self, inputs, pred_angle,dataset_mod,evaluate=False, return_angle=False):
        output = {}
        self.dataset_mod = dataset_mod
        n_ss_p = self.args.n_ss_position            
        

        # Ground Truth
        relative_angle = self.get_relative_angle(inputs) # [angle_1, angle_2, ...]
        
        # if self.loss_type in ['v00613','ve01'] and self.args.dataset_mod in ['0'] and self.azimuth_loss_type == 'classification':
        #     if n_ss_p == 2:
        #         batch_size = pred_angle.shape[0]
        #         assert batch_size % (2*n_ss_p)  == 0
        #         n_ss_p *= 2
        #         concatenated_relative_angle_tensor = torch.cat(relative_angle, dim=0).view(2,-1).T.contiguous().view(-1,n_ss_p).T
        #         relative_angle = [concatenated_relative_angle_tensor[i, :] for i in range(0, n_ss_p)] # [[angle_1_1 ...], [angle_1_2 ...], [angle_2_1 ...], [angle_2_2 ...]]
            
        pred_angles,pred_angles_expectation = self.logit2angle(pred_angle,n_ss_p)
        # Angle Direction Loss: whether left or right
        pred_angle_direction = self.get_pred_angle_direction(pred_angles, pred_angles_expectation, n_ss_p)
        real_angle_direction = self.get_real_angle_direction(relative_angle, n_ss_p)
        angle_direction_loss = self.calculate_angle_direction_loss(pred_angle_direction, real_angle_direction, n_ss_p)

        if return_angle:
            if self.n_clf == 1 and self.loss_type in ['v0040','v0041','v0050']:
                return pred_angle_direction
            return torch.cat(pred_angles_expectation, dim=0)

        # Binaural Loss
        binaural_ild_loss,ild_binaural_result,bina_pred = self.calc_binaural_loss(inputs, pred_angles_expectation, n_ss_position = n_ss_p)  # range: _,[0,1],[0,1]
        ild_angle_direction = [ild_binaural_result[:,i] for i in range(0, n_ss_p)]
        criterion_for_ild_angle_direction = nn.L1Loss(reduction='mean')
        binaural_ild_loss = sum([criterion_for_ild_angle_direction(pred_angle_direction[i], ild_angle_direction[i])  for i in range(0, n_ss_p)])

        # Angle Comparison Loss4
        angle_comp_loss = self.calc_comparison_loss_with_permutation_angle(pred_angles_expectation, relative_angle, n_ss_p)
        # angle_comp_loss = self.calc_comparison_loss_with_rotation(pred_angles_expectation, relative_angle, n_ss_p)
        
        def get_direction_mask(input_tensor, n_clf):
            if n_clf%2 == 0:
                n = n_clf//2
                output_tensor = torch.zeros(input_tensor.shape[0], n_clf)

                output_tensor[input_tensor == 1, :n] = 0
                output_tensor[input_tensor == 1, n:2*n] = 1
                output_tensor[input_tensor == 0, :n] = 1
                output_tensor[input_tensor == 0, n:2*n] = 0
            else:
                n = n_clf//2
                output_tensor = torch.zeros(600, n_clf)

                output_tensor[input_tensor == 1, :n+1] = 0
                output_tensor[input_tensor == 1, n+1:2*n+1] = 1
                output_tensor[input_tensor == 0, :n] = 1
                output_tensor[input_tensor == 0, n:2*n+1] = 0
            return output_tensor
        
        

        # label: > or <
        def cal_sign_comp_loss(pred_angles, relative_angle, n_ss_p):
            assert self.azimuth_loss_type == 'classification'
            # assert self.args.restrain_to_front # TODO
            # self.relative_angle 
            # self.angle_range_max//self.n_clf 
            
            def get_probabilities(angle_comparison):
                n_clf = self.n_clf
                group_num = pred_angles[0].shape[0]
                tensor = angle_comparison
                result = torch.zeros(group_num, n_clf, n_clf)
                mask_gt_zero = tensor > 0
                result[mask_gt_zero, :, :] = torch.triu(torch.ones(n_clf, n_clf), diagonal=1)
                mask_lt_zero = tensor < 0
                result[mask_lt_zero, :, :] = torch.tril(torch.ones(n_clf, n_clf), diagonal=-1)
                result = result*2/(n_clf-1)/(n_clf-1)
                return result.to(pred_angles[0].device)
            
            def get_mask(angle_comparison):
                if self.args.restrain_to_front:
                    if self.loss_type in ['v0061','v00612','v00613']:
                        diagonal_param = [1,-1]
                    elif self.loss_type == 'v00611':
                        diagonal_param = [0,0]
                    n_clf = self.n_clf
                    group_num = pred_angles[0].shape[0]
                    tensor = angle_comparison
                    result = torch.zeros(group_num, n_clf, n_clf)
                    mask_gt_zero = tensor > 0
                    result[mask_gt_zero, :, :] = torch.triu(torch.ones(n_clf, n_clf), diagonal=diagonal_param[0])
                    mask_lt_zero = tensor < 0
                    result[mask_lt_zero, :, :] = torch.tril(torch.ones(n_clf, n_clf), diagonal=diagonal_param[1])
                    pred_angle_comp_masked = pred_angle_comp * result.to(pred_angles[0].device)
                else:
                    if self.loss_type in ['v0061','v00612','v00613']:
                        diagonal_param = [1,-1]
                    n_clf = self.n_clf
                    group_num = pred_angles[0].shape[0]
                    tensor = angle_comparison
                    result = torch.zeros(group_num, n_clf, n_clf)
                    
                    def get_mask_for_360(flag='tril'):
                        if flag == 'tril':
                            x = torch.zeros(1, n_clf, n_clf)
                            piece = n_clf//4
                            x[0, piece : piece*3, piece : piece*3] = torch.tril(torch.ones(piece*2, piece*2), diagonal=diagonal_param[1])
                            x[0, piece*2 : , : piece*2] = torch.ones(piece*2, piece*2)
                            x[0, :piece, :piece] = torch.triu(torch.ones(piece, piece), diagonal=diagonal_param[0])
                            x[0, piece*3: , piece*3:] = torch.triu(torch.ones(piece, piece), diagonal=diagonal_param[0])
                            x[0 , :piece , piece : piece*2] =  torch.flip(x[0, :piece, :piece], dims=[1])
                            x[0, piece*3: , piece*2:piece*3] =x[0 , :piece , piece : piece*2]
                            x[0, piece: piece*2,:piece ] = torch.flip(x[0, :piece, :piece], dims=[0])
                            x[0, piece*2:piece*3, piece*3: ] = x[0, piece: piece*2,:piece ]
                        elif flag == 'triu':
                            x = torch.zeros(1, n_clf, n_clf)
                            piece = n_clf//4
                            x[0, piece : piece*3, piece : piece*3] = torch.triu(torch.ones(piece*2, piece*2), diagonal=diagonal_param[0])
                            x[0, : piece*2,  piece*2 : ] = torch.ones(piece*2, piece*2)
                            x[0, :piece, :piece] = torch.tril(torch.ones(piece, piece), diagonal=diagonal_param[1])
                            x[0, piece*3: , piece*3:] = torch.tril(torch.ones(piece, piece), diagonal=diagonal_param[1])
                            x[0 , :piece , piece : piece*2] =  torch.flip(x[0, :piece, :piece], dims=[1])
                            x[0, piece*3: , piece*2:piece*3] =x[0 , :piece , piece : piece*2]
                            x[0, piece: piece*2,:piece ] = torch.flip(x[0, :piece, :piece], dims=[0])
                            x[0, piece*2:piece*3, piece*3: ] = x[0, piece: piece*2,:piece ]
                        return x.unsqueeze(0)
                    mask_gt_zero = tensor > 0 # clockwise rotation
                    result[mask_gt_zero, :, :] = get_mask_for_360(flag='triu')
                    mask_lt_zero = tensor < 0
                    result[mask_lt_zero, :, :] = get_mask_for_360(flag='tril')
                    pred_angle_comp_masked = pred_angle_comp * result.to(pred_angles[0].device)
                return pred_angle_comp_masked
            
            if True:
            # if self.args.restrain_to_front:
                pred_angle_comps = []; rough_angle_comps = []
                # pred_angle_comp = torch.sigmoid(pred_angles_expectt5ation[0]-pred_angles_expectation[1]).float()
                # rough_angle_comp = restrain_to_0_1(angle_comparison).float()
                for i in range(0, n_ss_p):
                    for j in range(i+1, n_ss_p):
                        # relative_angle[i] = restrain_to_0_1(relative_angle[i]).float()
                        pred_angle_comp = (pred_angles[i].unsqueeze(1)*pred_angles[j].unsqueeze(2)).float()
                        angle_comparison = (relative_angle[i] - relative_angle[j] +180)%360 - 180
                        if self.loss_type == 'v0060':
                            angle_comparison = get_probabilities(angle_comparison)
                            rough_angle_comp = angle_comparison
                        elif self.loss_type in ['v0061','v00611','v00612','v00613']:
                            angle_comparison = get_mask(angle_comparison)
                            if self.loss_type == 'v00612':
                                IID_mask = get_direction_mask(ild_angle_direction[i], self.n_clf).unsqueeze(1)*get_direction_mask(ild_angle_direction[j], self.n_clf).unsqueeze(2)
                                IID_mask = IID_mask.to(angle_comparison.device)
                                angle_comparison = angle_comparison * IID_mask
                            rough_angle_comp = angle_comparison
                        else:
                            rough_angle_comp = self.trunk_rough_angle(angle_comparison,n_comp_clf=self.n_clf)
                        pred_angle_comps.append(pred_angle_comp)
                        rough_angle_comps.append(rough_angle_comp)
                        
                pred_angle_comp = torch.cat(pred_angle_comps, dim=0)
                rough_angle_comp = torch.cat(rough_angle_comps, dim=0)
                if self.loss_type == 'v0060':
                    criterion_for_angle_comp =  nn.CrossEntropyLoss()
                    sign_comp_loss = criterion_for_angle_comp(pred_angle_comp, rough_angle_comp)
                elif self.loss_type in ['v0061','v00611','v00612','v00613','v0062']:
                    criterion_for_angle_comp = nn.L1Loss(reduction='sum'); 
                    angle_comparison = torch.sum(angle_comparison, dim=(1,2), keepdim=True)
                    gt_sum = torch.ones(angle_comparison.shape).to(angle_comparison.device)
                    sign_comp_loss = criterion_for_angle_comp(angle_comparison, gt_sum)
                return sign_comp_loss
        
        
        def cal_rotation_angle_consistency():
            if self.args.restrain_to_front:
                criterion_for_angle_comp = nn.L1Loss(reduction='sum') 
                pred_angle_comps = []; rough_angle_comps = []
                # pred_angle_comp = torch.sigmoid(pred_angles_expectt5ation[0]-pred_angles_expectation[1]).float()
                # rough_angle_comp = restrain_to_0_1(angle_comparison).float()
                for i in range(0, n_ss_p):
                    for j in range(i+1, n_ss_p):
                        # relative_angle[i] = restrain_to_0_1(relative_angle[i]).float()
                        pred_angle_comp = (pred_angles_expectation[i]-pred_angles_expectation[j]).float()
                        angle_comparison = relative_angle[i] - relative_angle[j]
                        if self.loss_type == 'debug01':
                            angle_comparison = torch.where(angle_comparison > 0, torch.tensor(90).to(angle_comparison.device), torch.tensor(-90).to(angle_comparison.device))
                            rough_angle_comp = angle_comparison
                        else:
                            rough_angle_comp = self.trunk_rough_angle(angle_comparison)
                        pred_angle_comps.append(pred_angle_comp)
                        rough_angle_comps.append(rough_angle_comp)
                pred_angle_comp = torch.cat(pred_angle_comps, dim=0)
                rough_angle_comp = torch.cat(rough_angle_comps, dim=0)
                angle_comp_loss = criterion_for_angle_comp(pred_angle_comp, rough_angle_comp)
                return angle_comp_loss
        
        # Direct Angle Loss
        direct_angle_loss = sum([F.l1_loss(pred_angles_expectation[i].float(), relative_angle[i].float(),reduction='none')  for i in range(0, n_ss_p)])
        # direct_angle_loss = F.l1_loss(pred_angles_expectation[0].float(), relative_angle[0].float(),reduction='none') 
        # rough_angle_loss = sum([F.l1_loss(self.trunk_rough_angle(pred_angles_expectation[i].float()), relative_angle[i].float(),reduction='none')  for i in range(0, n_ss_p)])
        rough_angle_loss = sum([F.l1_loss(pred_angles_expectation[i].float(), self.trunk_rough_angle(relative_angle[i].float(),n_comp_clf=self.n_clf),reduction='none')  for i in range(0, n_ss_p)])
        
        # evaluate in one half
        evaluate_in_one_half = True
        
        output['binaural iid loss'] = binaural_ild_loss
        output['angle direction loss(left/right)'] = angle_direction_loss
        output['angle comparison MAE'] = angle_comp_loss/(n_ss_p*(n_ss_p-1)/2)
        output['angle comparison loss'] = angle_comp_loss
        output['angle MAE'] = direct_angle_loss/n_ss_p
        
        if not evaluate_in_one_half:
            pass
        else:
            def restrain_to_front(angle):
                angle[angle > 90] = 180 - angle[angle > 90]
                angle[angle < -90] = angle[angle < -90] + 90
                return angle
            output['angle MAE one side'] = sum([F.l1_loss(restrain_to_front(pred_angles_expectation[i].float()), restrain_to_front(relative_angle[i].float()),reduction='none')  for i in range(0, n_ss_p)])/n_ss_p
        
        iid_loss_type = ['vd01']
        supervised_loss_type = ['ve01','ve011','ve012']; supervised_clf_loss_type = ['ve02','ve03']
        only_comp_loss_type = ['va01','v00212','v00213']
        rough_loss_type_precise_direction = ['v00112', 'v00111', 'v0012','v00121','v00211','v002111','v002121','v002131','v0030']
        rough_loss_type_ILD_direction = ['v0040','v0041','v0050']
        rotation_direction_type = ['v0060','v0061','v00611','v00612','v00613','v0062','v0065']
        debug_loss_type = ['debug01','debug02']
        
        if self.loss_type in supervised_clf_loss_type and self.azimuth_loss_type == 'classification':
            if self.loss_type == 've02':
                criterion = nn.CrossEntropyLoss()
                probabilities = F.softmax(pred_angles[0], dim=1)
                target_labels = self.calc_source_direction_to_bin(relative_angle[0].cpu()).long().to(pred_angles_expectation[0].device)
                loss = criterion(probabilities, target_labels)
            if self.loss_type == 've03':
                loss = rough_angle_loss
            
        if self.loss_type in supervised_loss_type:
            loss = direct_angle_loss
        if self.loss_type in only_comp_loss_type:
            loss = angle_comp_loss
        if self.loss_type in rough_loss_type_precise_direction:
            if dataset_mod=='0':
                loss = binaural_ild_loss + angle_comp_loss*0
            elif dataset_mod in ['1','2','7']:
                loss = binaural_ild_loss + 0.5*angle_comp_loss
        if self.loss_type in rough_loss_type_ILD_direction:
            if dataset_mod=='0':
                loss = binaural_ild_loss*5 + 0.1*angle_comp_loss
            elif dataset_mod in ['1','2']:
                # loss = binaural_ild_loss + 0.02*angle_comp_loss  s
                # loss = binaural_ild_loss*5 + 0.000002*angle_comp_loss         
                # loss = binaural_ild_loss*5 + 0.00000002*angle_comp_loss   
                loss =  0.00000002*angle_comp_loss        
        if self.loss_type == 'vc01':
            loss = angle_direction_loss 
        if self.loss_type == 'vd01':
            loss = binaural_ild_loss
        if self.loss_type in rotation_direction_type:
            sign_comp_loss = cal_sign_comp_loss(pred_angles, relative_angle, n_ss_p)
            if dataset_mod=='0' and self.loss_type == 'v0060':
                loss = angle_direction_loss*3 + sign_comp_loss*10    
            if self.loss_type in ['v0061','v00611','v00612']:
                loss = binaural_ild_loss*10 + sign_comp_loss*0
            if dataset_mod=='0' and self.loss_type in ['v00613']:
                loss = binaural_ild_loss*1000 + sign_comp_loss*5
            if self.loss_type in ['v0062']:
                loss = sign_comp_loss*5
            if self.loss_type == 'v0065':
                loss = angle_direction_loss*3 + sign_comp_loss*10
            if evaluate and self.azimuth_loss_type == 'classification':
                output['sign comparison loss'] = sign_comp_loss
            
        # -----for evaluation-----
        if evaluate and self.azimuth_loss_type == 'classification':
            if evaluate_in_one_half:
                def get_accuracy_rate(n_bin):
                    relative_angle_bin = self.calc_source_direction_to_bin(restrain_to_front(relative_angle[0].cpu()),n_bin=n_bin,angle_range_max=90)
                    pred_angle_bin = self.calc_source_direction_to_bin(restrain_to_front(pred_angles_expectation[0].cpu()),n_bin=n_bin,angle_range_max=90)
                    accuracy = (relative_angle_bin == pred_angle_bin).sum().item() / len(relative_angle_bin)
                    return accuracy
                output['accuracy rate, 16clf, in one side'] = torch.tensor(get_accuracy_rate(n_bin=16)).to(angle_comp_loss.device)  
                output['accuracy rate, 8clf, in one side'] = torch.tensor(get_accuracy_rate(n_bin=8)).to(angle_comp_loss.device)  
                output['accuracy rate, 4clf, in one side'] = torch.tensor(get_accuracy_rate(n_bin=4)).to(angle_comp_loss.device)      
                output['accuracy rate, 2clf, in one side'] = torch.tensor(get_accuracy_rate(n_bin=2)).to(angle_comp_loss.device)            
            else:
                relative_angle_bin = self.calc_source_direction_to_bin(relative_angle[0].cpu())
                pred_angle_bin = self.calc_source_direction_to_bin(pred_angles_expectation[0].cpu())
                accuracy = (relative_angle_bin == pred_angle_bin).sum().item() / len(relative_angle_bin)
                output['accuracy rate'] = torch.tensor(accuracy).to(angle_comp_loss.device)
            
            # if self.loss_type == 'vd01':
            if True:
                ild_accuracys = []
                
                for i in range(0, n_ss_p):
                    binau_relative_angle_bin = self.calc_source_direction_to_bin(relative_angle[i].cpu(),n_bin=2) 
                    ild_r = ild_binaural_result[:,i].detach().cpu()
                    ild_accuracy = (binau_relative_angle_bin == ild_r).sum().item() / len(binau_relative_angle_bin)
                    ild_accuracys.append(ild_accuracy)
                ild_accuracys = sum(ild_accuracys)/n_ss_p
                output['ild accuracy rate(anno==ild/anno)']= torch.tensor(ild_accuracys).to(angle_comp_loss.device)
                
                binaur_accuracys = [];ild_pred_similarities = []
                threshold = 0.5
                for i in range(0, n_ss_p):
                    # if self.n_clf == 1:
                    #     pred = self.restrain_to_0_1(bina_pred[:,0].detach())
                    # else:
                        # pred = (pred_angles[i] > threshold).to(torch.float32)[:,0].to(angle_comp_loss.device)
                    pred = self.calc_source_direction_to_bin(pred_angles_expectation[i].cpu())
                    binaur_accuracy = (pred == real_angle_direction[i].cpu()).sum().item() / len(real_angle_direction[i])
                    binaur_accuracys.append(binaur_accuracy)
                    ild_pred_similarity = (pred == ild_binaural_result[:,0].cpu()).sum().item() / len(pred)
                    ild_pred_similarities.append(ild_pred_similarity)
                    
                binaur_accuracy = sum(binaur_accuracys)/n_ss_p
                ild_pred_similarity = sum(ild_pred_similarities)/n_ss_p
                output['binaur accuracy rate(pred==anno/anno)']= torch.tensor(binaur_accuracy).to(angle_comp_loss.device) # [pred,annotation]
                output['ild simi rate(pred==ild/ild)']= torch.tensor(ild_pred_similarity).to(angle_comp_loss.device)
                
        
        output['Loss'] = loss
        if evaluate:
            return output
        return loss

    def restrain_to_0_1(self,tensor):
        '''
        Map a tensor to a (nearly) binary 0/1 indicator of positivity.

        Entries <= 0 become 0 and entries >= 1e-10 become 1. Entries strictly
        between 0 and 1e-10 are scaled linearly, so this is a very steep ramp
        rather than an exact step. The input tensor is not modified.
        '''
        # A single clamp to [0, 1e-10] covers both bounds; the rescale then
        # stretches that sliver onto [0, 1].
        stepped = tensor.clamp(min=0, max=1e-10)
        stepped *= 1e10
        return stepped
        
    def get_pred_angle_direction(self, pred_angles, pred_angles_expectation, n_ss_p):
        '''
        Squash per-position predictions into (0, 1) direction scores.

        With a single-output head (n_clf == 1, the regression setting in
        __init__), column 0 of each raw prediction is passed through a sigmoid.
        Otherwise the sigmoid is applied to the per-position expected angle.

        Returns:
            list of n_ss_p tensors.
        '''
        if self.n_clf == 1:
            return [torch.sigmoid(pred_angles[pos])[:, 0] for pos in range(n_ss_p)]
        return [torch.sigmoid(pred_angles_expectation[pos]) for pos in range(n_ss_p)]

    def get_real_angle_direction(self, relative_angle, n_ss_p):
        '''
        Convert the first n_ss_p relative angles into float 0/1 direction labels.

        Each angle goes through `restrain_to_0_1`: positive becomes ~1, and
        non-positive becomes 0.
        '''
        labels = []
        for pos in range(n_ss_p):
            indicator = self.restrain_to_0_1(relative_angle[pos])
            labels.append(indicator.float())
        return labels

    def calculate_angle_direction_loss(self, pred_angle_direction, real_angle_direction, n_ss_p):
        '''
        Sum of per-position mean binary cross-entropies.

        Args:
            pred_angle_direction: sequence of probability tensors in (0, 1).
            real_angle_direction: matching sequence of 0/1 float targets.
            n_ss_p: number of positions to include.

        Returns:
            The summed loss (the int 0 when n_ss_p == 0).
        '''
        per_position = (
            F.binary_cross_entropy(pred_angle_direction[pos], real_angle_direction[pos], reduction='mean')
            for pos in range(n_ss_p)
        )
        return sum(per_position)
        
    def calc_source_direction_to_bin(self, source_angle, n_bin=None, angle_range_max=None):
        '''
        Quantize source angles into `n_bin` equal-width bins.

        We define turning left as +, turning right as -.
        Source angles are expected within (-180, 180].

        Args:
            source_angle: tensor of angles (degrees). CPU or GPU, with or
                without grad.
            n_bin: number of bins; defaults to `self.n_clf`.
            angle_range_max: bins span [-angle_range_max, angle_range_max];
                defaults to `self.angle_range_max`.

        Returns:
            Float tensor of bin indices in [0, n_bin - 1]. Angles outside the
            range are clamped into the first/last bin.
        '''
        if n_bin is None:
            n_bin = self.n_clf
        upper = self.angle_range_max if angle_range_max is None else angle_range_max
        lower = -upper
        bin_size = (upper - lower) / n_bin
        source_angle_bin = torch.div(source_angle - lower, bin_size, rounding_mode='floor')
        # torch.clamp instead of np.clip: it stays on the tensor's device and works
        # on tensors that require grad, with no numpy round-trip.
        return torch.clamp(source_angle_bin, 0, n_bin - 1)


    def calc_binaural_loss(self, inputs, pred_angles, n_ss_position): # gt_angles: (1,2,3,)
        '''
        Weakly-supervised binaural-cue loss: is the sound on the left or the right?

        Predictions are mapped to [0, 1] via (sin(angle) + 1) / 2 (left as +,
        right as -). They are compared with an L1 loss against a target derived
        from the interaural level difference (ILD) of each input clip.

        Args:
            inputs: dict with 'audio_1' ... f'audio_{n_ss_position}'. Each
                clip is presumably a (B, 2, L) binaural waveform; channel 0 is
                compared against channel 1 (presumably left vs right, TODO confirm).
            pred_angles: sequence of n_ss_position tensors, stacked along a new
                axis 1. NOTE(review): torch.sin expects radians; confirm the units.
            n_ss_position: number of audio clips / positions.

        Returns:
            (loss, target, angles):
                loss: the summed L1 loss (0-dim tensor).
                target: the ILD target, reshaped to (-1, n_ss_position), with
                    values in {0, 0.5, 1}; 0.5 means no dominant channel.
                angles: the mapped predictions.
        '''
        angles = torch.cat([x.unsqueeze(1) for x in pred_angles], dim=1)
        angles = (torch.sin(angles) + 1) / 2

        audios = torch.cat(
            [inputs[f'audio_{i+1}'].unsqueeze(1) for i in range(n_ss_position)], dim=1
        )
        # Fold the position axis into the batch axis (batch-major order).
        audios = audios.contiguous().view(-1, *audios.shape[2:])
        # assumes audio_net.wave2spec returns (N, C, F, T) like this class's wave2spec; TODO confirm
        mag = self.audio_net.wave2spec(audios, return_complex=True).abs()

        # Per-frame ILD vote: the sign of the level difference between channel 0 and 1.
        # sign(a - b) equals sign(log(a / b)) for positive levels, but yields 0
        # (no vote) instead of NaN when both channels are silent in a frame.
        level_0 = mag[:, 0, :, :].mean(dim=-2)
        level_1 = mag[:, 1, :, :].mean(dim=-2)
        frame_votes = torch.sign(level_0 - level_1)
        # Majority vote over frames, then map {-1, 0, 1} -> {0, 0.5, 1}.
        ild_cues = torch.sign(frame_votes.sum(dim=-1))
        ild_cues = ild_cues.view(-1, n_ss_position)
        target = (ild_cues + 1) / 2

        loss = F.l1_loss(angles, target.detach(), reduction='sum')
        return loss, target, angles
    
    def inference(self, inputs, pred_audio):
        '''
        Reconstruct binaural waveforms from a predicted difference spectrogram.

        Only the mono-to-binaural setting is supported. The ground-truth views
        2..n_view are mixed (first half of channels + second half). The
        predicted side (difference) signal is inverted with an iSTFT, and then
        left = (mix + diff) / 2 and right = (mix - diff) / 2.

        Args:
            inputs: dict with f'audio_{i}' for i = 2..n_view.
            pred_audio: real/imag parts stacked on axis 1 (size 2), presumably
                (N, 2, F - 1, T) where F = n_fft // 2 + 1. The missing top
                frequency bin is taken from the mixture's STFT.

        Returns:
            dict with 'pred_wave' (N, 2c, L) and 'gt_wave' (N, C, L).

        Raises:
            NotImplementedError: if `args.mono2binaural` is not set.
        '''
        gt_audio = [inputs[f'audio_{i+1}'].unsqueeze(1) for i in range(1, self.n_view)]
        gt_audio = torch.cat(gt_audio, dim=1)
        gt_audio = gt_audio.contiguous().view(-1, *gt_audio.shape[2:])
        c = int(gt_audio.shape[1] // 2)

        if self.args.mono2binaural:
            audio_mix = gt_audio[:, :c, :] + gt_audio[:, c:, :]
            # squeeze(1), not squeeze(): a batch of size 1 must keep its batch axis,
            # otherwise the slice below would pick the last time frame.
            audio_input = self.wave2spec(audio_mix, return_complex=True).squeeze(1).detach()
            pred_audio = pred_audio.permute(0, 2, 3, 1)
            pred_audio = torch.view_as_complex(pred_audio.contiguous())
            # Re-append the highest frequency bin (taken from the mixture).
            pred_audio = torch.cat([pred_audio, audio_input[:, -1:, ...]], dim=1)

            # `length` pins the output to the input length; without it the iSTFT
            # output is shorter whenever L is not a multiple of hop_length.
            pred_audio = torch.istft(
                input=pred_audio,
                n_fft=self.pr.n_fft,
                hop_length=self.pr.hop_length,
                win_length=self.pr.win_length,
                length=audio_mix.shape[-1],
            ).unsqueeze(1)
            pred_left = (audio_mix + pred_audio) / 2
            pred_right = (audio_mix - pred_audio) / 2
            pred_audio = torch.cat([pred_left, pred_right], dim=1)
        else:
            raise NotImplementedError

        return {
            'pred_wave': pred_audio,
            'gt_wave': gt_audio
        }
    

    def encode_conditional_feature(self, inputs, cond_audio, augment):
        '''
        Build the conditioning features for the generative net.

        View 1 (Img 1 / `cond_audio`) is always the conditional view. The
        features are repeated/computed once per other view i = 2..n_view, and
        that view axis is folded into the batch axis (batch-major order).

        Args:
            inputs: dict with 'img_1' ... f'img_{n_view}' (read unless
                `no_vision`), and f'relative_camera{i}_angle' for
                i = 1..n_view-1 (read when `use_gt_rotation`).
            cond_audio: conditional audio fed to `self.audio_net`. It is not
                touched (may be None) when `no_cond_audio` is set.
            augment: forwarded to the audio/vision encoders.

        Returns:
            The audio features, the image features, their concatenation along
            the last axis, or None when both are disabled.
        '''
        # ------  Encode the conditional audio at the source viewpoint  --------- #
        if self.no_cond_audio:
            cond_audio_feat = None
        else:
            cond_audio_feat = self.audio_net(cond_audio, augment=augment)
            # Repeat once per target view, then fold the view axis into the batch.
            cond_audio_feat = torch.cat([cond_audio_feat.unsqueeze(1)] * (self.n_view - 1), dim=1)
            cond_audio_feat = cond_audio_feat.contiguous().view(-1, *cond_audio_feat.shape[2:])

        # ------  Encode the relative camera pose between different view  --------- #
        if self.no_vision:
            im_features = None
        else:
            # NOTE(review): `self.vision_net` is not assigned in __init__; it must be
            # attached elsewhere before `no_vision` is turned off.
            single_im_features = []
            for i in range(0, self.n_view):
                im_feature = self.vision_net.forward_backbone(inputs[f'img_{i+1}'], augment=augment)
                single_im_features.append(im_feature)

            # Correlate every other view against view 1.
            im_features = []
            for i in range(1, self.n_view):
                corr_feature = self.vision_net.forward_correlation(single_im_features[i], single_im_features[0])
                im_features.append(corr_feature.unsqueeze(1))

            im_features = torch.cat(im_features, dim=1)
            im_features = im_features.contiguous().view(-1, *im_features.shape[2:])

            if self.use_gt_rotation:
                # The correlation features above are discarded and replaced by a
                # fixed embedding of the ground-truth relative angle (degrees -> radians).
                theta = torch.cat([inputs[f'relative_camera{i}_angle'].unsqueeze(1) for i in range(1, self.n_view)], dim=1)
                theta = theta.contiguous().view(theta.shape[0] * theta.shape[1], -1)
                theta = theta / 180.0 * math.pi
                im_features = self.rota_embedding(theta.float()).detach()

        # ------  Concat the conditional features  --------- #
        if self.no_vision and not self.no_cond_audio:
            cond_feats = cond_audio_feat
        elif not self.no_vision and self.no_cond_audio:
            cond_feats = im_features
        elif not self.no_vision and not self.no_cond_audio:
            cond_feats = torch.cat([cond_audio_feat, im_features], dim=-1)
        else:
            cond_feats = None

        return cond_feats
    
    def generate_audio_pair(self, inputs): # [B,C,L]
        '''
        Stack the per-position audio clips and convert them to spectrograms.

        Reads inputs['audio_1'] ... inputs[f'audio_{n_ss_position}'] and stacks
        them on a new axis 1. That axis is then folded into the batch axis, so
        per-position clips of shape (B, C, L) become (B * n_ss_position, C, L),
        batch-major, before the STFT.

        Returns:
            The detached result of `self.wave2spec(..., return_mag_phase=False)`,
            i.e. wave2spec's default normalized log-magnitude output.

        Raises:
            NotImplementedError: if `args.mono2binaural` is set; no spectrogram
                input is built for that mode.
        '''
        n_positions = self.args.n_ss_position
        target_view_audio = torch.cat(
            [inputs[f'audio_{i}'].unsqueeze(1) for i in range(1, n_positions + 1)], dim=1
        )
        target_view_audio = target_view_audio.contiguous().view(-1, *target_view_audio.shape[2:])

        if self.args.mono2binaural:
            # Fail explicitly: this mode has no input construction, and falling
            # through would reference an undefined result.
            raise NotImplementedError('generate_audio_pair does not support mono2binaural')
        return self.wave2spec(target_view_audio, return_mag_phase=False).detach()


    def wave2spec(self, wave, return_complex=False, return_real_imag=False, return_mag_phase=False):
        '''
        STFT a batch of multi-channel waveforms.

        The flags are checked in order and the first True one wins. With
        F = n_fft // 2 + 1:
            return_complex: raw complex STFT, shape (N, C, F, T), not trimmed.
            return_real_imag: real/imag interleaved per channel,
                shape (N, 2C, F - 1, T) (channel 2c = real, 2c + 1 = imag).
            return_mag_phase: normalized magnitude/phase interleaved per
                channel, shape (N, 2C, F - 1, T).
            default: normalized log magnitude, shape (N, C, F - 1, T).

        Args:
            wave: (N, C, L) waveform tensor; it does not need to be contiguous.

        No window is passed to torch.stft, so a rectangular window is used.
        '''
        N, C, _ = wave.shape
        # reshape (not view): callers may pass non-contiguous slices.
        spec = torch.stft(
            input=wave.reshape(N * C, -1),
            n_fft=self.pr.n_fft,
            hop_length=self.pr.hop_length,
            win_length=self.pr.win_length,
            return_complex=True
        )
        spec = spec.reshape(N, C, *spec.shape[1:])
        if return_complex:
            return spec
        elif return_real_imag:
            spec = torch.view_as_real(spec).permute(0, 1, 4, 2, 3)  # (N, C, 2, F, T)
            # After permute the tensor is non-contiguous, so `.view` cannot merge
            # the C and real/imag axes; reshape copies when needed.
            spec = spec.reshape(N, -1, *spec.shape[3:])
        elif return_mag_phase:
            mag = self.normalize_magnitude(spec.abs().unsqueeze(-1))
            phase = self.normalize_phase(spec.angle().unsqueeze(-1))
            spec = torch.cat([mag, phase], dim=-1).permute(0, 1, 4, 2, 3)  # (N, C, 2, F, T)
            spec = spec.reshape(N, -1, *spec.shape[3:])
        else:
            # return normalized log magnitude
            spec = self.normalize_magnitude(spec.abs())
        # Drop the highest frequency bin: (N, C', F - 1, T).
        return spec[:, :, :-1, :]


    def normalize_magnitude(self, spec, inverse=False):
        '''
        Map linear magnitudes to [-1, 1] on a dB scale, or invert that map.

        Forward: floor at `pr.log_offset`, convert to dB (20 * log10), map the
        window [-100 dB, 60 dB] linearly onto [-1, 1], and clip.
        Inverse: undo the linear map and the dB conversion (the clipping and
        the floor are not undone).
        '''
        low_db, high_db = -100, 60
        if inverse:
            db = (spec + 1) / 2 * (high_db - low_db) + low_db
            return 10 ** (db / 20)
        floored = torch.maximum(spec, torch.tensor(self.pr.log_offset))
        db = 20 * torch.log10(floored)
        scaled = (db - low_db) / (high_db - low_db) * 2 - 1
        return torch.clip(scaled, -1.0, 1.0)


    def normalize_phase(self, phase, inverse=False):
        '''
        Scale a phase (radians) into [-1, 1], or scale it back.

        A truncated pi (3.1416) is used in both directions so that a round trip
        is consistent; the forward direction also clips to [-1, 1].
        '''
        approx_pi = 3.1416
        if inverse:
            return phase * approx_pi
        return torch.clip(phase / approx_pi, -1.0, 1.0)


    def freeze_param(self, args):
        '''
        Disable gradients for the sub-networks selected by flags on `args`.

        freeze_camera -> vision_net, freeze_audio -> audio_net,
        freeze_generative -> generative_net. A sub-network is only looked up
        when its flag is set.
        NOTE(review): `vision_net` is not assigned in __init__, so
        freeze_camera=True fails unless it is attached elsewhere first.
        '''
        flag_to_net = (
            ('freeze_camera', 'vision_net'),
            ('freeze_audio', 'audio_net'),
            ('freeze_generative', 'generative_net'),
        )
        for flag_name, net_name in flag_to_net:
            if not getattr(args, flag_name):
                continue
            for weight in getattr(self, net_name).parameters():
                weight.requires_grad = False


    def score_model_performance(self, res):
        '''Return a higher-is-better score: the reciprocal of res['Loss'].'''
        total_loss = res['Loss']
        return 1 / total_loss
    
    def construct_feature_head(self, args, pr, n_clf):
        '''
        Build a randomly initialized ResNet that maps a 2-channel input to `n_clf` outputs.

        The stem conv is replaced to accept `in_channels` inputs, and the final
        fc layer is replaced to output `n_clf` values. Conv weights use Kaiming
        normal init, Linear weights use a truncated normal (std 0.01), and norm
        layers get weight ~ N(1, 0.02) and bias 0.

        Args:
            args: needs `audio_backbone`, one of 'resnet10', 'resnet18',
                'resnet34', 'resnet50'.
            pr: unused here.
            n_clf: output size (1 in the regression setting of __init__).

        Raises:
            ValueError: if `args.audio_backbone` is not supported.
        '''
        backbone = args.audio_backbone
        supported = ('resnet10', 'resnet18', 'resnet34', 'resnet50')
        # Validate up front; otherwise an unknown name would surface later as an
        # UnboundLocalError on `model`.
        if backbone not in supported:
            raise ValueError(f'Unsupported audio_backbone {backbone!r}; expected one of {supported}')

        in_channels = 2 # TODO
        if backbone == 'resnet10':
            # One BasicBlock per stage, built via torchvision's private helper.
            model = torchvision.models.resnet._resnet(torchvision.models.resnet.BasicBlock, [1, 1, 1, 1], weights=None, progress=False)
        elif backbone == 'resnet18':
            model = torchvision.models.resnet18(weights=None)
        elif backbone == 'resnet34':
            model = torchvision.models.resnet34(weights=None)
        else:  # 'resnet50'
            model = torchvision.models.resnet50(weights=None)

        model.conv1 = torch.nn.Conv2d(in_channels, 64, kernel_size=(7, 7), stride=(2, 2), padding=(3, 3), bias=False)
        model.fc = nn.Linear(model.fc.in_features, n_clf)

        # Initialize weights
        for m in model.modules():
            if isinstance(m, nn.Conv2d):
                nn.init.kaiming_normal_(m.weight, mode='fan_out', nonlinearity='relu')
            elif isinstance(m, nn.Linear):
                nn.init.trunc_normal_(m.weight, mean=0.0, std=0.01)
            elif isinstance(m, (nn.BatchNorm2d, nn.GroupNorm)):
                nn.init.normal_(m.weight, mean=1, std=0.02)
                nn.init.constant_(m.bias, 0)
        return model
