/*
 * Globally-Optimal Gaussian Mixture Alignment (GOGMA): gogma_api
 * An interface for Globally-Optimal Gaussian Mixture Alignment. Parses user
 * input, reads configuration file, loads point-sets, builds GMMs, runs GOGMA,
 * displays results, optionally refines the result with GMA.
 *
 * Campbell, D., and Petersson, L., "GOGMA: Globally-Optimal Gaussian
 * Mixture Alignment", IEEE Conference on Computer Vision and
 * Pattern Recognition (CVPR), Las Vegas, USA, IEEE, Jun. 2016
 *
 * For the full license, see license.txt in the root directory
 *
 * Author: Dylan Campbell
 * Date: 20160212
 * Revision: 1.1
 */

#include <time.h>
#include <cstdlib>
#include <iostream>
#include <fstream>
#include "gogma.h"
#include "svgm.h"
#include "config_map.h"

#define DEFAULT_POINTSET_X_FNAME "x.txt"
#define DEFAULT_POINTSET_Y_FNAME "y.txt"
#define DEFAULT_CONFIG_FNAME "config.txt"
#define DEFAULT_OUTPUT_FNAME "output.txt"

using std::cout;
using std::endl;
using std::string;
using std::ofstream;

void ParseInput(int argc, char **argv, string& pointset_x_filename, string& pointset_y_filename, string& config_filename, string& output_filename);
void ReadConfig(string filename, GOGMA& gogma);
void ReadConfig(string filename, SVGM& svgm_x, SVGM& svgm_y);
void ReadConfig(string filename, SVGM& svgm_x, SVGM& svgm_y, bool& do_refinement);

