import math
import warnings

import numpy as np
import scipy.optimize
import scipy.spatial

class Gaussian3D:
    """Fit a rotated, anisotropic 3-D Gaussian (with free centre) to samples.

    The model evaluated at a point ``p`` is::

        offset + amplitude**2 * exp(-0.5 * (p - mean)^T cov^-1 (p - mean))

    with ``mean = [xo, yo, zo]`` and ``cov = R @ diag(s0, s1, s2)**2 @ R.T``,
    where ``R`` comes from SciPy's ``Rotation.from_euler("XYZ", [r0, r1, r2])``
    (uppercase axes = intrinsic rotations, angles in radians by default).
    Squaring ``amplitude`` and the scales keeps the peak height non-negative
    and ``cov`` positive semi-definite without needing bounded optimisation.
    """

    def args(self, amplitude, offset, xo, yo, zo, s0, s1, s2, r0, r1, r2):
        """Map the 11 raw optimiser parameters to model quantities.

        Returns:
            ``(amplitude**2, offset, mean, cov)`` where ``mean`` is the
            length-3 array ``[xo, yo, zo]`` and ``cov`` is a 3x3 array.
        """
        mean = np.asarray([xo, yo, zo])

        amplitude = np.square(amplitude)

        # Squared scales are the variances along the principal axes.
        S = np.square(np.diag([s0, s1, s2]))
        R = scipy.spatial.transform.Rotation.from_euler("XYZ", [r0, r1, r2]).as_matrix()
        cov = R @ S @ R.T

        return amplitude, offset, mean, cov

    def call(self, p, *args):
        """Evaluate the model at the points ``p``.

        Args:
            p: Array-like of shape ``(3, N)`` (one column per point), or a
                single point of shape ``(3,)``.
            *args: The 11 raw parameters accepted by :meth:`args`.

        Returns:
            Array of shape ``(N,)`` (``(1,)`` for a single point).

        Raises:
            numpy.linalg.LinAlgError: if the covariance is singular
                (e.g. one of the scales is zero).
        """
        amplitude, offset, mean, cov = self.args(*args)

        p = np.asarray(p, dtype=float)
        if p.ndim == 1:
            # A lone point would otherwise broadcast against mean[:, None]
            # into a 3x3 array and produce a silently wrong result.
            p = p[:, np.newaxis]

        d = p - mean[:, np.newaxis]
        # Quadratic form d_n^T cov^-1 d_n for every column n; solve() avoids
        # forming the explicit inverse.
        mahal = np.einsum("in,in->n", d, np.linalg.solve(cov, d))

        return offset + amplitude * np.exp(-0.5 * mahal)

    def fit(self, x, y, initial_params=None):
        """Least-squares fit of the model to samples ``y`` taken at ``x``.

        Args:
            x: Array-like of shape ``(N, 3)``, one row per sample point.
            y: Array-like of shape ``(N,)`` with the observed values.
            initial_params: Optional 11-tuple of raw parameters (see
                :meth:`args`). Defaults to a unit Gaussian at the origin.

        Returns:
            ``(mean, cov)``: fitted centre (length-3 array) and 3x3 covariance.

        Raises:
            RuntimeError: propagated from ``scipy.optimize.curve_fit`` when the
                least-squares minimisation fails (e.g. ``maxfev`` exceeded).
        """
        if initial_params is None:
            # amplitude, offset, xo, yo, zo, s0, s1, s2, r0, r1, r2
            initial_params = (1, 0, 0, 0, 0, 1, 1, 1, 0, 0, 0)

        x = np.asarray(x, dtype=float)
        y = np.asarray(y, dtype=float)

        popt, _pcov = scipy.optimize.curve_fit(
            self.call, x.T, y, p0=initial_params, maxfev=10000, ftol=1e-6, xtol=1e-6
        )

        _amplitude, _offset, mean, cov = self.args(*popt)

        return mean, cov

