/*
 * Globally-Optimal Gaussian Mixture Alignment (GOGMA): svgm
 * A class for the Support Vector-parametrised Gaussian Mixture (SVGM)
 * data representation. Includes methods for loading point-sets, constructing
 * the SVGM and manipulation.
 *
 * Campbell, D., and Petersson, L., "GOGMA: Globally-Optimal Gaussian
 * Mixture Alignment", IEEE Conference on Computer Vision and
 * Pattern Recognition (CVPR), Las Vegas, USA, IEEE, Jun. 2016
 *
 * For the full license, see license.txt in the root directory
 *
 * Author: Dylan Campbell
 * Date: 20160212
 * Revision: 1.1
 */

#include "svgm.h"

#include <algorithm>
#include <cstdlib>
#include <random>
#include <vector>

// libsvm print callback that discards every message it receives.
// InitialiseSVMParameters installs it so that training runs silently.
void DisableSVMOutput(const char * /* message: intentionally ignored */) {
}

/*
 * Construct an SVGM with default settings and no point-set loaded.
 * Sizes and counts start at the -1 "unknown" sentinel. At most 20000 points
 * are kept when loading; larger files are randomly subsampled. The SVM
 * problem starts out empty, so the destructor is safe even when
 * LoadPointSet() is never called.
 */
SVGM::SVGM() {
	num_points_ = -1;
	num_components_ = -1;
	max_num_points_ = 20000;
	min_num_components_ = 10;
	dimension_ = -1;
	sigma_ = 0.01;
	nu_ = 0.01;
	sigma_factor_ = 0.5;
	model_ = NULL;
	// problem_ would otherwise hold indeterminate values. The destructor
	// delete[]s problem_.y / problem_.x and walks problem_.l entries, so an
	// unloaded SVGM must still describe a valid, empty problem.
	problem_.l = 0;
	problem_.y = NULL;
	problem_.x = NULL;
	InitialiseSVMParameters();
}

/*
 * Release everything allocated by PrepareSVMProblem() (labels, per-point
 * coordinate buffers, node array) and the trained libsvm model, if any.
 */
SVGM::~SVGM() {
	delete[] problem_.y;
	// Each point owns its own coordinate buffer; free those before the node array
	for (int idx = 0; idx < problem_.l; ++idx) {
		delete[] problem_.x[idx].values;
	}
	delete[] problem_.x;
	svm_free_and_destroy_model(&model_);
}

/*
 * Public Class Functions
 */

/*
 * Load a point-set from `filename` into the internal SVM problem.
 * Records the number of points and their dimension.
 * Returns the number of points loaded, or -1 if the file could not be
 * opened or parsed.
 */
int SVGM::LoadPointSet(std::string filename) {
	num_points_ = LoadPointSet(filename, &problem_);
	// problem_.x[0] only names a valid point when at least one point was loaded.
	// On failure (-1) or an empty file (0), reset the dimension to "unknown"
	// instead of dereferencing it.
	dimension_ = (num_points_ > 0) ? problem_.x[0].dim : -1;
	return num_points_;
}

/*
 * Compute the axis-aligned bounding box of the loaded point-set, using the
 * first three coordinates of each point.
 *   centre - receives the midpoint of the box along each axis
 *   widths - receives the extent (max - min) along each axis
 * If no points are loaded, both outputs are set to zero. Previously the
 * results were computed from uninitialised values in that case.
 * NOTE(review): assumes every point has at least three coordinates.
 */
void SVGM::GetBoundingBox(double centre[3], double widths[3]) {
	if (problem_.l <= 0) {
		for (int k = 0; k < 3; ++k) {
			centre[k] = 0.0;
			widths[k] = 0.0;
		}
		return;
	}

	// Seed the extremes with the first point, then scan the rest
	double min_c[3];
	double max_c[3];
	for (int k = 0; k < 3; ++k) {
		min_c[k] = problem_.x[0].values[k];
		max_c[k] = problem_.x[0].values[k];
	}
	for (int i = 1; i < problem_.l; ++i) {
		const double* p = problem_.x[i].values;
		for (int k = 0; k < 3; ++k) {
			if (p[k] < min_c[k]) min_c[k] = p[k];
			if (p[k] > max_c[k]) max_c[k] = p[k];
		}
	}

	for (int k = 0; k < 3; ++k) {
		centre[k] = (max_c[k] + min_c[k]) / 2;
		widths[k] = max_c[k] - min_c[k];
	}
}

void SVGM::TranslatePointSet(double translation[3]) {
	for (int i = 0; i < problem_.l; ++i) {
		for (int j = 0; j < problem_.x[i].dim; ++j) {
			problem_.x[i].values[j] += translation[j];
		}
	}
}

