import numpy as np
import matplotlib.pyplot as plt
from PIL import Image

def fig_make(m_rows, n_cols, flatten, **kwargs):
    """Create a figure holding an ``m_rows`` x ``n_cols`` grid of axes.

    The figure size is the current ``figure.figsize`` rcParam scaled by the
    grid shape, so each subplot gets roughly one default-sized cell.

    Args:
        m_rows: number of subplot rows.
        n_cols: number of subplot columns.
        flatten: if true, return the axes as a flat list in row-major order;
            otherwise return the 2-D array produced by ``plt.subplots``.
        **kwargs: forwarded unchanged to ``plt.subplots``.

    Returns:
        ``(fig, axes)``.
    """
    base_w, base_h = plt.rcParams["figure.figsize"]
    figsize = (base_w * n_cols, base_h * m_rows)
    # squeeze=False guarantees a 2-D array even for 1x1 or 1xN grids, so the
    # flattening below works for every grid shape.
    fig, axes = plt.subplots(m_rows, n_cols, figsize=figsize, squeeze=False, **kwargs)
    if flatten:
        # ndarray.flat iterates in row-major (C) order: row 0 left-to-right,
        # then row 1, and so on.
        axes = list(axes.flat)
    return fig, axes
def set_plot_style():
    """Apply this module's default plot styling to the global rcParams.

    The change is made in place on ``plt.rcParams``, so it affects every
    figure created afterwards in the process.
    """
    plt.rcParams.update({
        # Base figure size; fig_make() scales it per subplot and some plot
        # functions override it after calling this.
        "figure.figsize": [3.0, 2],
        # Axes frame, light grid, and no tick marks on bottom/left.
        "axes.linewidth": 1,
        "axes.grid": True,
        "grid.alpha": 0.4,
        "xtick.bottom": False,
        "ytick.left": False,
        "legend.edgecolor": "0.3",
        "axes.xmargin": 0.025,
        # Line and marker sizes.
        "lines.linewidth": 1.25,
        "lines.markersize": 5.0,
        # Compact font sizes.
        "font.size": 10,
        "axes.titlesize": 10,
        "legend.fontsize": 8,
        "legend.title_fontsize": 8,
        "xtick.labelsize": 7,
        "ytick.labelsize": 7,
    })

def dilation_vs_time():
    """Plot measured time against dilation rate and save it to plots/foo2.pdf.

    Uses hard-coded measurements: ``times[i]`` is paired with dilation rate
    ``i + 1`` (rates 1..18). NOTE(review): the time unit is not recorded
    here; values around 5.5e-4 suggest seconds per call -- confirm.
    """
    set_plot_style()
    dilation_rates = list(range(1, 19))
    times = [
        0.0006168842315673828, 0.0006783771514892579, 0.0006738805770874024,
        0.0005538702011108398, 0.0005596113204956054, 0.0005542564392089843,
        0.0005533742904663086, 0.0005581331253051758, 0.0005504536628723144,
        0.0005482935905456543, 0.000562136173248291, 0.0005651140213012696,
        0.0005467510223388671, 0.0005594706535339356, 0.0005618429183959961,
        0.0005541324615478515, 0.0005780768394470215, 0.0005727362632751465,
    ]
    # plt.plot(x, y): dilation rate is the x axis and time the y axis, so the
    # axis labels must follow that order.
    plt.plot(dilation_rates, times, label="hello")  # TODO: replace placeholder legend label
    plt.xlabel('dilation rate')
    plt.ylabel('time')
    plt.legend(loc="lower right")
    plt.tight_layout(pad=0.3, h_pad=1.08, w_pad=1.08)
    plt.savefig('plots/foo2.pdf')
    plt.show()
def random_resize():
    """Plot mIOU against val_size for two random-resize ranges.

    ``resize1`` and ``resize2`` are hard-coded mIOU values, one per entry of
    ``val_sizes``, for the ranges named in their legend labels
    ([512, 2048] and [400, 1600]; presumably training-time random-resize
    ranges -- confirm). A red vertical line marks val_size 1024. The figure
    is saved to plots/random_resize.pdf.
    """
    val_sizes = [512, 640, 768, 896, 1024, 1152, 1280, 1408, 1536, 1664, 1792, 1920, 2048]
    # Range [512, 2048]
    resize1 = [70.22, 72.51, 74.63, 75.97, 76.71, 76.95, 76.99, 77.03, 76.83, 76.76, 76.13, 75.87, 75.15]
    # Range [400, 1600]
    resize2 = [71.78, 74.39, 76.5, 77.38, 77.99, 78.04, 77.93, 77.76, 76.84, 76.09, 75.03, 73.28, 71.48]
    # Alternative [400, 1600] measurements kept for reference (not plotted):
    # [72.63, 75.37, 76.95, 77.71, 78.06, 78.25, 77.81, 77.53, 77.05, 76.5, 75.63, 74.53, 73.12]
    # [72.43, 75.08, 76.56, 77.67, 78.08, 78.35, 77.98, 77.76, 77.41, 76.37, 75.52, 74.22, 72.75]

    set_plot_style()
    plt.plot(val_sizes, resize1, label="[512,2048]")
    plt.plot(val_sizes, resize2, label="[400,1600]")
    plt.axvline(x=1024, label="interested val_size", c="r")
    plt.xlabel('val_size')
    plt.ylabel('mIOU')
    plt.legend(loc="lower right")
    # Put a tick at every other val_size.
    plt.xticks(ticks=val_sizes[::2])
    # tight_layout runs last so the layout accounts for the final tick labels.
    plt.tight_layout(pad=0.5, h_pad=1.08, w_pad=1.08)
    plt.savefig('plots/random_resize.pdf')
    plt.show()
