import json
import os
import numpy as np

def calc_miou(ious):
    """Return the best mean IoU among several per-class IoU lists.

    Args:
        ious: iterable of per-class IoU sequences, e.g. one per evaluation
            found in a training log.

    Returns:
        ``(best_miou, best_iou)``: the highest mean IoU and the per-class
        list that produced it. When several lists tie, the earliest one wins.
        Returns ``(0, 0)`` when *ious* is empty.
    """
    ious = list(ious)
    if not ious:
        return 0, 0
    # max() returns an actual entry even when every mean is 0. A running max
    # seeded with 0 and a strict ">" would leave best_iou as the int 0, and
    # callers later index into best_iou.
    best_iou = max(ious, key=lambda iou: sum(iou) / len(iou))
    return sum(best_iou) / len(best_iou), best_iou
def calc_miou_reduced(ious, exclude_indices):
    """Return the best mean IoU after dropping some class indices.

    Args:
        ious: iterable of per-class IoU sequences.
        exclude_indices: class indices to leave out of each mean.

    Returns:
        The highest mean over the remaining classes, or ``0`` if there is
        nothing to average. Entries in which every class is excluded are
        skipped instead of dividing by zero.
    """
    # Build the set once so each membership test is O(1).
    excluded = set(exclude_indices)
    best_miou = 0
    for iou in ious:
        kept = [value for idx, value in enumerate(iou) if idx not in excluded]
        if not kept:
            continue
        miou = sum(kept) / len(kept)
        if miou > best_miou:
            best_miou = miou
    return best_miou

def print_dict(dic):
    """Print each entry's name and score, then each entry's per-class IoU list.

    Args:
        dic: mapping of log name -> ``(score, per_class_iou, ...)``, already
            ordered the way it should be printed.

    Entries whose name contains ``"trainval"`` are skipped in BOTH sections,
    so the IoU lists printed in the second section line up row-for-row with
    the names printed in the first.
    """
    shown = {name: entry for name, entry in dic.items() if "trainval" not in name}
    for name, entry in shown.items():
        print(name)
        print(entry[0])
    for entry in shown.values():
        print(entry[1])

def get_directory_logs(log_dir, filter_func=None, exclude_indices=None):
    """Summarise every ``.txt`` log in *log_dir* and print a ranking.

    Args:
        log_dir: directory scanned (non-recursively) for files whose name
            contains ``".txt"``.
        filter_func: optional predicate forwarded to ``extract_ious``. Logs
            for which it returns a truthy value are skipped.
        exclude_indices: class indices dropped when computing the "reduced"
            mIoU. When non-empty, the printed ranking uses the reduced mIoU;
            otherwise it uses the plain mIoU.

    Returns:
        ``(miou_dic, miou_reduced_dic)``. Both map log file name ->
        ``(score, best_iou)`` and are ordered by descending score. In both
        dicts ``best_iou`` is the per-class list with the best FULL mIoU.
    """
    if exclude_indices is None:
        exclude_indices = []
    # Iterate the bare entry names from listdir. This avoids recovering the
    # name from a joined path by splitting on "/", which breaks on Windows.
    # Sorting by name gives the same order as sorting by the joined path.
    names = sorted(name for name in os.listdir(log_dir) if ".txt" in name)
    dic = {}
    for name in names:
        ious = extract_ious(os.path.join(log_dir, name), filter_func)
        if len(ious) == 0:
            continue
        best_miou, best_iou = calc_miou(ious)
        best_miou_reduced = calc_miou_reduced(ious, exclude_indices)
        dic[name] = (best_miou, best_miou_reduced, best_iou)
    miou_dic = {
        k: (v[0], v[2])
        for k, v in sorted(dic.items(), key=lambda item: item[1][0], reverse=True)
    }
    miou_reduced_dic = {
        k: (v[1], v[2])
        for k, v in sorted(dic.items(), key=lambda item: item[1][1], reverse=True)
    }
    if len(exclude_indices) > 0:
        print("miou reduced")
        print_dict(miou_reduced_dic)
    else:
        print("miou original")
        print_dict(miou_dic)
    return miou_dic, miou_reduced_dic