/*
 * Rotate the point-set in place by left-multiplying the first three
 * coordinates of every point with the 3x3 matrix `rotation`.
 */
void SVGM::RotatePointSet(double rotation[3][3]) {
	for (int p = 0; p < problem_.l; ++p) {
		double* coords = problem_.x[p].values;
		// Snapshot the input first: every output row reads all three originals
		const double before[3] = { coords[0], coords[1], coords[2] };
		for (int row = 0; row < 3; ++row) {
			coords[row] = rotation[row][0] * before[0] + rotation[row][1] * before[1] + rotation[row][2] * before[2];
		}
	}
}

/*
 * Uniformly scale the point-set in place: every coordinate of every point
 * is multiplied by `scale`.
 */
void SVGM::ScalePointSet(double scale) {
	for (int p = 0; p < problem_.l; ++p) {
		double* coords = problem_.x[p].values;
		const int dims = problem_.x[p].dim;
		for (int axis = 0; axis < dims; ++axis) coords[axis] *= scale;
	}
}

/*
 * Build the SVGM from the loaded point-set:
 * derive the SVM parameters, train a one-class SVM, then convert its
 * support vectors into GMM means and weights.
 * Prints a message and does nothing if no point-set has been loaded.
 * Returns early without touching the GMM if SVM parameter validation fails.
 */
void SVGM::BuildSVGM() {
	const bool has_points = num_points_ > 0;
	if (!has_points) {
		std::cout << "Load Point-Set Before Calling BuildSVGM()" << std::endl;
		return;
	}
	InitialiseSVM();
	const bool trained = TrainSVM(&problem_, &model_);
	if (trained) {
		MapToGMM();
	}
}

/*
 * Private Class Functions
 */

/*
 * Set the fixed libsvm options: a one-class SVM with an RBF kernel,
 * no shrinking heuristics and no probability estimates.
 * The gamma and nu values set here are placeholders. InitialiseSVM()
 * overwrites them from sigma_ and nu_/min_num_components_ before training.
 * Fields this configuration does not use are still given libsvm's default
 * values so the struct never holds indeterminate data; libsvm copies the
 * whole parameter struct into every trained model.
 * Also installs DisableSVMOutput to silence libsvm's console output.
 */
void SVGM::InitialiseSVMParameters() {
	svm_parameters_.svm_type = ONE_CLASS;
	svm_parameters_.kernel_type = RBF;
	svm_parameters_.gamma = 1000;
	svm_parameters_.nu = 0.01;
	svm_parameters_.cache_size = 1000;
	svm_parameters_.eps = 1e-3;
	svm_parameters_.shrinking = 0; // 1: Use shrinking heuristics
	svm_parameters_.probability = 0; // 1: Provide probability estimates
	svm_parameters_.degree = 0; // Degree of polynomial kernel
	// Unused by a one-class RBF SVM, but must not be left uninitialised
	svm_parameters_.coef0 = 0; // Polynomial/sigmoid kernel offset
	svm_parameters_.C = 1; // Cost (C-SVC, epsilon-SVR, nu-SVR only)
	svm_parameters_.p = 0.1; // Epsilon-SVR loss margin
	svm_parameters_.nr_weight = 0; // No per-class weights
	svm_parameters_.weight_label = NULL;
	svm_parameters_.weight = NULL;
	svm_set_print_string_function(&DisableSVMOutput); // Comment out if output is desired
}

/*
 * Load whitespace-separated point-set from a file
 * Determines point-set size and dimension automatically
 */
int SVGM::LoadPointSet(std::string filename, struct svm_problem* problem) {
	int num_rows = -1;
	int num_cols = -1;
	int num_points = -1;
	int num_dimensions = -1;
	vnl_matrix<double> pointset;
	vnl_matrix<double> pointset_full;

	std::ifstream ifile(filename.c_str(), std::ios_base::in);
	if (ifile.is_open()) {
		if (pointset_full.read_ascii(ifile)) {
			num_rows = pointset_full.rows();
			num_cols = pointset_full.cols();
			// If user specifies to use all points, do not randomise the point-set
			if (max_num_points_ == 0 || num_rows <= max_num_points_) {
				num_points = num_rows;
				num_dimensions = num_cols;
				pointset = pointset_full;
			} else {
				// Generate randomised indices
				std::vector<int> random_indices;
				for(int i = 0; i < num_rows; ++i) random_indices.push_back(i);
				std::random_shuffle(random_indices.begin(), random_indices.end());
				num_points = std::min(num_rows, max_num_points_);
				num_dimensions = num_cols;
				// Fill in point-set
				pointset.set_size(num_points, num_dimensions);
				for(int i = 0; i < num_points; i++) {
					for(int j = 0; j < num_dimensions; ++j) {
						pointset[i][j] = pointset_full[random_indices[i]][j];
					}
				}
			}
			PrepareSVMProblem(pointset, num_points, num_dimensions, problem);
		} else {
			std::cout << "Cannot parse input file '" << filename << "'" << std::endl;
		}
		ifile.close();
	} else {
		std::cout << "Unable to open point file '" << filename << "'" << std::endl;
	}
	return num_points;
}

