#!/usr/bin/env python3

# %%

import os
import os.path as path

import matplotlib.pyplot as plt
import numpy as np
from scipy.integrate import dblquad
from scipy.stats import norm

# %%

# Half-width of the integration range over x: x runs over [-x_lim, x_lim].
x_lim = 10
# Relative tolerance (dblquad's `epsrel`) requested on each double integral.
eps = 0.001


def sig_bar_ann(sig_ann, *, lim=None, rel_tol=None):
    """Return the root of the averaged squared error for the ANN case.

    For each t in [0, 1] the squared error is the mean of
    (x - (1 - t))**2 and (x - (1 + t))**2. It is weighted by the normal
    density of x with mean 0 and standard deviation ``sig_ann``, then
    integrated over x in [-lim, lim] and t in [0, 1]. The square root of
    that double integral is returned.

    Without truncating x, the integral evaluates analytically to
    sig_ann**2 + 4/3, because E[(x - a)**2] = sig_ann**2 + a**2. The result
    should therefore be close to sqrt(sig_ann**2 + 4/3), which is a useful
    sanity check.

    Parameters
    ----------
    sig_ann : float
        Standard deviation of x (should be positive for a proper density).
    lim : float, optional
        Half-width of the x integration range. Defaults to the module-level
        ``x_lim``.
    rel_tol : float, optional
        Relative tolerance passed to ``dblquad`` as ``epsrel``. Defaults to
        the module-level ``eps``.

    Returns
    -------
    float
        Square root of the double integral.
    """
    if lim is None:
        lim = x_lim
    if rel_tol is None:
        rel_tol = eps

    def integrand_ann(x, t):
        # dblquad calls func(inner, outer): x is the inner variable and
        # t is the outer one.
        err_1 = (x - (1 - t)) ** 2
        err_2 = (x - (1 + t)) ** 2
        return 0.5 * (err_1 + err_2) * norm.pdf(x, loc=0, scale=sig_ann)

    # Outer integral over t in [0, 1], inner integral over x in [-lim, lim].
    # dblquad returns (value, abs_error_estimate); only the value is used.
    value, _ = dblquad(integrand_ann, 0, 1, -lim, lim, epsrel=rel_tol)
    return np.sqrt(value)


def sig_bar_snn(sig_1_snn, sig_2_snn, *, lim=None, rel_tol=None):
    """Return the root of the averaged squared error for the SNN case.

    For each t in [0, 1] the squared error is (x - t)**2. It is weighted by
    the normal density of x with mean 0 and standard deviation
    sig_t = (1 - t) * sig_1_snn + t * sig_2_snn, which interpolates linearly
    from ``sig_1_snn`` at t = 0 to ``sig_2_snn`` at t = 1. The weighted
    error is integrated over x in [-lim, lim] and t in [0, 1], and the
    square root of the result is returned.

    Without truncating x, the integral evaluates analytically to
    (sig_1_snn**2 + sig_1_snn*sig_2_snn + sig_2_snn**2 + 1) / 3. The result
    should therefore be close to the square root of that expression, which
    is a useful sanity check.

    Parameters
    ----------
    sig_1_snn, sig_2_snn : float
        Standard deviations at t = 0 and t = 1. The interpolated sig_t
        should stay positive (presumably both are > 0; confirm with callers).
    lim : float, optional
        Half-width of the x integration range. Defaults to the module-level
        ``x_lim``.
    rel_tol : float, optional
        Relative tolerance passed to ``dblquad`` as ``epsrel``. Defaults to
        the module-level ``eps``.

    Returns
    -------
    float
        Square root of the double integral.
    """
    if lim is None:
        lim = x_lim
    if rel_tol is None:
        rel_tol = eps

    def integrand_snn(x, t):
        # dblquad calls func(inner, outer): x is the inner variable and
        # t is the outer one.
        err = (x - t) ** 2
        sig_t_snn = (1 - t) * sig_1_snn + t * sig_2_snn
        return err * norm.pdf(x, loc=0, scale=sig_t_snn)

    # Outer integral over t in [0, 1], inner integral over x in [-lim, lim].
    # dblquad returns (value, abs_error_estimate); only the value is used.
    value, _ = dblquad(integrand_snn, 0, 1, -lim, lim, epsrel=rel_tol)
    return np.sqrt(value)