int main(int argc, char* argv[]) {
	// Parse user input
	string pointset_x_filename, pointset_y_filename, config_filename, output_filename;
	ParseInput(argc, argv, pointset_x_filename, pointset_y_filename, config_filename, output_filename);

	/* * * * * * * * * * * * * * * * * * * * * * * * * * * * * * * * * * * * * * * * * * * * * * * * * * * * * * * * * * * * * * * * * * * * * * * * * * * */
	// Load and construct Gaussian mixtures
	// Here we use SVGMs; alternatively KDE-GMMs or EM-GMMs can be used
	SVGM svgm_x, svgm_y;

	ReadConfig(config_filename, svgm_x, svgm_y);
	int num_points_x = svgm_x.LoadPointSet(pointset_x_filename);
	int num_points_y = svgm_y.LoadPointSet(pointset_y_filename);
	if (svgm_x.problem().x[0].dim != svgm_y.problem().x[0].dim) {
		std::cout << "Data point-set has a different dimension to the model point-set" << std::endl;
		return -1;
	}

	// Centre the point-sets with respect to their axis-aligned bounding boxes and scale both to fit within [-1 1]^3
	double centre_x[3], centre_y[3];
	double widths_x[3], widths_y[3];
	double negative_centre_x[3], negative_centre_y[3];
	svgm_x.GetBoundingBox(centre_x, widths_x);
	svgm_y.GetBoundingBox(centre_y, widths_y);
	for (int i = 0; i < svgm_x.problem().x[0].dim; ++i) negative_centre_x[i] = -centre_x[i];
	svgm_x.TranslatePointSet(negative_centre_x);
	for (int i = 0; i < svgm_y.problem().x[0].dim; ++i) negative_centre_y[i] = -centre_y[i];
	svgm_y.TranslatePointSet(negative_centre_y);
	double scale_factor = 0.0;
	for (int i = 0; i < svgm_x.problem().x[0].dim; ++i) {
		if (scale_factor < widths_x[i]) scale_factor = widths_x[i];
	}
	for (int i = 0; i < svgm_y.problem().x[0].dim; ++i) {
		if (scale_factor < widths_y[i]) scale_factor = widths_y[i];
	}
	scale_factor /= 2.0;
	svgm_x.ScalePointSet(1.0 / scale_factor);
	svgm_y.ScalePointSet(1.0 / scale_factor);
	// Ensure sigma is consistent with the new scale
	svgm_x.set_sigma(svgm_x.sigma() / scale_factor);
	svgm_y.set_sigma(svgm_y.sigma() / scale_factor);

	// Train SVMs to build the SVGMs
	clock_t  clock_begin = clock();
	svgm_x.BuildSVGM();
	svgm_y.BuildSVGM();
	double duration_svgm = (double)(clock() - clock_begin) / CLOCKS_PER_SEC;

	// Inflate sigma by sigma_factor (decouples SVM from GMM)
	svgm_x.set_sigma(svgm_x.sigma() * svgm_x.sigma_factor());
	svgm_y.set_sigma(svgm_y.sigma() * svgm_y.sigma_factor());

	/* * * * * * * * * * * * * * * * * * * * * * * * * * * * * * * * * * * * * * * * * * * * * * * * * * * * * * * * * * * * * * * * * * * * * * * * * * * */
	// Setup and run GOGMA
	GOGMA gogma;
	ReadConfig(config_filename, gogma);

	// Pass GMMs to GOGMA
	gogma.set_dimension(svgm_x.dimension());
	gogma.set_num_components_x(svgm_x.num_components());
	gogma.set_num_components_y(svgm_y.num_components());
	gogma.set_mu_x(svgm_x.mu().data_block());
	gogma.set_mu_y(svgm_y.mu().data_block());
	gogma.set_phi_x(svgm_x.phi().data_block());
	gogma.set_phi_y(svgm_y.phi().data_block());
	gogma.set_sigma_x(svgm_x.sigma());
	gogma.set_sigma_y(svgm_y.sigma());

  // Set initial transformation parameters to the null transform
	double initial_theta[7] = {0};
	initial_theta[3] = 1.0;
	gogma.set_initial_theta(initial_theta);

	// Run GOGMA
	cout << "Point-Set X: " << pointset_x_filename << " (" << num_points_x << " points, " << svgm_x.num_components() << " Gaussians), Point-Set Y: " << pointset_y_filename << " (" << num_points_y << " points, " << svgm_y.num_components() << " Gaussians)" << endl;
	clock_begin = clock();
	gogma.Run();
	double duration_gogma = (double)(clock() - clock_begin) / CLOCKS_PER_SEC;

	// Outputs
	double translation[3];
	double rotation[3][3];
	for (int i = 0; i < 3; ++i) translation[i] = gogma.translation()[i];
	for (int i = 0; i < 3; ++i) {
		for (int j = 0; j < 3; ++j) {
			rotation[i][j] = gogma.rotation()[i * 3 + j];
		}
	}

	// Remove centring and scaling from the translation vector
	double rotated_centre_y[3];
	rotated_centre_y[0] = rotation[0][0] * centre_y[0] + rotation[0][1] * centre_y[1] + rotation[0][2] * centre_y[2];
	rotated_centre_y[1] = rotation[1][0] * centre_y[0] + rotation[1][1] * centre_y[1] + rotation[1][2] * centre_y[2];
	rotated_centre_y[2] = rotation[2][0] * centre_y[0] + rotation[2][1] * centre_y[1] + rotation[2][2] * centre_y[2];
	for (int i = 0; i < 3; ++i) translation[i] = scale_factor * translation[i] + centre_x[i] - rotated_centre_y[i];

	// Print results
	cout << "Rotation Matrix:" << endl;
	cout << rotation[0][0] << " " << rotation[0][1] << " " << rotation[0][2] << endl;
	cout << rotation[1][0] << " " << rotation[1][1] << " " << rotation[1][2] << endl;
	cout << rotation[2][0] << " " << rotation[2][1] << " " << rotation[2][2] << endl;
	cout << "Translation Vector:" << endl;
	cout << translation[0] << " " << translation[1] << " " << translation[2] << endl;
	cout << "Duration: " << duration_svgm << "s (SVGM), " << duration_gogma << "s (GOGMA)" << endl;
	cout << "Optimality Certificate: " << gogma.optimality_certificate() << endl;

	// Save results
	ofstream ofile(output_filename.c_str(), ofstream::out | ofstream::app);
	if (ofile.is_open()) {
		ofile << translation[0] << " " << translation[1] << " " << translation[2] << " " << rotation[0][0] << " " << rotation[0][1] << " " << rotation[0][2] << " " << rotation[1][0] << " " << rotation[1][1] << " " << rotation[1][2] << " " << rotation[2][0] << " " << rotation[2][1] << " " << rotation[2][2] << endl;
		ofile.close();
	} else {
		cout << "Cannot open output file" << endl;
		return -1;
	}
	
	/* * * * * * * * * * * * * * * * * * * * * * * * * * * * * * * * * * * * * * * * * * * * * * * * * * * * * * * * * * * * * * * * * * * * * * * * * * * */
	// Setup and run refinement
	// Construct Gaussian Mixtures with finer resolution
	bool do_refinement = false;
	ReadConfig(config_filename, svgm_x, svgm_y, do_refinement);

	if (do_refinement) {
		// Ensure sigma is consistent with the new scale
		svgm_x.set_sigma(svgm_x.sigma() / scale_factor);
		svgm_y.set_sigma(svgm_y.sigma() / scale_factor);

		clock_begin = clock();
		svgm_x.BuildSVGM();
		svgm_y.BuildSVGM();
		duration_svgm = (double)(clock() - clock_begin) / CLOCKS_PER_SEC;

		// Inflate sigma by sigma_factor (decouples SVM from GMM)
		svgm_x.set_sigma(svgm_x.sigma() * svgm_x.sigma_factor());
		svgm_y.set_sigma(svgm_y.sigma() * svgm_y.sigma_factor());

		// Pass GMMs to GOGMA
		gogma.set_dimension(svgm_x.dimension());
		gogma.set_num_components_x(svgm_x.num_components());
		gogma.set_num_components_y(svgm_y.num_components());
		gogma.set_mu_x(svgm_x.mu().data_block());
		gogma.set_mu_y(svgm_y.mu().data_block());
		gogma.set_phi_x(svgm_x.phi().data_block());
		gogma.set_phi_y(svgm_y.phi().data_block());
		gogma.set_sigma_x(svgm_x.sigma());
		gogma.set_sigma_y(svgm_y.sigma());
		gogma.set_epsilon(1.0); // Use a large epsilon to force immediate termination (after initial GMA)

		// Use theta found by GOGMA as initial parameters for GMA refinement
		gogma.set_initial_theta(gogma.theta()); // theta is still in the centred and scaled coordinate system

		// Run GMA (via GOGMA)
		cout << "Point-Set X: " << pointset_x_filename << " (" << num_points_x << " points, " << svgm_x.num_components() << " Gaussians), Point-Set Y: " << pointset_y_filename << " (" << num_points_y << " points, " << svgm_y.num_components() << " Gaussians)" << endl;
		clock_begin = clock();
		gogma.Run();
		duration_gogma = (double)(clock() - clock_begin) / CLOCKS_PER_SEC;

		// Outputs
		for (int i = 0; i < 3; ++i) translation[i] = gogma.translation()[i];
		for (int i = 0; i < 3; ++i) {
			for (int j = 0; j < 3; ++j) {
				rotation[i][j] = gogma.rotation()[i * 3 + j];
			}
		}
		gogma.set_optimality_certificate(0); // Refinement cannot guarantee optimality

		// Return to original scale and remove rotation perturbation
		rotated_centre_y[0] = rotation[0][0] * centre_y[0] + rotation[0][1] * centre_y[1] + rotation[0][2] * centre_y[2];
		rotated_centre_y[1] = rotation[1][0] * centre_y[0] + rotation[1][1] * centre_y[1] + rotation[1][2] * centre_y[2];
		rotated_centre_y[2] = rotation[2][0] * centre_y[0] + rotation[2][1] * centre_y[1] + rotation[2][2] * centre_y[2];
		for (int i = 0; i < 3; ++i) translation[i] = scale_factor * translation[i] + centre_x[i] - rotated_centre_y[i];

		// Print results
		cout << "Rotation Matrix:" << endl;
		cout << rotation[0][0] << " " << rotation[0][1] << " " << rotation[0][2] << endl;
		cout << rotation[1][0] << " " << rotation[1][1] << " " << rotation[1][2] << endl;
		cout << rotation[2][0] << " " << rotation[2][1] << " " << rotation[2][2] << endl;
		cout << "Translation Vector:" << endl;
		cout << translation[0] << " " << translation[1] << " " << translation[2] << endl;
		cout << "Duration: " << duration_svgm << "s (SVGM), " << duration_gogma << "s (GOGMA)" << endl;
		cout << "Optimality Certificate: " << gogma.optimality_certificate() << endl;
		cout << endl;

		// Save results
		ofstream ofile_refinement((output_filename.substr(0, output_filename.size() - 4) + "_refinement.txt").c_str(), ofstream::out | ofstream::app);
		if (ofile_refinement.is_open()) {
			ofile_refinement << translation[0] << " " << translation[1] << " " << translation[2] << " " << rotation[0][0] << " " << rotation[0][1] << " " << rotation[0][2] << " " << rotation[1][0] << " " << rotation[1][1] << " " << rotation[1][2] << " " << rotation[2][0] << " " << rotation[2][1] << " " << rotation[2][2] << endl;
			ofile_refinement.close();
		} else {
			cout << "Cannot open output file" << endl;
			return -1;
		}
	}
	return 0;
}

