import json
import os
import numpy as np
import matplotlib.pyplot as plt

def steep_transform(r):
    """Warp ``r`` with a piecewise-linear map that is steep around each x.5.

    The integer part of ``r`` is kept. Its fractional part is remapped inside
    the same unit interval:

    * [0.00, 0.45) -> [0.0, 0.2)   shallow, slope 0.2/0.45
    * [0.45, 0.55) -> [0.2, 0.8)   steep,   slope 0.6/0.1
    * [0.55, 1.00) -> [0.8, 1.0)   shallow, slope 0.2/0.45

    The three pieces meet end to end, so the map is continuous and integers
    map (up to float rounding) to themselves.
    """
    frac = r % 1
    whole = r - frac
    gentle = 0.2 / 0.45
    if frac < 0.45:
        slope, bias = gentle, -whole * gentle
    elif frac < 0.55:
        slope = 0.6 / 0.1
        bias = 0.2 - (0.45 + whole) * slope
    else:
        slope, bias = gentle, 0.8 - (0.55 + whole) * gentle
    # Written as slope*r + bias (+ integer part) so each piece is an affine
    # function of the raw input r.
    return slope * r + bias + whole

def extract_info(filename):
    """Collect the loss and dilation-rate records from a training log.

    Each line is stripped and split at its first ": " into a key and a value.
    A value is parsed as JSON when its line is keyed "loss" or
    "dilation rates". Every other line is ignored.

    Args:
        filename: path of the log file to read.

    Returns:
        A ``(losses, all_dilation_rates)`` tuple. Each element is a list of
        the parsed JSON values, in file order.
    """
    losses = []
    all_dilation_rates = []
    with open(filename) as f:
        # Stream line by line rather than materializing the whole file.
        for raw_line in f:
            # Split only at the FIRST ": ". A value that itself contains ": "
            # (e.g. a JSON object) then reaches json.loads intact instead of
            # being truncated.
            key, _, value = raw_line.strip().partition(": ")
            if key == "loss":
                losses.append(json.loads(value))
            elif key == "dilation rates":
                all_dilation_rates.append(json.loads(value))
    return losses, all_dilation_rates


def f(filename, *, steep=False):
    """Plot the training loss (x10) and every learned dilation rate per epoch.

    Example log: "logs/learn3_decoder30_500_epochs_log.txt".

    Args:
        filename: path to a training log in the format read by
            ``extract_info``.
        steep: if True, pass every dilation rate through ``steep_transform``
            before plotting. Defaults to False, which plots raw rates.

    Each logged "dilation rates" entry is indexed as ``rates[0]`` and
    ``rates[1]``, so it must hold at least two values. Component ``j`` of
    layer ``i`` is drawn as the curve labelled ``2*i + j``. The layer count is
    taken from the first logged epoch.
    """
    losses, all_dilation_rates = extract_info(filename)
    # Loss is multiplied by 10 before plotting — presumably to bring it onto
    # the same scale as the dilation rates; verify.
    plt.plot([loss * 10 for loss in losses], label="loss")

    # Guard: a log without any "dilation rates" line previously raised
    # IndexError on all_dilation_rates[0].
    num_layers = len(all_dilation_rates[0]) if all_dilation_rates else 0
    all_ds = [[] for _ in range(2 * num_layers)]
    for dilation_rates in all_dilation_rates:
        for i, rates in enumerate(dilation_rates):
            r0, r1 = rates[0], rates[1]
            if steep:
                r0, r1 = steep_transform(r0), steep_transform(r1)
            all_ds[2 * i].append(r0)
            all_ds[2 * i + 1].append(r1)

    for i, ds in enumerate(all_ds):
        plt.plot(ds, label=f"{i}")
    plt.legend(loc="lower right")
    plt.show()
def g(filenames):
    """Plot field-of-view against epoch for each training log in ``filenames``.

    For every file, ``get_fov`` from the project-local ``fov`` module is
    applied to each logged set of dilation rates. One curve per file is drawn
    on shared axes, the axes are labelled, and the figure is shown.
    """
    from fov import get_fov
    for filename in filenames:
        _, all_dilation_rates = extract_info(filename)
        fov_curve = [get_fov(rates) for rates in all_dilation_rates]
        plt.plot(fov_curve, label=filename)
    # Legend is not drawn. Uncomment to label curves by filename:
    # plt.legend(loc="lower right")
    plt.ylabel("field-of-view", fontsize=20)
    plt.xlabel("epochs", fontsize=20)
    plt.title("field-of-view vs epochs for DNAS", fontsize=20)
    plt.show()
if __name__ == "__main__":
    # Other runs that can be swapped into log_files:
    #   "mapillary_logs/mapillary_L1_decoder26_180_epochs_768_res_run1.txt"
    #   "mapillary_logs/mapillary_L1_decoder26_180_epochs_run1.txt"
    #   "logs/L1_decoder26_1000_epochs_768_res_run1.txt"
    #   "logs/L1_decoder26_1000_epochs_run1.txt"
    log_files = [
        "logs/learn2_decoder26_500_epochs_1024_crop_run1.txt",
    ]
    g(log_files)
