import numpy as np
import fdasrsf as fs
from scipy import interpolate
from scipy.integrate import cumtrapz, solve_ivp
import optimum_reparamN2 as orN2
from curve_FrenetSerret_framework import curve_FrenetSerret_framework
from plotly.subplots import make_subplots
import plotly.graph_objects as go
import plotly.express as px
from plotly.express.colors import sample_colorscale

# Qualitative palette shared by the plotting helpers.
color_list = px.colors.qualitative.Plotly
# Two-sample blue/red scale: first colour for the start curve x0, last one for the end curve x1.
color_list_geod = sample_colorscale('Bluered_r', list(np.linspace(0, 1, 2)))
color_x0, color_x1 = color_list_geod[0], color_list_geod[-1]
# One qualitative colour per representation (SRVF, SRC, Frenet curvatures).
color_srvf, color_src, color_fc = (color_list[idx] for idx in (2, 5, 6))


class curve_shape_analysis():

    def __init__(self, dim):
        """
        Tools to compare two curves of R^d through geodesics computed in three representations:
            - SRVF: square root velocity functions
            - Frenet curvatures (FC)
            - SRC Transforms: square root curvature transforms

        Inputs:
            - dim: int
              dimension d of the ambient space; curves have dim coordinates and dim-1 Frenet curvatures.
        """
        self.dim = dim


    """ Functions for the method SRVF """

    def align_srvf(self, x0, x1):
        """
            Optimal rotation and reparameterization of x1 onto x0 in the SRVF framework.

            Thin wrapper around fs.curve_functions.find_rotation_and_seed_coord (called with
            rotation=True on the transposed curves).

            Inputs:
            - x0, x1: numpy array of shape (N, dim).

            Returns:
            - the 4-tuple (x1_bis, q1_bis, R_opt, h_opt) produced by fdasrsf: aligned curve,
              its SRVF, the optimal rotation and the optimal warping (layout as defined by fdasrsf).
        """
        aligned_curve, aligned_q, rotation, warping = fs.curve_functions.find_rotation_and_seed_coord(
            x0.T, x1.T, rotation=True
        )
        return aligned_curve, aligned_q, rotation, warping
    
    def geodesic_srvf(self, x0, x1, k, align=True):
        """
            Geodesic path between x0 and x1 under the SRVF framework, sampled at k steps.

            The SRVFs q0 and q1 are joined by the great-circle formula
            q_tau = (sin(d - tau*d) q0 + sin(tau*d) q1) / sin(d), where d is the distance given by
            dist_sphere (plain linear interpolation is used when d == 0). Each q_tau is mapped back
            to a curve with fs.curve_functions.q_to_curve.

            Inputs:
            - x0, x1: numpy array of shape (N, dim).
            - k: int. Number of steps in the geodesic path.
            - align: if True, x1 is first aligned on x0 with align_srvf before computing the geodesic.

            Returns:
            - dist: distance on the sphere between q0 and the (aligned) q1.
            - path_x: numpy array (k, N, dim); curves along the geodesic.
            - path_q: numpy array (k, dim, N); SRVFs along the geodesic.
            - R_opt, h_opt: optimal rotation and warping (identity matrix and identity grid when align is False).
        """
        n_pts = len(x0)
        q0, _, _ = fs.curve_functions.curve_to_q(x0.T)
        if align:
            _, q1, R_opt, h_opt = self.align_srvf(x0, x1)
        else:
            R_opt = np.eye(self.dim)
            h_opt = np.linspace(0, 1, n_pts)
            q1, _, _ = fs.curve_functions.curve_to_q(x1.T)
        dist = self.dist_sphere(q0, q1)

        path_q = np.zeros((k, self.dim, n_pts))
        path_x = np.zeros((k, n_pts, self.dim))
        for step in range(k):
            # step 0 is special-cased so that k == 1 does not divide by zero
            frac = 0 if step == 0 else step / (k - 1.)
            arc = dist * frac
            if dist > 0:
                q_step = (np.sin(dist - arc) * q0 + np.sin(arc) * q1) / np.sin(dist)
            elif dist == 0:
                q_step = (1 - frac) * q0 + frac * q1
            else:
                raise Exception("geod_sphere computed a negative distance")
            path_q[step] = q_step
            path_x[step] = fs.curve_functions.q_to_curve(path_q[step]).T

        return dist, path_x, path_q, R_opt, h_opt


    def deform_srvf(self, x, gam, T=None):
        """
            Compute the deformed curve in the SRVF registration problem. 

            Inputs:
            - x: numpy array of shape (N,dim); Euclidean curve
            - gam: numpy array of shape (N,); warping function
            
        """
        N = x.shape[0]
        time = np.linspace(0,1,N)
        if T is None:
            dot_x = np.gradient(x, 1. / (N - 1))
            dot_x = dot_x[0]
            T = np.zeros((N,self.dim))
            for i in range(N):
                L = np.linalg.norm(dot_x[i,:])
                if L > 0.0001:
                    T[i] = dot_x[i] / L
                else:
                    T[i] = dot_x[i] * 0.0001
        gam_warp_fct = interpolate.UnivariateSpline(time, gam, s=0.0001)
        gam_smooth = lambda t: (gam_warp_fct(t) - gam_warp_fct(time).min())/ (gam_warp_fct(time).max() - gam_warp_fct(time).min())
        T_warp = np.sqrt(gam_warp_fct(time,1))*(interpolate.interp1d(np.linspace(0,1,N), T.T)(gam_smooth(time)))
        x_def = cumtrapz(T_warp, np.linspace(0,1,N), initial=0).T
        return x_def

    
    def dist_srvf(self, x0, x1, align=True):
        """
            Compute the SRVF distance between x0 and x1.

            Inputs:
            - x0, x1: numpy array of shape (N, dim).
            - align: if True, x1 is first aligned on x0 with align_srvf before computing the distance.

            Returns:
            - the distance on the sphere (dist_sphere) between the SRVF of x0 and the (aligned) SRVF of x1.
        """
        q0, _, _ = fs.curve_functions.curve_to_q(x0.T)
        if not align:
            q1, _, _ = fs.curve_functions.curve_to_q(x1.T)
        else:
            _, q1, _, _ = self.align_srvf(x0, x1)
        return self.dist_sphere(q0, q1)
    
    def karcher_mean_srvf(self, arr_x, new_N=None):
        """
            Compute the Karcher mean under the SRVF framework (delegated to fdasrsf's fdacurve).

            Input:
            - arr_x: numpy array of shape (K, N, dim) of K curves in dimension 'dim' with N samples points
              Set of Euclidean curves.
            - new_N: optional int; value passed as N to fdacurve (defaults to the number of samples N).

            Returns:
            - obj.beta_mean.T, the mean curve computed by fdasrsf (presumably of shape (N or new_N, dim)).
        """
        n_samples = arr_x[0].shape[0]
        N = n_samples if new_N is None else new_N
        # fdacurve expects the curves stacked as (dim, n_samples, K)
        beta = np.zeros((self.dim, n_samples, len(arr_x)))
        for idx, curve in enumerate(arr_x):
            beta[:, :, idx] = curve.T
        obj = fs.curve_stats.fdacurve(beta, N=N)
        obj.karcher_mean()
        return obj.beta_mean.T
    



    """ Functions for the method SRC """

    def warp_src(self, c, gam, smooth=True):
        """
            Compute the group action of a warping function 'gam' of the SRC 'c'
        """
        time = np.linspace(0,1,len(gam))
        if smooth:
            gam_warp_fct = interpolate.UnivariateSpline(time, gam, s=0.0001)
            gam_smooth = lambda t: (gam_warp_fct(t) - gam_warp_fct(time).min())/ (gam_warp_fct(time).max() - gam_warp_fct(time).min())
            c_warp_fct = lambda t: c(gam_smooth(t))*np.sqrt(gam_warp_fct(t,1))
        else:
            g = gam
            g_dev = np.gradient(gam, 1. / (len(gam)-1))
            c_warp = c(g)*np.sqrt(g_dev)
            c_warp_fct = interpolate.interp1d(time, c_warp)
        return c_warp_fct
    

    def align_src(self, c0, c1, time, lam=1, smooth=False):
        """
            Compute the optimal alignment between two SRC representations.

            Inputs:
            - c0, c1: functions such that c0(t) is a numpy array of shape (dim-1); the two SRCs
            - time: numpy array of shape (N,); grid of time points bewteen 0 and 1.
            - lam: float; ponderation coefficient in SRC distance.
            - smooth: if True an additional smoothing of the optimal warping function is made.
        """
        if np.linalg.norm(c0(time)-c1(time),'fro') > 0.0001:
            gam = orN2.coptimum_reparam_curve(np.ascontiguousarray(c0(time)), time, np.ascontiguousarray(c1(time)), lam)
            c1_align = self.warp_src(c1, gam, smooth)
        else:
            gam = time
            c1_align = c1
        return c1_align, gam


    def geodesic_Psi(self, gam0, gam1, k):
        """
            Compute the geodesic bewteen two diffeomorphisms in Diff_+([0,1]) in the space Psi([0,1]).

            Inputs:
            - gam0, gam1: numpy arrays of shape (N,); two diffeomorphisms in Diff_+([0,1])
            - k: int; number of steps along the geodesic path.
        """
        N = gam0.shape[0]
        time = np.linspace(0,1,N)
        binsize = np.mean(np.diff(time))
        psi0 = np.sqrt(np.gradient(gam0,binsize))
        psi1 = np.sqrt(np.gradient(gam1,binsize))
        dist = self.dist_sphere(psi0, psi1)
        path_psi = np.zeros((k, N))
        path_gam = np.zeros((k, N))
        for tau in range(0, k):
            if tau == 0:
                tau1 = 0
            else:
                tau1 = tau / (k - 1.)
            s = dist * tau1
            if dist > 0:
                path_psi[tau, :] = (np.sin(dist-s)*psi0+np.sin(s)*psi1)/np.sin(dist)
            elif dist == 0:
                path_psi[tau, :] = (1 - tau1)*psi0 + (tau1)*psi1
            else:
                raise Exception("geod_sphere computed a negative distance")
            gam = cumtrapz(path_psi[tau, :]*path_psi[tau, :], time, initial=0)
            path_gam[tau,:] = (gam - gam.min()) / (gam.max() - gam.min())
        return dist, path_gam, path_psi
    

    def geodesic_src(self, theta0, theta1, s0, s1, time, k, align=True, smooth=True, lam=1, arr_Q0=None):
        """
            Compute the geodesic path under the SRC framework.

            Inputs:
            - theta0, theta1: functions such that theta_i(t) is a numpy array of shape (dim-1); Frenet curvatures of the two curves from which we want to compute the geodesic.
            - s0, s1: function on [0,1]; Arc-length function of x0 and x1.
            - time: numpy array of shape (N,); grid of N time points between 0 and 1.
            - k: int; number of steps along the geodesic. 
            - lam: float; ponderation coefficient in SRC distance.
            - smooth: if True an additional smoothing of the optimal warping function is made.
            - align: if True the optimal reparameterization is computed before computed the geodesic. 
            - arr_Q0: numpy array of shape (k, dim, dim); possibility to use a list of initial rotation to reconstruct the curves along the geodesic. 
        """
        c0_theta = lambda t: theta0(t)/np.sqrt(np.linalg.norm(theta0(t)))
        c1_theta = lambda t: theta1(t)/np.sqrt(np.linalg.norm(theta1(t)))
        if align:
            _, gamma_opt = self.align_src(c0_theta, c1_theta, time, lam=lam, smooth=smooth)
            tmp_spline = interpolate.UnivariateSpline(time, gamma_opt, s=0.0001)
            gam_smooth = lambda t: (tmp_spline(t) - tmp_spline(time).min())/ (tmp_spline(time).max() - tmp_spline(time).min())
            c1_theta_align = lambda t: c1_theta(gam_smooth(t))*np.sqrt(tmp_spline(t,1))
            s1_h = gam_smooth(s0(time))
        else:
            gamma_opt = time
            s1_h = s1(time)
            c1_theta_align = c1_theta

        dist_gam, path_arc_s, path_psi = self.geodesic_Psi(s0(time), s1_h, k)

        f_s0_dot = interpolate.UnivariateSpline(time, s0(time), s=0.00001)
        s0_dot = lambda t: f_s0_dot(t, 1)

        dist = np.linalg.norm(np.sqrt(s0_dot(time))*(c0_theta(s0(time))-c1_theta_align(s0(time)))) + lam*dist_gam

        path_theta = np.zeros((k, self.dim-1, len(time)))
        path_Q = np.zeros((k, self.dim, self.dim, len(time)))
        path_curves = np.zeros((k, len(time), self.dim))
        for tau in range(0, k):
            if tau == 0:
                tau1 = 0
            else:
                tau1 = tau / (k - 1.)
            theta_tau = lambda s: ((1-tau1)*c0_theta(s) + tau1*c1_theta_align(s))*np.linalg.norm(((1-tau1)*c0_theta(s) + tau1*c1_theta_align(s)))
            sdot_theta_tau = lambda t: theta_tau(s0(t))*s0_dot(t)
            for i in range(len(time)):
                path_theta[tau, :, i] = sdot_theta_tau(time[i])/(path_psi[tau][i]**2)

            if arr_Q0 is None:
                Q = self.solve_FrenetSerret_ODE(time, sdot_theta_tau)
            else:
                Q = self.solve_FrenetSerret_ODE(time, sdot_theta_tau, Q0=arr_Q0[tau])
            path_curves[tau] = cumtrapz(Q[:,0,:], path_arc_s[tau], initial=0).T
            path_Q[tau] = Q

        return dist, path_curves, path_Q, path_theta, path_arc_s, path_psi, gamma_opt


    def deform_src(self, theta, gam):
        """
            Compute the deformed curve in the SRC registration problem. 

            Inputs:
            - theta: function such that theta(t) is a numpy array of shape (dim-1); Frenet curvatures
            - gam: numpy array of shape (N,); warping function
            
        """ 
        time = np.linspace(0,1,len(gam))
        gam_warp_fct = interpolate.UnivariateSpline(time, gam, s=0.0001)
        gam_smooth = lambda t: (gam_warp_fct(t) - gam_warp_fct(time).min())/ (gam_warp_fct(time).max() - gam_warp_fct(time).min())
        theta_warp = lambda t: theta(gam_smooth(t))*gam_warp_fct(t,1)
        time = np.linspace(0,1,len(gam))
        Q = self.solve_FrenetSerret_ODE(time, theta_warp)
        X = cumtrapz(Q[:,0,:], time, initial=0).T
        return X


    def dist_src(self, theta0, theta1, s0, s1, time, align=True, smooth=True, lam=1):
        """
            Compute the geodesic distance under the SRC framework.

            Inputs:
            - theta0, theta1: functions such that theta_i(t) is a numpy array of shape (dim-1); Frenet curvatures of the two curves from which we want to compute the geodesic distance.
            - s0, s1: function on [0,1]; Arc-length function of x0 and x1.
            - time: numpy array of shape (N,); grid of N time points between 0 and 1.
            - lam: float; ponderation coefficient in SRC distance.
            - smooth: if True an additional smoothing of the optimal warping function is made.
            - align: if True the optimal reparameterization is computed before computed the geodesic. 
        """
        c0_theta = lambda t: theta0(t)/np.sqrt(np.linalg.norm(theta0(t)))
        c1_theta = lambda t: theta1(t)/np.sqrt(np.linalg.norm(theta1(t)))
        if align:
            _, gamma_opt = self.align_src(c0_theta, c1_theta, time, lam=lam, smooth=smooth)
            gamma_opt = (gamma_opt - gamma_opt.min())/(gamma_opt.max() - gamma_opt.min())
            binsize = np.mean(np.diff(time))
            psi_gamma_opt = np.sqrt(np.gradient(gamma_opt,binsize))
            s1_h = np.interp(s0(time), time, gamma_opt)
            s1_h = (s1_h - s1_h.min())/(s1_h.max() - s1_h.min())
            c1_theta_align_s0 = c1_theta(s1_h)*psi_gamma_opt
        else:
            gamma_opt = time
            s1_h = s1(time)
            c1_theta_align_s0 = c1_theta(s0(time))
        
        binsize = np.mean(np.diff(time))
        psi0 = np.sqrt(np.gradient(s0(time),binsize))
        psi1 = np.sqrt(np.gradient(s1_h,binsize))
        dist_gam = self.dist_sphere(psi0, psi1)
        dist = np.linalg.norm(psi0*(c0_theta(s0(time))-c1_theta_align_s0)) + lam*dist_gam    
        return dist
    
    def karcher_mean_src(self, arr_theta, arr_arc_s, tol, max_iter, lam=1):
        """
            Karcher mean under the square-root curvature transform framework. 

            Input:
            - arr_theta: array of K Frenet curvatures functions such that arr_theta[k](s) is a numpy array of size (dim-1) (the number of Frenet curvatures)
              Set of Frenet curvature functions of the arc-length parameter (not of the time) of the Euclidean curves considered for computing the mean.
            - arr_arc_s: numpy array of size (K,N) of the K arc-length functions of the considered curves, with N samples points. 
            - tol: float 
              tolerance for the difference between iterative error. If |error_k - error_{k-1}| < tol, the iteration is stop. 
            - max_iter: int
              Number of maximum iterations to find the optimal mean.
            - lam: float
              Parameter used to add a ponteration in the SRC distance. 

        """
        print('Computing Karcher Mean of 20 curves in SRC space.. \n')
        n = len(arr_theta)
        T = len(arr_arc_s[0])
        time = np.linspace(0,1,T)
        binsize = np.mean(np.diff(time))

        arr_c = np.zeros((n,self.dim-1,T))
        arr_psi = np.zeros(arr_arc_s.shape)
        for i in range(n):
            arr_psi[i] = np.sqrt(np.gradient(arr_arc_s[i],binsize))
            for j in range(T):
                arr_c[i,:,j] = arr_psi[i,j]*arr_theta[i](arr_arc_s[i,j])/np.sqrt(np.linalg.norm(arr_theta[i](arr_arc_s[i,j])))
        mean_c = np.mean(arr_c, axis=0)
        mean_psi = np.mean(arr_psi, axis=0)
        
        dist_arr = np.zeros(n)
        for i in range(n):
            dist_arr[i] = np.linalg.norm(mean_c - arr_c[i]) + lam*self.dist_sphere(mean_psi, arr_psi[i]) 
        ind = np.argmin(dist_arr)
        
        temp_mean_psi = arr_psi[ind]
        temp_mean_c = arr_c[ind]
        temp_error = np.linalg.norm((mean_c - temp_mean_c)) + lam*self.dist_sphere(mean_psi, temp_mean_psi) 
        up_err = temp_error
        k = 0
        print('Iteration ', k, '/', max_iter, ': error ', temp_error)
        while up_err > tol and k < max_iter:
            arr_c_align = np.zeros((n,self.dim-1,T))
            arr_arc_align = np.zeros((n,T))
            arr_psi_align = np.zeros((n,T))
            for i in range(n):
                if np.linalg.norm(temp_mean_c - arr_c[i],'fro') > 0.0001:
                    h_opt = orN2.coptimum_reparam_curve(np.ascontiguousarray(temp_mean_c), time, np.ascontiguousarray(arr_c[i]), lam)
                else:
                    h_opt = time
                h_opt = (h_opt - h_opt.min())/(h_opt.max() - h_opt.min())
                si_h = np.interp(h_opt, time, arr_arc_s[i])
                arr_arc_align[i] = (si_h - si_h.min())/(si_h.max() - si_h.min())
                arr_psi_align[i] = np.sqrt(np.gradient(arr_arc_align[i],binsize))
                for j in range(T):
                    arr_c_align[i,:,j] = arr_psi_align[i,j]*arr_theta[i](arr_arc_align[i,j])/np.sqrt(np.linalg.norm(arr_theta[i](arr_arc_align[i,j])))

            mean_c = np.mean(arr_c_align, axis=0)
            mean_psi = np.mean(arr_psi_align, axis=0)
            error = np.linalg.norm((mean_c - temp_mean_c)) + lam*self.dist_sphere(mean_psi, temp_mean_psi) 
            up_err = abs(temp_error - error)
            temp_error = error
            k += 1
            print('Iteration ', k, '/', max_iter, ': error ', temp_error)
            temp_mean_psi = mean_psi
            temp_mean_c = mean_c
        
        print('Number of iterations', k, '\n')
        
        mean_s = cumtrapz(temp_mean_psi*temp_mean_psi, time, initial=0)
        mean_s = (mean_s - mean_s.min())/(mean_s.max() - mean_s.min())
        mean_theta = np.zeros(temp_mean_c.shape)
        for j in range(T):
            x = mean_c[:,j]/mean_psi[j]
            mean_theta[:,j] = x*np.linalg.norm(x)

        theta = lambda t: interpolate.griddata(mean_s, mean_theta.T, t, method='cubic').T
        Q = self.solve_FrenetSerret_ODE(mean_s, theta)
        mean_x = cumtrapz(Q[:,0,:], mean_s, initial=0).T

        return mean_x, theta, mean_s
    



    
    """ Functions for the method Frenet Curvatures """


    def geodesic_frenet_curvatures(self, theta0, theta1, s0, s1, time, k, arr_Q0=None):
        """
            Compute the geodesic path under the Frenet curvatures framework.

            Along the path the curvatures are the linear interpolation
            theta_tau = (1 - tau) * theta0 + tau * theta1 for tau in {0, 1/(k-1), ..., 1},
            and each curve is rebuilt by integrating the Frenet-Serret ODE.

            Inputs:
            - theta0, theta1: functions such that theta_i(t) is a numpy array of shape (dim-1); Frenet curvatures of the two curves from which we want to compute the geodesic.
            - s0, s1: numpy arrays of shape (N,); arc-length functions of x0 and x1 sampled on `time`
              (they are passed to fs.utility_functions.invertGamma and np.interp, so they must be arrays).
            - time: numpy array of shape (N,); grid of N time points between 0 and 1.
            - k: int; number of steps along the geodesic.
            - arr_Q0: numpy array of shape (k, dim, dim); optional initial frames used to reconstruct the curves along the geodesic.

            Returns:
            - dist: Euclidean norm of theta0(time) - theta1(time).
            - path_curves: numpy array (k, N, dim); path_Q: (k, dim, dim, N); path_theta: (k, dim-1, N).
            - path_theta_fct: object array of k functions, path_theta_fct[tau](t) = theta_tau(t).
            - h_opt: invertGamma(s1) evaluated at s0 (linear interpolation on `time`).
        """
        s1_inv = fs.utility_functions.invertGamma(s1)
        h_opt = np.interp(s0, time, s1_inv)
        path_theta = np.zeros((k, self.dim - 1, len(time)))
        path_theta_fct = np.empty((k), dtype=object)
        path_Q = np.zeros((k, self.dim, self.dim, len(time)))
        path_curves = np.zeros((k, len(time), self.dim))

        def interpolated_theta(w):
            # Factory: each returned function keeps its own weight w. A lambda defined
            # directly in the loop would close over the loop variable, and every stored
            # path_theta_fct[tau] would then evaluate with the last weight (late binding).
            return lambda t: (1 - w) * theta0(t) + w * theta1(t)

        dist = np.linalg.norm((theta0(time) - theta1(time)))
        for tau in range(k):
            tau1 = 0 if tau == 0 else tau / (k - 1.)
            theta_tau = interpolated_theta(tau1)
            path_theta_fct[tau] = theta_tau
            path_theta[tau] = theta_tau(time)

            if arr_Q0 is None:
                Q = self.solve_FrenetSerret_ODE(time, theta_tau)
            else:
                Q = self.solve_FrenetSerret_ODE(time, theta_tau, Q0=arr_Q0[tau])
            path_curves[tau] = cumtrapz(Q[:, 0, :], time, initial=0).T
            path_Q[tau] = Q

        return dist, path_curves, path_Q, path_theta, path_theta_fct, h_opt

    def dist_frenet_curvatures(self, theta0, theta1, time):
        """
            Compute the geodesic distance under the Frenet curvatures framework.

            Inputs:
            - theta0, theta1: functions such that theta_i(t) is a numpy array of shape (dim-1); Frenet curvatures of the two curves from which we want to compute the geodesic.
            - time: numpy array of shape (N,); grid of N time points between 0 and 1.
        """
        dist = np.linalg.norm((theta0(time)-theta1(time)))
        return dist

    def karcher_mean_frenet_curvatures(self, arr_theta, arr_arc_s):
        """
            Karcher mean under the Frenet curvatures representation.

            The mean curvature is the pointwise arithmetic mean of the curvature functions; the mean
            arc-length function is the one returned by fs.utility_functions.SqrtMean; the mean curve is
            obtained by integrating the Frenet-Serret ODE on that arc-length grid.

            Input:
            - arr_theta: array of K Frenet curvatures functions such that arr_theta[k](s) is a numpy array of size (dim-1) (the number of Frenet curvatures)
              Set of Frenet curvature functions of the arc-length parameter (not of the time) of the Euclidean curves considered for computing the mean.
            - arr_arc_s: numpy array of size (K,N) of the K arc-length functions of the considered curves, with N samples points.

            Returns:
            - mean_x: mean curve; mean_theta: function of s; gam_mu: mean arc-length function.
        """
        n_curves = len(arr_theta)

        def mean_theta(s):
            return np.mean([arr_theta[i](s) for i in range(n_curves)], axis=0)

        _, gam_mu, _, _ = fs.utility_functions.SqrtMean(arr_arc_s.T)
        frames = self.solve_FrenetSerret_ODE(gam_mu, mean_theta)
        mean_x = cumtrapz(frames[:, 0, :], gam_mu, initial=0).T
        return mean_x, mean_theta, gam_mu
    

    """ Utils functions """


    def smooth_gam(self, time, gam, lam):
        tmp_spline = interpolate.UnivariateSpline(time, gam, s=lam)
        gam_smooth = tmp_spline(time)
        gam_smooth = (gam_smooth - gam_smooth.min()) / (gam_smooth.max() - gam_smooth.min())
        gam_smooth_dev = tmp_spline(time, 1)
        return gam_smooth, gam_smooth_dev
    
    def gamma_to_h(self, gamma, s0, s1):
        """
            Convert a warping gamma into h = s1^{-1}(gamma(s0)), all functions being sampled
            on a uniform grid of [0, 1] with len(s0) points and composed by linear interpolation.
        """
        grid = np.linspace(0, 1, len(s0))
        s1_inverse = fs.utility_functions.invertGamma(s1)
        return np.interp(np.interp(s0, grid, gamma), grid, s1_inverse)

    def h_to_gamma(self, h, s0, s1):
        """
            Convert a warping h into gamma = s1(h(s0^{-1})), all functions being sampled on a
            uniform grid of [0, 1] with len(s0) points and composed by linear interpolation.
        """
        grid = np.linspace(0, 1, len(s0))
        s0_inverse = fs.utility_functions.invertGamma(s0)
        s1_of_h = np.interp(h, grid, s1)
        return np.interp(s0_inverse, grid, s1_of_h)
    
    def dist_sphere(self, obj0, obj1):
        """
            Geodesic distance on the sphere. 
        """
        N = obj0.shape[-1]
        val = np.sum(np.sum(obj0 * obj1))/N
        if val > 1:
            if val < 1.001: # assume numerical error
                # import warnings
                # warnings.warn(f"Corrected a numerical error in geod_sphere: rounded {val} to 1")
                val = 1
            else:
                raise Exception(f"innerpod_q2 computed an inner product of {val} which is much greater than 1")
        elif val < -1:
            if val > -1.001: # assume numerical error
                # import warnings
                # warnings.warn(f"Corrected a numerical error in geod_sphere: rounded {val} to -1")
                val = -1
            else:
                raise Exception(f"innerpod_q2 computed an inner product of {val} which is much less than -1")
        dist = np.arccos(val)
        if np.isnan(dist):
            raise Exception("geod_sphere computed a dist value which is NaN")
        return dist


    def solve_FrenetSerret_ODE(self, t_eval, theta, Q0=None):
        """
            Compute the solution of the Frenet Serret ode.
        """
        if Q0 is None:
            Q0 = np.eye(self.dim)
        else:
            Q0 = Q0
        A_theta = lambda s: - np.diag(theta(s), 1) + np.diag(theta(s), -1)
        ode = lambda s,Q: (np.matmul(Q.reshape(self.dim,self.dim),A_theta(s))).flatten()
        sol = solve_ivp(ode, t_span=(0,1), y0=Q0.flatten(), t_eval=t_eval, method='Radau')
        Q = sol.y.reshape(self.dim,self.dim,len(t_eval))
        return Q


    def estimate_theta_along_geodesic(self, x_arr, h_deriv, smoothing_parameter, nb_basis, h=None, method='extrinsic'):
        """
            Estimate the Frenet curvatures of each curve along a geodesic.

            For every curve a curve_FrenetSerret_framework is built and its full framework computed
            with the given estimation parameters (forwarded unchanged); its theta is then evaluated on
            a uniform grid of [0, 1] with as many points as the curves have samples.

            Returns:
            - numpy array (K, dim-1, N); estimated curvatures of the K curves.
        """
        n_curves = len(x_arr)
        n_pts = len(x_arr[0])
        grid = np.linspace(0, 1, n_pts)
        theta = np.zeros((n_curves, self.dim - 1, n_pts))
        for idx, curve in enumerate(x_arr):
            framework = curve_FrenetSerret_framework(curve)
            framework.compute_full_FrenetSerret_framework(h_deriv, smoothing_parameter, nb_basis, h, method)
            theta[idx] = framework.theta(grid)
        return theta

   
    def center_3D(self, x_arr, rotation):
        """
            Center a set of 3D curves and optionally rotate them onto the first one.

            Inputs:
            - x_arr: numpy array of shape (k,N,dim). Set of k curves in R^dim with N samples points
            - rotation: if True, each centered curve is replaced by its best rotation onto the first centered curve.

            Returns:
            - x_arr_cent: numpy array with the shape of x_arr; the centered (and possibly rotated) curves.
            - [x, y, z]: largest absolute x, y and z coordinates over all returned curves
              (usable as symmetric axis ranges when plotting them).
        """
        x_arr_cent = np.zeros((x_arr.shape))
        K = x_arr.shape[0]
        N = x_arr[0].shape[0]
        for k in range(K):
            centroid = fs.curve_functions.calculatecentroid(x_arr[k].T)
            x_arr_cent[k] = x_arr[k] - np.tile(centroid, [N, 1])
            if rotation:
                new_x, _ = fs.curve_functions.find_best_rotation(x_arr_cent[0].T, x_arr_cent[k].T)
                x_arr_cent[k] = new_x.T
        # Extents are measured after the optional rotation so that they bound the
        # curves actually returned (rotating can move points outside pre-rotation bounds).
        x = np.max(np.abs(x_arr_cent[:, :, 0]))
        y = np.max(np.abs(x_arr_cent[:, :, 1]))
        z = np.max(np.abs(x_arr_cent[:, :, 2]))
        return x_arr_cent, [x, y, z]

    def center_2D(self, x_arr, rotation):
        """
            Center a set of 2D curves and optionally rotate them onto the first one.

            Inputs:
            - x_arr: numpy array of shape (k,N,dim). Set of k curves in R^dim with N samples points
            - rotation: if True, each centered curve is replaced by its best rotation onto the first centered curve.

            Returns:
            - x_arr_cent: numpy array with the shape of x_arr; the centered (and possibly rotated) curves.
            - [x, y]: largest absolute x and y coordinates over all returned curves
              (usable as symmetric axis ranges when plotting them).
        """
        x_arr_cent = np.zeros((x_arr.shape))
        K = x_arr.shape[0]
        N = x_arr[0].shape[0]
        for k in range(K):
            centroid = fs.curve_functions.calculatecentroid(x_arr[k].T)
            x_arr_cent[k] = x_arr[k] - np.tile(centroid, [N, 1])
            if rotation:
                new_x, _ = fs.curve_functions.find_best_rotation(x_arr_cent[0].T, x_arr_cent[k].T)
                x_arr_cent[k] = new_x.T
        # Extents are measured after the optional rotation so that they bound the
        # curves actually returned (rotating can move points outside pre-rotation bounds).
        x = np.max(np.abs(x_arr_cent[:, :, 0]))
        y = np.max(np.abs(x_arr_cent[:, :, 1]))
        return x_arr_cent, [x, y]


    def plot_init_curves_geodesic(self, x0, x1):
        """
            Display the two end curves of a geodesic: side by side when dim == 2, overlaid in a
            3D scene when dim == 3 (nothing is drawn for other dimensions).

            Inputs:
            - x0, x1: numpy arrays of shape (N, dim).
        """
        palette = sample_colorscale('Bluered_r', list(np.linspace(0, 1, 2)))
        color0, color1 = palette[0], palette[1]
        if self.dim == 2:
            axis_style = dict(showline=True, showgrid=True, linewidth=1, linecolor='black', zeroline=False, zerolinecolor='black', zerolinewidth=1)
            fig = make_subplots(rows=1, cols=2)
            for col, (curve, label, color) in enumerate([(x0, 'x_0', color0), (x1, 'x_1', color1)], start=1):
                fig.add_trace(go.Scatter(x=curve[:, 0], y=curve[:, 1], mode='lines', name=label, line=dict(width=2, color=color)), row=1, col=col)
            fig.update_xaxes(**axis_style)
            fig.update_yaxes(**axis_style)
            fig.update_layout(showlegend=True, height=400, width=800)
            fig.update_layout(go.Layout(plot_bgcolor='rgba(0,0,0,0)'))
            fig.show()
        elif self.dim == 3:
            fig = go.Figure(layout=go.Layout(plot_bgcolor='rgba(0,0,0,0)'))
            for curve, label, color in ((x0, 'x_0', color0), (x1, 'x_1', color1)):
                fig.add_trace(go.Scatter3d(x=curve[:, 0], y=curve[:, 1], z=curve[:, 2], name=label, mode='lines', line=dict(width=5, color=color)))
            scene_axis = dict(backgroundcolor="rgb(0, 0, 0)", gridcolor="grey", gridwidth=0.8, zeroline=False, showbackground=False)
            scene = {name: dict(scene_axis) for name in ('xaxis', 'yaxis', 'zaxis')}
            fig.update_layout(legend=dict(orientation="h", yanchor="top", y=1.2, xanchor="right", x=1),
                              scene=scene, height=600, width=600)
            fig.show()


    def plot_geodesic_theta(self, arr_s, arr_theta, height=500, title=''):
        """
            One figure per Frenet curvature index i, showing curvature i of each of the k steps of a
            geodesic in its own subplot (coloured from blue to red along the path).

            Inputs:
            - arr_s: sequence of k abscissa arrays, one per step.
            - arr_theta: numpy array indexed as arr_theta[j, i] (step j, curvature i).
            - height: int; figure height in pixels.
            - title: str; prefix of each figure title.
        """
        n_steps = len(arr_theta)
        n_curv = len(arr_theta[0])
        palette = sample_colorscale('Bluered_r', list(np.linspace(0, 1, n_steps)))
        axis_style = dict(showline=True, showgrid=True, linewidth=1, linecolor='black', zeroline=True, zerolinecolor='black', zerolinewidth=1)
        for curv_idx in range(n_curv):
            fig = make_subplots(rows=1, cols=n_steps, shared_xaxes=True, shared_yaxes=True)
            for step in range(n_steps):
                trace = go.Scatter(x=arr_s[step], y=arr_theta[step, curv_idx], mode='lines', line=dict(width=2, color=palette[step]))
                fig.add_trace(trace, row=1, col=step + 1)
                fig.update_xaxes(**axis_style)
                fig.update_yaxes(**axis_style)
            fig.update_layout(title_text=title + " geodesic Frenet curvature " + str(curv_idx + 1), showlegend=False, height=height)
            fig.update_layout(go.Layout(plot_bgcolor='rgba(0,0,0,0)'))
            fig.show()


    def plot_comparison_one_theta(self, s_arr, theta_arr, height):
        """
            Grid of curvature plots: row i contains the curves (s_arr[i][j], theta_arr[i][j]) in
            column j, coloured from blue to red along the columns.

            Inputs:
            - s_arr, theta_arr: nested sequences with the same layout [row][column].
            - height: int; figure height in pixels.
        """
        n_rows = len(s_arr)
        n_cols = len(s_arr[0])
        palette = sample_colorscale('Bluered_r', list(np.linspace(0, 1, n_cols)))
        axis_style = dict(showline=True, showgrid=True, linewidth=1, linecolor='black', zeroline=True, zerolinecolor='black', zerolinewidth=1)
        fig = make_subplots(rows=n_rows, cols=n_cols, shared_xaxes=True, shared_yaxes=True)
        for row in range(n_rows):
            for col in range(n_cols):
                trace = go.Scatter(x=s_arr[row][col], y=theta_arr[row][col], mode='lines', line=dict(width=2, color=palette[col]))
                fig.add_trace(trace, row=row + 1, col=col + 1)
                fig.update_xaxes(**axis_style)
                fig.update_yaxes(**axis_style)
        fig.update_layout(title_text="Comparison geodesic curvature", showlegend=False, height=height)
        fig.update_layout(go.Layout(plot_bgcolor='rgba(0,0,0,0)'))
        fig.show()


    def plot_comparison_theta(self, list_arr_s, list_arr_theta, height=500):
        """
            For each Frenet curvature index, one figure comparing the geodesic curvatures of three
            methods (rows: SRVF, SRC, Frenet curvatures) across the K steps (columns).

            Inputs:
            - list_arr_s: list of 3 sequences of K abscissa arrays.
            - list_arr_theta: list of 3 numpy arrays indexed as [step, curvature].
            - height: int; figure height in pixels.
        """
        n_methods = 3
        n_steps = len(list_arr_theta[0])
        n_curv = self.dim - 1
        palette = sample_colorscale('Bluered_r', list(np.linspace(0, 1, n_steps)))
        axis_style = dict(showline=True, showgrid=True, linewidth=1, linecolor='black', zeroline=True, zerolinecolor='black', zerolinewidth=1)
        for curv_idx in range(n_curv):
            fig = make_subplots(rows=n_methods, cols=n_steps, shared_xaxes=True, shared_yaxes=True)
            for method_idx in range(n_methods):
                abscissas = list_arr_s[method_idx]
                curvatures = list_arr_theta[method_idx]
                for step in range(n_steps):
                    trace = go.Scatter(x=abscissas[step], y=curvatures[step, curv_idx], mode='lines', line=dict(width=2, color=palette[step]))
                    fig.add_trace(trace, row=method_idx + 1, col=step + 1)
                    fig.update_xaxes(**axis_style)
                    fig.update_yaxes(**axis_style)
            fig.update_layout(title_text="Comparison geodesic curvature number " + str(curv_idx + 1) + ": SRVF (first row), SRC (second row), Frenet curvatures (third row)", showlegend=False, height=height)
            fig.update_layout(go.Layout(plot_bgcolor='rgba(0,0,0,0)'))
            fig.show()


    def plot_geodesic_2D(self, x_arr, height=500):
        """
        Plot each planar curve of a geodesic path in its own column, coloured from the
        start to the end of the path.

        Inputs:
            - x_arr: array indexed as x_arr[step, point, coordinate] (coordinates 0 and 1 are plotted).
            - height: figure height (default 500).
        """
        n_steps = len(x_arr)
        palette = sample_colorscale('Bluered_r', list(np.linspace(0, 1, n_steps)))
        fig = make_subplots(rows=1, cols=n_steps, shared_xaxes=True, shared_yaxes=True)
        for step in range(n_steps):
            trace = go.Scatter(x=x_arr[step, :, 0], y=x_arr[step, :, 1], mode='lines',
                               line=dict(width=2, color=palette[step]))
            fig.add_trace(trace, row=1, col=step + 1)
        # Same style on every subplot axis.
        axis_style = dict(showline=True, showgrid=True, linewidth=1, linecolor='black',
                          zeroline=False, zerolinecolor='black', zerolinewidth=1)
        fig.update_xaxes(**axis_style)
        fig.update_yaxes(**axis_style)
        fig.update_layout(showlegend=False, height=height)
        fig.update_layout(go.Layout(plot_bgcolor='rgba(0,0,0,0)'))
        fig.show()


    def plot_geodesic_3D(self, x_arr):
        """
        Plot each 3D curve of a geodesic path in its own scene, after centering them with
        self.center_3D (rotation disabled). All scenes share the symmetric ranges returned
        by center_3D.

        Inputs:
            - x_arr: sequence of curves; each centered curve is indexed as curve[:, 0..2].
        """
        centered, half_width = self.center_3D(x_arr, False)
        n_steps = len(x_arr)
        palette = sample_colorscale('Bluered_r', list(np.linspace(0, 1, n_steps)))
        fig = make_subplots(rows=1, cols=n_steps, specs=[[{"type": "scene"} for _ in range(n_steps)]],)

        def scene_layout():
            # Fresh dict per call; each axis spans [-half_width[d], half_width[d]].
            def axis(limit):
                return dict(backgroundcolor="white", gridcolor="gray", gridwidth=0.8, range=[-limit, limit])
            return dict(xaxis=axis(half_width[0]), yaxis=axis(half_width[1]), zaxis=axis(half_width[2]))

        for step in range(n_steps):
            curve = centered[step]
            trace = go.Scatter3d(x=curve[:, 0], y=curve[:, 1], z=curve[:, 2], mode="lines",
                                 line=dict(width=5, color=palette[step]))
            fig.add_trace(trace, row=1, col=step + 1)
            fig.update_scenes(scene_layout(), aspectmode='cube', row=1, col=step + 1)
        fig.show()


    def plot_comparison_2D(self, path_x_srvf, path_x_src, path_x_fc, rotation):
        """
        Plot the SRVF, SRC and Frenet-curvature geodesic paths of planar curves in a 3-row grid
        (one row per method, one column per step). The paths are centered with self.center_2D,
        and all subplots share the largest symmetric x / y ranges over the three methods.

        Inputs:
            - path_x_srvf, path_x_src, path_x_fc: geodesic paths; each centered path is indexed
              as path[step, point, coordinate]. The step count is taken from path_x_fc.
            - rotation: forwarded to self.center_2D.
        """
        n_steps = len(path_x_fc)
        centered_paths = []
        extents = []
        for path in (path_x_srvf, path_x_src, path_x_fc):
            centered, extent = self.center_2D(path, rotation)
            centered_paths.append(centered)
            extents.append(extent)
        x_lim = np.max([extent[0] for extent in extents])
        y_lim = np.max([extent[1] for extent in extents])
        palette = sample_colorscale('Bluered_r', list(np.linspace(0, 1, n_steps)))
        fig = make_subplots(rows=3, cols=n_steps, shared_xaxes=True, shared_yaxes=True)
        for step in range(n_steps):
            # Row 1: SRVF, row 2: SRC, row 3: FC.
            for row, centered in enumerate(centered_paths, start=1):
                trace = go.Scatter(x=centered[step, :, 0], y=centered[step, :, 1], mode='lines',
                                   line=dict(width=2, color=palette[step]))
                fig.add_trace(trace, row=row, col=step + 1)
                fig.update_xaxes(showline=True, showgrid=True, linewidth=1, linecolor='black', zeroline=False,
                                 range=[-x_lim, x_lim], row=row, col=step + 1)
                fig.update_yaxes(showline=True, showgrid=True, linewidth=1, linecolor='black', zeroline=False,
                                 range=[-y_lim, y_lim], row=row, col=step + 1)
        fig.update_layout(showlegend=False, title="Comparison of the SRVF (first row), SRC (second row), Frenet curvatures (third row) geodesic paths.")
        fig.update_layout(go.Layout(plot_bgcolor='rgba(0,0,0,0)'))
        fig.show()


    def plot_comparison_3D(self, path_x_srvf, path_x_src, path_x_fc, rotation):
        """
        Show three figures (SRVF, SRC, Frenet curvatures), each with one 3D scene per step of the
        corresponding geodesic path. The paths are centered with self.center_3D, and all scenes
        share the largest symmetric ranges over the three methods.

        Inputs:
            - path_x_srvf, path_x_src, path_x_fc: geodesic paths; each centered curve is indexed
              as curve[:, 0..2]. The step count is taken from path_x_fc.
            - rotation: forwarded to self.center_3D.
        """
        n_steps = len(path_x_fc)
        centered_paths = []
        extents = []
        for path in (path_x_srvf, path_x_src, path_x_fc):
            centered, extent = self.center_3D(path, rotation)
            centered_paths.append(centered)
            extents.append(extent)
        limits = [np.max([extent[d] for extent in extents]) for d in range(3)]
        palette = sample_colorscale('Bluered_r', list(np.linspace(0, 1, n_steps)))
        figures = [make_subplots(rows=1, cols=n_steps, specs=[[{"type": "scene"} for _ in range(n_steps)]],)
                   for _ in range(3)]

        def scene_layout():
            # Fresh dict per call; each axis spans [-limit, limit].
            def axis(limit):
                return dict(backgroundcolor="white", gridcolor="gray", gridwidth=0.8, range=[-limit, limit])
            return dict(xaxis=axis(limits[0]), yaxis=axis(limits[1]), zaxis=axis(limits[2]))

        for step in range(n_steps):
            for fig, centered in zip(figures, centered_paths):
                curve = centered[step]
                trace = go.Scatter3d(x=curve[:, 0], y=curve[:, 1], z=curve[:, 2], mode="lines",
                                     line=dict(width=5, color=palette[step]))
                fig.add_trace(trace, row=1, col=step + 1)
                fig.update_scenes(scene_layout(), aspectmode='cube', row=1, col=step + 1)
        titles = ("Geodesic SRVF 3D curves", "Geodesic SRC 3D curves", "Geodesic FC 3D curves")
        for fig, title in zip(figures, titles):
            fig.update_layout(title=title)
        for fig in figures:
            fig.show()


    # def comparison_geodesics(self, x0, x1, theta0, theta1, s0, s1, k, lam=1, rotation=False, param_theta_srvf={"h":0.02,"nb_basis":30,"hyperparam":[0.02, 1e-06, 1e-06]}):
    def comparison_geodesics(self, x0, x1, theta0, theta1, s0, s1, k, lam=1, rotation=False, param_theta_srvf=None, full_plot=False):
        """ Compute and compare the geodesics between two curves under the three frameworks
        (SRVF, Frenet curvatures, SRC), and plot the results with a common layout.

            For each framework, this function computes
                - the geodesic path (k steps) and the geodesic distance,
                - the curve x1 deformed to match x0 by the corresponding registration
                  (for the Frenet-curvature framework x1 itself is returned),
                - the optimal warping functions h and gamma.

            Inputs:
            - x0, x1: numpy arrays of shape (N, dim), the two curves.
            - theta0, theta1: curvature representations of x0 and x1, passed unchanged to
              geodesic_frenet_curvatures / geodesic_src / deform_src.
            - s0, s1: callables evaluated on the uniform time grid np.linspace(0, 1, N).
            - k: int, number of steps along each geodesic.
            - lam: forwarded to geodesic_src.
            - rotation: forwarded to the path-plotting helpers.
            - param_theta_srvf: dict of parameters used to estimate curvatures along the SRVF
              geodesic. Keys are "h_deriv", "smoothing_param", "nb_basis", "h" and "method".
              Missing keys (or None) fall back to
              {"h_deriv": 0.02, "smoothing_param": 0.001, "nb_basis": 30, "h": None, "method": 'extrinsic'}.
            - full_plot: if True, also plot the warping functions and the deformed curves.

            Outputs:
            - res_SRVF, res_SRC, res_FrenetCurvatures: dicts of results, one per framework.
              Each "deform_x1" entry holds the deformed curve of its own framework.
        """
        # None sentinel avoids a shared mutable default; a partial dict is completed with defaults.
        default_params = {"h_deriv": 0.02, "smoothing_param": 0.001, "nb_basis": 30, "h": None, "method": 'extrinsic'}
        params = dict(default_params) if param_theta_srvf is None else {**default_params, **param_theta_srvf}

        time = np.linspace(0, 1, len(x0))

        # SRVF
        dist_srvf, path_x_srvf, path_q, R_opt, h_opt_srvf = self.geodesic_srvf(x0, x1, k)
        gamma_opt_srvf = self.h_to_gamma(h_opt_srvf, s0(time), s1(time))
        path_theta_srvf = self.estimate_theta_along_geodesic(path_x_srvf, params["h_deriv"], params["smoothing_param"], params["nb_basis"], h=params["h"], method=params["method"])
        def_x1_srvf = self.deform_srvf(x0, h_opt_srvf)

        # Frenet curvatures: no deformation is computed, x1 is reported as is.
        dist_fc, path_x_fc, path_Q_fc, path_theta_fc, path_theta_fct_fc, h_opt_fc = self.geodesic_frenet_curvatures(theta0, theta1, s0(time), s1(time), time, k)
        gamma_opt_fc = self.h_to_gamma(h_opt_fc, s0(time), s1(time))
        def_x1_fc = x1

        # SRC
        dist_src, path_x_src, path_Q_src, path_theta_src, path_arc_s, path_psi, gamma_opt_src = self.geodesic_src(theta0, theta1, s0, s1, time, k, align=True, smooth=True, lam=lam)
        h_opt_src = self.gamma_to_h(gamma_opt_src, s0(time), s1(time))
        def_x1_src = self.deform_src(theta0, gamma_opt_src)

        # Results dicts (each "deform_x1" matches its own framework).
        res_SRVF = {"dist" : dist_srvf, "path_x": path_x_srvf, "path_q": path_q, "path_theta": path_theta_srvf, "R_opt": R_opt, "h_opt": h_opt_srvf, "gamma_opt": gamma_opt_srvf, "deform_x1": def_x1_srvf}
        res_SRC = {"dist" : dist_src, "path_x": path_x_src, "path_Q": path_Q_src, "path_theta": path_theta_src, "path_arc_s": path_arc_s, "path_psi": path_psi, "h_opt": h_opt_src, "gamma_opt": gamma_opt_src, "deform_x1": def_x1_src}
        res_FrenetCurvatures = {"dist" : dist_fc, "path_x": path_x_fc, "path_Q": path_Q_fc, "path_theta": path_theta_fc, "path_theta_fct":path_theta_fct_fc,  "h_opt": h_opt_fc, "gamma_opt": gamma_opt_fc, "deform_x1": def_x1_fc}

        # Plot geodesic paths
        if self.dim == 3:
            self.plot_comparison_3D(path_x_srvf, path_x_src, path_x_fc, rotation)
        elif self.dim == 2:
            self.plot_comparison_2D(path_x_srvf, path_x_src, path_x_fc, rotation)

        # Plot curvatures along the geodesics
        path_s = np.array([time for _ in range(k)])
        # The endpoint curvatures of the SRVF path are replaced by the FC ones. This is an
        # in-place assignment, so res_SRVF["path_theta"] (the same object) is updated as well.
        path_theta_srvf[-1] = path_theta_fc[-1]
        path_theta_srvf[0] = path_theta_fc[0]
        self.plot_comparison_theta([path_s, path_arc_s, path_s], [path_theta_srvf, path_theta_src, path_theta_fc])

        if full_plot:

            # Local palette for these figures (differs from the module-level colours).
            color_list = px.colors.qualitative.Plotly
            color_x0 = color_list[1]
            color_x1 = color_list[0]
            color_srvf = color_list[2]
            color_src = color_list[5]
            color_fc = color_list[6]

            # Warping functions h and gamma
            warpings = (("Warping function h", (h_opt_srvf, h_opt_src, h_opt_fc)),
                        ("Warping function gamma", (gamma_opt_srvf, gamma_opt_src, gamma_opt_fc)))
            for title, (w_srvf, w_src, w_fc) in warpings:
                fig = go.Figure(layout=go.Layout(plot_bgcolor='rgba(0,0,0,0)'))
                fig.add_trace(go.Scatter(x=time, y=w_srvf, mode='lines', name="SRVF", line=dict(width=2, color=color_srvf)))
                fig.add_trace(go.Scatter(x=time, y=w_src, mode='lines', name="SRC", line=dict(width=2, color=color_src)))
                fig.add_trace(go.Scatter(x=time, y=w_fc, mode='lines', name="FC", line=dict(width=2, color=color_fc)))
                fig.update_layout(title=title, height=600, width=600)
                fig.update_xaxes(showline=True, showgrid=False, linewidth=1, linecolor='black')
                fig.update_yaxes(showline=True, showgrid=False, linewidth=1, linecolor='black')
                fig.show()

            # Deformed curves
            if self.dim==3:
                center_def_x, _ = self.center_3D(np.array([x0, x1, def_x1_srvf, def_x1_src, def_x1_fc]), True)
                fig = go.Figure(layout=go.Layout(plot_bgcolor='rgba(0,0,0,0)'))
                fig.add_trace(go.Scatter3d(x=center_def_x[0][:,0],y=center_def_x[0][:,1],z=center_def_x[0][:,2],name='X_0',mode='lines',line=dict(width=5,color=color_x0)))
                fig.add_trace(go.Scatter3d(x=center_def_x[1][:,0],y=center_def_x[1][:,1],z=center_def_x[1][:,2],name='X_1',mode='lines',line=dict(width=5,color=color_x1)))
                fig.add_trace(go.Scatter3d(x=center_def_x[2][:,0],y=center_def_x[2][:,1],z=center_def_x[2][:,2],name='SRVF',mode='lines',line=dict(width=5,color=color_srvf, dash='longdash')))
                fig.add_trace(go.Scatter3d(x=center_def_x[3][:,0],y=center_def_x[3][:,1],z=center_def_x[3][:,2],name='SRC',mode='lines',line=dict(width=5,color=color_src, dash='dashdot')))
                fig.add_trace(go.Scatter3d(x=center_def_x[4][:,0],y=center_def_x[4][:,1],z=center_def_x[4][:,2],name='FC',mode='lines',line=dict(width=5,color=color_fc, dash='dash')))
                fig.update_layout(legend=dict(orientation="h",yanchor="top",y=1.2,xanchor="right", x=1),
                                scene = dict(xaxis = dict(backgroundcolor="rgb(0, 0, 0)",gridcolor="grey",gridwidth=0.8,zeroline=False,showbackground=False,),
                                            yaxis = dict(backgroundcolor="rgb(0, 0, 0)",gridcolor="grey",gridwidth=0.8,zeroline=False,showbackground=False,),
                                            zaxis = dict(backgroundcolor="rgb(0, 0, 0)",gridcolor="grey",gridwidth=0.8,zeroline=False,showbackground=False,),),
                                            height=600, width=600, title="Deformed curves")
                fig.show()
            if self.dim==2:
                center_def_x, _ = self.center_2D(np.array([x0, x1, def_x1_srvf, def_x1_src, def_x1_fc]), True)
                fig = go.Figure(layout=go.Layout(plot_bgcolor='rgba(0,0,0,0)'))
                fig.add_trace(go.Scatter(x=center_def_x[0][:,0],y=center_def_x[0][:,1], mode='lines', name="X_0", line=dict(width=2, color=color_x0)))
                fig.add_trace(go.Scatter(x=center_def_x[1][:,0],y=center_def_x[1][:,1], mode='lines', name="X_1", line=dict(width=2, color=color_x1)))
                fig.add_trace(go.Scatter(x=center_def_x[2][:,0],y=center_def_x[2][:,1], mode='lines', name="SRVF", line=dict(width=2, color=color_srvf, dash='longdash')))
                fig.add_trace(go.Scatter(x=center_def_x[3][:,0],y=center_def_x[3][:,1], mode='lines', name="SRC", line=dict(width=2, color=color_src, dash='dashdot')))
                fig.add_trace(go.Scatter(x=center_def_x[4][:,0],y=center_def_x[4][:,1], mode='lines', name="FC", line=dict(width=2, color=color_fc, dash='dash')))
                fig.update_layout(title="Deformed curves", height=600, width=600)
                fig.update_xaxes(showline=True, showgrid=False, linewidth=1, linecolor='black', zeroline=False)
                fig.update_yaxes(showline=True, showgrid=False, linewidth=1, linecolor='black', zeroline=False)
                fig.show()

        return res_SRVF, res_SRC, res_FrenetCurvatures


    def comparison_dist_geodesics(self, x0, x1, theta0, theta1, s0, s1, lam=1):
        """
            Compute the geodesic distances between two curves under the three frameworks.

            The distances are evaluated in the order SRVF, Frenet curvatures, SRC, on the
            uniform grid np.linspace(0, 1, len(x0)).

            Returns:
            - (dist_srvf, dist_src, dist_fc)
        """
        grid = np.linspace(0, 1, len(x0))
        srvf_distance = self.dist_srvf(x0, x1)
        fc_distance = self.dist_frenet_curvatures(theta0, theta1, grid)
        src_distance = self.dist_src(theta0, theta1, s0, s1, grid, align=True, smooth=True, lam=lam)
        return srvf_distance, src_distance, fc_distance
    

    def plot_karcher_means(self, mean_srvf, mean_src, mean_fc, rotation=True):
        """
        Plot the SRVF, SRC and Frenet-curvature Karcher means side by side, one subplot each.
        The means are first centered with self.center_2D / self.center_3D. Each mean is drawn
        twice: a thin grey line plus a thicker coloured dashed line. Only dimensions 2 and 3
        are supported; otherwise a message is printed.
        """
        # (legend name, colour, dash) for columns 1, 2, 3.
        styles = [("SRVF", color_srvf, 'dot'), ("SRC", color_src, 'dashdot'), ("FC", color_fc, 'dash')]

        if self.dim == 2:
            centered, _ = self.center_2D(np.array([mean_srvf, mean_src, mean_fc]), rotation)
            fig = make_subplots(rows=1, cols=3, shared_xaxes=False, shared_yaxes=False)
            for col, (label, colour, dash) in enumerate(styles, start=1):
                curve = centered[col - 1]
                fig.add_trace(go.Scatter(x=curve[:, 0], y=curve[:, 1], mode='lines', name=label, line=dict(width=0.5, color='grey'), showlegend=False), row=1, col=col)
                fig.add_trace(go.Scatter(x=curve[:, 0], y=curve[:, 1], mode='lines', name=label, line=dict(width=2.5, color=colour, dash=dash)), row=1, col=col)
            for col in range(1, 4):
                fig.update_xaxes(showline=True, showgrid=False, linewidth=1, linecolor='black', zeroline=False, row=1, col=col)
                fig.update_yaxes(showline=True, showgrid=False, linewidth=1, linecolor='black', zeroline=False, row=1, col=col)
            fig.update_layout(autosize=False, width=1400, height=500,)
            fig.update_layout(go.Layout(plot_bgcolor='rgba(0,0,0,0)'))
            fig.show()
            return

        if self.dim == 3:
            centered, _ = self.center_3D(np.array([mean_srvf, mean_src, mean_fc]), rotation)
            fig = make_subplots(rows=1, cols=3, specs=[[{"type": "scene"} for _ in range(3)]],)
            for col, (label, colour, dash) in enumerate(styles, start=1):
                curve = centered[col - 1]
                fig.add_trace(go.Scatter3d(x=curve[:, 0], y=curve[:, 1], z=curve[:, 2], mode='lines', name=label, line=dict(width=2, color='grey'), showlegend=False), row=1, col=col)
                fig.add_trace(go.Scatter3d(x=curve[:, 0], y=curve[:, 1], z=curve[:, 2], mode='lines', name=label, line=dict(width=8, color=colour, dash=dash), showlegend=True), row=1, col=col)
            for col in range(1, 4):
                scene = dict(xaxis=dict(backgroundcolor="white", gridcolor="gray", gridwidth=0.8),
                             yaxis=dict(backgroundcolor="white", gridcolor="gray", gridwidth=0.8),
                             zaxis=dict(backgroundcolor="white", gridcolor="gray", gridwidth=0.8),)
                fig.update_scenes(scene, aspectmode='cube', row=1, col=col)
            fig.update_layout(autosize=False, width=1300, height=500,)
            fig.update_layout(go.Layout(plot_bgcolor='rgba(0,0,0,0)'))
            fig.show()
            return

        print("'plot_karcher_means' is only for dimension 2 or 3")
    
    

    
    
    
    