%% adding paths
% Add every folder below the current working directory to the MATLAB path
% so the helper functions (cpt_geo_a, Sig_alignment, NLOS_*_a, DDTF*, ...)
% and the .mat data files can be found.
addpath(genpath(pwd))
%% computing acquisition geometry
% experimental_setup is a project script executed in this workspace. The
% names c, delta_t, xbc/ybc/zbc, bx/by/bz, xec/yec/zec, xg/yg/zg and
% lxs/lys/lzs are used below but not defined in this script, so they must
% come from experimental_setup or the loaded .mat files -- TODO confirm.
experimental_setup
load('sub_Sig.mat');
% Number of trials in the binomial likelihood used below (tau = sub_Sig / N).
% NOTE(review): the meaning of the factor 60*60/512/512*1E7 is not shown
% here -- confirm its physical interpretation.
N = round(60 * 60 / 512 / 512 * 1E7);
load('C_i.mat');
load('C_d.mat');
% Clamp negative values to zero: tau = sub_Sig / N is later passed to
% log(tau), which would become complex for negative entries.
sub_Sig(sub_Sig < 0) = 0;
num_meas = size(sub_Sig,1);
% The measurements are reshaped to an nped-by-nped grid in the main loop
% (DDTF3D step), so the row count must be a perfect square. Fail here with
% a clear message instead of deep inside reshape later on.
nped = sqrt(num_meas);
if nped ~= round(nped)
    error('sub_Sig must have a perfect-square number of rows, got %d.', num_meas);
end
% Geometry lookup tables from cpt_geo_a; the _i / _d outputs pair with the
% C_i / C_d inputs. max_bin is also used to align the measured signal.
[Arr_bin_i,Arr_bin_d, Des_value_i, Des_value_d, min_bin, max_bin,...
xmin_i, ymin_i, zmin_i, xmin_d, ymin_d, zmin_d] =...
cpt_geo_a(c,delta_t, xbc,ybc,zbc, bx,by,bz, xec,yec,zec, C_i, C_d);
sub_Sig = Sig_alignment(sub_Sig,max_bin);
% One flag/weight per measurement row, all ones (presumably: every
% measurement is used with unit weight).
Sig_flag = ones(num_meas,1); Sig_weight = Sig_flag;
%% 1 Initialization
% 1.1 Initializing tau
% tau starts at the empirical rate sub_Sig / N.
tau = sub_Sig / N;
% Binomial negative log-likelihood of the data at the initial tau:
%   obj_data = sum( (sub_Sig - N) .* log(1 - tau) - sub_Sig .* log(tau) )
% Terms of the form 0 * log(0) evaluate to NaN (0 * -Inf) in MATLAB;
% 'omitnan' drops them, matching the convention 0 * log(0) = 0. The two
% sums are kept separate so a NaN in one term does not discard the other.
temp_sum = (sub_Sig - N) .* log(1 - tau); obj_data = sum(temp_sum(:),'omitnan');
temp_sum =  - sub_Sig .* log(tau); obj_data = obj_data + sum(temp_sum(:),'omitnan');
clear temp_sum
% 1.2 Initializing u
% Back-projection of tau; it is the right-hand side of every u-solve below.
u_BP = NLOS_adjoint_a(Sig_flag, xmin_i,ymin_i,zmin_i, xmin_d,ymin_d,zmin_d, xg,yg,zg,...
                      C_i, C_d, Arr_bin_i,Arr_bin_d, Des_value_i, Des_value_d, tau);
% First reconstruction with a zero array in the slot where the loop below
% passes v, and 0 as the proximal weight (the loop passes mu there).
[u, r_err] = NLOS_reconstruction_a(Sig_weight, Sig_flag,xmin_i,ymin_i,zmin_i,xmin_d,ymin_d,zmin_d, xg,yg,zg,...
                                   max_bin, C_i,C_d, Arr_bin_i,Arr_bin_d, Des_value_i,Des_value_d, zeros(lxs,lys,lzs), u_BP, 20, 5E-3, 0);
% Heuristic weights: s_u balances the data misfit against the L1 norm of
% the initial reconstruction; mu scales s_u by the mean magnitude of u.
b_u = forward_a(Sig_flag, C_i,C_d, u, max_bin,xmin_i,ymin_i,zmin_i, xmin_d,ymin_d,zmin_d,xg,yg,zg,Arr_bin_i,Arr_bin_d,Des_value_i,Des_value_d);
data_error = sum((tau(:) - b_u(:)).^2); obj_L1_norm = sum(abs(u(:)));
% Both s_u and mu divide by the size of u; an all-zero (or NaN) u would
% silently turn them into Inf/NaN.
if ~(obj_L1_norm > 0)
    error('Initializing u: initial reconstruction is identically zero, cannot set s_u/mu.');
