/*
 * Globally-Optimal Gaussian Mixture Alignment (GOGMA): gogma
 * A class for Globally-Optimal Gaussian Mixture Alignment. Branches
 * a hyperoctree representation of the transformation space SE(3)
 * and bounds each branch on the GPU. Uses the VNL library for local
 * optimisation.
 *
 * Campbell, D., and Petersson, L., "GOGMA: Globally-Optimal Gaussian
 * Mixture Alignment", IEEE Conference on Computer Vision and
 * Pattern Recognition (CVPR), Las Vegas, USA, IEEE, Jun. 2016
 *
 * For the full license, see license.txt in the root directory
 *
 * Author: Dylan Campbell
 * Date: 20160212
 * Revision: 1.1
 */

#ifndef GOGMA_H_
#define GOGMA_H_

#include <time.h>

#include <algorithm> // std::copy (GOGMA mutators)
#include <fstream>
#include <iomanip>
#include <iostream>
#include <queue>
#include <vector>

#include <vnl/vnl_vector.h> // For optimiser
#include <vnl/algo/vnl_lbfgsb.h> // For optimiser

#include "l2_cost_function.h"

// Numeric constants used by the GOGMA implementation.
// NOTE(review): these are unscoped macros in a public header, so they are visible to
// every file that includes gogma.h and may collide with other libraries' INF/PI.
// Consider scoped constants once all users (including l2_cost_function.h) are known.
// INF: large finite value, presumably used as an "infinite" initial bound — confirm usage.
#define INF 1e10
// PI: pi to 10 decimal places.
#define PI 3.1415926536
// SQRT3: square root of 3 to 9 decimal places.
#define SQRT3 1.732050808

// Only a pointer to L2CostFunction is stored in GOGMA (cost_function_), so a forward
// declaration is sufficient for this header.
class L2CostFunction; // Forward declare class L2CostFunction

// 6D Transformation Space Node
struct Node {
	float tx, ty, tz; // Minimum translation values (lower corner) of the node
	float rx, ry, rz; // Minimum rotation values (lower corner) of the node
	float tw; // Width of the node along each translation axis
	float rw; // Width of the node along each rotation axis
	float ub, lb; // Upper and lower bound computed for the node

	// Inverted ordering so that a max-heap (e.g. std::priority_queue) yields the
	// most promising node first. Nodes are ranked by the lowest lower bound. When
	// lower bounds are equal, the lowest upper bound wins.
	friend bool operator < (const Node& node1, const Node& node2) {
		return (node1.lb == node2.lb) ? (node2.ub < node1.ub) : (node2.lb < node1.lb);
	}
};

// GOGMA: globally-optimal alignment of two Gaussian mixtures, X and Y, by
// branch-and-bound over a hyperoctree representation of SE(3) (see file header).
//
// Mutator ordering matters: set_mu_x/set_mu_y/set_phi_x/set_phi_y size their
// buffers from the component counts and dimension already stored, so call
// set_num_components_x(), set_num_components_y() and set_dimension() first.
//
// Those setters allocate with new[] and keep the raw pointer, so the object owns
// heap buffers. It is therefore non-copyable (see the private section).
class GOGMA {
public:
	// Constructors
	GOGMA();

	// Destructors
	~GOGMA();

	// Accessors
	int num_components_x() const {return num_components_x_;}
	int num_components_y() const {return num_components_y_;}
	int dimension() const {return dimension_;}
	int optimality_certificate() const {return optimality_certificate_;}
	int initialisation_type() const {return initialisation_type_;}
	// optimal_value_ is stored as double; this accessor narrows it to float.
	float optimal_value() const {return static_cast<float>(optimal_value_);}
	float epsilon() const {return epsilon_;}
	float sigma_x() const {return sigma_x_;}
	float sigma_y() const {return sigma_y_;}
	float* mu_x() const {return mu_x_;}
	float* mu_y() const {return mu_y_;}
	float* phi_x() const {return phi_x_;}
	float* phi_y() const {return phi_y_;}
	// Result arrays (3, 9 and 7 elements respectively). The const overloads allow
	// reading them through a const GOGMA&.
	float* translation() {return translation_;}
	const float* translation() const {return translation_;}
	float* rotation() {return rotation_;}
	const float* rotation() const {return rotation_;}
	double* theta() {return theta_;}
	const double* theta() const {return theta_;}

	// Mutators
	void set_num_components_x(int num_components_x) {num_components_x_ = num_components_x;}
	void set_num_components_y(int num_components_y) {num_components_y_ = num_components_y;}
	void set_dimension(int dimension) {dimension_ = dimension;}
	void set_optimality_certificate(int optimality_certificate) {optimality_certificate_ = optimality_certificate;}
	void set_initialisation_type(int initialisation_type) {initialisation_type_ = initialisation_type;}
	// Takes double so the full precision reaches the double member; float arguments still convert.
	void set_optimal_value(double optimal_value) {optimal_value_ = optimal_value;}
	void set_epsilon(float epsilon) {epsilon_ = epsilon;}
	void set_sigma_x(double sigma_x) {sigma_x_ = static_cast<float>(sigma_x);}
	void set_sigma_y(double sigma_y) {sigma_y_ = static_cast<float>(sigma_y);}