# sig_1_snn and sig_2_snn should be 1D arrays
def make_grid_snn(sig_1_snn, sig_2_snn):
    """Tabulate ``sig_bar_snn`` over a grid of (sigma_1, sigma_2) pairs.

    Parameters
    ----------
    sig_1_snn, sig_2_snn : array_like
        1D sequences of sigma_1 and sigma_2 values. NumPy arrays and plain
        Python sequences are both accepted.

    Returns
    -------
    numpy.ndarray
        Float array of shape (len(sig_1_snn), len(sig_2_snn)). Entry
        [j, k] is ``sig_bar_snn(sig_1_snn[j], sig_2_snn[k])`` when
        sig_2_snn[k] <= sig_1_snn[j], and NaN otherwise.
    """
    # asarray lets callers pass lists; the original `.size` access required
    # NumPy arrays.
    sig_1_snn = np.asarray(sig_1_snn)
    sig_2_snn = np.asarray(sig_2_snn)

    grid = np.full(shape=(sig_1_snn.size, sig_2_snn.size), fill_value=np.nan)
    for j, s_1 in enumerate(sig_1_snn):
        for k, s_2 in enumerate(sig_2_snn):
            # Only the sigma_2 <= sigma_1 half is evaluated; the rest stays NaN.
            if s_2 <= s_1:
                grid[j, k] = sig_bar_snn(s_1, s_2)
    return grid


# %%

# Each panel compares the SNN grid against one value of sigma_ANN.
sig_ann_values = np.linspace(0.2, 1.8, 5)

fig, ax = plt.subplots(nrows=1, ncols=sig_ann_values.size,
                       figsize=(19.2, 4.8), tight_layout=True)

sig_snn = np.linspace(start=0.1, stop=2, num=20)
grid_snn = make_grid_snn(sig_snn, sig_snn)

# imshow puts pixel centres at the sig_snn values. The image therefore has
# to extend half a grid step beyond the first and last value on each axis,
# which gives (0.05, 2.05) for the grid above.
half_step = 0.5 * (sig_snn[1] - sig_snn[0])
lo, hi = sig_snn[0] - half_step, sig_snn[-1] + half_step
extent = (lo, hi, lo, hi)

ticks = np.linspace(start=0.2, stop=2, num=10)

for i, sig_ann_i in enumerate(sig_ann_values):
    # Positive where the SNN value exceeds the ANN value, negative where it
    # is smaller. NaN entries (the unevaluated half of the grid) are drawn
    # in the colormap's "bad" colour.
    diff = grid_snn - sig_bar_ann(sig_ann_i)
    # "bwr" is a diverging colormap, so centre it on zero: white then means
    # "no difference". Without vmin/vmax, imshow maps the data's own
    # [min, max] onto the colormap, and white marks an arbitrary value.
    v_max = np.nanmax(np.abs(diff))
    image = ax[i].imshow(diff,
                         cmap="bwr",
                         vmin=-v_max,
                         vmax=v_max,
                         origin="lower",
                         extent=extent)
    fig.colorbar(image, ax=ax[i])
    ax[i].set(xticks=ticks,
              yticks=ticks,
              title=rf"$\sigma_\mathrm{{ANN}} = {sig_ann_i:.1f}$")
    if i == 0:
        ax[i].set(xlabel=r"$\sigma_{2, \mathrm{SNN}}$",
                  ylabel=r"$\sigma_{1, \mathrm{SNN}}$")
    # Reference lines where each SNN sigma equals sigma_ANN.
    ax[i].axvline(sig_ann_i, color="black", linewidth=1)
    ax[i].axhline(sig_ann_i, color="black", linewidth=1)

base_dir = path.join("figures", "other")
os.makedirs(base_dir, exist_ok=True)
fig.savefig(path.join(base_dir, "random_walk.pdf"))
# Figure.show() does not run a GUI event loop, so in a plain script the
# window may flash and vanish. pyplot.show() runs the loop.
plt.show()
