%% SETUP
clc
clear
close all

% folder containing one numbered sub-folder per captured frame
% (the trailing backslash is relied upon when frame paths are concatenated)
data_path = 'C:\Users\remelli\Desktop\fitting_cpp_initialization\edoardo\stage1\';

% elementary rotation matrices about the x, y and z axes (angle in radians)
Rx = @(t) [1, 0, 0; 0, cos(t), -sin(t); 0, sin(t), cos(t)];
Ry = @(t) [cos(t), 0, sin(t); 0, 1, 0; -sin(t), 0, cos(t)];
Rz = @(t) [cos(t), -sin(t), 0; sin(t), cos(t), 0; 0, 0, 1];

% DEFINE USEFUL TRANSFORMS
% R maps our orientation to htrack's one; Rinv maps htrack's back to ours
% (each factor of R inverted, applied in reverse order)
R = Rx(-pi/2) * Rz(-pi);
Rinv = Rz(pi) * Rx(pi/2);

% LOAD TEMPLATE HAND MODEL (the .mat file is expected to define hand_model)
load('hand_model.mat');
hand_model = reindex_fullhand(hand_model);

% frames to process, each one a numbered sub-folder of data_path
% (use e.g. frames = {11}; to process a single frame)
frames = num2cell(1:16);

%% READ htrack INPUT
% For every selected frame: read the htrack (C++) model, transfer its
% phalanx rotations and global pose onto a copy of the template model, and
% read the masked depth image as a point cloud expressed in our orientation.

v_hand_model = {};
v_data_points = {};

% camera parameters (identical for every frame)
tx = 640 / 4; ty = 480 / 4; fx = 287.26; fy = 287.26;

% homogeneous (4x4) versions of the orientation changes
to_mine = [Rinv, zeros(3,1); zeros(1,3), 1];
to_htrack = [R, zeros(3,1); zeros(1,3), 1];

% htrack phalanges 2..16 come three per finger; they belong to our
% segments{1}, {5}, {4}, {3}, {2} in that order
finger_order = [1, 5, 4, 3, 2];