	// The four setters below copy the caller's double data into a newly allocated
	// float buffer. The input is only read, so the caller keeps ownership of it.
	// NOTE(review): each call allocates a fresh buffer without freeing a previous one.
	// Calling a setter twice leaks unless the old buffer is released elsewhere.
	// Confirm ownership in Clear()/~GOGMA() before relying on repeated calls.

	// Copies num_components_x() * dimension() means from mu_x.
	void set_mu_x(const double* mu_x) {
		const int count = num_components_x_ * dimension_;
		mu_x_ = new float[count];
		std::copy(mu_x, mu_x + count, mu_x_);
	}
	// Copies num_components_y() * dimension() means from mu_y.
	void set_mu_y(const double* mu_y) {
		const int count = num_components_y_ * dimension_;
		mu_y_ = new float[count];
		std::copy(mu_y, mu_y + count, mu_y_);
	}
	// Copies num_components_x() mixture weights from phi_x.
	void set_phi_x(const double* phi_x) {
		phi_x_ = new float[num_components_x_];
		std::copy(phi_x, phi_x + num_components_x_, phi_x_);
	}
	// Copies num_components_y() mixture weights from phi_y.
	void set_phi_y(const double* phi_y) {
		phi_y_ = new float[num_components_y_];
		std::copy(phi_y, phi_y + num_components_y_, phi_y_);
	}
	// Copies exactly 7 values into the stored initial parameter vector.
	void set_initial_theta(const double initial_theta[7]) {std::copy(initial_theta, initial_theta + 7, initial_theta_);}
	// Sets the same minimum value on all three translation axes of the root search node.
	void set_translation_minimum(float translation_minimum) {initial_node_.tx = translation_minimum; initial_node_.ty = translation_minimum; initial_node_.tz = translation_minimum;}
	void set_translation_width(float translation_width) {initial_node_.tw = translation_width;}

	// Public Class Functions
	void Run();
	void TransformMuY(const double* theta);
	double GaussTransform(float sigma_x, float sigma_y, double* gradient);
	void QuaternionToRotation(const double* q, double* R, double* dR1, double* dR2, double* dR3, double* dR4);

private:
	// Non-copyable: declared but intentionally not defined. A compiler-generated copy
	// would duplicate the owning pointers below rather than the buffers they point to.
	GOGMA(const GOGMA&);
	GOGMA& operator=(const GOGMA&);

	int num_components_x_; // Number of components in GMM X
	int num_components_y_; // Number of components in GMM Y
	int dimension_; // Dimension of both GMMs
	int optimality_certificate_;
	int initialisation_type_;
	float epsilon_;
	float sigma_x_;
	float sigma_y_;
	float* mu_x_; // num_components_x_ * dimension_ values (allocated by set_mu_x)
	float* mu_y_; // num_components_y_ * dimension_ values (allocated by set_mu_y)
	float* mu_y_transformed_;
	float* phi_x_; // num_components_x_ values (allocated by set_phi_x)
	float* phi_y_; // num_components_y_ values (allocated by set_phi_y)
	float* lower_bounds_;
	float* upper_bounds_;
	// d_* pointers: presumably device (GPU) copies of the arrays above; confirm in the .cpp/.cu.
	float* d_mu_x_;
	float* d_mu_y_;
	float* d_mu_y_transformed_;
	float* d_phi_x_;
	float* d_phi_y_;
	float* d_lower_bounds_;
	float* d_upper_bounds_;
	float translation_[3];
	float rotation_[9];
	double optimal_value_;
	double theta_[7];
	double initial_theta_[7];
	Node initial_node_;
	Node optimal_node_;
	L2CostFunction* cost_function_;

	// Private Class Functions
	void Initialise();
	void BranchAndBound();
	void Clear();
	void GetChild(Node& node_parent, int index, int bits_per_dim, int nodes_per_dim, Node& node);
	void QuaternionToRotation(const float* q, float* R);
	double GaussTransform();
	double GaussTransform(double* gradient);
	double GaussTransform(const float* x, const float* y, const float* phi_x, const float* phi_y, const float sigma_x, const float sigma_y);
	double GaussTransform(const float* x, const float* y, const float* phi_x, const float* phi_y, const float sigma_x, const float sigma_y, double* gradient);
	void InitialiseOptimiser(vnl_lbfgsb& solver);
	void SetOptimiserBounds(vnl_vector<long>& optimiser_num_bounds, vnl_vector<double>& optimiser_lower_bounds, vnl_vector<double>& optimiser_upper_bounds);
	void SetOptimisationOptions(vnl_lbfgsb& solver);
	void StartOptimisation(vnl_lbfgsb& solver, double initial_theta[], double& initial_value, double& final_value);
};

#endif /* GOGMA_H_ */
