/*
 * Globally-Optimal Gaussian Mixture Alignment (GOGMA): gogma
 * A class for Globally-Optimal Gaussian Mixture Alignment. Branches
 * a hyperoctree representation of the transformation space SE(3)
 * and bounds each branch on the GPU. Uses the VNL library for local
 * optimisation.
 *
 * Campbell, D., and Petersson, L., "GOGMA: Globally-Optimal Gaussian
 * Mixture Alignment", IEEE Conference on Computer Vision and
 * Pattern Recognition (CVPR), Las Vegas, USA, IEEE, Jun. 2016
 *
 * For the full license, see license.txt in the root directory
 *
 * Author: Dylan Campbell
 * Date: 20160212
 * Revision: 1.1
 */
 
#include "gogma.h"

// Upper limit on the number of nodes held in the branch-and-bound queue;
// the search stops without an optimality certificate once it is exceeded
#define MAX_QUEUE_SIZE 5e8
// Checks the result of a CUDA runtime call and reports a failure together with
// the calling file and line. Wrapped in do { } while (0) so that the macro
// followed by a semicolon is exactly one statement (safe in unbraced if/else)
#define CudaErrorCheck(ans) do { __CudaErrorCheck((ans), __FILE__, __LINE__); } while (0)

// Declare Constant Memory
// All of these are written by the host with cudaMemcpyToSymbol before the
// kernel is launched and are only read by the kernel
__constant__ int c_bits_per_dim; // number of index bits used per dimension
__constant__ int c_nodes_per_dim; // number of subcubes along each dimension
__constant__ int c_dimension; // stride (in floats) between consecutive means
__constant__ int c_num_components_x; // number of components in GMM X
__constant__ int c_num_components_y; // number of components in GMM Y
__constant__ float c_sigma_x; // sigma of GMM X
__constant__ float c_sigma_y; // sigma of GMM Y
__constant__ Node c_node_parent; // node whose subcubes are being bounded

// Terminates the program when a CUDA runtime call did not succeed.
// On failure, prints the source location and the CUDA error string to
// standard output and exits with the error code as the exit status.
// Does nothing when code is cudaSuccess.
void __CudaErrorCheck(cudaError_t code, const char* file, int line) {
	if (code == cudaSuccess) {
		return;
	}
	std::cout << "CUDA Error (" << file << ":" << line << "): " << cudaGetErrorString(code) << std::endl;
	exit(code);
}

/*
 * GetBounds: computes an upper and a lower bound of the objective function
 * for every subcube of the parent node held in constant memory (c_node_parent).
 *
 * Launch layout: one thread per subcube on a 1D grid. The flat thread index
 * encodes the subcube position using c_bits_per_dim bits per dimension, so at
 * least 2^(6 * c_bits_per_dim) threads must be launched; surplus threads exit
 * without writing. No shared memory is used.
 *
 * Inputs (device memory):
 *   d_mu_x_, d_mu_y_:   component means of GMM X and Y, c_dimension floats per
 *                       component (only the first three of each are read)
 *   d_phi_x_, d_phi_y_: per-component coefficients of GMM X and Y
 * Outputs (device memory, one float per subcube):
 *   d_lower_bounds_:    lower bound, or -INF for a subcube that was skipped
 *   d_upper_bounds_:    objective value at the subcube centre, or +INF for a
 *                       subcube that was skipped
 * All c_* constants must be copied to constant memory before the launch.
 */