/*
 * Copy the SVGM settings into the libsvm parameters before training.
 *
 * nu: when min_num_components_ is positive (or nu_ is 0), nu_ is recomputed
 *     as min_num_components_ / num_points_. The result is then clamped to
 *     [0, 1]. Presumably this targets at least min_num_components_ support
 *     vectors (GMM components), since nu lower-bounds the fraction of support
 *     vectors in a one-class SVM. TODO confirm.
 * gamma: libsvm's RBF kernel is exp(-gamma * |u - v|^2). It matches a
 *     Gaussian with standard deviation sigma_ when gamma = 1 / (2 * sigma_^2).
 *
 * Requires num_points_ > 0; BuildSVGM checks this before calling.
 */
void SVGM::InitialiseSVM() {
	// If min_num_components_ was specified, set nu_ automatically.
	// Divide in double precision: float casts would discard precision in nu.
	if (min_num_components_ > 0 || nu_ == 0) {
		nu_ = static_cast<double>(min_num_components_) / static_cast<double>(num_points_);
	}
	// Ensure 0 <= nu <= 1
	if (nu_ > 1.0) nu_ = 1.0;
	if (nu_ < 0.0) nu_ = 0.0;
	svm_parameters_.nu = nu_;
	// sigma_ * sigma_ equals pow(sigma_, 2) and avoids a transcendental call
	svm_parameters_.gamma = 1.0 / (2.0 * (sigma_ * sigma_));
}

/*
 * Fill *problem with the first num_points rows of `pointset`. Each row is
 * one training point of num_dimensions coordinates, all with label 1.
 * Allocates problem->y, problem->x and one coordinate buffer per point with
 * new[]; the SVGM destructor releases them. Any arrays already held by
 * *problem are overwritten, not freed.
 */
void SVGM::PrepareSVMProblem(vnl_matrix<double>& pointset, int num_points, int num_dimensions, struct svm_problem* problem) {
	problem->l = num_points;
	problem->y = new double[num_points];
	problem->x = new struct svm_node[num_points];

	for (int row = 0; row < num_points; ++row) {
		problem->y[row] = 1; // Single (positive) class for the one-class SVM
		struct svm_node& node = problem->x[row];
		node.dim = num_dimensions;
		node.values = new double[num_dimensions];
		for (int col = 0; col < num_dimensions; ++col) node.values[col] = pointset[row][col];
	}
}

bool SVGM::TrainSVM(const struct svm_problem* problem, struct svm_model** model) {
	bool ret = false;
	if(CheckSVMParameters(problem))	{
		(*model) = svm_train(problem, &svm_parameters_);
		ret = true;
	}
	return ret;
}

/*
 * Ask libsvm whether svm_parameters_ are valid for `problem`.
 * Returns true if they are. Otherwise prints libsvm's error message
 * (prefixed with "SVM ERROR: ") and returns false.
 */
bool SVGM::CheckSVMParameters(const struct svm_problem* problem) {
	const char* const message = svm_check_parameter(problem, &svm_parameters_);
	if (message == NULL) return true;
	std::cout << "SVM ERROR: " << message << std::endl;
	return false;
}

/*
 * Convert the trained SVM into a Gaussian mixture:
 * each support vector becomes a component mean (a row of mu_), and its
 * coefficient becomes the component weight in phi_. The weights are then
 * divided by their sum so they add up to one.
 * Requires model_ to hold a trained model and dimension_ to be set.
 */
void SVGM::MapToGMM() {
	num_components_ = model_->l;
	mu_.set_size(num_components_, dimension_);
	phi_.set_size(num_components_);

	double weight_sum = 0.0;
	for (int k = 0; k < num_components_; ++k) {
		const double* support_vector = model_->SV[k].values;
		for (int d = 0; d < dimension_; ++d) mu_[k][d] = support_vector[d];
		const double coefficient = model_->sv_coef[0][k];
		phi_[k] = coefficient;
		weight_sum += coefficient;
	}

	// Normalise phi. No Gaussian normalisation constant is folded in; an
	// unused candidate formula is kept for reference:
	// double t = pow(2 * M_PI * (sigma_^2 + sigma_y_^2), dimension_ / 2);
	for (int k = 0; k < num_components_; ++k) phi_[k] /= weight_sum;
}