for i = 1 : length(frames)

    frame_path = [data_path, num2str(frames{i}) ,'\'];
    % read cpp model
    [centers, radii, blocks, theta, phalanges, mean_centers] = read_cpp_model(frame_path);

    v_hand_model{i} = hand_model;

    % READ INITIAL TRANSFORMATIONS from htrack: re-express each phalanx
    % rotation in our orientation
    for f = 1:numel(finger_order)
        for k = 1:3
            p = 1 + 3*(f-1) + k;
            v_hand_model{i}.segments{finger_order(f)}{k}.local(1:3,1:3) = Rinv * phalanges{p}.local(1:3,1:3) * R;
        end
    end

    % POSE MODEL
    % per-finger joint angles taken from htrack's theta(10:29), with the
    % entry order and signs remapped to this model's convention
    dtheta = { ...
        +[-theta(11); +theta(10); -theta(12:13)], ...
        [-theta(15); +theta(14); -theta(16:17)], ...
        -[+theta(19); -theta(18); +theta(20:21)], ...
        -[+theta(23); -theta(22); +theta(24:25)], ...
        -[+theta(27); -theta(26); +theta(28:29)]};
    v_hand_model{i}.global_pose = to_mine * makehgtform('translate', theta(1:3)) ...
        * makehgtform('axisrotate', [1;0;0], theta(4)) ...
        * makehgtform('axisrotate', [0;1;0], theta(5)) ...
        * makehgtform('axisrotate', [0;0;1], theta(6)) * to_htrack;
    v_hand_model{i}.global_translation = v_hand_model{i}.global_pose(1:3,4);
    % this should be improved, but should not really cause any problem like that
    v_hand_model{i}.global_rotation = [-theta(4);theta(5);theta(6)];
    v_hand_model{i}.theta = dtheta;
    pose.global_translation = zeros(3,1);
    pose.global_rotation = zeros(3,1);
    v_hand_model{i} = pose_model(v_hand_model{i}, dtheta, pose);

    % READ POINT CLOUD
    D = imread([frame_path, 'depth.png']);
    M = imread([frame_path, 'mask.png']);

    % set depth value of pixels not belonging to mask to zero
    D(M == 0) = 0;

    % one (u, v, depth) column per pixel, in column-major pixel order
    [img_h, img_w] = size(D);
    [U, V] = meshgrid(1:img_w, 1:img_h);
    uvd = [U(:)'; V(:)'; double(D(:))'];
    I = convert_uvd_to_xyz(tx, ty, fx, fy, uvd);

    % FILTER DATA: smooth the normalized z channel with bfilter2
    % (presumably a bilateral filter). The image size is taken from the
    % depth image rather than hard-coded from the camera center.
    depth_image = shiftdim(reshape(I, 3, img_h, img_w), 1);   % img_h x img_w x 3
    depth = depth_image(:, :, 3);
    max_depth = max(depth(:));
    depth = bfilter2(depth ./ max_depth, 5, [2 0.1]) .* max_depth;
    depth_image(:, :, 3) = depth;
    I2 = reshape(shiftdim(depth_image, 2), 3, img_h * img_w);

    % drop points with NaN coordinates, ROTATE POINT CLOUD into our
    % orientation and store it as a cell array of 3x1 points
    valid = ~any(isnan(I2), 1);
    data_points = num2cell(Rinv * I2(:, valid), 1);
    % save pc
    v_data_points{i} = data_points;
end

%% VISUALIZE TO CHECK EVERYTHING IS ALLRIGHT

% Compute the initial correspondences (the optimization below reads
% v_indices / v_model_points / v_block_indices) and draw each initial
% model together with its point cloud.
green = [0.3, 0.8, 0.3];
for fr = 1:length(frames)
    model_fr = v_hand_model{fr};
    cloud_fr = v_data_points{fr};
    [v_indices{fr}, v_model_points{fr}, v_block_indices{fr}] = compute_projections(cloud_fr, model_fr);
    display_model(model_fr, 1.0, 'big')
    mypoints(cloud_fr, green, 4);
end

%% START OPTIMIZATION

% OPTIMIZATION PARAMETERS
tol = 1e-03;

% settings struct passed to COMPUTE_JACOBIAN_M2D
% NOTE(review): W uses 636 rather than 640 — confirm this is intended
settings = struct( ...
    'fov', 15, ...
    'H', 480/2, ...
    'W', 636/2, ...
    'D', 3, ...
    'sparse_data', false, ...
    'RAND_MAX', 32767, ...
    'side', 'front', ...
    'view_axis', 'Y', ...
    'theta_factor', 0.2, ...
    'block_safety_factor', 1.1);

% energy weights
w_bounds = 400;
w_val = 2000;
w_stabilization_theta = 300;
w_stabilization_radii = 1000;

%% REDUCED PARAMETERS OPTIMIZATION
% Damped Gauss-Newton (Levenberg-Marquardt) over a reduced set of unknowns
% shared by all frames:
%   1      relative scale of every finger and palm/wrist radius
%   2      relative scale of every beta (finger shape) parameter
%   3:5    per-axis (x, y, z) relative scale of palm/wrist and finger base centers
% followed, for each frame, by 26 pose unknowns
%   (20 finger angles, 3 global rotation, 3 global translation).
% The full-parameter Jacobian is projected onto the scale unknowns by the
% chain rule: for p -> p*(1 + s), dp/ds = p, so every column is weighted by
% the current parameter value and the columns sharing one scale are summed.

lambda = 100.0;
n_iter = 0;
target_delta = 2.0;
err = tol + 1;

n_frames = length(frames);
n_pose = 26;     % per-frame unknowns: 4*5 finger angles + 3 rotation + 3 translation
n_reduced = 5;   % shared scale unknowns (radii, beta, 3 center axes)

disp('Optimization over reduced parameters')

while err > tol && n_iter < 20

    n_iter = n_iter + 1

    % Jacobian and residual stacked over all frames
    J_full = [];
    F_full = [];

    for i = 1:n_frames
        % compute Jacobian matrix and right hand side
        [F_d2m, Jtheta_d2m, Jbeta_d2m, Jr_d2m, Jcenters_fingers_d2m, Jcenters_palm_d2m, Jr_palm_d2m, Js_membrane_d2m, Jglobal_rotation_d2m, Jglobal_translation_d2m] = COMPUTE_D2M_ENERGY(v_hand_model{i}, v_data_points{i}, v_model_points{i}, v_indices{i}, v_block_indices{i},true,false );
        [F_m2d, Jtheta_m2d, Jbeta_m2d, Jr_m2d, Jcenters_fingers_m2d, Jr_palm_m2d, Jcenters_palm_m2d, Js_membrane_m2d, Jglobal_rotation_m2d, Jglobal_translation_m2d] = COMPUTE_JACOBIAN_M2D(v_hand_model{i}, v_data_points{i}, settings,false );

        % full Jacobian: shape columns, then zero padding so that this
        % frame's pose columns land in its own 26-column slot
        pad_before_d2m = zeros(length(F_d2m), n_pose*(i-1));
        pad_after_d2m = zeros(length(F_d2m), n_pose*(n_frames-i));
        pad_before_m2d = zeros(length(F_m2d), n_pose*(i-1));
        pad_after_m2d = zeros(length(F_m2d), n_pose*(n_frames-i));
        J = [Jr_d2m, Jr_palm_d2m, Jbeta_d2m, Jcenters_palm_d2m, Jcenters_fingers_d2m, pad_before_d2m, Jtheta_d2m, Jglobal_rotation_d2m, Jglobal_translation_d2m, pad_after_d2m;
             Jr_m2d, Jr_palm_m2d, Jbeta_m2d, Jcenters_palm_m2d, Jcenters_fingers_m2d, pad_before_m2d, Jtheta_m2d, Jglobal_rotation_m2d, Jglobal_translation_m2d, pad_after_m2d];
        F = [F_d2m;F_m2d];

        % param_value(:, c) = current value of the parameter behind column c
        param_value = zeros(size(J));

        % finger radii, then palm/wrist radii
        for j = 1:length(v_hand_model{i}.finger_radii)
            param_value(:,j) = v_hand_model{i}.finger_radii{j};
        end
        l1 = length(v_hand_model{i}.finger_radii);
        for j = 1:length(v_hand_model{i}.palm_wrist_radii)
            param_value(:,l1 + j) = v_hand_model{i}.palm_wrist_radii{j};
        end
        l1 = l1 + length(v_hand_model{i}.palm_wrist_radii);

        % beta of all fingers, flattened in finger order; the count is taken
        % from the model instead of assuming 16 values
        beta_rows = cellfun(@(b) b(:)', v_hand_model{i}.beta, 'UniformOutput', false);
        beta_values = [beta_rows{:}];
        l2 = l1 + numel(beta_values);
        param_value(:, l1 + 1 : l2) = repmat(beta_values, size(J, 1), 1);

        % palm/wrist centers, then finger base centers, x/y/z interleaved
        for j = 1:length( v_hand_model{i}.palm_wrist_centers_relative)
            for k = 1:3
                param_value(:,l2 + 3*(j-1) + k) = v_hand_model{i}.palm_wrist_centers_relative{j}(k);
            end
        end
        l3 = l2 + 3*length( v_hand_model{i}.palm_wrist_centers_relative);
        for j = 1:length(v_hand_model{i}.segments)
            for k = 1:3
                param_value(:,l3 + 3*(j-1) + k) = v_hand_model{i}.segments{j}{1}.local(k,4);
            end
        end
        l3 = l3 + 3*length(v_hand_model{i}.segments);

        % pose columns are kept unchanged
        param_value(:,l3+1:end) = 1;

        J = J.*param_value;
        % collapse: radii -> 1 column, beta -> 1 column, centers -> x, y, z columns
        J = [ sum(J(:,1:l1),2), sum(J(:,1+l1:l2),2), sum(J(:,1+l2:3:l3),2), sum(J(:,2+l2:3:l3),2), sum(J(:,3+l2:3:l3),2), J(:,l3+1:end) ];

        J_full = [J_full;J];
        F_full = [F_full;F];
    end

    % add stabilization energy: zero-residual rows that damp every pose
    % unknown (w_stabilization_theta) and the radii scale (w_stabilization_radii)
    n_pose_all = n_pose*n_frames;
    J_full = [J_full;
              zeros(n_pose_all, n_reduced), w_stabilization_theta*eye(n_pose_all);
              w_stabilization_radii, zeros(1, n_reduced - 1 + n_pose_all)];
    F_full = [F_full; zeros(n_pose_all + 1, 1)];

    % perform descent step
    JtJ = J_full' * J_full;
    LHS = JtJ + lambda*eye(size(JtJ));
    delta = LHS \ (J_full' * F_full);

    err = norm(delta)
    % Adapt the damping so that the next step norm approaches target_delta:
    % a step longer than the target raises lambda (shorter next step), and a
    % shorter one lowers it. The inverse ratio would push lambda the wrong way
    % and make err shrink through damping alone rather than convergence.
    lambda = err/target_delta * lambda;

    delta_radii = delta(1);
    delta_beta = delta(2);
    delta_alpha = delta(3:5);

    delta_theta = {};
    delta_pose = {};
    for i = 1:n_frames
        offset = n_reduced + n_pose*(i-1);
        % 20 finger angles -> five 4x1 vectors, one per finger
        delta_theta{i} = num2cell(reshape(delta(offset + (1:20)), 4, 5), 1);
        delta_pose{i}.global_rotation = delta(offset + (21:23));
        delta_pose{i}.global_translation = delta(offset + (24:26));
    end

    % update
    for i = 1:n_frames

        % scale radii, palm/wrist centers and beta by the shared factors
        for j = 1:length(v_hand_model{i}.finger_radii)
            v_hand_model{i}.finger_radii{j} = v_hand_model{i}.finger_radii{j}*(1 + delta_radii);
        end

        for j = 1:length(v_hand_model{i}.palm_wrist_radii)
            v_hand_model{i}.palm_wrist_radii{j} = v_hand_model{i}.palm_wrist_radii{j}*(1 + delta_radii);
            v_hand_model{i}.palm_wrist_centers_relative{j} = v_hand_model{i}.palm_wrist_centers_relative{j}.*(1 + delta_alpha);
        end

        v_hand_model{i}.beta = { v_hand_model{i}.beta{1}*(1 + delta_beta), v_hand_model{i}.beta{2}*(1 + delta_beta),  v_hand_model{i}.beta{3}*(1 + delta_beta), v_hand_model{i}.beta{4}*(1 + delta_beta), v_hand_model{i}.beta{5}*(1 + delta_beta) };

        % finger base centers move by delta_alpha .* their current position
        delta_finger_center = {delta_alpha.*v_hand_model{i}.segments{1}{1}.local(1:3,4), delta_alpha.*v_hand_model{i}.segments{2}{1}.local(1:3,4), delta_alpha.*v_hand_model{i}.segments{3}{1}.local(1:3,4), delta_alpha.*v_hand_model{i}.segments{4}{1}.local(1:3,4), delta_alpha.*v_hand_model{i}.segments{5}{1}.local(1:3,4)};

        % update parameters
        for j = 1:length(v_hand_model{i}.theta)
            v_hand_model{i}.theta{j} = v_hand_model{i}.theta{j} + delta_theta{i}{j};
        end

        % update centers
        [ v_hand_model{i} ] = update_centers(v_hand_model{i});
        % update shape
        [ v_hand_model{i}.segments ] = update_fingers_shape(v_hand_model{i}.segments, v_hand_model{i}.beta );

        % update model pose: finger increments first, then the global pose
        % increment with zero finger increments
        [ v_hand_model{i} ] = update_fingers_pose(v_hand_model{i}, delta_theta{i}, delta_finger_center );
        delta_theta{i} = {zeros(4,1), zeros(4,1), zeros(4,1), zeros(4,1), zeros(4,1)};
        [ v_hand_model{i} ] = pose_model(v_hand_model{i}, delta_theta{i},delta_pose{i} );

        % update membranes
        [ v_hand_model{i} ] = update_membranes( v_hand_model{i} );

        % compute updated correspondencies
        [v_indices{i}, v_model_points{i}, v_block_indices{i}] = compute_projections(v_data_points{i}, v_hand_model{i});
    end

end

%% FULL OPTIMIZATION
% Joint refinement of every individual shape parameter (shared by all
% frames) together with the per-frame pose. Uses damped Gauss-Newton steps,
% accepted only when the residual norm decreases.
% Column layout of J_full / delta:
%   1:16     beta (4 + 3 + 3 + 3 + 3 values, see the beta update below)
%   17:37    finger radii
%   38:76    palm/wrist centers (x, y, z per center)
%   77:89    palm/wrist radii
%   90:104   finger base centers (x, y, z per finger)
%   then 26 columns per frame: 20 finger angles, 3 global rotation,
%   3 global translation.

lambda = 100.0;
n_iter = 0;
% NOTE(review): target_delta is only referenced by the commented-out lambda update below
target_delta = 1.0;
err = tol + 1;
% radii damping weight is halved with respect to the reduced stage
w_stabilization_radii = w_stabilization_radii/2;




% compute Jacobian matrix and right hand side
% (at the current models; inside the loop it is replaced only when a step is accepted)
    J_full = [];
    F_full = [];

    for i = 1:length(frames)

        % compute Jacobian matrix and right hand side
        [F_d2m, Jtheta_d2m, Jbeta_d2m, Jr_d2m, Jcenters_fingers_d2m, Jcenters_palm_d2m, Jr_palm_d2m, Js_membrane_d2m, Jglobal_rotation_d2m, Jglobal_translation_d2m] = COMPUTE_D2M_ENERGY(v_hand_model{i}, v_data_points{i}, v_model_points{i}, v_indices{i}, v_block_indices{i},true,false );
        [F_m2d, Jtheta_m2d, Jbeta_m2d, Jr_m2d, Jcenters_fingers_m2d, Jr_palm_m2d, Jcenters_palm_m2d, Js_membrane_m2d, Jglobal_rotation_m2d, Jglobal_translation_m2d] = COMPUTE_JACOBIAN_M2D(v_hand_model{i}, v_data_points{i}, settings,false );

        % shape columns first, then zero padding so this frame's pose
        % columns land in its own 26-column slot
        J = [Jbeta_d2m, Jr_d2m, Jcenters_palm_d2m, Jr_palm_d2m, Jcenters_fingers_d2m, zeros(length(F_d2m),26*(i-1)), Jtheta_d2m, Jglobal_rotation_d2m, Jglobal_translation_d2m,zeros(length(F_d2m),26*(length(frames)-i)); Jbeta_m2d, Jr_m2d, Jcenters_palm_m2d, Jr_palm_m2d, Jcenters_fingers_m2d, zeros(length(F_m2d),26*(i-1)), Jtheta_m2d, Jglobal_rotation_m2d, Jglobal_translation_m2d,zeros(length(F_m2d),26*(length(frames)-i))] ;
        F = [F_d2m;F_m2d;];

        % store in full matrix
        J_full = [J_full;J];
        F_full = [F_full;F];
    end

    % add stabilization energies (zero-residual rows):
    %   w_stabilization_theta * I on every pose column,
    %   w_stabilization_radii * I on the finger radii columns 17:37,
    %   w_stabilization_radii on columns 90:94.
    % NOTE(review): eye(5,15) touches only the first 5 of the 15 finger-center
    % columns (90:104) — confirm whether eye(15) was intended.
    J_full = [J_full; zeros(length(frames)*26,104),w_stabilization_theta*eye(length(frames)*26,length(frames)*26); w_stabilization_radii*[zeros(21,16),eye(21,21),zeros(21,67),zeros(21,length(frames)*26)];[zeros(5,89), w_stabilization_radii*eye(5,15), zeros(5,length(frames)*26)] ];
    F_full = [F_full; zeros(26*length(frames),1); zeros(21,1);zeros(5,1)];


disp('Full optimization')

% NOTE(review): err is never reassigned inside this loop (its update is
% commented out), so the loop always runs exactly 6 iterations.
while err > tol && n_iter < 6

    % keep a copy so that a rejected step can be undone
    v_hand_model_backup = v_hand_model;

    n_iter = n_iter + 1

    % perform descent step
    JtJ = J_full' * J_full;
    LHS = JtJ + lambda*eye(size(JtJ));
    delta = LHS \ (J_full' * F_full);

    %err = norm(delta)
    %lambda = err/target_delta * lambda;

    % split the shared part of delta according to the column layout
    delta_radii = delta(17:37);
    delta_centers = delta (38: 76);
    delta_radii_palm = delta(77 : 89);
    delta_finger_center_vec = delta( 90 : 104 );

    % per-frame part: five 4x1 finger-angle increments, rotation, translation
    delta_theta = {};
    delta_pose = {};
    for i = 1 :length(frames)

        delta_theta{i} = { delta(26*(i-1)+105:26*(i-1)+108), delta(26*(i-1)+109:26*(i-1)+112), delta(26*(i-1)+113:26*(i-1)+116), delta(26*(i-1)+117:26*(i-1)+120), delta(26*(i-1)+121:26*(i-1)+124)};
        delta_pose{i}.global_rotation = delta(26*(i-1)+125:26*(i-1)+127);
        delta_pose{i}.global_translation = delta(26*(i-1)+128:26*(i-1)+130);
    end

    %update stuff
    for i = 1 :length(frames)

         v_hand_model{i}.beta = { v_hand_model{i}.beta{1} + delta(1:4), v_hand_model{i}.beta{2} + delta(5:7),  v_hand_model{i}.beta{3} + delta(8:10), v_hand_model{i}.beta{4} + delta(11:13), v_hand_model{i}.beta{5} + delta(14:16) };
        % update parameters
        for j = 1:length(v_hand_model{i}.theta)
            v_hand_model{i}.theta{j} = v_hand_model{i}.theta{j} + delta_theta{i}{j};
        end

        % finger radii are clamped from below at 0.1; palm/wrist radii are not clamped
        for j = 1:length(v_hand_model{i}.finger_radii)
            v_hand_model{i}.finger_radii{j} = max(v_hand_model{i}.finger_radii{j} + delta_radii(j),0.1);
        end

        % assumes one relative center per palm/wrist radius (3 delta entries each)
        for j = 1:length(v_hand_model{i}.palm_wrist_radii)
            v_hand_model{i}.palm_wrist_radii{j} = v_hand_model{i}.palm_wrist_radii{j} + delta_radii_palm(j);
            v_hand_model{i}.palm_wrist_centers_relative{j} = v_hand_model{i}.palm_wrist_centers_relative{j} +  delta_centers( (3*(j-1) +1) : 3*j);
        end
        delta_finger_center = {delta_finger_center_vec(1:3), delta_finger_center_vec(4:6), delta_finger_center_vec(7:9), delta_finger_center_vec(10:12), delta_finger_center_vec(13:15)};


        % update centers
        [ v_hand_model{i} ] = update_centers(v_hand_model{i});
        % update shape
        [ v_hand_model{i}.segments ] = update_fingers_shape(v_hand_model{i}.segments, v_hand_model{i}.beta );
        % update model pose: finger increments first, then the global pose
        % increment with zero finger increments
        [ v_hand_model{i} ] = update_fingers_pose(v_hand_model{i}, delta_theta{i}, delta_finger_center );
        delta_theta{i}= {zeros(4,1), zeros(4,1), zeros(4,1), zeros(4,1), zeros(4,1)};
        [ v_hand_model{i} ] = pose_model(v_hand_model{i}, delta_theta{i},delta_pose{i} );
        % update membranes
        [ v_hand_model{i} ] = update_membranes( v_hand_model{i} );
        [v_hand_model{i}] = reindex_fullhand(v_hand_model{i});
        % compute updated correspondencies
        [v_indices{i}, v_model_points{i}, v_block_indices{i}] = compute_projections(v_data_points{i}, v_hand_model{i});
    end


    % compute new Jacobians and stuff
    % (same assembly as before the loop, evaluated at the candidate models)

    % compute Jacobian matrix and right hand side
    J_full_new = [];
    F_full_new = [];

    for i = 1:length(frames)

        % compute Jacobian matrix and right hand side
        [F_d2m, Jtheta_d2m, Jbeta_d2m, Jr_d2m, Jcenters_fingers_d2m, Jcenters_palm_d2m, Jr_palm_d2m, Js_membrane_d2m, Jglobal_rotation_d2m, Jglobal_translation_d2m] = COMPUTE_D2M_ENERGY(v_hand_model{i}, v_data_points{i}, v_model_points{i}, v_indices{i}, v_block_indices{i},true,false );
        [F_m2d, Jtheta_m2d, Jbeta_m2d, Jr_m2d, Jcenters_fingers_m2d, Jr_palm_m2d, Jcenters_palm_m2d, Js_membrane_m2d, Jglobal_rotation_m2d, Jglobal_translation_m2d] = COMPUTE_JACOBIAN_M2D(v_hand_model{i}, v_data_points{i}, settings,false );

        J = [Jbeta_d2m, Jr_d2m, Jcenters_palm_d2m, Jr_palm_d2m, Jcenters_fingers_d2m, zeros(length(F_d2m),26*(i-1)), Jtheta_d2m, Jglobal_rotation_d2m, Jglobal_translation_d2m,zeros(length(F_d2m),26*(length(frames)-i)); Jbeta_m2d, Jr_m2d, Jcenters_palm_m2d, Jr_palm_m2d, Jcenters_fingers_m2d, zeros(length(F_m2d),26*(i-1)), Jtheta_m2d, Jglobal_rotation_m2d, Jglobal_translation_m2d,zeros(length(F_m2d),26*(length(frames)-i))] ;
        F = [F_d2m;F_m2d;];

        % store in full matrix
        J_full_new = [J_full_new;J];
        F_full_new = [F_full_new;F];
    end

    % add stabilization energy
    J_full_new = [J_full_new; zeros(length(frames)*26,104),w_stabilization_theta*eye(length(frames)*26,length(frames)*26); w_stabilization_radii*[zeros(21,16),eye(21,21),zeros(21,67),zeros(21,length(frames)*26)];[zeros(5,89), w_stabilization_radii*eye(5,15), zeros(5,length(frames)*26)] ];
    F_full_new = [F_full_new; zeros(26*length(frames),1); zeros(21,1);zeros(5,1)];

    % accept the step if the residual norm decreased (less damping next time),
    % otherwise restore the models and increase the damping.
    % On rejection v_indices / v_model_points / v_block_indices are not
    % restored; they are recomputed after the next update inside this loop.
    if (norm(F_full_new) < norm(F_full))
        lambda = lambda/4;
        J_full = J_full_new;
        F_full = F_full_new;
    else
        lambda = lambda *3;
        v_hand_model = v_hand_model_backup;
        disp('rejected step');
    end


end



%% DISPLAY RESULTS

% Refresh the correspondences for each frame, then draw the fitted model,
% the data points, the model points returned by compute_projections, and
% the segments joining each pair.
data_rgb = [0.3, 0.8, 0.3];
model_rgb = [0.3, 0.3, 0.8];
link_rgb = [0.85, 0.85, 0.85];
for fr = 1:length(frames)
    [v_indices{fr}, v_model_points{fr}, v_block_indices{fr}] = ...
        compute_projections(v_data_points{fr}, v_hand_model{fr});
    display_model(v_hand_model{fr}, 1.0, 'big')
    mypoints(v_data_points{fr}, data_rgb, 4);
    mypoints(v_model_points{fr}, model_rgb, 4);
    mylines(v_data_points{fr}, v_model_points{fr}, link_rgb);
end


% save for storing later
save('calib.mat','v_hand_model')
% NOTE(review): phalanges holds the value read for the last frame of the
% input loop only — confirm that this is what downstream code expects
save('phalanges.mat','phalanges')
save('data_path.mat','data_path');