#ifndef _FFT_OT_H_
#define _FFT_OT_H_

#include <cassert>
#include <cmath>
#include <iostream>

#include <opencv2\opencv.hpp>
#include <opencv2\core\core.hpp>
#include <opencv2\core\mat.hpp>

#include "FiniteDifferenceOperator.h"
#include "PoissonDCT.h"
#include "Bezier.h"

using namespace std;
using namespace cv;

/*******************************************************************************
*
*	The implementation of the FFT OT Algorithm
*
*/

class CFFT_OT{

public:
	/* The constructor
	*  width, height - image dimensions
	*/
	CFFT_OT(int width, int height) : 
		m_solver(width, height),
		m_phi(width, height, CV_64F, Scalar(0)), 
		m_rhs(width, height, CV_64F, Scalar(0)),
		m_F(width,   height, CV_64F, Scalar(0)),
		m_Dxx(width, height, CV_64F, Scalar(0)),
		m_Dyy(width, height, CV_64F, Scalar(0)),
		m_Dxy(width, height, CV_64F, Scalar(0)),
		m_Dx(width, height, CV_64F, Scalar(0)),
		m_Dy(width, height, CV_64F, Scalar(0)),
		m_gradient( width, height, CV_64FC2, Scalar(0,0)),
		m_phi_ghost(width+2, height+2, CV_64F, Scalar(0)),
		m_width(width),
		m_height(height)
	{
		//initialize the matrices
		//initialize the spacial step lengths
		m_hx = 2.0 / m_width;
		m_hy = 2.0 / m_height;
	}

	/*
	*	The destructor
	*/
	~CFFT_OT() {};

protected:

	/*
	*	The Gaussian density N(mu,sigma)
	* 
	*	p     - a planar point
	*	sigma - standard deviation
	*	mu	  - mean 
	*/

	double _gaussian(cv::Vec2d p, cv::Vec2d sigma, cv::Vec2d mu )
	{
		cv::Vec2d d;

		for (int i = 0; i < 2; i++) {
			d[i] = (p[i] - mu[i]) / sigma[i];
		}

		double r = -1.0 / 2.0 * cv::norm(d, NORM_L2);
		double z = exp(r) / (2 * M_PI * sigma[0] * sigma[1]);
		return z;
	}

	/* The input is the density matrix, normalize it to make 
	*  the total measure equals to 4.0 
	*/

	void _normalize(cv::Mat& F) {
		int width = F.cols;
		int height = F.rows;

		double hx = 2.0 / width;
		double hy = 2.0 / height;

		/* ensure the total measure is 4.0 */
		double sum = 0;
		for (int i = 0; i < width; i++)
			for (int j = 0; j < height; j++) {
				sum += F.at<double>(i, j) * hx * hy;
			}
		double factor = 4.0 / sum;
		for (int i = 0; i < width; i++)
			for (int j = 0; j < height; j++) {
				F.at<double>(i, j) *= factor;
			}

	}

	/*
	 *	set the source density to be the Gaussian distribution
	 *	F - the source density matrix
	 */
	void set_source_density_Gaussian( cv::Mat & F ) {

		int width  = F.cols;
		int height = F.rows;

		double hx = 2.0 / width;
		double hy = 2.0 / height;

		/* change these parameters */
		//cv::Vec2d sigma(0.15, 0.15);
		cv::Vec2d sigma(0.25, 0.25);
		//cv::Vec2d sigma(0.5, 0.5);
		cv::Vec2d mu(-0.25, -0.25);
		
		// set the Gaussian density
		for( int i = 0; i < width; i ++ )
			for (int j = 0; j < height; j++) {
				cv::Vec2d p = m_solver.coord(i, j, m_width, m_height);
				F.at<double>(i, j) = _gaussian(p, sigma, mu);
			}

		// normalize the density function
		_normalize(F);
	}

	/*
	 *  set the source density to be a Bezier function
	 *	F - the source density matrix
	*/
	void set_source_density_Bezier(cv::Mat& F) {

		int width = F.cols;
		int height = F.rows;

		double hx = 2.0 / width;
		double hy = 2.0 / height;

		//the Bezier surface
		CBezierSurface bezier;
		//control points of the Bezier surface
		double z[4][4] = { {1,2,0,1},{2,1,1,2},{1,2,2,1},{2,1,0,1} };

		for (int i = 0; i < width; i++)
			for (int j = 0; j < height; j++) {
				//convert the matrix indices to the planar coordinates
				cv::Vec2d p = m_solver.coord(i, j, m_width, m_height);
				//evaluate the Bezier function
				F.at<double>(i, j) = bezier.value(z, (p[0]+1.0)/2.0, (p[1]+1.0)/2.0);
			}

		//normalize the density function
		_normalize(F);
	}