class Gaussian3DConstantMean:
    """Rotated, anisotropic 3-D Gaussian whose centre is fixed at construction.

    The model evaluated at a point ``p`` is::

        offset + amplitude**2 * exp(-0.5 * (p - mean)^T cov^-1 (p - mean))

    with ``cov = R @ diag(s0, s1, s2)**2 @ R.T`` and ``R`` from SciPy's
    ``Rotation.from_euler("XYZ", [r0, r1, r2])``. Only the amplitude, offset,
    scales and rotation are optimised; ``mean`` stays constant.
    """

    def __init__(self, mean):
        """Store the fixed centre.

        Args:
            mean: Length-3 array-like centre of the Gaussian.
        """
        self.mean = mean

    def args(self, amplitude, offset, s0, s1, s2, r0, r1, r2):
        """Map the 8 raw optimiser parameters to model quantities.

        Returns:
            ``(amplitude**2, offset, cov)`` where ``cov`` is a 3x3 array.
        """
        amplitude = np.square(amplitude)

        # Squared scales are the variances along the principal axes.
        S = np.square(np.diag([s0, s1, s2]))
        R = scipy.spatial.transform.Rotation.from_euler("XYZ", [r0, r1, r2]).as_matrix()
        cov = R @ S @ R.T

        return amplitude, offset, cov

    def call(self, p, *args):
        """Evaluate the model at the points ``p``.

        Args:
            p: Array-like of shape ``(3, N)`` (one column per point), or a
                single point of shape ``(3,)``.
            *args: The 8 raw parameters accepted by :meth:`args`.

        Returns:
            Array of shape ``(N,)`` (``(1,)`` for a single point).

        Raises:
            numpy.linalg.LinAlgError: if the covariance is singular.
        """
        amplitude, offset, cov = self.args(*args)

        p = np.asarray(p, dtype=float)
        if p.ndim == 1:
            # A lone point would otherwise broadcast against mean[:, None]
            # into a 3x3 array and produce a silently wrong result.
            p = p[:, np.newaxis]

        # asarray lets callers pass the centre as a list/tuple as well.
        mean = np.asarray(self.mean, dtype=float)
        d = p - mean[:, np.newaxis]
        # Quadratic form d_n^T cov^-1 d_n for every column n; solve() avoids
        # forming the explicit inverse.
        mahal = np.einsum("in,in->n", d, np.linalg.solve(cov, d))

        return offset + amplitude * np.exp(-0.5 * mahal)

    def fit(self, x, y, initial_params=None):
        """Least-squares fit of amplitude, offset, scales and rotation.

        Args:
            x: Array-like of shape ``(N, 3)``, one row per sample point.
            y: Array-like of shape ``(N,)`` with the observed values.
            initial_params: Optional 8-tuple of raw parameters (see
                :meth:`args`). Defaults to a unit isotropic Gaussian.

        Returns:
            ``(cov, metrics)`` on success: ``cov`` is the fitted 3x3
            covariance and ``metrics`` a dict with keys ``"gaussian_mse"``
            and ``"gaussian_mean_translation_std"``. On any failure the
            method is best-effort and returns ``(None, None)``, emitting a
            ``RuntimeWarning`` that describes the error.
        """
        if initial_params is None:
            # amplitude, offset, s0, s1, s2, r0, r1, r2
            initial_params = (1, 0, 1, 1, 1, 0, 0, 0)

        try:
            x = np.asarray(x, dtype=float)
            y = np.asarray(y, dtype=float)
            popt, pcov = scipy.optimize.curve_fit(
                self.call, x.T, y, p0=initial_params, maxfev=3000, ftol=1e-6, xtol=1e-6
            )
            # 1-sigma parameter uncertainties from the diagonal of pcov.
            perr = np.sqrt(np.diag(pcov))

            y_pred = self.call(x.T, *popt)

            _amplitude, _offset, cov = self.args(*popt)

            metrics = {
                "gaussian_mse": np.mean((y_pred - y) ** 2),
                # NOTE(review): perr[3] and perr[4] are the uncertainties of
                # the scales s1 and s2 (this model has no mean parameters),
                # and this takes the sqrt of a SUM of standard deviations.
                # It looks copied from a free-mean model; value kept
                # unchanged pending confirmation of the intended formula.
                "gaussian_mean_translation_std": math.sqrt(perr[3] + perr[4]),
            }
        except Exception as exc:  # deliberate best-effort contract: (None, None)
            warnings.warn(
                f"Gaussian3DConstantMean.fit failed: {exc!r}",
                RuntimeWarning,
                stacklevel=2,
            )
            return None, None
        else:
            return cov, metrics