__global__ void GetBounds(float* d_mu_x_, float* d_mu_y_, float* d_phi_x_, float* d_phi_y_, float* d_lower_bounds_, float* d_upper_bounds_) {
	// Calculate the thread global_index
	int global_index = blockIdx.x * blockDim.x + threadIdx.x;

	// Ignore surplus threads: there are only 2^(6 * c_bits_per_dim) subcubes
	// and the output arrays hold exactly one element per subcube
	if (global_index >= (1 << (6 * c_bits_per_dim))) {
		return;
	}

	// Calculate the subcube translation and rotation widths
	float tw = c_node_parent.tw / c_nodes_per_dim;
	float rw = c_node_parent.rw / c_nodes_per_dim;

	// Calculate the subcube centre
	// global_index has the structure: {rz/rw ry/rw rx/rw tz/tw ty/tw tx/tw}
	// mask selects the c_bits_per_dim least significant bits from global_index
	int mask = c_nodes_per_dim - 1;
	float tx = c_node_parent.tx + tw / 2 + tw * (global_index & mask);
	float ty = c_node_parent.ty + tw / 2 + tw * ((global_index >> (1 * c_bits_per_dim)) & mask);
	float tz = c_node_parent.tz + tw / 2 + tw * ((global_index >> (2 * c_bits_per_dim)) & mask);
	float rx = c_node_parent.rx + rw / 2 + rw * ((global_index >> (3 * c_bits_per_dim)) & mask);
	float ry = c_node_parent.ry + rw / 2 + rw * ((global_index >> (4 * c_bits_per_dim)) & mask);
	float rz = c_node_parent.rz + rw / 2 + rw * ((global_index >> (5 * c_bits_per_dim)) & mask);

	// Skip subcube if it lies entirely outside the rotation PI-ball, that is,
	// if its centre is further than PI plus the half-diagonal from the origin
	float t = sqrt(rx * rx + ry * ry + rz * rz);
	if (t - SQRT3 * rw / 2 > PI) {
		d_lower_bounds_[global_index] = -INF;
		d_upper_bounds_[global_index] = +INF;
		return;
	}

	// Convert angle-axis rotation into a rotation matrix
	// A zero rotation angle has no defined axis; leaving the axis at zero avoids
	// a division by zero and yields the identity matrix (ct = 1, st = 0)
	float v[3] = {0, 0, 0};
	if (t > 0) {
		v[0] = rx / t;
		v[1] = ry / t;
		v[2] = rz / t;
	}
	float ct = cos(t);
	float ct2 = 1 - ct;
	float st = sin(t);
	float v0st = v[0] * st;
	float v1st = v[1] * st;
	float v2st = v[2] * st;
	float v0v1ct2 = v[0] * v[1] * ct2;
	float v0v2ct2 = v[0] * v[2] * ct2;
	float v1v2ct2 = v[1] * v[2] * ct2;
	float R[3][3] = {{ct + v[0] * v[0] * ct2, v0v1ct2 - v2st, v0v2ct2 + v1st},
									 {v0v1ct2 + v2st, ct + v[1] * v[1] * ct2, v1v2ct2 - v0st},
									 {v0v2ct2 - v1st, v1v2ct2 + v0st, ct + v[2] * v[2] * ct2}};

	// Calculate subcube translation uncertainty radius
	float uncertainty_radius_translation = SQRT3 * tw / 2;

	// Calculate cosA, the cosine of the rotation uncertainty angle (capped at PI)
	float cosA = SQRT3 * rw / 2;
	if (cosA > PI) {
		cosA = PI;
	}
	cosA = cos(cosA);

	// Denominator of the Gaussian exponent, identical for every pair of components
	float two_sigma_squared = 2 * (c_sigma_x * c_sigma_x + c_sigma_y * c_sigma_y);

	// For every component in GMM Y
	float ub = 0;
	float lb = 0;
	for (int i = 0; i < c_num_components_y; ++i) {
		// Rotate and translate GMM parameters
		float py[3] = {d_mu_y_[c_dimension * i],
									 d_mu_y_[c_dimension * i + 1],
									 d_mu_y_[c_dimension * i + 2]};
		float py_transformed[3] = {R[0][0] * py[0] + R[0][1] * py[1] + R[0][2] * py[2] + tx,
															 R[1][0] * py[0] + R[1][1] * py[1] + R[1][2] * py[2] + ty,
															 R[2][0] * py[0] + R[2][1] * py[1] + R[2][2] * py[2] + tz};

		// Calculate terms that depend on y
		float norm_y_squared = py[0] * py[0] + py[1] * py[1] + py[2] * py[2];
		float norm_y = sqrt(norm_y_squared);

		// For every component in GMM X
		for (int j = 0; j < c_num_components_x; ++j) {
			float px[3] = {d_mu_x_[c_dimension * j],
										 d_mu_x_[c_dimension * j + 1],
										 d_mu_x_[c_dimension * j + 2]};
			float phi_xy = d_phi_x_[j] * d_phi_y_[i];

			// Calculate terms that depend on x and y
			float dx = px[0] - py_transformed[0];
			float dy = px[1] - py_transformed[1];
			float dz = px[2] - py_transformed[2];
			float norm_xy_squared = dx * dx + dy * dy + dz * dz;
			float norm_x_squared = (px[0] - tx) * (px[0] - tx) + (px[1] - ty) * (px[1] - ty) + (px[2] - tz) * (px[2] - tz);
			float norm_x = sqrt(norm_x_squared);
			// A NaN here (e.g. 0/0 when one of the norms is zero) makes the
			// comparisons below false, which selects the most conservative
			// contribution to the lower bound (the full coefficient product)
			float cosB = (norm_x_squared + norm_y_squared - norm_xy_squared) / (2 * norm_x * norm_y);

			// Calculate upper bound
			// Uses distance from point xj to point yi transformed by the subcube centre
			ub -= phi_xy * exp(-norm_xy_squared / two_sigma_squared);

			// Calculate lower bound
			// Uses distance from point xj to rotation surface
			float lower_bound_distance;
			if (cosB >= cosA) {
				// CASE A: point is within rotation cone
				lower_bound_distance = fabsf(norm_x - norm_y) - uncertainty_radius_translation;
			} else {
				// CASE B: point is outside rotation cone
				lower_bound_distance = sqrt(norm_x_squared + norm_y_squared - 2 * norm_x * norm_y * (cosB * cosA + sqrt(1 - cosB * cosB) * sqrt(1 - cosA * cosA))) - uncertainty_radius_translation;
			}
			if (lower_bound_distance > 0) {
				lb -= phi_xy * exp(-lower_bound_distance * lower_bound_distance / two_sigma_squared);
			} else {
				lb -= phi_xy;
			}
		}
	}
	// Update output arrays
	d_lower_bounds_[global_index] = lb;
	d_upper_bounds_[global_index] = ub;
}

// Constructs a GOGMA object with no mixtures loaded.
// Sets the sizes to -1, the parameters to the null transform, the search
// domain to the cube [-0.5, 0.5]^3 x [-PI, PI]^3 and allocates the cost function.
GOGMA::GOGMA() {
	num_components_x_ = -1;
	num_components_y_ = -1;
	dimension_ = -1;
	optimality_certificate_ = 0;
	initialisation_type_ = 1;
	epsilon_ = 0.0;
	sigma_x_ = 0.0;
	sigma_y_ = 0.0;

	// Owned host and device buffers start out null, so that releasing them
	// before they were ever allocated is a no-op instead of undefined behaviour
	mu_x_ = NULL;
	mu_y_ = NULL;
	mu_y_transformed_ = NULL;
	phi_x_ = NULL;
	phi_y_ = NULL;
	lower_bounds_ = NULL;
	upper_bounds_ = NULL;
	d_mu_x_ = NULL;
	d_mu_y_ = NULL;
	d_phi_x_ = NULL;
	d_phi_y_ = NULL;
	d_lower_bounds_ = NULL;
	d_upper_bounds_ = NULL;

	// Set initial transformation parameters to the null transform
	// Layout: [qx qy qz qw tx ty tz], so index 3 is the quaternion scalar part
	for (int i = 0; i < 7; ++i) {
		initial_theta_[i] = 0.0;
	}
	initial_theta_[3] = 1.0;

	for (int i = 0; i < 7; ++i) {
		theta_[i] = initial_theta_[i];
	}

	// Initial node: tx/ty/tz and rx/ry/rz are the minimum corner, tw and rw the
	// side lengths of the translation and rotation cubes
	initial_node_.tx = -0.5;
	initial_node_.ty = -0.5;
	initial_node_.tz = -0.5;
	initial_node_.tw = 1.0;
	initial_node_.rx = -PI;
	initial_node_.ry = -PI;
	initial_node_.rz = -PI;
	initial_node_.rw = 2.0 * PI;
	initial_node_.lb = -1.0;
	initial_node_.ub = 0.0;

	cost_function_ = new L2CostFunction;
}