	/*
	 *	Compute the right hand side of equation (6)
	 * 
	 *	phi        - current Kantorovich potential
	 *  phi_ghost  - extended Kantorovich potential with the ghost cells
	 *  Dxx        - phi_{xx}
	 *  Dyy	       - phi_{yy}
	 *  Dxy        - phi_{xy}
	 *	F          - Source density
	 *  rhs        - right hand side of Eqn. (6)
	 */
	void _rhs(const Mat & phi,Mat & phi_ghost,Mat & Dxx,Mat & Dyy,Mat & Dxy,const Mat & F,Mat & rhs) {

		int width  = phi_ghost.cols - 2;
		int height = phi_ghost.rows - 2;

		double h1 = 2.0 / width;
		double h2 = 2.0 / height;

		assert(rhs.cols == phi_ghost.cols - 2 && rhs.rows == phi_ghost.cols - 2);

		//set the ghost cells
		m_solver.set_ghost_cells(phi_ghost, phi);
		//compute the partial derivatives using the central difference operator
		m_solver._dxx(phi_ghost, Dxx, h1);
		m_solver._dyy(phi_ghost, Dyy, h2);
		m_solver._dxy(phi_ghost, Dxy, h1, h2);
		//compute the right hand side
		for (int i = 0; i < width; i++)
			for (int j = 0; j < height; j++) {

				double dxx = Dxx.at<double>(i,j) + 1.0;
				double dyy = Dyy.at<double>(i,j) + 1.0;
				double dxy = Dxy.at<double>(i,j);
				double f = F.at<double>(i,j);

				double d = m_solver._sqr(dxx) + m_solver._sqr(dyy) + 2.0 * m_solver._sqr(dxy) + 2.0 * f;
				d = sqrt(d) - 2.0;
				rhs.at<double>(i,j) = d;
			}
		
		//normalize the right hand side
		//to be compatible with the Neumann boundary condition

		double sum = 0;
		for (int i = 0; i < width; i++)
			for (int j = 0; j < height; j++) {
				sum += rhs.at<double>(i, j);
			}

		double mean = sum/width * height;
		for (int i = 0; i < width; i++)
			for (int j = 0; j < height; j++) {
				rhs.at<double>(i, j) -= mean;
			}
	}

	/*
	*	Compute the gradient map
	*	
	*	phi       - the final Kantorovich potential
	*	phi_ghost - the extended Kantorovich potential
	*	Dx        - phi_{x}
	*   Dy		  - phi_{y}
	*	gradient  - nabla phi
	*/
	void gradient_map(const Mat & phi, Mat & phi_ghost, Mat & Dx, Mat & Dy, Mat & gradient )
	{
		int width  = phi.cols;
		int height = phi.rows;

		double hx = 2.0 / width;
		double hy = 2.0 / height;

		m_solver.set_ghost_cells(phi_ghost, phi);
		m_solver._dx(phi_ghost, Dx, hx);
		m_solver._dy(phi_ghost, Dy, hy);

		for (int i = 0; i < width; i++) {
			for (int j = 0; j < height; j++)
			{
				Vec2d  c = m_solver.coord(i, j, width, height);
				Vec2d& p = gradient.at<Vec2d>(i, j);				
				double dx = c[0] + Dx.at<double>(i, j);
				double dy = c[1] + Dy.at<double>(i,j);
				p = p + Vec2d(dx, dy);
			}
		}
	}

	/*
	*	The Monge-Ampere operator det(D^2 u)
	*	
	*	phi       - the input Kantorovich potential
	*	phi_ghost - the extended Kantorovich potential
	*	Dxx       - phi_{xx}
	*   Dyy       - phi_{yy}
	*   Dxy       - phi_{xy}
	*	ma        - det( D^2 u ), u = (x^2+y^2)/2 + phi
	*/