def competitors_info():
    """Return the comparison table of segmentation models (RegSeg included).

    Each row is ``(name, mIOU, fps, params_M, GFLOPs)``. The original column
    note labels mIOU as "test", presumably test-set mIOU. A 0 in a numeric
    column means the value is not available; the plotting functions skip
    rows whose value is 0 for the metric being plotted.

    Rows are in ascending mIOU order, except that RegSeg is deliberately
    kept last, after the higher-scoring DDRNet-23. NOTE(review): this is
    presumably so it can be singled out in plots -- keep it last when adding
    rows.

    Returns:
        A freshly built list of 5-tuples on every call.
    """
    table = [
        # name             mIOU  fps    params(M)  GFLOPs
        ("CAS",            70.5, 108,   0,         0),
        ("DFANet A",       71.3, 100,   7.8,       3.4),
        ("FasterSeg",      71.5, 163.9, 4.4,       28.2),
        ("GAS",            71.8, 108.4, 0,         0),
        ("MobileNetV3",    72.6, 0,     1.51,      9.74),
        ("HMSeg",          74.3, 83.2,  0,         0),
        ("TD4-Bise18",     74.9, 47.6,  0,         0),
        ("BiSeNetV2-L",    75.3, 47.3,  4.586,     138.993),
        ("DF2-Seg2",       75.3, 56.3,  0,         0),
        ("SwiftNetRN-18",  75.5, 39.9,  11.8,      104.0),
        ("HyperSeg-M",     75.8, 36.9,  10.311,    8.404),
        ("CABiNet",        75.9, 76.5,  2.64,      12),
        ("FC-HarDNet-70",  76.0, 53,    4.119,     35.628),  # TODO: check fps
        ("STDC2-Seg75",    76.8, 97.0,  16.079,    54.935),
        ("MSFNet",         77.1, 41,    0,         96.8),
        ("SFNet(DF2)",     77.8, 53,    17.876,    80.35),
        ("DDRNet-23",      79.4, 37.1,  20.1,      143.1),
        ("RegSeg",         78.3, 30,    3.335,     39.077),
        # ("RegSeg-M",     79.5, 20,    6.132,     87.452),  # disabled entry
    ]
    return table
def _collect_points(competitors, metric_index, text_offsets):
    """Select rows that report both mIOU and one x-axis metric.

    Args:
        competitors: rows shaped like ``competitors_info()`` output,
            ``(name, miou, fps, params_m, gflops)``; 0 means "not reported".
        metric_index: index into each row of the x-axis metric
            (2 = FPS, 3 = params in M, 4 = GFLOPs).
        text_offsets: ``{name: (dx, dy)}`` shifts added to that point's label
            position, in data units; names not present are not shifted.

    Returns:
        Four parallel lists ``(names, mious, xs, text_positions)``, in the
        same order as ``competitors``.
    """
    names, mious, xs, text_positions = [], [], [], []
    for competitor in competitors:
        name, miou, x = competitor[0], competitor[1], competitor[metric_index]
        if miou == 0 or x == 0:
            continue  # value not reported for this model
        dx, dy = text_offsets.get(name, (0, 0))
        names.append(name)
        mious.append(miou)
        xs.append(x)
        text_positions.append((x + dx, miou + dy))
    return names, mious, xs, text_positions


def _plot_miou_vs(metric_index, xlabel, filename, text_offsets, highlight="RegSeg"):
    """Scatter mIOU against one competitors_info() metric, annotate, and save.

    The model named ``highlight`` is drawn in red and every other model in the
    default color. The highlight is chosen by name rather than list position,
    so it stays correct if rows are reordered or a model lacks the metric.

    Args:
        metric_index: column of competitors_info() rows used for the x axis.
        xlabel: x-axis label.
        filename: output path passed to ``plt.savefig``.
        text_offsets: per-name label shifts in data units (see _collect_points).
        highlight: name of the model drawn in red.
    """
    names, mious, xs, text_positions = _collect_points(
        competitors_info(), metric_index, text_offsets)
    rest = [i for i, name in enumerate(names) if name != highlight]
    marked = [i for i, name in enumerate(names) if name == highlight]

    set_plot_style()
    # Overrides the figsize from set_plot_style(); it applies when the scatter
    # call below creates a new figure.
    plt.rcParams["figure.figsize"] = [4.5, 3]
    plt.scatter([xs[i] for i in rest], [mious[i] for i in rest])
    plt.scatter([xs[i] for i in marked], [mious[i] for i in marked], c="r")
    for name, position in zip(names, text_positions):
        plt.annotate(name, position)
    plt.xlabel(xlabel)
    plt.ylabel('mIOU')
    plt.tight_layout(pad=0.3, h_pad=1.08, w_pad=1.08)
    plt.savefig(filename)
    plt.show()