def extract_ious(filename, filter_func=None):
    """Parse every per-class IoU list from a training log file.

    The first line of the file is evaluated as the run's config. Every line of
    the form ``IoU: [...]`` or ``IOU: [...]`` has its single quotes stripped
    and is parsed as JSON.

    Args:
        filename: path of the log file.
        filter_func: optional predicate called with the parsed config. When it
            returns a truthy value, the whole file is skipped.

    Returns:
        List of parsed IoU lists in file order. It is empty when the file is
        empty, is filtered out, or contains no IoU lines.
    """
    with open(filename) as f:
        lines = f.readlines()
    if not lines:
        return []
    # NOTE(review): eval() runs arbitrary code taken from the log's first line.
    # Only use this on logs you produced yourself. ast.literal_eval would be
    # safer if the config line is always a plain literal — confirm first.
    config = eval(lines[0])
    if filter_func and filter_func(config):
        return []
    ious = []
    for line in lines:
        parts = line.strip().split(": ")
        # Skip a bare "IoU" line with no value instead of raising IndexError.
        if parts[0] in ("IoU", "IOU") and len(parts) > 1:
            ious.append(json.loads(parts[1].replace("'", "")))
    return ious

def mean_and_std(v):
    """Format the mean and (population) standard deviation of *v*.

    Both statistics are computed with NumPy, rounded to three decimals, and
    rendered as ``"<mean> +- <std>"``.
    """
    rounded_mean, rounded_std = [np.round(stat(v), 3) for stat in (np.mean, np.std)]
    return f"{rounded_mean} +- {rounded_std}"
def comparison_helper(log_dic, class_dic, exclude_indices=None):
    """Compare several multi-run experiments, overall and per class.

    For each experiment directory, ``get_directory_logs`` picks the best
    per-class IoU list of every run (it also prints its own ranking). This
    function then prints the mean +- std of the runs' mIoU, both with and
    without the classes in *exclude_indices*. Finally it prints the mean +- std
    of each class's IoU for every experiment.

    Args:
        log_dic: mapping of experiment label -> log directory.
        class_dic: class names, indexed by class index ``0 .. len(class_dic)-1``
            (a list or int-keyed dict — presumably from ``get_classes``).
        exclude_indices: class indices left out of the "reduced" mIoU.
    """
    if exclude_indices is None:
        exclude_indices = []
    excluded = set(exclude_indices)
    ious_dic = {}
    for name, log_dir in log_dic.items():
        miou_dic, _ = get_directory_logs(log_dir)
        run_ious = [entry[1] for entry in miou_dic.values()]
        ious_dic[name] = run_ious
        mious = [np.mean(iou) for iou in run_ious]
        mious_reduced = [
            np.mean([value for idx, value in enumerate(iou) if idx not in excluded])
            for iou in run_ious
        ]
        print(f"{name}:", mean_and_std(mious))
        print(f"{name} reduced:", mean_and_std(mious_reduced))
    # Assumes every run's IoU list has at least len(class_dic) entries.
    for i in range(len(class_dic)):
        print("class:", i, class_dic[i])
        for name, run_ious in ious_dic.items():
            print(f"{name}:", mean_and_std([iou[i] for iou in run_ious]))
def cityscapes_comparisons():
    """Print multi-run mIoU comparisons for the Cityscapes experiment logs."""
    print("cityscapes_comparisons")
    from datasets.cityscapes import get_classes

    # Experiment label -> directory holding that experiment's run logs.
    run_dirs = dict(
        exp48="logs/exp48_5runs",
        exp55="logs/exp55_5runs",
        ddrnet23="logs/ddrnet23_5runs",
    )
    comparison_helper(run_dirs, get_classes())