	void _MongeAmpere(const Mat& phi, Mat& phi_ghost, Mat& Dxx, Mat& Dyy, Mat& Dxy, Mat & ma)
	{
		int width = phi.cols;
		int height = phi.rows;

		double hx = 2.0 / width;
		double hy = 2.0 / height;

		//set the ghost cells
		m_solver.set_ghost_cells(phi_ghost, phi);
		//compute the Hessian matrix
		m_solver._dxx(phi_ghost, Dxx, hx);
		m_solver._dyy(phi_ghost, Dyy, hx);
		m_solver._dxy(phi_ghost, Dxy, hx, hy);

		//compute the det (D^2 u)
		for (int i = 0; i < width; i++)
			for (int j = 0; j < height; j++) {

				double dxx = Dxx.at<double>(i, j) + 1;
				double dyy = Dyy.at<double>(i, j) + 1;
				double dxy = Dxy.at<double>(i, j);
				double f = dxx * dyy - dxy * dxy;
				ma.at<double>(i, j) = f;
			}
	}


public:
	/*
	*	for large density function, one can increase the interpolation steps
	*
	*/
	void solve() {

		//set the source density, either using Gaussian distribution or Bezier functions
		//set_source_density_Bezier(m_F);
		set_source_density_Gaussian( m_F );

		cv::Mat old_phi(m_width,m_height,CV_64F,Scalar(0));

			while (true) {
				//compute the right hand side of Eqn. (6) 
				_rhs(m_phi, m_phi_ghost, m_Dxx, m_Dyy, m_Dxy, m_F, m_rhs);
				//Solive the Poisson equation to get the Kantorovich potential
				m_solver.solve(m_phi, m_rhs, m_hx, m_hy);
				//compute the difference between the old and the new potentials
				double diff = cv::norm(old_phi - m_phi) * sqrt(m_hx * m_hy);
				//if the difference is too small, terminate
				if (diff < 1e-14 ) break;
				//update the old potential
				std::cout << diff << " ";
				m_phi.copyTo(old_phi);
			}
		//compute the L2 error between the MA(u) and the given density F
		_MongeAmpere(m_phi, m_phi_ghost, m_Dxx, m_Dyy, m_Dxy, m_rhs);
		double approximation_error = cv::norm(m_rhs - m_F, NORM_L2) * sqrt(m_hx * m_hy);
		std::cout << "\n\nThe Solution L2 Error " << approximation_error << std::endl;
		//compute the gradient map
		gradient_map(m_phi, m_phi_ghost, m_Dx, m_Dy, m_gradient);
	}
	
	/*
	 *	Compute the forward mapping
	 * 
	 *	map     - the OT map
	 *  source  - input image
	 *  target  - the target image
	 */

	void forward_map(const Mat& map, const Mat& source, Mat& target) {

		int width = map.cols;
		int height = map.rows;

		double hx = 2.0 / width;
		double hy = 2.0 / height;

		//forward mapping
		for (int i = 0; i < width; i++)
			for (int j = 0; j < height; j++) {

				Vec2d p = map.at<Vec2d>(i, j);
				int u = int((p[0] - (-1.0 + hx / 2)) / hx + 0.5);
				int v = int((p[1] - (-1.0 + hy / 2)) / hy + 0.5);

				if (u > width - 1 || u < 0) continue;
				if (v > height - 1 || v < 0) continue;

				target.at<double>(u, v) = source.at<double>(i, j);

			}
	}

public:
	//handles to the data members 
	Mat& gradient() { return m_gradient; };
	Mat& source_density() { return m_F;  };
	CPoissonDCT& Poisson_solver() { return m_solver; };

protected:
	/* width and height */
	int m_width, m_height;
	/* spacial step length */
	double m_hx, m_hy;

	/* Source density function*/
	cv::Mat     m_F;
	/* Kantorovich potential */
	cv::Mat     m_phi;
	/* Kantorovich potential with ghost cells */
	cv::Mat     m_phi_ghost;
	/* right hand side   */
	cv::Mat     m_rhs;
	/* gradient map */
	cv::Mat     m_gradient;

	/* derivatives */
	cv::Mat     m_Dxx;
	cv::Mat     m_Dyy;
	cv::Mat     m_Dxy;
	cv::Mat     m_Dy;
	cv::Mat     m_Dx;
	/* The DCT Poisson solver    */
	CPoissonDCT m_solver;
};

#endif