// Releases the cost function allocated by the constructor.
// The host and device buffers are released by Clear(), not here.
// NOTE(review): assumes GOGMA objects are never copied (a copy would share
// cost_function_ and delete it twice) — confirm against the class declaration
GOGMA::~GOGMA() {
	delete cost_function_;
}

// Runs the complete alignment: resets the search state, performs the
// branch-and-bound search, then frees the host and device buffers.
// NOTE(review): Clear() also deletes mu_x_, mu_y_, phi_x_ and phi_y_, so the
// mixtures presumably have to be loaded again before another Run() — confirm
void GOGMA::Run() {
	Initialise();
	BranchAndBound();
	Clear();
}

// Resets the search state before a run: restores the initial transformation
// parameters, clears the best value, node and certificate found so far, and
// allocates the buffer that receives the transformed means of GMM Y.
void GOGMA::Initialise() {
	// Start from the initial transformation parameters [qx qy qz qw tx ty tz]
	std::copy(initial_theta_, initial_theta_ + 7, theta_);

	// Nothing has been found yet
	optimal_value_ = INF;
	optimal_node_ = initial_node_;
	optimality_certificate_ = 0;

	// One entry per coordinate of every component of GMM Y
	mu_y_transformed_ = new float[num_components_y_ * dimension_];
}

// Converts the centre of a node into transformation parameters
// [qx qy qz qw tx ty tz]: the angle-axis rotation at the centre of the
// rotation cube becomes a quaternion and the centre of the translation cube
// is copied across. A zero rotation maps to the identity quaternion.
static void NodeCentreToTheta(const Node& node, double theta[]) {
	// Find subcube centre
	float tx_centre = node.tx + node.tw / 2;
	float ty_centre = node.ty + node.tw / 2;
	float tz_centre = node.tz + node.tw / 2;
	float rx_centre = node.rx + node.rw / 2;
	float ry_centre = node.ry + node.rw / 2;
	float rz_centre = node.rz + node.rw / 2;

	// Convert angle-axis rotation into a quaternion
	float t = sqrt(rx_centre * rx_centre + ry_centre * ry_centre + rz_centre * rz_centre);
	if (t > 0.0) {
		float v[3] = {rx_centre / t,
									ry_centre / t,
									rz_centre / t};
		theta[0] = v[0] * sin(t / 2);
		theta[1] = v[1] * sin(t / 2);
		theta[2] = v[2] * sin(t / 2);
		theta[3] = cos(t / 2);
	} else {
		theta[0] = 0;
		theta[1] = 0;
		theta[2] = 0;
		theta[3] = 1;
	}
	theta[4] = tx_centre;
	theta[5] = ty_centre;
	theta[6] = tz_centre;
}