/*
 * Fills the four filenames from the positional command-line arguments
 * (point-set X, point-set Y, config, output). Any argument that is not supplied
 * falls back to the corresponding DEFAULT_* name, and extra arguments are ignored.
 * Echoes the command line, a usage banner and the resolved filenames to stdout.
 */
void ParseInput(int argc, char **argv, string & pointset_x_filename, string & pointset_y_filename, string & config_filename, string & output_filename) {
	// Set default values
	pointset_x_filename = DEFAULT_POINTSET_X_FNAME;
	pointset_y_filename = DEFAULT_POINTSET_Y_FNAME;
	config_filename = DEFAULT_CONFIG_FNAME;
	output_filename = DEFAULT_OUTPUT_FNAME;

	// Parse input: each positional argument overrides its default when present
	if (argc > 4) output_filename = argv[4];
	if (argc > 3) config_filename = argv[3];
	if (argc > 2) pointset_y_filename = argv[2];
	if (argc > 1) pointset_x_filename = argv[1];

	// argc may be 0 (argv[0] is then a null pointer) when launched with an empty argv;
	// streaming a null char* is undefined behaviour, so substitute a placeholder
	const char* program_name = (argc > 0 && argv[0]) ? argv[0] : "gogma";

	// Print to screen
	for (int i = 0; i < argc; ++i) cout << argv[i] << " ";
	cout << endl;
	cout << "Globally-Optimal Gaussian Mixture Alignment (GOGMA)" << endl;
	cout << "Copyright (C) 2015 Dylan Campbell & Lars Petersson" << endl;
	cout << "USAGE:" << program_name << " <POINT-SET X FILENAME> <POINT-SET Y FILENAME> <CONFIG FILENAME> <OUTPUT FILENAME>" << endl;
	cout << "OUTPUT: [translation vector | rotation matrix] transforming point-set Y to align with point-set X" << endl;
	cout << endl;
	cout << "INPUT:" << endl;
	cout << "(pointset_x_filename)->(" << pointset_x_filename << ")" << endl;
	cout << "(pointset_y_filename)->(" << pointset_y_filename << ")" << endl;
	cout << "(config_filename)->(" << config_filename << ")" << endl;
	cout << "(output_filename)->(" << output_filename << ")" << endl;
	cout << endl;
}