end
s_u = 10 * data_error / obj_L1_norm;
mu = 0.5 * s_u / (mean(abs(u(:))));
% Split-Bregman iterations with splitting u = v and Bregman variable b:
%   v <- Soft_thre(s_u / mu, u - b), clipped to be nonnegative
%   u <- NLOS_reconstruction_a with RHS u_BP + mu * (v + b), weight mu
%   b <- b + v - u
b = zeros(lxs,lys,lzs);
num_L1_iter = 20;
Err_uv = NaN(1, num_L1_iter);
converged_init = false;
for L1_iter = 1: num_L1_iter
    v = Soft_thre(s_u / mu, u - b); v(v<0) = 0;
    [u, r_err_L1] = NLOS_reconstruction_a(Sig_weight, Sig_flag,xmin_i,ymin_i,zmin_i,xmin_d,ymin_d,zmin_d, xg,yg,zg,...
    max_bin, C_i,C_d, Arr_bin_i,Arr_bin_d, Des_value_i,Des_value_d, v, u_BP + mu * (v + b), 20, 5E-3, mu);
    % Relative gap between u and its thresholded, nonnegative copy v.
    Err_uv(L1_iter) = norm(u(:) - v(:)) / norm(u(:));
    if  Err_uv(L1_iter) < 5E-3
        fprintf('Initializing u: At iteration %d, R(u,v) = %1.6f\n', L1_iter, Err_uv(L1_iter));
        converged_init = true;
        break
    end
    b = b + v - u;
end
% Previously a non-converged run was silent; report it explicitly.
if ~converged_init
    fprintf('Initializing u: no convergence after %d iterations, R(u,v) = %1.6f\n', num_L1_iter, Err_uv(end));
