import numpy as np
from scipy import interpolate
from scipy.integrate import cumtrapz, solve_ivp
from scipy.linalg import logm
from scipy.interpolate import BSpline as SciBSpline
from sklearn.preprocessing import PolynomialFeatures
from sklearn.linear_model import LinearRegression
from skfda.representation.grid import FDataGrid
from skfda.representation.basis import BSpline
from skfda import preprocessing
from skfda.misc.regularization import TikhonovRegularization
from skfda.misc.operators import LinearDifferentialOperator



class curve_FrenetSerret_framework():
    """
        Estimation of the Frenet-Serret framework (arc length, Frenet path, Frenet curvatures) of a sampled Euclidean curve.
    """

    def __init__(self, x, scale=True):
        """
            Class containing functions to compute the different elements of the Frenet-Serret framework of a Euclidean curve: arc-length function, derivatives, Frenet paths and Frenet curvatures.

            x: numpy array of shape (N, dim). Euclidean curve in R^dim with N samples points, from which we want to compute the Frenet framework.
            scale: bool. If True, the arc-length grid is rescaled to [0,1]: compute_arc_length divides x by the curve length L and theta_extrinsic_formula multiplies the raw curvatures by L.

        """
        self.x = x
        self.N, self.dim = x.shape
        if self.dim < 2:
            raise Exception("The Frenet Serret framework is defined only for curves in R^d with d >= 2.")
        self.scale_ind = scale
        # Uniform parametrisation of the N samples on [0, 1].
        self.time = np.linspace(0, 1, self.N)

    def compute_full_FrenetSerret_framework(self, h_deriv, smoothing_parameter, nb_basis, h=None, method='extrinsic'):
        """
            Function to compute the full Frenet-Serret framework of x (arc-length, Frenet path, curvatures).

            Inputs:
            - h_deriv: float. Parameter for the local polynomial estimation of the derivatives.
            - smoothing_parameter: float. Smoothing parameter for the Bspline smoothing of the raw estimates of the curvatures.
            - nb_basis: int. Number of Bspline basis to smooth the curvatures
            - h: float. Parameter needed if the method is set to 'approx_ode'. Hyperparameter of the kernel used to make a local approximation of the ode.
            - method: 'extrinsic' (only for dim 2 or 3) or 'approx_ode'

            Sets self.theta (function s -> array of the dim-1 Frenet curvatures) and, if the ODE reconstruction succeeds, self.reconst_curve.

            Raises:
            - ValueError if the method is unknown, if method='extrinsic' and dim is not 2 or 3, or if method='approx_ode' and h is None.
              These checks are done before any computation, so the object is left untouched on error.
        """
        if method not in ('extrinsic', 'approx_ode'):
            raise ValueError(f"Unknown method {method!r}: expected 'extrinsic' or 'approx_ode'.")
        if method == 'extrinsic' and self.dim not in (2, 3):
            raise ValueError("Only available for curve in dimension 2 or 3.")
        if method == 'approx_ode' and h is None:
            raise ValueError("Error: the parameter 'h' must not be set to None if the method is 'approx_ode'.")

        # The Frenet frame uses the derivatives of order 1..dim, so the local polynomial
        # degree must be at least dim (4 is kept as the minimum, as before).
        self.compute_derivatives(max(4, self.dim), h_deriv)
        self.compute_arc_length()
        self.compute_FrenetPath()
        if method == 'extrinsic':
            _, theta_extrinsic_fct = self.theta_extrinsic_formula(smoothing_parameter, nb_basis)
            self.theta = theta_extrinsic_fct
        else:
            self.theta = self.theta_approx_ODE(h, smoothing_parameter, nb_basis)

        # Best-effort reconstruction of the curve from the estimated curvatures:
        # a failure is reported but does not invalidate the estimates above.
        try:
            Q = self.solve_FrenetSerret_ODE(self.time, self.theta, Q0=self.Q[:, :, 0])
            self.reconst_curve = cumtrapz(Q[:, 0, :], self.time, initial=0).T
        except Exception as err:
            print(f"Issue in the curve reconstruction from theta extrins: {err!r}")

    def kernel(self, x):
        """ Epanechnikov kernel (3/4)(1 - x^2) on |x| < 1, zero elsewhere. """
        return (3/4)*(1-np.power(x, 2))*(np.abs(x) < 1)

    def local_kernel(self, x, delta):
        """ Tricube-shaped kernel (1 - (|x|/delta)^3)^3. No support indicator: callers must ensure |x| < delta. """
        return np.power((1 - np.power((np.abs(x)/delta), 3)), 3)

    def compute_derivatives(self, deg, h):
        """
            Compute the derivatives of x by local polynomial smoothing.

            Inputs:
            - deg: int, degree of the polynomial basis. Must be >= dim since the derivatives up to order dim are stored.
            - h: float, hyperparameter of the kernel

            Sets self.derivatives, numpy array of shape (dim+1, N, dim): self.derivatives[k] is the estimate of the k-th derivative.

            Raises:
            - ValueError if deg < dim.
        """
        from math import factorial  # np.math was removed in NumPy 2.0

        if deg < self.dim:
            raise ValueError(f"The polynomial degree ({deg}) must be at least the dimension of the curve ({self.dim}).")
        pre_process = PolynomialFeatures(degree=deg)
        # Column j of the design matrix is T^j / j!, so that the j-th fitted coefficient
        # is directly the estimate of the j-th derivative.
        factorials = np.array([factorial(j) for j in range(deg+1)], dtype=float)
        deriv_estim = np.zeros((self.N, (deg+1)*self.dim))
        for i in range(self.N):
            T = self.time - self.time[i]
            W = self.kernel(T/h)
            T_poly = pre_process.fit_transform(T.reshape(-1, 1)) / factorials
            pr_model = LinearRegression(fit_intercept=False)
            pr_model.fit(T_poly, self.x, W)
            # coef_ has shape (dim, deg+1); Fortran-order flattening lays the derivatives out order by order.
            deriv_estim[i, :] = pr_model.coef_.reshape(1, (deg+1)*self.dim, order='F')
        self.derivatives = np.zeros((self.dim+1, self.N, self.dim))
        for k in range(self.dim+1):
            self.derivatives[k] = deriv_estim[:, k*self.dim:(k+1)*self.dim]

    def compute_arc_length(self):
        """
            Compute the arc length function and its derivative.

            Sets self.L (length of the curve), self.arc_s_dot, self.arc_s (linear interpolants on self.time)
            and self.grid_arc_s (arc length at the samples). If scale is True the arc length is rescaled
            to [0,1] and self.x is divided by L.

            Raises:
            - ValueError if the estimated length is zero or undefined (e.g. constant curve),
              which would otherwise produce NaNs through a division by zero.
        """
        sdot = np.linalg.norm(self.derivatives[1], axis=1)
        s_int = cumtrapz(sdot, self.time, initial=0)
        self.L = (s_int.max() - s_int.min())
        if not self.L > 0:
            raise ValueError("The estimated arc length of the curve is zero (or undefined); the Frenet-Serret framework cannot be computed.")
        if self.scale_ind:
            self.arc_s_dot = interpolate.interp1d(self.time, sdot/self.L)
            self.grid_arc_s = (s_int - s_int.min()) / self.L
            self.arc_s = interpolate.interp1d(self.time, self.grid_arc_s)
            self.x = self.x/self.L
        else:
            self.arc_s_dot = interpolate.interp1d(self.time, sdot)
            self.arc_s = interpolate.interp1d(self.time, s_int)
            self.grid_arc_s = s_int - s_int.min()

    def GramSchmidt(self, derivatives):
        """
            Do the Gram-Schmidt orthogonalization of the derivatives.

            Inputs:
            - derivatives: numpy array of shape (dim, dim), row k being the (k+1)-th derivative of the curve at one point.

            Returns a (dim, dim) matrix whose columns are the orthonormalised vectors; the last one is
            flipped if needed so that the determinant is positive.
        """
        vect_frenet_frame = np.zeros((self.dim, self.dim))
        vect_frenet_frame[0] = derivatives[0]/np.linalg.norm(derivatives[0])
        for k in range(1, self.dim):
            vect_frenet_frame[k] = derivatives[k]
            for j in range(0, k):
                # Classical Gram-Schmidt: project the original derivative, not the partially orthogonalised vector.
                vect_frenet_frame[k] = vect_frenet_frame[k] - np.dot(vect_frenet_frame[j], derivatives[k])*vect_frenet_frame[j]
            vect_frenet_frame[k] = vect_frenet_frame[k]/np.linalg.norm(vect_frenet_frame[k])
        if np.linalg.det(vect_frenet_frame) < 0:
            vect_frenet_frame[-1] = - vect_frenet_frame[-1]
        return np.transpose(vect_frenet_frame)

    def compute_FrenetPath(self):
        """
            Compute the Frenet path by Gram-Schmidt orthogonalization.

            Sets:
            - self.Q: numpy array (dim, dim, N), Frenet frames at the samples (indexed by self.grid_arc_s).
            - self.unif_grid: uniform grid on [0,1] (scaled) or [0,L] (not scaled).
            - self.Q_unifparam, self.X_unifparam: cubic interpolation of the frames and of the 0-th order
              estimate self.derivatives[0] onto self.unif_grid.
        """
        if self.scale_ind:
            self.unif_grid = self.time
        else:
            self.unif_grid = self.time*self.L

        Q = np.zeros((self.dim, self.dim, self.N))
        Q_unif_param = np.zeros((self.dim, self.dim, self.N))
        for i in range(self.N):
            Q[:, :, i] = self.GramSchmidt(self.derivatives[1:, i, :])
        for k in range(self.dim):
            Q_unif_param[:, k, :] = interpolate.griddata(self.grid_arc_s, Q[:, k, :].T, self.unif_grid, method='cubic').T
        self.Q = Q
        self.X_unifparam = interpolate.griddata(self.grid_arc_s, self.derivatives[0], self.unif_grid, method='cubic')
        self.Q_unifparam = Q_unif_param

    def theta_extrinsic_formula(self, smoothing_parameter=0.01, nb_basis=30):
        """
            Compute the B-spline smoothing of the raw estimate of Frenet curvatures computed using the extrinsic formulas. Available in dimension 2 or 3.

            - dim 3: curvature |x' ^ x''| / |x'|^3 and torsion (x' ^ x'').x''' / |x' ^ x''|^2.
            - dim 2: signed curvature (x'_1 x''_2 - x'_2 x''_1) / |x'|^3.
            When scale is True the raw values are multiplied by L (curvatures of the rescaled curve).

            Returns:
            - theta: numpy array (dim-1, N), raw estimates on self.grid_arc_s.
            - theta_fct: function s -> numpy array of the dim-1 smoothed curvatures.

            Raises:
            - ValueError if dim is not 2 or 3.
        """
        if self.dim not in (2, 3):
            raise ValueError("Only available for curve in dimension 2 or 3.")

        theta = np.zeros((self.dim-1, self.N))
        if self.dim == 3:
            for i in range(self.N):
                crossvect = np.cross(self.derivatives[1, i, :], self.derivatives[2, i, :])
                norm_crossvect = np.linalg.norm(crossvect)
                theta[0, i] = norm_crossvect/np.power(np.linalg.norm(self.derivatives[1, i, :]), 3)
                theta[1, i] = np.dot(crossvect, self.derivatives[3, i, :])/(norm_crossvect**2)
        else:
            for i in range(self.N):
                d1 = self.derivatives[1, i, :]
                d2 = self.derivatives[2, i, :]
                # z-component of the cross product, written explicitly because
                # np.cross on 2-vectors is deprecated since NumPy 2.0.
                cross_z = d1[0]*d2[1] - d1[1]*d2[0]
                theta[0, i] = cross_z/np.power(np.linalg.norm(d1), 3)

        if self.scale_ind:
            theta = theta*self.L

        smoothed = [self.Bspline_smoothing(self.grid_arc_s, theta[k], smoothing_parameter, nb_basis) for k in range(self.dim-1)]
        theta_fct = lambda t: np.array([f(t) for f in smoothed])

        return theta, theta_fct

    def solve_FrenetSerret_ODE(self, t_eval, theta, Q0=None):
        """
            Solve the Frenet-Serret ode with parameter theta evaluated at time points 't_eval'. Return the Frenet path.

            Inputs:
            - t_eval: numpy array of shape (N,). Grid of N points to evaluate the solution.
            - theta: function. Frenet curvatures functions such that theta(s) is a numpy array of size (dim-1) (the number of Frenet curvatures)
            - Q0: numpy array of shape (dim, dim). Initial frame (identity if None).

            Returns a numpy array of shape (dim, dim, len(t_eval)), solution of Q' = Q A_theta where
            A_theta has theta on its sub-diagonal and -theta on its super-diagonal.

            Raises:
            - RuntimeError if the ODE solver fails.
        """
        if Q0 is None:
            Q0 = np.eye(self.dim)

        def A_theta(s):
            th = theta(s)
            return np.diag(th, -1) - np.diag(th, 1)

        def ode(s, Q):
            return (np.matmul(Q.reshape(self.dim, self.dim), A_theta(s))).flatten()

        sol = solve_ivp(ode, t_span=(t_eval[0], t_eval[-1]), y0=Q0.flatten(), t_eval=t_eval, method='Radau')
        if not sol.success:
            raise RuntimeError(f"The Frenet-Serret ODE could not be solved: {sol.message}")
        return sol.y.reshape(self.dim, self.dim, len(t_eval))

    def theta_approx_ODE(self, h, smoothing_parameter=0.01, nb_basis=30, adaptive_h=False):
        """
            Compute the B-spline smoothing of the raw estimate of Frenet curvatures computed using a local approximation of the Frenet-Serret ode.

            Inputs:
            - h: float (or int number of neighbours if adaptive_h is True). Hyperparameter of the kernel used to make the local approximation.
            - smoothing_parameter: float. Hyperparameter for the amount of smoothing of the raw estimates
            - nb_basis: int. Number of B-spline basis functions.
            - adaptive_h: bool. Use nearest-neighbour (adaptive) windows instead of a fixed bandwidth.

            Requires compute_FrenetPath to have been called (uses self.unif_grid and self.Q_unifparam).
            Returns a function s -> numpy array of the dim-1 smoothed curvatures.
        """
        mTheta, mS, mOmega = self.__compute_raw_curvatures(self.unif_grid, self.Q_unifparam, h, adaptive_h)
        if self.dim == 2:
            curv = self.Bspline_smoothing(mS, mTheta, smoothing_parameter, nb_basis, order=3, weights=mOmega)
            theta_smooth = lambda s: np.array([curv(s)])
        else:
            theta_fct = [self.Bspline_smoothing(mS, mTheta[i], smoothing_parameter, nb_basis, order=3, weights=mOmega)
                         for i in range(self.dim-1)]
            theta_smooth = lambda s: np.array([f(s) for f in theta_fct])

        return theta_smooth

    # Utils functions

    def __skfda_to_scipy_bspline(self, knots_skfda, order_skfda, coefs):
        """ Convert an skfda B-spline (knots without boundary repetition, order = degree+1) to a scipy BSpline. """
        knots = np.concatenate((np.repeat(knots_skfda[0], order_skfda - 1), knots_skfda, np.repeat(knots_skfda[-1], order_skfda - 1),))
        return SciBSpline(knots, coefs, order_skfda - 1)

    def Bspline_smoothing(self, grid_pts, data_pts, smoothing_parameter, nb_basis, order=3, weights=None):
        """
            Penalised B-spline smoothing of the points (grid_pts, data_pts).

            Inputs:
            - grid_pts: numpy array (n,), sorted; the basis is defined on [grid_pts[0], grid_pts[-1]].
            - data_pts: numpy array (n,), values to smooth.
            - smoothing_parameter: float, weight of the Tikhonov penalty on the second derivative.
            - nb_basis: int, number of B-spline basis functions.
            - order: int, polynomial degree of the splines (skfda's 'order' is degree+1).
            - weights: optional numpy array (n,), observation weights.

            Returns a function s -> smoothed value(s).
        """
        basis = BSpline(domain_range=(grid_pts[0], grid_pts[-1]), n_basis=nb_basis, order=order+1)
        fd = FDataGrid(data_matrix=data_pts, grid_points=grid_pts, extrapolation="bounds")
        smoother_kwargs = dict(smoothing_parameter=smoothing_parameter,
                               regularization=TikhonovRegularization(LinearDifferentialOperator(2)),
                               return_basis=True, method='cholesky')
        if weights is not None:
            smoother_kwargs['weights'] = np.diag(weights)
        smoother = preprocessing.smoothing.BasisSmoother(basis, **smoother_kwargs)
        fd_basis = smoother.fit_transform(fd)
        # Build the scipy spline once instead of rebuilding it at every evaluation.
        spline = self.__skfda_to_scipy_bspline(basis.knots, order+1, fd_basis.coefficients.squeeze())

        def f(x):
            return spline(x)

        return f

    def __my_log_M3(self, R):
        """
            Compute the matrix logarithm of a 3x3 rotation matrix, with Rodrigues Formula.

            Returns the skew-symmetric matrix log(R). Returns the zero matrix if R is numerically the
            identity (|R - I| < 10e-6, i.e. 1e-5) or contains NaN/inf. Ill-conditioned for rotation
            angles close to pi, where sin(theta) vanishes.
        """
        N = np.linalg.norm(R-np.eye(len(R)))
        if np.isnan(N) or np.isinf(N) or N < 10e-6:
            return np.zeros((len(R), len(R)))
        # Rotation angle; the cosine is clipped to [-1, 1] to absorb rounding errors.
        c = 0.5*(np.trace(R)-1)
        theta = np.arccos(np.clip(c, -1, 1))
        if np.abs(theta) > 10e-6:
            beta = theta/(2*np.sin(theta))
        else:
            # Second-order Taylor expansion of theta / (2 sin(theta)) around 0.
            beta = 0.5 * (1 + np.square(theta)/6)
        vecA = beta*np.array([R[1, 0] - R[0, 1], R[0, 2] - R[2, 0], R[2, 1] - R[1, 2]])
        return np.array([[0, -vecA[0], vecA[1]], [vecA[0], 0, -vecA[2]], [-vecA[1], vecA[2], 0]])

    def __compute_neighbors(self, grid, h, adaptive=False):
        """
            For each grid point t_q, find its neighbouring grid points s_j and the associated kernel weights.

            Inputs:
            - grid: numpy array (nb_grid,), sorted.
            - h: if adaptive is False, float bandwidth of the Epanechnikov kernel (the window is shrunk near
                 the ends of the grid so that it stays symmetric around t_q); if adaptive is True, int number
                 of nearest neighbours, weighted by the tricube kernel with an adaptive radius.
            - adaptive: bool.

            Returns four lists of length nb_grid whose q-th entries are numpy arrays over the neighbours of t_q:
            neighbour indices, weights K_h(t_q - s_j), midpoints (t_q + s_j)/2 and differences t_q - s_j.
            Lists are returned (not stacked arrays) because the neighbourhoods generally have different sizes,
            and stacking ragged arrays raises a ValueError with NumPy >= 1.24.
        """
        neighbor_obs, weight, grid_double, delta = [], [], [], []
        nb_grid = len(grid)

        if adaptive:
            for t_q in grid:
                delta_s = abs(grid-t_q)
                # Radius slightly above the distance to the h-th nearest point, so that
                # every selected neighbour receives a positive tricube weight.
                D = 1.0001*np.sort(delta_s)[h-1]
                idx = np.argsort(delta_s)[:h]
                neighbor_obs.append(idx)
                weight.append((1/D)*self.local_kernel((t_q - grid[idx]), D))
                grid_double.append((t_q + grid[idx])/2)
                delta.append(t_q - grid[idx])
        else:
            val_min = np.min(grid)
            val_max = np.max(grid)
            for q, t_q in enumerate(grid):
                if t_q-val_min < h and q != 0:
                    # Close to the left end: shrink the window to keep it symmetric around t_q.
                    h_bis = np.abs(t_q-val_min) + 10e-10
                    idx = np.where(abs(grid - t_q) <= h_bis)[0]
                elif val_max-t_q < h and q != nb_grid-1:
                    # Close to the right end: same treatment.
                    h_bis = np.abs(val_max-t_q) + 10e-10
                    idx = np.where(abs(grid - t_q) <= h_bis)[0]
                elif q == 0:
                    idx = np.array([0, 1])
                elif q == nb_grid-1:
                    idx = np.array([nb_grid-2, nb_grid-1])
                else:
                    idx = np.where(abs(grid - t_q) <= h)[0]
                neighbor_obs.append(idx)
                weight.append((1/h)*self.kernel((t_q - grid[idx])/h))
                grid_double.append((t_q + grid[idx])/2)
                delta.append(t_q - grid[idx])

        return neighbor_obs, weight, grid_double, delta

    def __compute_sort_unique_val(self, S, Omega, Theta):
        """
            Step of function Compute Raw Curvature: group the raw estimates by unique abscissa.

            Returns the sorted unique values of S, the summed weights per unique value, and for each curvature
            the weighted average (over positive weights) of the raw estimates sharing that abscissa
            (0 where the summed weight is not positive).
        """
        uniqueS = np.unique(S)
        nb_unique_val = len(uniqueS)
        mOmega = np.zeros(nb_unique_val)
        mTheta = np.zeros((self.dim-1, nb_unique_val))
        for ijq in range(nb_unique_val):
            id_ijq = np.where(S == uniqueS[ijq])[0]
            Omega_ijq = Omega[id_ijq]
            mOmega[ijq] = np.sum(Omega_ijq)
            if mOmega[ijq] > 0:
                positive = np.where(Omega_ijq > 0)
                for k in range(self.dim-1):
                    Thetak_ijq = Theta[k][id_ijq]
                    mTheta[k, ijq] = (np.ascontiguousarray(Omega_ijq[positive]) @ np.ascontiguousarray(np.transpose(Thetak_ijq[positive])))/mOmega[ijq]
        return uniqueS, mOmega, mTheta

    def __compute_Rq_boucle(self, N_q, Obs_q, data, u_q, q, nb_grid):
        """
            For the grid point q, compute -log(Q_q^T Q_j) / u_j for each neighbour j (Rodrigues formula in
            dimension 3, scipy.linalg.logm otherwise). Returns a numpy array (dim, dim, N_q).
        """
        log_fn = self.__my_log_M3 if self.dim == 3 else logm
        R_q = np.zeros((self.dim, self.dim, N_q))
        for j in range(N_q):
            # At the two ends of the grid the self-pair has u = 0 and is skipped (left at zero).
            if (q != 0 or j != 0) and (q != nb_grid-1 or j != N_q-1):
                R_q[:, :, j] = -log_fn(np.transpose(np.ascontiguousarray(data))@np.ascontiguousarray(Obs_q[:, :, j]))/u_q[j]
        return R_q

    def __compute_raw_curvatures(self, grid, Q, h, adaptive):
        """
            Compute the weighted instantaneous rate of change of the Frenet frames.

            Inputs:
            - grid: numpy array (nb_grid,), abscissa of the frames.
            - Q: numpy array (dim, dim, nb_grid), Frenet frames on the grid.
            - h, adaptive: see __compute_neighbors.

            Returns (Mtheta, Ms, Momega): raw curvature estimates, their abscissa and their weights,
            restricted to the abscissa with non-zero total weight.
        """
        neighbor_obs, weight, grid_double, delta = self.__compute_neighbors(grid, h, adaptive=adaptive)
        nb_grid = len(grid)

        Omega, S = [], []
        Theta = [[] for _ in range(self.dim-1)]
        for q in range(nb_grid):
            if q == 0:
                s = grid[0]*np.ones(len(neighbor_obs[q]))
            elif q == nb_grid-1:
                s = grid[-1]*np.ones(len(neighbor_obs[q]))
            else:
                s = grid_double[q]
            S += list(s)

            N_q = len(neighbor_obs[q])
            Obs_q = Q[:, :, neighbor_obs[q]]
            w_q = weight[q]
            u_q = np.copy(delta[q])
            # Weights are computed before u = 0 is replaced, so self-pairs get a zero weight.
            omega_q = np.multiply(w_q, np.power(u_q, 2))
            if q != 0 and q != nb_grid-1:
                v_q = np.where(u_q == 0)[0]
                u_q[u_q == 0] = 1
            R_q = self.__compute_Rq_boucle(N_q, Obs_q, Q[:, :, q], u_q, q, nb_grid)

            if q != 0 and q != nb_grid-1:
                # Zero the self-pairs of interior points.
                R_q[:, :, v_q] = np.abs(0*R_q[:, :, v_q])

            for i in range(self.dim-1):
                theta_i = np.squeeze(R_q[i+1, i, :])
                Theta[i] = np.append(Theta[i], theta_i.tolist())

            Omega = np.append(Omega, omega_q.tolist())

        Ms, Momega, Mtheta = self.__compute_sort_unique_val(np.around(S, 8), Omega, Theta)

        # Remove the abscissa with zero total weight.
        Momega = np.asarray(Momega)
        ind_nozero = np.where(Momega != 0.)
        Momega = np.squeeze(Momega[ind_nozero])
        Mtheta = np.squeeze(np.asarray(Mtheta)[:, ind_nozero])
        Ms = Ms[ind_nozero]

        return Mtheta, Ms, Momega
    

   


 