#!/usr/bin/env python3
import collections
import argparse
from pathlib import Path

import scipy
import scipy.signal
import scipy.interpolate
import scipy.optimize
import numpy as np
import matplotlib.pyplot as plt
import matplotlib.patches

# Signal generation
def gausskernel(std):
    """
    Return a normalized, symmetric Gaussian kernel with sigma = std.

    The number of coefficients is set automatically. The kernel extends
    +/- std*sqrt(2*ln(100)) samples from its centre, where the Gaussian has
    fallen to 1/100 of its peak. This gives round(2*std*sqrt(2*ln 100) + 1)
    taps.

    Parameters
    ----------
    std : float
        Standard deviation of the Gaussian, in samples. Must be positive.

    Returns
    -------
    numpy.ndarray
        1-D float array of kernel coefficients that sum to 1.

    Raises
    ------
    ValueError
        If `std` is not strictly positive (including NaN).
    """
    if not std > 0:
        raise ValueError(f"std must be positive, got {std!r}")
    # np.round returns a float; the tap count must be an integer
    # (np.linspace/np.arange reject float counts in current NumPy).
    N = int(np.round(2.0*std*np.sqrt(np.log(100)/0.5) + 1.0))
    # Tap positions 1..N shifted so the kernel is centred on zero.
    xc = np.arange(1, N + 1) - (N + 1.0)/2.0
    gf = np.exp(-(xc/std)**2/2.0)

    return gf/np.sum(gf)