def mapillary_comparisons(iou_threshold=None):
    """Print multi-run mIoU comparisons for the Mapillary experiment logs.

    Args:
        iou_threshold: when given, classes whose reference IoU (listed below)
            is ``<= iou_threshold`` are excluded from the "reduced" mIoU. The
            default ``None`` excludes nothing, which matches the previous
            behaviour. Pass ``30`` to reproduce the exclusion list that was
            previously computed and then discarded.
    """
    print("mapillary_comparisons")
    from datasets.mapillary import get_classes
    # Reference per-class IoUs, used only to choose which classes to exclude.
    # They appear to be percentages. NOTE(review): record which run they came from.
    ious = [0.0, 0.0, 57.68, 58.66, 63.16, 56.59, 50.9, 45.04, 39.82, 18.31,
            22.22, 45.77, 49.91, 87.97, 43.31, 70.61, 76.67, 86.43, 41.84,
            66.81, 46.77, 50.41, 0.0, 69.53, 57.07, 48.28, 4.99, 97.77, 76.83,
            68.69, 88.77, 72.93, 17.02, 22.26, 5.05, 45.31, 29.76, 0.0, 20.38,
            36.26, 2.43, 43.12, 4.4, 0.0, 37.03, 40.48, 52.69, 44.16, 60.96,
            36.36, 66.61, 43.94, 47.44, 16.69, 73.89, 89.68, 0.0, 55.74, 46.28,
            22.28, 6.71, 67.39, 8.41, 68.79, 91.75]
    if iou_threshold is None:
        exclude_indices = []
    else:
        exclude_indices = [i for i, iou in enumerate(ious) if iou <= iou_threshold]
    class_dic = get_classes("mapillary_dataset")
    log_dic = {
        "exp48": "mapillary_logs/exp48_5runs",
        "exp55": "mapillary_logs/exp55_3runs",
        "ddrnet23": "mapillary_logs/ddrnet_5runs",
        "L1": "mapillary_logs/L1_3runs",
        "L6": "mapillary_logs/L6_3runs",
    }
    comparison_helper(log_dic, class_dic, exclude_indices)
# Class indices dropped from the "reduced" mIoU in the training_log reports.
# NOTE(review): presumably truck/bus/train in Cityscapes train-ID order — confirm.
_TRAINING_LOG_EXCLUDE_INDICES = (14, 15, 16)


def _training_log_report(title, subdir):
    """Print *title*, the reduced-mIoU ranking of ``training_log/<subdir>``, then a blank line."""
    print(title)
    get_directory_logs(f"training_log/{subdir}", exclude_indices=_TRAINING_LOG_EXCLUDE_INDICES)
    print()


def reproducibility():
    """Report the logs in ``training_log/reproducibility``."""
    _training_log_report("reproducibility", "reproducibility")


def random_resizes_random_crops():
    """Report the logs in ``training_log/random_resizes_random_crops``."""
    _training_log_report("random_resizes_random_crops", "random_resizes_random_crops")


def training_techniques():
    """Report the logs in ``training_log/training_techniques``."""
    _training_log_report("training_techniques", "training_techniques")


def logs():
    """Report the logs in ``logs`` using the plain (not reduced) mIoU ranking."""
    print("logs")
    get_directory_logs("logs")
    print()


def camvid_5runs():
    """Print mean +- std over five hard-coded CamVid run results."""
    # All 5 runs can be found inside training_log/camvid/camvid_exp48_decoder26_200_epochs_288_1152_resize_log.txt
    # Presumably the best mIoU of each run, transcribed from that log.
    run_mious = [80.86231578480114, 80.9288031838157, 81.02897435968572, 80.83710479736328, 80.86533043601297]
    print("camvid 5 runs")
    print(mean_and_std(run_mious))
    print()


def backbone_ablation_studies():
    """Report the logs in ``training_log/backbone_ablation_studies``."""
    _training_log_report("backbone ablation studies", "backbone_ablation_studies")


def decoder_ablation_studies():
    """Report the logs in ``training_log/decoder_ablation_studies``."""
    _training_log_report("decoder ablation studies", "decoder_ablation_studies")
if __name__ == "__main__":
    # Default report sequence, run in order. mapillary_comparisons() and
    # logs() are defined above but are not part of the default run.
    for report in (
        cityscapes_comparisons,
        reproducibility,
        backbone_ablation_studies,
        decoder_ablation_studies,
        random_resizes_random_crops,
        training_techniques,
        camvid_5runs,
    ):
        report()
