#!/usr/bin/env python3

# %%

import os.path as path
import pickle

import matplotlib.pyplot as plt
import numpy as np
from matplotlib import rc

# %%

# Line color shared by every convergence curve drawn in this script.
color = "#3233AF"

# Render all figure text through LaTeX, using a sans-serif font family.
plt.rcParams.update({
    "text.usetex": True,
    "font.family": "sans-serif",
})


# %%

def add_plots(pickle_filename, plots, split, name, title=False,
              iterations=(100, 1000, 10000)):
    """Draw the normalized per-phase loss curves of one run onto a row of axes.

    Loads a pickled sequence of per-phase result dicts from
    ``results/optimize/<pickle_filename>`` and plots one curve per phase.
    Every curve is rescaled with the same affine map, so the first loss of
    the first phase becomes 1 and the last loss of the final phase becomes 0.

    Args:
        pickle_filename: File name inside ``results/optimize``.
        plots: Indexable sequence of matplotlib axes, one per phase.
        split: Suffix of the loss history to plot. The key read from each
            phase dict is ``"loss_history_<split>"`` (e.g. ``"opt"`` / ``"val"``).
        name: Row label, set as the y-axis label of the first axes.
        title: If true, title each axes with "Phase k" and its iteration count.
        iterations: Per-phase x-axis limit, also shown in the title. The
            default reproduces the original three phases of 100, 1000 and
            10000 iterations. Its length fixes how many phases are drawn.

    Raises:
        IndexError: If the pickle holds fewer phases than ``len(iterations)``.
    """
    key = "loss_history_{}".format(split)

    # NOTE: pickle.load can execute arbitrary code; only load result files
    # produced by this project, never files from an untrusted source.
    with open(path.join("results", "optimize", pickle_filename), "rb") as pickle_file:
        results = pickle.load(pickle_file)

    # Normalization anchors: the very first loss of the first phase (-> 1)
    # and the very last loss of the final phase (-> 0).
    start_loss = results[0][key][0]
    final_loss = results[len(iterations) - 1][key][-1]
    scale = start_loss - final_loss

    phases = zip(results, plots, iterations)
    for phase, (result, ax, n_iter) in enumerate(phases, start=1):
        norm = (np.array(result[key]) - final_loss) / scale
        ax.plot(norm, color=color)
        ax.set_xlim(0, n_iter)
        ax.set_xticks([])
        if phase == 1:
            # Only the first axes hides y ticks. The caller builds the grid
            # with sharey=True, so this presumably applies to the whole row.
            ax.set_yticks([])
        if title:
            ax.set_title("Phase {}\n{} Iterations".format(phase, n_iter))
        if phase == 1:
            ax.set_ylabel(name)


# %%

# One figure per loss split. Each row is one (result file, row label) pair;
# only the top row gets per-phase column titles.
DATASETS = (
    ("dense_mnist_o-1.0_n.p", "MNIST Dense"),
    ("conv_mnist_o-0.6_n.p", "MNIST Conv"),
    ("lite_cifar10_o-0.2_n.p", "CIFAR-10"),
    ("lite_cifar100_o-0.3_n.p", "CIFAR-100"),
)

for split in ("opt", "val"):
    fig, subplots = plt.subplots(len(DATASETS), 3, figsize=(6, 6), sharey=True)

    for row, (axes_row, (pickle_name, label)) in enumerate(zip(subplots, DATASETS)):
        add_plots(pickle_name, axes_row, split, label, title=(row == 0))

    out_file = path.join("results", "opt_convergence_{}.pdf".format(split))
    fig.savefig(out_file)
    fig.show()