// Branch-and-bound search over the transformation space.
// Repeatedly removes the node at the top of the priority queue, bounds all of
// its subcubes on the GPU, refines the best-so-far solution with the local
// optimiser and queues every subcube that may still contain a better solution.
// On exit optimality_certificate_ is
//   0: the queue grew beyond MAX_QUEUE_SIZE (no guarantee),
//   1: the best value is within epsilon_ of the lowest remaining lower bound,
//   2: the queue is empty (every region was explored or discarded),
// and rotation_ and translation_ hold the transformation stored in theta_.
void GOGMA::BranchAndBound() {
	// Set constants
	int bits_per_dim = 2; // number of times each dimension is halved (2)
	int nodes_per_dim = 1 << bits_per_dim; // number of nodes along each dimension (4)
	int num_dim = 6; // number of dimensions (3 rotation and 3 translation dimensions)
	int num_concurrent = 1 << (bits_per_dim * num_dim); // number of subcubes per node (4,096)
	int num_threads_per_block = 64;
	// One thread per subcube (64 blocks for 4,096 subcubes)
	int num_blocks = (num_concurrent + num_threads_per_block - 1) / num_threads_per_block;

	// Calculate Initial and Post-GMA Error
	double initial_value = 0.0;
	double optimal_value_gma = 0.0;
	double theta[7];
	std::copy(theta_, theta_ + 7, theta);
	vnl_lbfgsb solver(*cost_function_);
	InitialiseOptimiser(solver);
	StartOptimisation(solver, theta, initial_value, optimal_value_gma); // Updates theta_
	optimal_value_ = initial_value;
	if (optimal_value_gma < optimal_value_) optimal_value_ = optimal_value_gma;
	optimal_node_ = initial_node_;
	std::cout << "Optimal Value: " << initial_value << " (Initial)" << std::endl;
	std::cout << "Optimal Value: " << optimal_value_gma << " (GMA)" << std::endl;

	// Push initial node to queue
	std::priority_queue<Node> queue;
	queue.push(initial_node_);

	// Set GPU to use
	int device_id = 0;
	CudaErrorCheck(cudaSetDevice(device_id));

	// Allocate output arrays in host memory
	lower_bounds_ = new float[num_concurrent];
	upper_bounds_ = new float[num_concurrent];
	// Allocate input arrays in device memory
	CudaErrorCheck(cudaMalloc(&d_mu_x_, sizeof(float) * num_components_x_ * dimension_));
	CudaErrorCheck(cudaMalloc(&d_mu_y_, sizeof(float) * num_components_y_ * dimension_));
	CudaErrorCheck(cudaMalloc(&d_phi_x_, sizeof(float) * num_components_x_));
	CudaErrorCheck(cudaMalloc(&d_phi_y_, sizeof(float) * num_components_y_));
	// Allocate output arrays in device memory
	CudaErrorCheck(cudaMalloc(&d_lower_bounds_, sizeof(float) * num_concurrent));
	CudaErrorCheck(cudaMalloc(&d_upper_bounds_, sizeof(float) * num_concurrent));
	// Copy constants to device constant memory
	CudaErrorCheck(cudaMemcpyToSymbol(c_bits_per_dim, &bits_per_dim, sizeof(int)));
	CudaErrorCheck(cudaMemcpyToSymbol(c_nodes_per_dim, &nodes_per_dim, sizeof(int)));
	CudaErrorCheck(cudaMemcpyToSymbol(c_dimension, &dimension_, sizeof(int)));
	CudaErrorCheck(cudaMemcpyToSymbol(c_num_components_x, &num_components_x_, sizeof(int)));
	CudaErrorCheck(cudaMemcpyToSymbol(c_num_components_y, &num_components_y_, sizeof(int)));
	CudaErrorCheck(cudaMemcpyToSymbol(c_sigma_x, &sigma_x_, sizeof(float)));
	CudaErrorCheck(cudaMemcpyToSymbol(c_sigma_y, &sigma_y_, sizeof(float)));
	// Copy input arrays to device
	CudaErrorCheck(cudaMemcpy(d_mu_x_, mu_x_, sizeof(float) * num_components_x_ * dimension_, cudaMemcpyHostToDevice));
	CudaErrorCheck(cudaMemcpy(d_mu_y_, mu_y_, sizeof(float) * num_components_y_ * dimension_, cudaMemcpyHostToDevice));
	CudaErrorCheck(cudaMemcpy(d_phi_x_, phi_x_, sizeof(float) * num_components_x_, cudaMemcpyHostToDevice));
	CudaErrorCheck(cudaMemcpy(d_phi_y_, phi_y_, sizeof(float) * num_components_y_, cudaMemcpyHostToDevice));

	// Keep exploring transformation space until convergence is achieved
	int num_iterations = 0;
	while (1) {
		// If the queue is empty, all regions have been explored and discarded within epsilon of the current optimal value
		if (queue.empty()) {
			optimality_certificate_ = 2;
			std::cout << "Queue Empty" << std::endl;
			std::cout << "Optimal Value: " << std::setprecision(6) << optimal_value_ << std::endl;
			break;
		}
		// If the queue is full, break with sub-optimality
		if (queue.size() > MAX_QUEUE_SIZE) {
			std::cout << "Queue Full" << std::endl;
			std::cout << "Optimal Value: " << std::setprecision(6) << optimal_value_ << std::endl;
			break;
		}

		// Access the node at the top of the queue and remove it from the queue
		Node node_parent = queue.top();
		queue.pop();

		// Exit if the optimal_value_ is less than or equal to the lower bound plus epsilon (epsilon-suboptimality)
		if ((optimal_value_ - node_parent.lb) <= epsilon_) {
			std::cout << "Optimal Value: " << std::setprecision(6) << optimal_value_ << ", Lower Bound: " << std::setprecision(6) << node_parent.lb << std::endl;
			optimality_certificate_ = 1;
			break;
		}

		// Subdivide hypercube into subcubes and calculate upper and lower bounds for each
		CudaErrorCheck(cudaMemcpyToSymbol(c_node_parent, &node_parent, sizeof(Node)));
		GetBounds<<<num_blocks, num_threads_per_block>>>(d_mu_x_, d_mu_y_, d_phi_x_, d_phi_y_, d_lower_bounds_, d_upper_bounds_);
		CudaErrorCheck(cudaPeekAtLastError()); // Check for kernel launch error
		CudaErrorCheck(cudaDeviceSynchronize()); // Check for kernel execution error

		// Copy results to host
		CudaErrorCheck(cudaMemcpy(lower_bounds_, d_lower_bounds_, sizeof(float) * num_concurrent, cudaMemcpyDeviceToHost));
		CudaErrorCheck(cudaMemcpy(upper_bounds_, d_upper_bounds_, sizeof(float) * num_concurrent, cudaMemcpyDeviceToHost));

		// Set when the epsilon-suboptimality condition is met while refining
		bool converged = false;

		// For the first iteration (#0), choose whether to run GMA for the new best-so-far node (standard)
		// or for every subcube in the first subdivision (batch)
		if (num_iterations > 0 || initialisation_type_ == 0) {
			// Update optimal error
			int optimal_node_index = -1;
			for (int i = 0; i < num_concurrent; ++i) {
				if (upper_bounds_[i] < optimal_value_) {
					optimal_value_ = upper_bounds_[i];
					optimal_node_index = i;
				}
			}
			// Run GMA for best node, if better than optimal_value
			if (optimal_node_index >= 0) {
				Node node;
				GetChild(node_parent, optimal_node_index, bits_per_dim, nodes_per_dim, node);
				optimal_node_ = node;

				// Find transformation parameters corresponding to the subcube centre
				NodeCentreToTheta(optimal_node_, theta);

				// optimal_value_ is now the objective value at this subcube centre, so
				// the centre is the best transformation known; store it in case the
				// optimiser below does not find a strictly better one
				std::copy(theta, theta + 7, theta_);

				// Run GMA
				StartOptimisation(solver, theta, initial_value, optimal_value_gma); // Updates theta_ if a better value is found
				if (optimal_value_gma < optimal_value_) {
					optimal_value_ = optimal_value_gma;
				}

				// Exit loop early if optimality conditions are met
				if ((optimal_value_ - node_parent.lb) <= epsilon_) {
					std::cout << "Optimal Value: " << std::setprecision(6) << optimal_value_ << ", Lower Bound: " << std::setprecision(6) << node_parent.lb << std::endl;
					optimality_certificate_ = 1;
					converged = true;
				}
			}
		} else {
			for (int i = 0; i < num_concurrent; ++i) {
				Node node;
				GetChild(node_parent, i, bits_per_dim, nodes_per_dim, node);

				// Find transformation parameters corresponding to the subcube centre
				NodeCentreToTheta(node, theta);

				// Run GMA
				StartOptimisation(solver, theta, initial_value, optimal_value_gma); // Updates theta_ if a better value is found
				if (optimal_value_gma < optimal_value_) {
					optimal_value_ = optimal_value_gma;
					optimal_node_ = node;
				}

				// Exit loop early if optimality conditions are met
				if ((optimal_value_ - node_parent.lb) <= epsilon_) {
					std::cout << "Optimal Value: " << std::setprecision(6) << optimal_value_ << ", Lower Bound: " << std::setprecision(6) << node_parent.lb << std::endl;
					optimality_certificate_ = 1;
					converged = true;
					break;
				}
			}
		}

		// Leave the search loop itself, not just the batch loop above, so that the
		// certificate is not overwritten by a later iteration
		if (converged) {
			break;
		}

		// Add nodes to queue if lb < optimal_value_
		for (int i = 0; i < num_concurrent; ++i) {
			if (lower_bounds_[i] < optimal_value_) {
				// Must be within rotation sphere (skipped subcubes have lb = -INF)
				if (lower_bounds_[i] > -INF) {
					Node node;
					GetChild(node_parent, i, bits_per_dim, nodes_per_dim, node);
					node.lb = lower_bounds_[i];
					node.ub = upper_bounds_[i];
					queue.push(node);
				}
			}
		}
		num_iterations++;
	}

	// Update rotation_ and translation_
	float quaternion[4];
	std::copy(theta_, theta_ + 4, quaternion);
	QuaternionToRotation(quaternion, rotation_);
	std::copy(theta_ + 4, theta_ + 7, translation_);
}