def miou_vs_fps():
    """Scatter mIOU vs. FPS for all models and save plots/miou_vs_fps.pdf."""
    _plot_miou_vs(2, 'FPS', 'plots/miou_vs_fps.pdf', text_offsets={})


def miou_vs_params():
    """Scatter mIOU vs. parameter count (M) and save plots/miou_vs_params.pdf."""
    text_offsets = {
        "HyperSeg-M": (-0.5, -1),
        "CABiNet": (-0.2, 0.3),
        "SFNet(DF2)": (-1.2, 0.2),
        "FC-HarDNet-70": (0.2, -0.3),
        "BiSeNetV2-L": (0, -0.5),
        "STDC2-Seg75": (-0.5, 0.2),
        "FasterSeg": (0, 0.2),
        "DFANet A": (2 / 7, -0.2),
        "DDRNet-23": (-4, -0.5),
        "SwiftNetRN-18": (-10 / 7, 0.2),
        "RegSeg": (0.1, 0.1),
    }
    _plot_miou_vs(3, 'Params (M)', 'plots/miou_vs_params.pdf', text_offsets=text_offsets)


def miou_vs_flops():
    """Scatter mIOU vs. GFLOPs and save plots/miou_vs_flops.pdf."""
    text_offsets = {
        "HyperSeg-M": (-4, -0.6),
        "CABiNet": (-4, 0.2),
        "BiSeNetV2-L": (-25, -0.5),
        "STDC2-Seg75": (-5, 0.2),
        "FasterSeg": (0, 0.2),
        "DFANet A": (2, -0.2),
        "DDRNet-23": (-30, -0.5),
        "SwiftNetRN-18": (-10, 0.2),
        "RegSeg": (1, 0.1),
    }
    _plot_miou_vs(4, 'GFLOPs', 'plots/miou_vs_flops.pdf', text_offsets=text_offsets)
def reproducibility():
    """Print an IoU summary for each of four hard-coded runs.

    Each run is a list of 19 per-class IoUs. For every run this prints a
    tuple ``(excluded, mIOU, mIOU_reduced)``:
      * ``excluded``: the IoUs at class indices 14, 15 and 16, in order;
      * ``mIOU``: the mean over all classes, rounded to 2 decimals;
      * ``mIOU_reduced``: the mean over the remaining classes, rounded to 2
        decimals.
    NOTE(review): with 19 classes these indices are presumably Cityscapes'
    truck/bus/train -- confirm.
    """
    skipped = {14, 15, 16}

    def summarize(per_class):
        kept = [iou for idx, iou in enumerate(per_class) if idx not in skipped]
        dropped = [iou for idx, iou in enumerate(per_class) if idx in skipped]
        overall = sum(per_class) / len(per_class)
        reduced = sum(kept) / len(kept)
        return dropped, round(overall, 2), round(reduced, 2)

    runs = [
        [97.95, 83.61, 91.95, 55.4, 58.5, 61.45, 66.12, 75.25, 92.01, 63.24, 94.74, 78.92, 59.29, 94.27, 68.11, 73.55, 35.77, 58.0, 73.41],
        [97.99, 83.82, 92.11, 54.8, 60.36, 61.55, 66.22, 75.5, 92.09, 64.73, 94.72, 79.16, 59.18, 94.38, 77.35, 79.05, 53.18, 58.58, 73.2],
        [98.12, 84.58, 92.13, 56.12, 59.83, 61.33, 66.51, 75.2, 92.19, 64.07, 94.83, 79.14, 58.5, 94.3, 73.17, 77.78, 58.77, 58.87, 72.87],
        [98.04, 84.09, 92.29, 58.96, 60.11, 61.38, 65.85, 75.47, 92.13, 63.78, 94.92, 79.16, 58.64, 94.36, 71.31, 81.73, 74.27, 54.33, 72.98],
    ]
    for run in runs:
        print(summarize(run))
if __name__ == "__main__":
    # Uncomment whichever figure/report to regenerate. The plotting functions
    # save a PDF under plots/; reproducibility() prints to stdout.
    # dilation_vs_time()
    # reproducibility()
    # miou_vs_params()
    # miou_vs_flops()
    # random_resize()
    miou_vs_fps()