end
% Keep the thresholded, nonnegative iterate as the initial estimate.
u = v;
%% 2 Iteration
%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%
% 2.1.p number of outer iterations
Num_loop = 2;
% 2.2.p parameters passed to DDTF_bm (denoising of u into uDC)
lambda_pu = 20; lambda_RMSE = 50; sparse_value = 10; sparse_threshold = 4;
pxo = 2; pyo = 3; pzo = 3; sxo = 1; syo = 1; szo = 1; ws = 4; nno = 3;
% 2.3.p parameters passed to DDTF3D (denoising of the signal into S);
%       lambda_ut_imp weights b_u against tau in the DDTF3D input
lambda_ut_imp = 1; lambda_pt = 40; pxt = 3; pyt = 3; pzt = 3; sxt =1; syt = 1; szt = 1;
% 2.4.p data fidelity: relative weights of b_u (lambda_imp) and S
%       (lambda_t_imp); rescaled to lambda / lambda_t in the first loop
lambda_imp = 0.5; lambda_t_imp = 0.5;
% 2.5.p update u: k_sparse scales the L1 weight s_u; lambda_u_imp and
%       lambda_g_imp set the pull towards uDC and g; num_Main_L1_iter caps
%       the inner split-Bregman iterations
k_sparse = 1; lambda_u_imp = 0.5; lambda_g_imp = 10; num_Main_L1_iter = 3;
%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%
for num_loop = 1:Num_loop
    % NOTE(review): the first three slices of u along dimension 1 are zeroed
    % every outer iteration; the reason is not visible here -- confirm.
    u(1:3,:,:) = 0;
    % denoise_factor divides by max(u(:)) and mu (first loop) divides by
    % sum(abs(u(:))). u is nonnegative here (it is the clipped iterate v), so
    % an estimate with no positive entry would silently yield Inf/NaN.
    if ~(max(u(:)) > 0)
        error('Main loop %d: u has no positive entries, cannot rescale.', num_loop);
    end
    % 2.1 Updating g
    g = get_pseudo_surface(u);
    % 2.2 Updating uDC
    % denoise_factor maps max(u) to 255 (presumably the value range DDTF_bm
    % is tuned for -- confirm).
    denoise_factor = 255 / max(u(:)); 
    [uDC, Dxy, Dz] = DDTF_bm(u, denoise_factor, lambda_pu,...
                        lambda_RMSE, pxo, pyo, pzo, sxo, syo, szo,...
                        ws, nno, sparse_value, sparse_threshold);
    % Match the peak of g to the peak of uDC. Skip when g has a zero
    % maximum, where the ratio is a division by zero (NaN/Inf in g).
    g_max = max(g(:));
    if g_max ~= 0
        g = max(uDC(:)) / g_max * g;
    end
    % 2.3 Updating S
    % Denoise a blend of the current tau and the forward projection of u.
    b_u = forward_a(Sig_flag, C_i,C_d, u, max_bin,xmin_i,ymin_i,zmin_i, xmin_d,ymin_d,zmin_d,xg,yg,zg,Arr_bin_i,Arr_bin_d,Des_value_i,Des_value_d);
    temp_Sig = (tau + lambda_ut_imp * b_u) / (1 + lambda_ut_imp);
    % amp_factor is fixed in the first loop and reused afterwards.
    if num_loop == 1
        amp_factor = 255 / max(temp_Sig(:));
    end
    [S,~] = DDTF3D(amp_factor, reshape(temp_Sig,nped,nped,[]), lambda_pt, pxt,pyt,pzt, sxt,syt,szt);
    S = reshape(S,nped^2,[]);
    % 2.4 Updating tau
    temp_Sig = (lambda_t_imp * S + lambda_imp * b_u) / (lambda_t_imp + lambda_imp);
    % In the first loop, scale the fidelity weights so the regularization
    % term is comparable to the initial data objective obj_data.
    if num_loop == 1
        obj_reg = sum((tau(:) - temp_Sig(:)).^2);
        lambda = lambda_imp * obj_data / obj_reg;
        lambda_t = lambda_t_imp * obj_data / obj_reg;
    end
    tau = update_tau(sub_Sig, N, lambda_t + lambda, temp_Sig);
    % 2.5 Updating u
    % In the first loop, balance each penalty weight against its current
    % residual so that no single term dominates the u-subproblem.
    if num_loop == 1
        err_g = sum((u(:) - g(:)).^2);
        err_uDC = sum((u(:) - uDC(:)).^2);
        data_error = sum((tau(:) - b_u(:)).^2);
        data_error_S = sum((S(:) - b_u(:)).^2);
        lambda_g = lambda * lambda_g_imp * data_error / err_g;
        lambda_u = lambda * lambda_u_imp * data_error / err_uDC;
        lambda_ut = lambda_ut_imp * lambda_t;
        s_u = k_sparse * (lambda * data_error + lambda_ut * data_error_S + lambda_u * err_uDC + lambda_g * err_g) / sum(abs(u(:)));
        mu = s_u * nnz(u) / sum(abs(u(:))) / 2;
    end
    % Back-projection of the weighted fidelity targets (tau and S).
    RHS_BP = NLOS_adjoint_a(Sig_flag, xmin_i,ymin_i,zmin_i, xmin_d,ymin_d,zmin_d, xg,yg,zg,...
                              C_i, C_d, Arr_bin_i,Arr_bin_d, Des_value_i, Des_value_d, lambda * tau + lambda_ut * S);           
    % Inner split-Bregman loop (u = v splitting, Bregman variable b), with
    % extra pulls lambda_g * g and lambda_u * uDC in the u-solve.
    Err_uv_main = NaN(1, num_Main_L1_iter);
    b = zeros(lxs,lys,lzs); 
    converged_main = false;
    for main_L1_iter = 1:num_Main_L1_iter
        % update v
        v = Soft_thre(s_u / mu, u - b); v(v < 0) = 0;
        % update u
        [u, r_err_main_L1] = NLOS_reconstruction_a((lambda + lambda_ut) * Sig_weight, Sig_flag,...
                             xmin_i,ymin_i,zmin_i,xmin_d,ymin_d,zmin_d,...
                             xg,yg,zg,max_bin, C_i,C_d, Arr_bin_i,...
                             Arr_bin_d, Des_value_i,Des_value_d,...
                             v, RHS_BP + mu * (v + b) + lambda_g * g + lambda_u * uDC,...
                             20, 5E-3, mu + lambda_g + lambda_u);
        % Relative gap between u and its thresholded, nonnegative copy v.
        Err_uv_main(main_L1_iter) = norm(u(:) - v(:)) / norm(u(:));
        if Err_uv_main(main_L1_iter) < 5E-3
           fprintf('Update u in main loop: At iteration %d, R(u,v) = %1.6f\n', main_L1_iter, Err_uv_main(main_L1_iter));
           converged_main = true;
           break
        end   
        b = b + v - u;
    end
    % Previously a non-converged inner loop was silent; report it.
    if ~converged_main
        fprintf('Update u in main loop: no convergence after %d iterations, R(u,v) = %1.6f\n', num_Main_L1_iter, Err_uv_main(end));
    end
    % Keep the thresholded, nonnegative iterate for the next outer loop.
    u = v;
end
% Display the result and write the variable u to u.mat (command syntax
% for save('u','u')).
view_albedo(u, 'SSCR', 0)
save u u