// Computes the geometry of one subcube of node_parent.
// index selects the subcube: it is read as six bit fields of bits_per_dim bits
// each, ordered from least to most significant as tx, ty, tz, rx, ry, rz.
// Only the corner (tx..rz) and the widths (tw, rw) of node are written; its
// lb and ub members are left untouched.
void GOGMA::GetChild(Node& node_parent, int index, int bits_per_dim, int nodes_per_dim, Node& node) {
	const int mask = nodes_per_dim - 1;

	// Extract the position of the subcube along each of the six dimensions
	int step[6];
	for (int d = 0; d < 6; ++d) {
		step[d] = (index >> (d * bits_per_dim)) & mask;
	}

	// Each subcube is nodes_per_dim times narrower than its parent
	node.tw = node_parent.tw / nodes_per_dim;
	node.rw = node_parent.rw / nodes_per_dim;

	// Offset the parent corner by a whole number of subcube widths
	node.tx = node_parent.tx + node.tw * step[0];
	node.ty = node_parent.ty + node.tw * step[1];
	node.tz = node_parent.tz + node.tw * step[2];
	node.rx = node_parent.rx + node.rw * step[3];
	node.ry = node_parent.ry + node.rw * step[4];
	node.rz = node_parent.rz + node.rw * step[5];
}

// Releases every host and device buffer used by the search, including the
// mixture parameters (mu_x_, mu_y_, phi_x_, phi_y_).
// Each pointer is reset to null after it is released, so that calling Clear()
// a second time does not free the same memory twice.
void GOGMA::Clear() {
	// Free host memory
	delete[] lower_bounds_;
	lower_bounds_ = NULL;
	delete[] upper_bounds_;
	upper_bounds_ = NULL;
	delete[] mu_x_;
	mu_x_ = NULL;
	delete[] mu_y_;
	mu_y_ = NULL;
	delete[] mu_y_transformed_;
	mu_y_transformed_ = NULL;
	delete[] phi_x_;
	phi_x_ = NULL;
	delete[] phi_y_;
	phi_y_ = NULL;

	// Free device memory
	CudaErrorCheck(cudaFree(d_mu_x_));
	d_mu_x_ = NULL;
	CudaErrorCheck(cudaFree(d_mu_y_));
	d_mu_y_ = NULL;
	CudaErrorCheck(cudaFree(d_phi_x_));
	d_phi_x_ = NULL;
	CudaErrorCheck(cudaFree(d_phi_y_));
	d_phi_y_ = NULL;
	CudaErrorCheck(cudaFree(d_lower_bounds_));
	d_lower_bounds_ = NULL;
	CudaErrorCheck(cudaFree(d_upper_bounds_));
	d_upper_bounds_ = NULL;
}

// Applies the transformation theta = [qx qy qz qw tx ty tz] to the means of
// GMM Y and stores the result in mu_y_transformed_.
// Also overwrites rotation_ and translation_ with the rotation matrix and
// translation vector described by theta.
void GOGMA::TransformMuY(const double* theta) {
	// The first four parameters are the quaternion, the last three the translation
	float quaternion[4];
	for (int i = 0; i < 4; ++i) {
		quaternion[i] = theta[i];
	}
	QuaternionToRotation(quaternion, rotation_);
	for (int i = 0; i < 3; ++i) {
		translation_[i] = theta[4 + i];
	}

	// mu_y_transformed = mu_y * rotation^T + translation
	// rotation_ is stored row-major
	for (int component = 0; component < num_components_y_; ++component) {
		const int base = component * dimension_;
		for (int row = 0; row < dimension_; ++row) {
			// Start from the translation and accumulate one row of the rotation
			float value = translation_[row];
			for (int col = 0; col < dimension_; ++col) {
				value += mu_y_[base + col] * rotation_[row * dimension_ + col];
			}
			mu_y_transformed_[base + row] = value;
		}
	}
}