def random_signal(*, Ns=10000, fs=100):
    """Generate a random test signal whose smoothness varies along its length.

    Uniform white noise in [-0.5, 0.5) is smoothed by a cascade of Gaussian
    kernels with sigma 7, 15 and 25 samples (each stage smooths the previous
    stage's output). The result is stitched from four consecutive quarters
    taken from the stage smoothed 3, 2, 1 and 2 times, respectively.

    Keyword-only parameters:
        Ns -- number of samples to generate.
        fs -- accepted but not used by this function.

    Returns a float array of length Ns. Uses NumPy's global random state.
    """
    white = np.random.rand(Ns) - 0.5

    # stages[k] is the white noise after k successive Gaussian smoothings.
    stages = [white]
    for sigma in (7.0, 15.0, 25.0):
        stages.append(np.convolve(stages[-1], gausskernel(sigma), 'same'))

    # Quarter boundaries; None makes the last slice run to the end.
    edges = [0, Ns//4, Ns//2, 3*Ns//4, None]
    smoothing_per_quarter = [3, 2, 1, 2]

    out = np.zeros_like(white)
    for (start, stop), level in zip(zip(edges, edges[1:]), smoothing_per_quarter):
        out[start:stop] = stages[level][start:stop]
    return out
    
    
# Expected variance of residuals after fit
def spline_interpolation_response(freqs, spline_dt):
    """Cubic B-spline interpolation frequency response according to

    Z. Mihajlovic, A. Goluban, and M. Zagar. Frequency Domain Analysis of
    B-Spline Interpolation. ISIE'99

    Evaluates H(w) = 3 * sinc(w / (2*pi))**4 / (2 + cos(w)) at
    w = 2*pi*freqs*spline_dt, i.e. at frequencies normalized by the knot
    spacing.

    Parameters
    ----------
    freqs : float or array_like
        Frequencies, in the reciprocal unit of `spline_dt`.
    spline_dt : float
        Spline knot spacing.

    Returns
    -------
    numpy.ndarray or numpy.float64
        Response with the same shape as `freqs`; equals 1 at zero frequency.
    """
    w = 2*np.pi*np.asarray(freqs, dtype=float)*spline_dt
    # np.sinc is the normalized sinc sin(pi*x)/(pi*x) with sinc(0) == 1.
    # It is vectorized in C, replacing a per-element Python np.vectorize loop.
    # The denominator 2 + cos(w) is always >= 1, so no division by zero.
    return 3*np.sinc(w/(2*np.pi))**4 / (2 + np.cos(w))

def estimate_variance_of_fit(signal, sample_rate, spline_dt, noise_std, *, spectrum=None):
    """Predict the residual variance of a cubic-spline fit from the spectrum.

    The prediction is the sum of two terms:
      * the mean power of (1 - H) * X, where X is the DFT of `signal` (or the
        precomputed `spectrum`) and H is the cubic B-spline interpolation
        response for knot spacing `spline_dt`;
      * the mean power of white noise with std `noise_std` filtered by H.

    Mean powers are computed from DFT bins via Parseval:
    mean(|x[n]|^2) == mean(|X[k]|^2) / n.

    NOTE(review): if `signal` already contains the measurement noise, that
    noise also enters the first term -- confirm this is the intended model.

    Parameters:
        signal      -- 1-D sample sequence.
        sample_rate -- samples per unit time (used to build the DFT frequency axis).
        spline_dt   -- knot spacing, in the same time unit.
        noise_std   -- standard deviation of the additive white noise.
        spectrum    -- optional precomputed np.fft.fft(signal), reused to avoid
                       recomputing the FFT.

    Returns the predicted residual variance as a float.
    """
    n = len(signal)
    if spectrum is None:
        spectrum = np.fft.fft(signal)

    response = spline_interpolation_response(
        np.fft.fftfreq(n, d=1/sample_rate), spline_dt)

    # Part of the signal the spline cannot follow.
    misfit = (1 - response) * spectrum
    signal_term = np.mean(np.abs(misfit)**2) / len(misfit)

    # RMS magnitude of a white-noise DFT bin is sqrt(n) * noise_std.
    noise_bin = np.sqrt(n) * noise_std
    noise_term = np.mean(np.abs(response*noise_bin)**2) / len(response)

    return signal_term + noise_term
        

def generate_data(*, sample_rate=50, num_samples=10000, noise_fraction=0.02,
                  max_dt_samples=20, num_dts=100):
    """Simulate noisy data and compare measured vs. predicted spline-fit residuals.

    A random signal (scaled by 100) is corrupted with Gaussian noise. For each
    candidate knot spacing, an LSQ cubic spline is fitted to the noisy samples.
    The standard deviation of its residuals is measured and compared with the
    frequency-domain prediction from `estimate_variance_of_fit`.

    All parameters are keyword-only. The defaults reproduce the original setup.

    Parameters
    ----------
    sample_rate : float
        Sampling rate; sample times are index / sample_rate.
    num_samples : int
        Number of samples in the generated signal.
    noise_fraction : float
        Noise std as a fraction of the noise-free signal's maximum value.
    max_dt_samples : float
        Largest knot spacing tried, in samples (the smallest is 1 sample).
    num_dts : int
        Number of knot spacings, evenly spaced in between.

    Returns
    -------
    dict
        'signal_true' : noise-free signal
        'signal'      : signal with added noise
        'times'       : sample times (index / sample_rate)
        'dts'         : knot spacings tried
        'std_est'     : predicted residual std per spacing
        'std_actual'  : std of (fit - noisy measurements) per spacing
        'std_nonoise' : std of (fit - noise-free signal) per spacing
        'noise_std'   : std of the added noise

    Raises
    ------
    ValueError
        If a knot spacing leaves fewer than two interior knots.
    """
    signal_orig = 100 * random_signal(fs=sample_rate, Ns=num_samples)

    # Noise level is relative to the peak of the noise-free signal.
    noise_std = noise_fraction * np.max(signal_orig)
    signal = signal_orig + np.random.normal(scale=noise_std, size=signal_orig.shape)

    times = np.arange(len(signal)) / sample_rate
    duration = times[-1] - times[0]
    # FFT of the noisy signal, computed once and reused for every spacing.
    SIGNAL = np.fft.fft(signal)

    dts = np.linspace(1/sample_rate, max_dt_samples/sample_rate, num=num_dts)
    std_nonoise, std_actual, std_est = [], [], []

    for spline_dt in dts:
        # Uniformly spaced knots. Drop the first two and last two: interior
        # knots must lie strictly inside (times[0], times[-1]), and the margin
        # helps satisfy the Schoenberg-Whitney conditions near the ends.
        num_knots = int(np.ceil(duration / spline_dt))
        knots = (np.arange(num_knots) * spline_dt)[2:-2]
        if knots.size < 2:
            raise ValueError(
                f"spline_dt={spline_dt} leaves too few interior knots")

        # Evaluate only on samples covered by the interior knots.
        valid = (times > knots[0]) & (times < knots[-1])
        spl = scipy.interpolate.LSQUnivariateSpline(times, signal, knots)
        fitted = spl(times[valid])

        # Error against the noise-free signal, and residual against the measurements.
        std_nonoise.append(np.std(fitted - signal_orig[valid]))
        std_actual.append(np.std(fitted - signal[valid]))

        # Expected residual std from the frequency-domain model.
        var_exp = estimate_variance_of_fit(signal, sample_rate, spline_dt,
                                           noise_std, spectrum=SIGNAL)
        std_est.append(np.sqrt(var_exp))

    return {
        'signal_true': signal_orig,
        'signal': signal,
        'times': times,
        'dts': dts,
        'std_est': np.array(std_est),
        'std_actual': np.array(std_actual),
        'std_nonoise': np.array(std_nonoise),
        'noise_std': noise_std,
    }
            
##### Begin plot

if __name__ == "__main__":
    print('Generating data')
    data = generate_data()

    # Figure layout: full signal (top, two thirds wide), zoomed part (top
    # right), residual std vs. knot spacing (bottom, full width).
    gs = plt.GridSpec(2, 3)
    fig = plt.figure()
    ax_sig = fig.add_subplot(gs[0, :2])
    ax_part = fig.add_subplot(gs[0, 2])
    ax_std = fig.add_subplot(gs[1, :])

    # Signal plot
    ax_sig.plot(data['times'], data['signal'])
    ax_sig.set(ylabel='Signal')
    ax_sig.grid(which='both')

    # Part plot: noisy samples a..b-1 with the noise-free signal overlaid.
    a, b = 300, 600
    part_times = data['times'][a:b]
    part_signal = data['signal'][a:b]
    ax_part.plot(part_times, part_signal)
    ax_part.plot(part_times, data['signal_true'][a:b], linewidth=2, color='k', alpha=0.5)
    ax_part.grid(which='both')

    # Rectangular box marking the zoomed part on the overview plot. It is
    # padded vertically by hmargin times the part's peak-to-peak range.
    rect_color = 'r'
    t0 = part_times[0]
    dt = part_times[-1] - t0
    h0 = np.max(part_signal) - np.min(part_signal)
    hmargin = 0.2
    h = h0 * (1 + 2*hmargin)
    y0 = np.min(part_signal) - hmargin * h0
    rect = matplotlib.patches.Rectangle((t0, y0), dt, h,
                                        transform=ax_sig.transData,
                                        color=rect_color,
                                        fill=False,
                                        zorder=10)
    ax_sig.add_patch(rect)

    # Arrow from the box's top-right corner (ax_sig data coordinates) to the
    # left edge of the part plot (ax_part axes fraction).
    start_data = (t0 + dt, y0 + h)
    end_axis = (0, 0.8)
    arrow = matplotlib.patches.ConnectionPatch(
        start_data, end_axis,
        coordsA="data",
        coordsB="axes fraction",
        axesA=ax_sig,
        axesB=ax_part,
        arrowstyle='-|>',
        zorder=100,
        linewidth=1.5,
        color=rect_color,
        connectionstyle="arc3,rad=-0.17"
    )
    # The arrow spans two axes. Add it to the figure rather than to ax_sig,
    # so ax_part (drawn after ax_sig) cannot occlude it.
    fig.add_artist(arrow)

    # Standard-deviation plot: (legend label, zorder) per data key.
    plot_data = {
        'std_actual': (r'$\sigma_r$', 1),
        'std_nonoise': (r'$\sigma_{r_0}$', 2),
        'std_est': ('Predicted', 3)
    }

    for key in ['std_est', 'std_actual', 'std_nonoise']:
        label, zorder = plot_data[key]
        ax_std.plot(data['dts'], data[key], label=label, zorder=zorder, linewidth=2)

    # Reference line at the true noise level.
    ax_std.axhline(data['noise_std'],
                   color='k',
                   linestyle='--',
                   label=r'$\sigma_n$',
                   zorder=0)

    ax_std.legend(ncol=2, loc='upper left')
    ax_std.set(ylabel='stddev',
               xlabel=r'$\Delta t$')
    ax_std.grid(which='both')

    fig.subplots_adjust(
        top=0.96,
        bottom=0.16,
        left=0.14,
        right=0.98,
        hspace=0.3,
        wspace=0.27
    )

    plt.show()