/*
 * Configures the GOGMA search from the config file: epsilon, the translation
 * domain (its lower bound plus a width derived from translation_maximum) and
 * initialisation_type. Echoes the parsed configuration to stdout afterwards.
 */
void ReadConfig(string filename, GOGMA& gogma) {
	ConfigMap config_map(filename.c_str());

	// The config stores the translation bounds; GOGMA expects a lower bound and a width
	const float t_lower = config_map.getF("translation_minimum");
	const float t_upper = config_map.getF("translation_maximum");
	const float t_width = t_upper - t_lower;

	gogma.set_epsilon(config_map.getF("epsilon"));
	gogma.set_translation_minimum(t_lower);
	gogma.set_translation_width(t_width);
	gogma.set_initialisation_type(config_map.getI("initialisation_type"));

	// Echo what was read so each run's log records its configuration
	cout << "CONFIG:" << endl;
	config_map.print();
	cout << endl;
}

/*
 * Reads the SVGM construction parameters for the GOGMA stage from the config
 * file: sigma, nu, sigma_factor, max_num_points and min_num_components.
 * Each key holds either one value, shared by both point-sets, or two values,
 * where the first is for point-set X and the second for point-set Y.
 * Exits with EXIT_FAILURE if any key is missing or has no values.
 *
 * NOTE(review): ownership of the arrays returned by getFArray/getIArray is not
 * visible here and they are never freed — confirm in config_map.h.
 */