// Converts a quaternion into a 3x3 rotation matrix.
//   q: quaternion stored as [b c d a] for q = a + bi + cj + dk (scalar part last)
//   R: output, 9 floats in row-major order
// The result is divided by the squared norm of q, so q need not have unit length.
void GOGMA::QuaternionToRotation(const float* q, float* R) {
	// Using: http://en.wikipedia.org/wiki/Quaternions_and_spatial_rotation#Conversion_to_and_from_the_matrix_representation
	const float a = q[3];
	const float b = q[0];
	const float c = q[1];
	const float d = q[2];

	// Squares and pairwise products of the quaternion components
	const float aa = a * a;
	const float bb = b * b;
	const float cc = c * c;
	const float dd = d * d;
	const float bc = b * c;
	const float cd = c * d;
	const float bd = b * d;
	const float ab = a * b;
	const float ac = a * c;
	const float ad = a * d;

	// Entries of the matrix before normalisation, row by row
	const float unnormalised[9] = {aa + bb - cc - dd, 2 * (bc - ad), 2 * (bd + ac),
																 2 * (bc + ad), aa - bb + cc - dd, 2 * (cd - ab),
																 2 * (bd - ac), 2 * (cd + ab), aa - bb - cc + dd};

	// Normalise by the squared norm of the quaternion
	const float squared_norm = aa + bb + cc + dd;
	for (int i = 0; i < 9; ++i) {
		R[i] = unnormalised[i] / squared_norm;
	}
}

// Converts a quaternion into a 3x3 rotation matrix and its partial derivatives.
//   q:   quaternion stored as [b c d a] for q = a + bi + cj + dk (scalar part last)
//   R:   output rotation matrix, 9 doubles in row-major order, normalised by
//        the squared norm z of q, so q need not have unit length
//   dR1..dR4: outputs, derivative of every entry of R with respect to
//        q[0] (b), q[1] (c), q[2] (d) and q[3] (a) respectively, same layout as R
void GOGMA::QuaternionToRotation(const double* q, double* R, double* dR1, double* dR2, double* dR3, double* dR4) {
	// Using: http://en.wikipedia.org/wiki/Quaternions_and_spatial_rotation#Conversion_to_and_from_the_matrix_representation
	// q = a + bi + cj + dk
	double a = q[3];
	double b = q[0];
	double c = q[1];
	double d = q[2];
	double a2 = a * a;
	double b2 = b * b;
	double c2 = c * c;
	double d2 = d * d;
	double bc = b * c;
	double cd = c * d;
	double bd = b * d;
	double ab = a * b;
	double ac = a * c;
	double ad = a * d;

	// Diagonal terms
	R[0] = a2 + b2 - c2 - d2;
	R[4] = a2 - b2 + c2 - d2;
	R[8] = a2 - b2 - c2 + d2;

	// Off-diagonal terms
	R[1] = 2 * (bc - ad);
	R[2] = 2 * (bd + ac);
	R[3] = 2 * (bc + ad);
	R[5] = 2 * (cd - ab);
	R[6] = 2 * (bd - ac);
	R[7] = 2 * (cd + ab);

	// Normalise
	double z = a2 + b2 + c2 + d2;
	double z2 = z * z;
	for (int i = 0; i < 9; ++i) {
		R[i] = R[i] / z;
	}

	// Find the derivative of the rotation matrix with respect to each quaternion parameter
	// Diagonal terms
	// Derivative of R(0,0) = (a2 + b2 - c2 - d2) / (a2 + b2 + c2 + d2)
	dR1[0] = +4 * b * (c2 + d2) / z2;
	dR2[0] = -4 * c * (b2 + a2) / z2;
	dR3[0] = -4 * d * (b2 + a2) / z2;
	dR4[0] = +4 * a * (c2 + d2) / z2;
	// Derivative of R(1,1) = (a2 - b2 + c2 - d2) / (a2 + b2 + c2 + d2)
	dR1[4] = -4 * b * (c2 + a2) / z2;
	dR2[4] = +4 * c * (b2 + d2) / z2;
	dR3[4] = -4 * d * (c2 + a2) / z2;
	dR4[4] = +4 * a * (b2 + d2) / z2;
	// Derivative of R(2,2) = (a2 - b2 - c2 + d2) / (a2 + b2 + c2 + d2)
	dR1[8] = -4 * b * (d2 + a2) / z2;
	dR2[8] = -4 * c * (a2 + d2) / z2;
	dR3[8] = +4 * d * (b2 + c2) / z2;
	dR4[8] = +4 * a * (b2 + c2) / z2;

	// Off-diagonal terms
	// For an entry R = n / z the quotient rule gives dR/dp = (dn/dp) / z - 2 * p * n / z^2.
	// R[] already holds the normalised value n / z at this point, so the second
	// term is 2 * p * R / z (dividing by z^2 here would be wrong whenever z != 1)
	// Derivative of R(0,1) = (2 * (bc - ad)) / (a2 + b2 + c2 + d2)
	dR1[1] = +2 * c / z - 2 * b * R[1] / z;
	dR2[1] = +2 * b / z - 2 * c * R[1] / z;
	dR3[1] = -2 * a / z - 2 * d * R[1] / z;
	dR4[1] = -2 * d / z - 2 * a * R[1] / z;
	// Derivative of R(0,2) = (2 * (bd + ac)) / (a2 + b2 + c2 + d2)
	dR1[2] = +2 * d / z - 2 * b * R[2] / z;
	dR2[2] = +2 * a / z - 2 * c * R[2] / z;
	dR3[2] = +2 * b / z - 2 * d * R[2] / z;
	dR4[2] = +2 * c / z - 2 * a * R[2] / z;
	// Derivative of R(1,0) = (2 * (bc + ad)) / (a2 + b2 + c2 + d2)
	dR1[3] = +2 * c / z - 2 * b * R[3] / z;
	dR2[3] = +2 * b / z - 2 * c * R[3] / z;
	dR3[3] = +2 * a / z - 2 * d * R[3] / z;
	dR4[3] = +2 * d / z - 2 * a * R[3] / z;
	// Derivative of R(1,2) = (2 * (cd - ab)) / (a2 + b2 + c2 + d2)
	dR1[5] = -2 * a / z - 2 * b * R[5] / z;
	dR2[5] = +2 * d / z - 2 * c * R[5] / z;
	dR3[5] = +2 * c / z - 2 * d * R[5] / z;
	dR4[5] = -2 * b / z - 2 * a * R[5] / z;
	// Derivative of R(2,0) = (2 * (bd - ac)) / (a2 + b2 + c2 + d2)
	dR1[6] = +2 * d / z - 2 * b * R[6] / z;
	dR2[6] = -2 * a / z - 2 * c * R[6] / z;
	dR3[6] = +2 * b / z - 2 * d * R[6] / z;
	dR4[6] = -2 * c / z - 2 * a * R[6] / z;
	// Derivative of R(2,1) = (2 * (cd + ab)) / (a2 + b2 + c2 + d2)
	dR1[7] = +2 * a / z - 2 * b * R[7] / z;
	dR2[7] = +2 * d / z - 2 * c * R[7] / z;
	dR3[7] = +2 * c / z - 2 * d * R[7] / z;
	dR4[7] = +2 * b / z - 2 * a * R[7] / z;
}

// Evaluates the Gauss transform between the means of GMM X (mu_x_) and the
// transformed means of GMM Y (mu_y_transformed_), using the stored
// coefficients phi_x_, phi_y_ and the stored sigma_x_, sigma_y_.
// Returns the value of the transform; no gradient is computed.
double GOGMA::GaussTransform() {
	return GaussTransform(mu_x_, mu_y_transformed_, phi_x_, phi_y_, sigma_x_, sigma_y_);
}

// Evaluates the Gauss transform between two sets of points:
//   sum over i, j of phi_x[i] * phi_y[j] * exp(-|x_i - y_j|^2 / (2 * (sigma_x^2 + sigma_y^2)))
//   x, y:         num_components_x_ and num_components_y_ points respectively,
//                 dimension_ floats per point
//   phi_x, phi_y: one coefficient per point
// Returns the value of the sum.
double GOGMA::GaussTransform(const float* x, const float* y, const float* phi_x, const float* phi_y, const float sigma_x, const float sigma_y) {
	// Denominator of the exponent, shared by every pair of points
	const float two_sigma_squared = 2 * (sigma_x * sigma_x + sigma_y * sigma_y);

	double total = 0;
	for (int i = 0; i < num_components_x_; ++i) {
		const float* point_x = x + i * dimension_;
		for (int j = 0; j < num_components_y_; ++j) {
			const float* point_y = y + j * dimension_;

			// Squared distance between point i of X and point j of Y
			double squared_distance = 0;
			for (int k = 0; k < dimension_; ++k) {
				const float difference = point_x[k] - point_y[k];
				squared_distance += difference * difference;
			}

			// Contribution of this pair to the objective function
			total += phi_x[i] * phi_y[j] * exp(-squared_distance / two_sigma_squared);
		}
	}
	return total;
}

// Evaluates the Gauss transform and its gradient for the stored mixtures
// (mu_x_ against mu_y_transformed_) using the given sigma_x and sigma_y
// instead of the stored ones.
// gradient receives num_components_y_ * dimension_ values, the derivative of
// the returned value with respect to every coordinate of mu_y_transformed_.
double GOGMA::GaussTransform(float sigma_x, float sigma_y, double* gradient) {
	return GaussTransform(mu_x_, mu_y_transformed_, phi_x_, phi_y_, sigma_x, sigma_y, gradient);
}

// Evaluates the Gauss transform and its gradient for the stored mixtures
// (mu_x_ against mu_y_transformed_) using the stored sigma_x_ and sigma_y_.
// gradient receives num_components_y_ * dimension_ values, the derivative of
// the returned value with respect to every coordinate of mu_y_transformed_.
double GOGMA::GaussTransform(double* gradient) {
	return GaussTransform(mu_x_, mu_y_transformed_, phi_x_, phi_y_, sigma_x_, sigma_y_, gradient);
}

// Evaluates the Gauss transform between two sets of points together with its
// gradient with respect to the points in y:
//   sum over i, j of phi_x[i] * phi_y[j] * exp(-|x_i - y_j|^2 / (2 * (sigma_x^2 + sigma_y^2)))
//   x, y:         num_components_x_ and num_components_y_ points respectively,
//                 dimension_ floats per point
//   phi_x, phi_y: one coefficient per point
//   gradient:     output, num_components_y_ * dimension_ doubles laid out like y
// Returns the value of the sum.
double GOGMA::GaussTransform(const float* x, const float* y, const float* phi_x, const float* phi_y, const float sigma_x, const float sigma_y, double* gradient) {
	const int gradient_length = num_components_y_ * dimension_;
	const float sigma_squared_sum = sigma_x * sigma_x + sigma_y * sigma_y;

	// Start from a zero gradient
	for (int n = 0; n < gradient_length; ++n) {
		gradient[n] = 0;
	}

	double total = 0;
	for (int i = 0; i < num_components_x_; ++i) {
		const float* point_x = x + i * dimension_;
		for (int j = 0; j < num_components_y_; ++j) {
			const float* point_y = y + j * dimension_;
			double* gradient_y = gradient + j * dimension_;

			// Squared distance between point i of X and point j of Y
			double squared_distance = 0;
			for (int k = 0; k < dimension_; ++k) {
				const float difference = point_x[k] - point_y[k];
				squared_distance += difference * difference;
			}

			// Contribution of this pair to the objective function
			const double pair_cost = phi_x[i] * phi_y[j] * exp(-squared_distance / (2 * sigma_squared_sum));
			total += pair_cost;

			// Contribution of this pair to the gradient (scaled after the loops)
			for (int k = 0; k < dimension_; ++k) {
				gradient_y[k] -= pair_cost * (point_y[k] - point_x[k]);
			}
		}
	}

	// Apply the common factor 1 / (sigma_x^2 + sigma_y^2)
	for (int n = 0; n < gradient_length; ++n) {
		gradient[n] /= sigma_squared_sum;
	}
	return total;
}