void ReadConfig(string filename, SVGM& svgm_x, SVGM& svgm_y) {
	// Open and parse the associated config file
	ConfigMap config(filename.c_str());

	int num_elements_sigma = 0;
	int num_elements_nu = 0;
	int num_elements_sigma_factor = 0;
	int num_elements_max_num_points = 0;
	int num_elements_min_num_components = 0;
	double* sigma = config.getFArray("sigma", num_elements_sigma);
	double* nu = config.getFArray("nu", num_elements_nu);
	double* sigma_factor = config.getFArray("sigma_factor", num_elements_sigma_factor);
	int* max_num_points = config.getIArray("max_num_points", num_elements_max_num_points);
	int* min_num_components = config.getIArray("min_num_components", num_elements_min_num_components);

	// Every key needs at least one value; otherwise the [0] reads below would dereference missing data
	if (!sigma || num_elements_sigma < 1 ||
		!nu || num_elements_nu < 1 ||
		!sigma_factor || num_elements_sigma_factor < 1 ||
		!max_num_points || num_elements_max_num_points < 1 ||
		!min_num_components || num_elements_min_num_components < 1) {
		cout << "Config file " << filename << " must define sigma, nu, sigma_factor, max_num_points and min_num_components" << endl;
		std::exit(EXIT_FAILURE);
	}

	// Point-set X always takes the first value
	svgm_x.set_sigma(sigma[0]);
	svgm_x.set_nu(nu[0]);
	svgm_x.set_sigma_factor(sigma_factor[0]);
	svgm_x.set_max_num_points(max_num_points[0]);
	svgm_x.set_min_num_components(min_num_components[0]);

	// Point-set Y takes the second value when one is given, otherwise shares X's value
	svgm_y.set_sigma(sigma[num_elements_sigma > 1 ? 1 : 0]);
	svgm_y.set_nu(nu[num_elements_nu > 1 ? 1 : 0]);
	svgm_y.set_sigma_factor(sigma_factor[num_elements_sigma_factor > 1 ? 1 : 0]);
	svgm_y.set_max_num_points(max_num_points[num_elements_max_num_points > 1 ? 1 : 0]);
	svgm_y.set_min_num_components(min_num_components[num_elements_min_num_components > 1 ? 1 : 0]);
}

/*
 * Reads the refinement settings from the config file. do_refinement is set to
 * true only when the integer key do_refinement is > 0. In that case svgm_x/svgm_y
 * are reconfigured from sigma_refinement, nu_refinement, sigma_factor_refinement,
 * max_num_points_refinement and min_num_components_refinement. Each key holds
 * either one value, shared by both point-sets, or two values, where the first is
 * for X and the second for Y. If refinement is enabled but any of these keys is
 * missing or has no values, the program exits with EXIT_FAILURE.
 *
 * NOTE(review): ownership of the arrays returned by getFArray/getIArray is not
 * visible here and they are never freed — confirm in config_map.h.
 */
void ReadConfig(string filename, SVGM& svgm_x, SVGM& svgm_y, bool& do_refinement) {
	// Open and parse the associated config file
	ConfigMap config(filename.c_str());

	if (config.getI("do_refinement") > 0) {
		do_refinement = true;

		int num_elements_sigma = 0;
		int num_elements_nu = 0;
		int num_elements_sigma_factor = 0;
		int num_elements_max_num_points = 0;
		int num_elements_min_num_components = 0;
		double* sigma = config.getFArray("sigma_refinement", num_elements_sigma);
		double* nu = config.getFArray("nu_refinement", num_elements_nu);
		double* sigma_factor = config.getFArray("sigma_factor_refinement", num_elements_sigma_factor);
		int* max_num_points = config.getIArray("max_num_points_refinement", num_elements_max_num_points);
		int* min_num_components = config.getIArray("min_num_components_refinement", num_elements_min_num_components);

		// Every key needs at least one value; otherwise the [0] reads below would dereference missing data
		if (!sigma || num_elements_sigma < 1 ||
			!nu || num_elements_nu < 1 ||
			!sigma_factor || num_elements_sigma_factor < 1 ||
			!max_num_points || num_elements_max_num_points < 1 ||
			!min_num_components || num_elements_min_num_components < 1) {
			cout << "Config file " << filename << " must define sigma_refinement, nu_refinement, sigma_factor_refinement, max_num_points_refinement and min_num_components_refinement when do_refinement is enabled" << endl;
			std::exit(EXIT_FAILURE);
		}

		// Point-set X always takes the first value
		svgm_x.set_sigma(sigma[0]);
		svgm_x.set_nu(nu[0]);
		svgm_x.set_sigma_factor(sigma_factor[0]);
		svgm_x.set_max_num_points(max_num_points[0]);
		svgm_x.set_min_num_components(min_num_components[0]);

		// Point-set Y takes the second value when one is given, otherwise shares X's value
		svgm_y.set_sigma(sigma[num_elements_sigma > 1 ? 1 : 0]);
		svgm_y.set_nu(nu[num_elements_nu > 1 ? 1 : 0]);
		svgm_y.set_sigma_factor(sigma_factor[num_elements_sigma_factor > 1 ? 1 : 0]);
		svgm_y.set_max_num_points(max_num_points[num_elements_max_num_points > 1 ? 1 : 0]);
		svgm_y.set_min_num_components(min_num_components[num_elements_min_num_components > 1 ? 1 : 0]);
	} else {
		do_refinement = false;
	}
}