// Prepares the L-BFGS-B solver for use: installs the parameter bounds, applies
// the termination options and finally initialises the cost function with
// this object.
void GOGMA::InitialiseOptimiser(vnl_lbfgsb& solver) {
	// Build the bound description for the current dimension
	vnl_vector<long> bound_selection;
	vnl_vector<double> lower;
	vnl_vector<double> upper;
	SetOptimiserBounds(bound_selection, lower, upper);

	// Hand the bounds to the solver
	solver.set_bound_selection(bound_selection);
	solver.set_lower_bound(lower);
	solver.set_upper_bound(upper);

	// Termination criteria
	SetOptimisationOptions(solver);

	// Give the cost function access to this object
	cost_function_->Initialise(this);
}

// Fills in the bound description used by the L-BFGS-B solver.
//   optimiser_num_bounds:   per-parameter bound selection; this function uses
//                           0 for unconstrained parameters and 2 for
//                           parameters with both a lower and an upper bound
//   optimiser_lower_bounds: lower bound of every parameter
//   optimiser_upper_bounds: upper bound of every parameter
// For dimension_ == 2 the parameters are [x y alpha] with alpha in [-pi, pi];
// for dimension_ == 3 they are [qx qy qz qw tx ty tz] with every quaternion
// component in [-1, 1]. For any other dimension_ the vectors are left untouched.
void GOGMA::SetOptimiserBounds(vnl_vector<long>& optimiser_num_bounds, vnl_vector<double>& optimiser_lower_bounds, vnl_vector<double>& optimiser_upper_bounds) {
	if (dimension_ == 2) {
		optimiser_num_bounds.set_size(3); // [x y alpha]
		optimiser_num_bounds.fill(0); // unconstrained
		optimiser_num_bounds[2] = 2; // constrained with lower and upper bounds
		optimiser_lower_bounds.set_size(3);
		optimiser_lower_bounds.fill(0);
		optimiser_lower_bounds[2] = -M_PI;
		optimiser_upper_bounds.set_size(3);
		optimiser_upper_bounds.fill(0);
		// The upper bound is +pi: using -pi here as well would collapse the
		// admissible range of the angle to the single value -pi
		optimiser_upper_bounds[2] = +M_PI;
	} else if (dimension_ == 3) {
		optimiser_num_bounds.set_size(7); // [qx qy qz qw tx ty tz]
		optimiser_num_bounds.fill(0); // unconstrained
		optimiser_lower_bounds.set_size(7);
		optimiser_lower_bounds.fill(0);
		optimiser_upper_bounds.set_size(7);
		optimiser_upper_bounds.fill(0);
		// The four quaternion components are constrained to [-1, 1]
		for (int i = 0; i < 4; ++i) {
			optimiser_num_bounds[i] = 2; // constrained with lower and upper bounds
			optimiser_lower_bounds[i] = -1;
			optimiser_upper_bounds[i] = 1;
		}
	}
}

// Configures the termination criteria of the L-BFGS-B solver.
// Every value equals the default quoted next to it, except the cost function
// convergence factor, which is set to 1e3 instead of the quoted default 1e+7.
void GOGMA::SetOptimisationOptions(vnl_lbfgsb& solver) {
	// Set the convergence tolerance on F (sum of squared residuals).
	// When the difference between successive RMS errors is less than this, the
	// routine terminates. Default: 1e-9
	solver.set_f_tolerance(1e-9);
	// Set the convergence tolerance on X.
	// When the length of the steps taken in X is about this long, the routine
	// terminates. Default: 1e-8
	solver.set_x_tolerance(1e-8);
	// Set the convergence tolerance on Grad(F)' * F.
	// Default: 1e-5
	solver.set_g_tolerance(1e-5);
	// Set the maximum number of function evaluations before termination.
	// Default: 10000
	solver.set_max_function_evals(10000);
	// Set the cost function convergence factor.
	// When an iteration changes the function value by an amount smaller than
	// this factor times the machine epsilon (scaled by function magnitude)
	// convergence is assumed. Default: 1e+7
	solver.set_cost_function_convergence_factor(1e3);
	// Set the projected gradient tolerance.
	// When the projected gradient vector has no component larger than
	// the given value convergence is assumed. Default: 1e-5
	solver.set_projected_gradient_tolerance(1e-5);
}

// Runs the local optimiser from initial_theta ([qx qy qz qw tx ty tz]).
//   initial_value: output, cost function value at the start
//   final_value:   output, cost function value at the end
// When the solver succeeds, theta_ is replaced by the optimised parameters if
// the final value is lower than both the starting value and optimal_value_.
// When the solver fails, the cost at initial_theta is evaluated directly (this
// overwrites mu_y_transformed_, rotation_ and translation_) and theta_ is
// replaced by initial_theta if that cost is lower than optimal_value_.
// optimal_value_ itself is not modified here.
void GOGMA::StartOptimisation(vnl_lbfgsb& solver, double initial_theta[], double& initial_value, double& final_value) {
	// Copy the starting parameters into the vector type used by the solver
	vnl_vector<double> theta;
	theta.set_size(7);
	for (int n = 0; n < 7; ++n) {
		theta[n] = initial_theta[n];
	}

	const bool solver_succeeded = solver.minimize(theta);
	if (solver_succeeded) {
		initial_value = solver.get_start_error();
		final_value = solver.get_end_error();
		// If optimisation improved the cost function value, copy the best parameters across
		const bool is_new_best = (final_value < initial_value) && (final_value < optimal_value_);
		if (is_new_best) {
			for (int n = 0; n < 7; ++n) {
				theta_[n] = theta[n];
			}
		}
		return;
	}

	// The solver failed: use initial theta to find the cost function value
	TransformMuY(initial_theta);
	initial_value = -GaussTransform();
	final_value = initial_value;
	// If the cost function value has improved, copy the best parameters across
	if (final_value < optimal_value_) {
		for (int n = 0; n < 7; ++n) {
			theta_[n] = initial_theta[n];
		}
	}
}
