import matplotlib
import matplotlib.pyplot as plt
from options import plot_args_parser
import pickle
import numpy as np
import math
import os
from sklearn.manifold import TSNE
from torchvision import datasets, transforms

from sampling import mnist_iid, mnist_noniid, mnist_noniid_unequal
from sampling import cifar_iid, cifar_noniid
from sampling import Dirichlet_noniid
from numpy.random import RandomState
from collections import Counter
#matplotlib.use('Agg')


def get_plot_dataset(args,data_dir=None,seed=None):
    """Return the most frequent training label of every user.

    The training set selected by ``args.dataset`` is loaded, its samples are
    partitioned amongst the users with the sampler selected by ``args.iid``,
    ``args.alpha`` and ``args.unequal``, and the dominant label of each
    user's partition is looked up.

    Args:
        args: parsed options. Reads ``dataset``, ``iid``, ``alpha``,
            ``unequal``, ``num_users`` and ``shards_per_client``.
            ``args.num_classes`` is overwritten (and ``args.num_users`` too
            for the 'shake' dataset).
        data_dir: directory the image datasets are downloaded to / read from.
            Defaults to ``'./data/'``. Ignored for 'shake', which always uses
            ``'./data/shakespeare/'``.
        seed: seed of the ``RandomState`` handed to the samplers.

    Returns:
        list: ``user_tar[u]`` is the label that occurs most often among the
        training samples assigned to user ``u``.

    Raises:
        NotImplementedError: unequal non-IID split requested for 'cifar'.
        RuntimeError: ``args.dataset`` is not one of the handled datasets.
    """
    if data_dir is None:
        data_dir = './data/'
    rs = RandomState(seed)
    if args.dataset == 'cifar':
        args.num_classes = 10
        apply_transform = transforms.Compose(
            [transforms.ToTensor(),
             transforms.Normalize((0.5, 0.5, 0.5), (0.5, 0.5, 0.5))])

        train_dataset = datasets.CIFAR10(data_dir, train=True, download=True,
                                       transform=apply_transform)

        test_dataset = datasets.CIFAR10(data_dir, train=False, download=True,
                                      transform=apply_transform)

        # Sample training data amongst users.
        # NOTE(review): the test partitions are not used for the labels
        # returned below; the calls are kept so the samplers are invoked
        # exactly as before.
        if args.iid:
            # Sample IID user data from CIFAR-10
            user_groups = cifar_iid(train_dataset, args.num_users,rs)
            user_groups_test = cifar_iid(test_dataset,args.num_users,rs)
        else:
            # Sample Non-IID user data from CIFAR-10
            if args.alpha is not None:
                # Dirichlet partition controlled by args.alpha
                user_groups,_ = Dirichlet_noniid(train_dataset, args.num_users,args.alpha,rs)
                user_groups_test,_ = Dirichlet_noniid(test_dataset, args.num_users,args.alpha,rs)
            elif args.unequal:
                # Unequal splits are not available for this dataset
                raise NotImplementedError()
            else:
                # Choose equal splits for every user
                user_groups = cifar_noniid(train_dataset, args.num_users,args.shards_per_client,rs)
                user_groups_test = cifar_noniid(test_dataset, args.num_users,args.shards_per_client,rs)

    elif args.dataset == 'mnist' or args.dataset == 'fmnist':
        args.num_classes = 10
        apply_transform = transforms.Compose([
            transforms.ToTensor(),
            transforms.Normalize((0.1307,), (0.3081,))])
        if args.dataset == 'mnist':
            train_dataset = datasets.MNIST(data_dir, train=True, download=True,
                                        transform=apply_transform)

            test_dataset = datasets.MNIST(data_dir, train=False, download=True,
                                        transform=apply_transform)
        else:
            train_dataset = datasets.FashionMNIST(data_dir, train=True, download=True,
                                        transform=apply_transform)

            test_dataset = datasets.FashionMNIST(data_dir, train=False, download=True,
                                        transform=apply_transform)

        # Sample training data amongst users (test partitions: see note in
        # the 'cifar' branch).
        if args.iid:
            # Sample IID user data
            user_groups = mnist_iid(train_dataset, args.num_users,rs)
            user_groups_test = mnist_iid(test_dataset,args.num_users,rs)
        else:
            # Sample Non-IID user data
            if args.alpha is not None:
                # Dirichlet partition controlled by args.alpha
                user_groups,_ = Dirichlet_noniid(train_dataset, args.num_users,args.alpha,rs)
                user_groups_test,_ = Dirichlet_noniid(test_dataset, args.num_users,args.alpha,rs)
            elif args.unequal:
                # Choose unequal splits for every user
                user_groups = mnist_noniid_unequal(train_dataset, args.num_users,rs)
                user_groups_test = mnist_noniid_unequal(test_dataset, args.num_users,rs)
            else:
                # Choose equal splits for every user
                user_groups = mnist_noniid(train_dataset, args.num_users,args.shards_per_client,rs)
                user_groups_test = mnist_noniid(test_dataset,args.num_users,args.shards_per_client,rs)
    
    elif args.dataset == 'shake':
        # NOTE(review): this branch relies on torch, json, tqdm, process_x,
        # process_y and Data, none of which are imported in the visible part
        # of this file, and the label lookup at the end of the function reads
        # `train_dataset.targets` -- confirm both before using 'shake' here.
        args.num_classes = 80
        data_dir = './data/shakespeare/'
        trainx = torch.tensor([], dtype = torch.uint8)
        trainy = torch.tensor([], dtype = torch.uint8)
        testx = torch.tensor([], dtype = torch.uint8)
        testy = torch.tensor([], dtype = torch.uint8)
        user_groups_test={}
        try:
            # Fast path: tensors cached by a previous run.
            trainx = torch.load(data_dir+'train/xdata.pt')
            trainy = torch.load(data_dir+'train/ydata.pt')
            user_groups = torch.load(data_dir+'train/user_groups.pt')
            testx  = torch.load(data_dir+'test/xdata.pt')
            testy = torch.load(data_dir+'test/ydata.pt')
            
        except Exception:
            # Cache missing or unreadable: rebuild it from the raw json files.
            # (Exception instead of a bare except so that Ctrl-C is not
            # swallowed.)
            # prepare training set
            user_groups = {}
            start = 0
            with open(data_dir+'/train/data.json', 'r') as inf:
                data = json.load(inf)
            for n,u in enumerate(tqdm(data['users'])):
                temp_x = process_x(data['user_data'][u]['x'])
                temp_y = process_y(data['user_data'][u]['y'])
                trainx = torch.cat((trainx, torch.tensor(temp_x, dtype = torch.uint8)))
                trainy = torch.cat((trainy, torch.tensor(temp_y, dtype = torch.uint8)))
                # user n owns the contiguous index range just appended
                user_groups[n]=np.arange(start,start+len(temp_x))
                start+=len(temp_x)
            trainy = torch.argmax(trainy,1)
            torch.save(trainx, data_dir+'train/xdata.pt')
            torch.save(trainy, data_dir+'train/ydata.pt')
            torch.save(user_groups,data_dir+'train/user_groups.pt')

            # prepare test set
            with open(data_dir+'/test/data.json', 'r') as inf:
                data = json.load(inf)
            for u in tqdm(data['users']):
                temp_x = process_x(data['user_data'][u]['x'])
                temp_y = process_y(data['user_data'][u]['y'])
                testx = torch.cat((testx, torch.tensor(temp_x, dtype = torch.uint8)))
                testy = torch.cat((testy, torch.tensor(temp_y, dtype = torch.uint8)))
            testy = torch.argmax(testy,1)
            torch.save(testx, data_dir+'test/xdata.pt')
            torch.save(testy, data_dir+'test/ydata.pt')
        
        train_dataset = Data.TensorDataset(trainx.long(),trainy.long()) 
        test_dataset = Data.TensorDataset(testx.long(), testy.long())
        if args.shards_per_client>1:
            # Merge `shards_per_client` randomly chosen roles into one user;
            # leftover roles that cannot fill a whole group are dropped.
            new_user_groups = {}
            remain_role = set(range(len(user_groups.keys())))
            i=0
            while len(remain_role)>=args.shards_per_client:
                idxs = []
                s = np.random.choice(list(remain_role),args.shards_per_client,replace=False)
                remain_role-=set(s)
                for r in s:
                    idxs.append(user_groups[r])
                new_user_groups[i]=np.concatenate(idxs,0)
                i+=1
            user_groups=new_user_groups
        args.num_users=len(user_groups.keys())
    else:
        raise RuntimeError("Not registered dataset! Please register it in utils.py")
    
    # Convert the labels once instead of once per user.
    targets = np.array(train_dataset.targets)
    user_tar = []
    for user in range(args.num_users):
        # list() accepts any iterable of indices; np.array cannot coerce a
        # set directly (presumably some samplers return sets -- verify).
        sample_idxs = np.array(list(user_groups[user]),dtype=np.int64)
        tar = Counter(targets[sample_idxs]).most_common(1)[0][0]
        user_tar.append(tar)
    
    return user_tar

# ---- Command-line options and derived plotting configuration ----
args = plot_args_parser()

# Which curves / figures are drawn.
plot_gpr, plot_powerd, plot_afl = args.gpr, args.power_d, args.afl
plot_gpr_0, plot_iid, plot_center = args.plot_gpr0, args.plot_iid, args.plot_center
plot_loss, plot_tsne = args.plot_loss, args.plot_tsne
plot_legend = args.legend_int or args.legend_beta or args.legend_norm

# Appearance: std-band opacity (0 disables the bands), font sizes, line width.
plot_std_alpha = 0.0
acc_th = args.target_accuracy
fs, ls = 36, 28
lw = 1.0
if plot_legend and not args.merge_legend:
    # thicker lines for the stand-alone legend figure
    lw = 2.0
center_acc = {'fmnist':0.8592,'cifar':0.7384}
if args.seed is None:
    args.seed = [None,]

# Both the result files and the figure directory are named after the same
# description of the run, so that description is formatted only once.
_algorithm = 'FedProx[%.3f]'%args.mu if args.FedProx else 'FedAvg'
if args.alpha is None:
    _split_key, _split_val = 'sp', args.shards_per_client
else:
    _split_key, _split_val = 'alpha', args.alpha
_run_tag = '{}_{}_{}_{}_C[{}]_iid[{}]_{}[{}]_E[{}]_B[{}]_mu[{}]_lr[{:.5f}]'.format(
    args.dataset, _algorithm, args.model, args.epochs, args.frac, args.iid,
    _split_key, _split_val, args.local_ep, args.local_bs, args.mu, args.lr)

rfile_name = './save/objects/' + _run_tag + '/random.pkl'
save_path = './save/figs/' + _run_tag

# ---- Random-selection baseline: load one result file per seed ----
train_loss = []
train_accuracy = []
chosen_clients = []
weights = []
gpr_sd = []
gt_global_losses = []
test_accuracy = []
# index of the first round whose test accuracy exceeds acc_th, for every
# seed that reaches it
test_accuracy_th = []
for seed in args.seed:
    fn = rfile_name.replace('.pkl','_{}.pkl'.format(seed))
    # NOTE: unpickling can execute arbitrary code -- only load result files
    # written by trusted runs. `with` closes the file even if loading fails.
    with open(fn,'rb') as f:
        tl, ta,cc,w,gsd,ggl,tea = pickle.load(f)
    train_loss.append(tl)
    train_accuracy.append(ta)
    chosen_clients.append(cc)
    weights.append(w)
    gpr_sd.append(gsd)
    gt_global_losses.append(ggl)
    test_accuracy.append(tea)
    for i,v in enumerate(tea):
        if v>acc_th:
            test_accuracy_th.append(i)
            break
# +1 turns the 0-based round index into a round count
print("Random for accuracy {}: {} ({})".format(acc_th,np.mean(test_accuracy_th)+1,np.std(test_accuracy_th)))
weights = np.array(weights)
mtrain_accuracy = np.mean(train_accuracy,axis=0)
# Weighted sum of the losses over the last axis; the weights are reshaped to
# (num_seeds, 1, K) so they broadcast over the middle axis.
mean_gt_global_losses = np.sum(gt_global_losses*weights.reshape([weights.shape[0],1,weights.shape[1]]),axis=2)
mtest_accuracy = np.mean(test_accuracy,axis = 0)

# mean / standard deviation across seeds
mtrain_loss = np.mean(train_loss,axis = 0)
strain_loss = np.std(train_loss,axis=0)
strain_accuracy = np.std(train_accuracy,axis=0)
mgt_global_losses = np.mean(mean_gt_global_losses,axis = 0)
sgt_global_losses = np.std(mean_gt_global_losses,axis=0)
stest_accuracy = np.std(test_accuracy,axis=0)

# ---- GPR (FedCor) results: one entry per setting, aggregated over seeds ----
if plot_gpr:
    # Pad the per-setting option lists to a common length by repeating their
    # last entry.
    num_settings = max([len(args.GPR_interval),len(args.group_size),len(args.discount)])
    while len(args.GPR_interval)<num_settings:
        args.GPR_interval.append(args.GPR_interval[-1])
    while len(args.group_size)<num_settings:
        args.group_size.append(args.group_size[-1])
    while len(args.discount)<num_settings:
        args.discount.append(args.discount[-1])

    save_path = save_path+'_int{}_gp{}_{}[{}]'.format(args.GPR_interval,args.group_size,
                            args.discount_method,args.loss_power if args.discount_method=='loss' else args.discount)
    
    mgt_global_losses_gpr = []
    sgt_global_losses_gpr = []
    mtest_accuracy_gpr = []
    stest_accuracy_gpr = []
    max_test_accuracy_gpr = []
    chosen_clients_settings_gpr = []
    for i in range(num_settings):
        file_name = rfile_name.replace('random','gpr[int{}_gp{}_norm{}]_{}[{}]'.\
                                        format(args.GPR_interval[i],args.group_size[i],args.poly_norm,
                                        args.discount_method,args.loss_power if args.discount_method=='loss' else args.discount[i]))
        train_loss_gpr = []
        train_accuracy_gpr = []
        chosen_clients_gpr = []
        weights_gpr = []
        gpr_sd_gpr = []
        gt_global_losses_gpr = []
        test_accuracy_gpr = []
        test_accuracy_th_gpr = []
        for seed in args.seed:
            fn = file_name.replace('.pkl','_{}.pkl'.format(seed))
            # NOTE: unpickling can execute arbitrary code -- only load result
            # files written by trusted runs. `with` closes the file even if
            # loading fails.
            with open(fn,'rb') as f:
                tl, ta,cc,w,gsd,ggl,tea = pickle.load(f)
            train_loss_gpr.append(tl)
            train_accuracy_gpr.append(ta)
            chosen_clients_gpr.append(cc)
            weights_gpr.append(w)
            gpr_sd_gpr.append(gsd)
            gt_global_losses_gpr.append(ggl)
            test_accuracy_gpr.append(tea)
            # First round whose accuracy exceeds the threshold. Uses its own
            # index name so the setting index `i` is not rebound.
            for rnd,acc in enumerate(tea):
                if acc>acc_th:
                    test_accuracy_th_gpr.append(rnd)
                    break
        # +1 turns the 0-based round index into a round count
        print("GPR for accuracy {}: {} ({})".format(acc_th,np.mean(test_accuracy_th_gpr)+1,np.std(test_accuracy_th_gpr)))
        weights_gpr = np.array(weights_gpr)
        mtrain_accuracy_gpr = np.mean(train_accuracy_gpr,axis=0)
        # Weighted sum of the losses over the last axis (weights broadcast
        # over the middle axis).
        mean_gt_global_losses_gpr = np.sum(gt_global_losses_gpr*weights_gpr.reshape([weights_gpr.shape[0],1,weights_gpr.shape[1]]),axis=2)
        mtrain_loss_gpr = np.mean(train_loss_gpr,axis = 0)
        strain_loss_gpr = np.std(train_loss_gpr,axis=0)
        strain_accuracy_gpr = np.std(train_accuracy_gpr,axis=0)

        # per-setting mean / std across seeds
        mgt_global_losses_gpr.append(np.mean(mean_gt_global_losses_gpr,axis = 0))
        sgt_global_losses_gpr.append(np.std(mean_gt_global_losses_gpr,axis = 0))
        mtest_accuracy_gpr.append(np.mean(test_accuracy_gpr,axis = 0))
        stest_accuracy_gpr.append(np.std(test_accuracy_gpr,axis = 0))

        # best test accuracy of every seed
        max_test_accuracy_gpr.append(np.max(test_accuracy_gpr,axis=1))
        chosen_clients_settings_gpr.append(chosen_clients_gpr)


    
# ---- Power-of-choice (Pow-d) results, aggregated over seeds ----
if plot_powerd:
    save_path = save_path+'_d[{}]'.format(args.d)
    file_name = rfile_name.replace('random','powerd_d[{}]'.format(args.d))
    train_loss_pd = []
    train_accuracy_pd = []
    chosen_clients_pd = []
    weights_pd = []
    gpr_sd_pd = []
    gt_global_losses_pd = []
    test_accuracy_pd = []
    # index of the first round whose test accuracy exceeds acc_th, per seed
    test_accuracy_th_pd = []
    for seed in args.seed:
        fn = file_name.replace('.pkl','_{}.pkl'.format(seed))
        # NOTE: unpickling can execute arbitrary code -- only load result
        # files written by trusted runs. `with` closes the file even if
        # loading fails.
        with open(fn,'rb') as f:
            tl, ta,cc,w,gsd,ggl,tea = pickle.load(f)
        train_loss_pd.append(tl)
        train_accuracy_pd.append(ta)
        chosen_clients_pd.append(cc)
        weights_pd.append(w)
        gpr_sd_pd.append(gsd)
        gt_global_losses_pd.append(ggl)
        test_accuracy_pd.append(tea)
        for i,v in enumerate(tea):
            if v>acc_th:
                test_accuracy_th_pd.append(i)
                break
    # +1 turns the 0-based round index into a round count
    print("Power-D for accuracy {}: {} ({})".format(acc_th,np.mean(test_accuracy_th_pd)+1,np.std(test_accuracy_th_pd)))
    weights_pd = np.array(weights_pd)
    mtrain_accuracy_pd = np.mean(train_accuracy_pd,axis=0)
    # Weighted sum of the losses over the last axis (weights broadcast over
    # the middle axis).
    mean_gt_global_losses_pd = np.sum(gt_global_losses_pd*weights_pd.reshape([weights_pd.shape[0],1,weights_pd.shape[1]]),axis=2)
    mtest_accuracy_pd = np.mean(test_accuracy_pd,axis = 0)

    # mean / standard deviation across seeds
    mtrain_loss_pd = np.mean(train_loss_pd,axis = 0)
    strain_loss_pd = np.std(train_loss_pd,axis=0)
    strain_accuracy_pd = np.std(train_accuracy_pd,axis=0)
    mgt_global_losses_pd = np.mean(mean_gt_global_losses_pd,axis = 0)
    sgt_global_losses_pd = np.std(mean_gt_global_losses_pd,axis=0)
    stest_accuracy_pd = np.std(test_accuracy_pd,axis=0)

# ---- AFL results, aggregated over seeds ----
if plot_afl:
    file_name = rfile_name.replace('random','afl')
    train_loss_afl = []
    train_accuracy_afl = []
    chosen_clients_afl = []
    weights_afl = []
    gpr_sd_afl = []
    gt_global_losses_afl = []
    test_accuracy_afl = []
    # index of the first round whose test accuracy exceeds acc_th, per seed
    test_accuracy_th_afl = []
    for seed in args.seed:
        fn = file_name.replace('.pkl','_{}.pkl'.format(seed))
        # NOTE: unpickling can execute arbitrary code -- only load result
        # files written by trusted runs. `with` closes the file even if
        # loading fails.
        with open(fn,'rb') as f:
            tl, ta,cc,w,gsd,ggl,tea = pickle.load(f)
        train_loss_afl.append(tl)
        train_accuracy_afl.append(ta)
        chosen_clients_afl.append(cc)
        weights_afl.append(w)
        gpr_sd_afl.append(gsd)
        gt_global_losses_afl.append(ggl)
        test_accuracy_afl.append(tea)
        for i,v in enumerate(tea):
            if v>acc_th:
                test_accuracy_th_afl.append(i)
                break
    # +1 turns the 0-based round index into a round count
    print("AFL for accuracy {}: {} ({})".format(acc_th,np.mean(test_accuracy_th_afl)+1,np.std(test_accuracy_th_afl)))
    weights_afl = np.array(weights_afl)
    mtrain_accuracy_afl = np.mean(train_accuracy_afl,axis=0)
    # Weighted sum of the losses over the last axis (weights broadcast over
    # the middle axis).
    mean_gt_global_losses_afl = np.sum(gt_global_losses_afl*weights_afl.reshape([weights_afl.shape[0],1,weights_afl.shape[1]]),axis=2)
    mtest_accuracy_afl = np.mean(test_accuracy_afl,axis = 0)

    # mean / standard deviation across seeds
    mtrain_loss_afl = np.mean(train_loss_afl,axis = 0)
    strain_loss_afl = np.std(train_loss_afl,axis=0)
    strain_accuracy_afl = np.std(train_accuracy_afl,axis=0)
    mgt_global_losses_afl = np.mean(mean_gt_global_losses_afl,axis = 0)
    sgt_global_losses_afl = np.std(mean_gt_global_losses_afl,axis=0)
    stest_accuracy_afl = np.std(test_accuracy_afl,axis=0)

# ---- GPR_0 (baseline GPR selection) results, aggregated over seeds ----
if plot_gpr_0:
    file_name = rfile_name.replace('random','gpr0')
    train_loss_gpr_0 = []
    train_accuracy_gpr_0 = []
    chosen_clients_gpr_0 = []
    weights_gpr_0 = []
    gpr_sd_gpr_0 = []
    gt_global_losses_gpr_0 = []
    test_accuracy_gpr_0 = []
    for seed in args.seed:
        fn = file_name.replace('.pkl','_{}.pkl'.format(seed))
        # NOTE: unpickling can execute arbitrary code -- only load result
        # files written by trusted runs. `with` closes the file even if
        # loading fails.
        with open(fn,'rb') as f:
            tl, ta,cc,w,gsd,ggl,tea = pickle.load(f)
        train_loss_gpr_0.append(tl)
        train_accuracy_gpr_0.append(ta)
        chosen_clients_gpr_0.append(cc)
        weights_gpr_0.append(w)
        gpr_sd_gpr_0.append(gsd)
        gt_global_losses_gpr_0.append(ggl)
        test_accuracy_gpr_0.append(tea)
    weights_gpr_0 = np.array(weights_gpr_0)
    mtrain_accuracy_gpr_0 = np.mean(train_accuracy_gpr_0,axis=0)
    # Weighted sum of the losses over the last axis (weights broadcast over
    # the middle axis).
    mean_gt_global_losses_gpr_0 = np.sum(gt_global_losses_gpr_0*weights_gpr_0.reshape([weights_gpr_0.shape[0],1,weights_gpr_0.shape[1]]),axis=2)
    mtest_accuracy_gpr_0 = np.mean(test_accuracy_gpr_0,axis = 0)

    # mean / standard deviation across seeds
    mtrain_loss_gpr_0 = np.mean(train_loss_gpr_0,axis = 0)
    strain_loss_gpr_0 = np.std(train_loss_gpr_0,axis=0)
    strain_accuracy_gpr_0 = np.std(train_accuracy_gpr_0,axis=0)
    mgt_global_losses_gpr_0 = np.mean(mean_gt_global_losses_gpr_0,axis = 0)
    sgt_global_losses_gpr_0 = np.std(mean_gt_global_losses_gpr_0,axis=0)
    stest_accuracy_gpr_0 = np.std(test_accuracy_gpr_0,axis=0)

# ---- IID reference run (single result file, no seed aggregation) ----
if plot_iid:
    file_name = './save/objects/{}_{}_{}_{}_C[{}]_iid[1]_E[{}]_B[{}]_mu[{}]_lr[{:.5f}]/random_None.pkl'.\
        format(args.dataset, 'FedProx[%.3f]'%args.mu if args.FedProx else 'FedAvg',args.model,args.epochs, args.frac,
            args.local_ep, args.local_bs,args.mu,args.lr)
    # The file used to be opened without ever being closed; `with` releases
    # the handle. NOTE: unpickling can execute arbitrary code -- only load
    # result files written by trusted runs.
    with open(file_name,'rb') as f:
        train_loss_iid, train_accuracy_iid,chosen_clients_iid,weights_iid,_,gt_global_losses_iid,test_accuracy_iid = pickle.load(f)
    # Report the first round (1-based) whose test accuracy exceeds acc_th.
    for i,v in enumerate(test_accuracy_iid):
        if v>acc_th:
            print("IID for accuracy {}: {}".format(acc_th,i+1))
            break
    # Weighted sum of the losses over axis 1 (weights broadcast over axis 0).
    mean_gt_global_losses_iid = np.sum(gt_global_losses_iid*np.expand_dims(weights_iid,0),axis=1)

# Create the figure directory. exist_ok avoids the check-then-create race
# when several plotting jobs start at the same time.
os.makedirs(save_path, exist_ok=True)

# ---- Figure set-up and (optional) training-loss panel ----
if plot_loss:
    # Two panels: training loss on the left (ax1), test accuracy on the right (ax2).
    fig0,(ax1,ax2)=plt.subplots(1,2,figsize=(16,6))
    y_major_locator = plt.MultipleLocator(0.2)

    # (mean, std) pairs in drawing order; std is None when no band exists.
    loss_curves = [(mgt_global_losses, sgt_global_losses)]
    if plot_powerd:
        loss_curves.append((mgt_global_losses_pd, sgt_global_losses_pd))
    if plot_gpr_0:
        loss_curves.append((mgt_global_losses_gpr_0, sgt_global_losses_gpr_0))
    if plot_afl:
        loss_curves.append((mgt_global_losses_afl, sgt_global_losses_afl))
    if plot_gpr:
        for i in range(num_settings):
            loss_curves.append((mgt_global_losses_gpr[i], sgt_global_losses_gpr[i]))
    if plot_iid:
        loss_curves.append((mean_gt_global_losses_iid, None))

    for loss_mean, loss_std in loss_curves:
        rounds = range(len(loss_mean))
        if loss_std is not None and plot_std_alpha>0:
            # shaded +/- one standard deviation band
            ax1.fill_between(rounds,loss_mean-loss_std,loss_mean+loss_std,alpha=plot_std_alpha)
        ax1.plot(rounds, loss_mean, linewidth=lw)

    ax1.set_ylabel('Training loss',fontsize = fs)
    ax1.set_xlabel('Communication Round',fontsize=fs)
    ax1.tick_params(labelsize = ls)
    ax1.yaxis.set_major_locator(y_major_locator)
else:
    # Accuracy-only figure; a wide canvas is used when the figure only serves
    # as a stand-alone legend.
    if plot_legend and not args.merge_legend:
        fig0,ax2 = plt.subplots(figsize=(24,6))
    else:
        fig0,ax2 = plt.subplots(figsize=(8,6))
# ---- Test-accuracy panel, legend and figure output ----
# num_column counts the labelled curves and is used as the number of legend
# columns for the stand-alone legend figure.
num_column = 1
y_major_locator = plt.MultipleLocator(0.1)
if plot_center:
    # horizontal reference line: accuracy stored in center_acc for this dataset
    ax2.plot(range(len(mtest_accuracy)), np.ones(len(mtest_accuracy))*center_acc[args.dataset],'--',label = 'cent',linewidth=lw)
    num_column += 1
if plot_std_alpha>0:
    ax2.fill_between(range(len(mtest_accuracy)),mtest_accuracy-stest_accuracy,mtest_accuracy+stest_accuracy,alpha=plot_std_alpha)
ax2.plot(range(len(mtest_accuracy)), mtest_accuracy,label = 'Rand',linewidth=lw)
print("Random Accuracy: ",np.max(test_accuracy,axis=1))
if plot_afl:
    if plot_std_alpha>0:
        ax2.fill_between(range(len(mtest_accuracy_afl)),mtest_accuracy_afl-stest_accuracy_afl,mtest_accuracy_afl+stest_accuracy_afl,alpha=plot_std_alpha)
    ax2.plot(range(len(mtest_accuracy_afl)), mtest_accuracy_afl, label = 'AFL',linewidth=lw)
    print("AFL Accuracy: ",np.max(test_accuracy_afl,axis=1))
    num_column+=1
if plot_powerd:
    if plot_std_alpha>0:
        ax2.fill_between(range(len(mtest_accuracy_pd)),mtest_accuracy_pd-stest_accuracy_pd,mtest_accuracy_pd+stest_accuracy_pd,alpha=plot_std_alpha)
    ax2.plot(range(len(mtest_accuracy_pd)), mtest_accuracy_pd, label = 'Pow-d',linewidth=lw)
    print("Power-D Accuracy: ",np.max(test_accuracy_pd,axis=1))
    num_column+=1
if plot_gpr_0:
    if plot_std_alpha>0:
        ax2.fill_between(range(len(mtest_accuracy_gpr_0)),mtest_accuracy_gpr_0-stest_accuracy_gpr_0,mtest_accuracy_gpr_0+stest_accuracy_gpr_0,alpha=plot_std_alpha)
    ax2.plot(range(len(mtest_accuracy_gpr_0)),mtest_accuracy_gpr_0, label = 'GPR Selection(Baseline)',linewidth=lw)
    print("GPR_0 Accuracy: ",np.max(test_accuracy_gpr_0,axis=1))
    num_column+=1

if plot_gpr:
    for i in range(num_settings):
        if args.legend_int:
            # backslash escaped explicitly: '\D' is an invalid escape
            # sequence; the resulting label text is unchanged
            label = 'FedCor,$\\Delta t={}$'.format(args.GPR_interval[i])
        elif args.legend_beta:
            label = 'FedCor,$\\beta={}$'.format(args.discount[i])
        else:
            label = 'FedCor'
        if plot_std_alpha>0:
            ax2.fill_between(range(len(mtest_accuracy_gpr[i])),mtest_accuracy_gpr[i]-stest_accuracy_gpr[i],mtest_accuracy_gpr[i]+stest_accuracy_gpr[i],alpha=plot_std_alpha)
        ax2.plot(range(len(mtest_accuracy_gpr[i])), mtest_accuracy_gpr[i],label = label, linewidth=lw)
        print("GPR Accuracy: ",max_test_accuracy_gpr[i])
        num_column+=1
if plot_iid:
    # the IID curve does not get a legend column of its own
    ax2.plot(range(len(test_accuracy_iid)), test_accuracy_iid, label = 'IID',linewidth=lw)
    print("IID Accuracy:%f"%max(test_accuracy_iid))


ax2.set_ylabel('Test Accuracy',fontsize = fs)
ax2.set_xlabel('Communication Round',fontsize = fs)
ax2.tick_params(labelsize = ls)
ax2.yaxis.set_major_locator(y_major_locator)
ax2.set_xlim(0,args.epochs)

if plot_legend:
    if args.merge_legend:
        # legend drawn inside the accuracy figure
        if args.plot_title:
            title = "FMNIST, " if args.dataset=='fmnist' else "CIFAR-10, "
            title += "Dir" if args.alpha is not None else "%dSPC"%args.shards_per_client
            ax2.set_title(title,fontsize=fs)
        fig0.legend(loc = 'lower right',bbox_to_anchor=(0.955,0.17),fontsize = ls-2)
        fig0.tight_layout()
        fig0.savefig(save_path+'/accuracy.png')
    else:
        # stand-alone legend: keep the legend, drop the axes it was built from
        fig0.legend(loc='center',ncol=num_column, borderaxespad=0.,fontsize = fs)
        ax2.remove()
        fig0.savefig(save_path+'/legend.png')
else:
    if args.plot_title:
        title = "FMNIST, " if args.dataset=='fmnist' else "CIFAR-10, "
        title += "Dir" if args.alpha is not None else "%dSPC"%args.shards_per_client
        ax2.set_title(title,fontsize=fs)
    fig0.tight_layout()
    fig0.savefig(save_path+'/accuracy.png')

# ---- Histograms of how often every client was selected, one set per seed ----
for n,s in enumerate(args.seed):
    split_tag = "Dir, " if args.alpha is not None else "%dSPC, "%args.shards_per_client

    # One (title, selected clients, file name) entry per histogram, in panel
    # order. The panel count comes from this list rather than from the legend
    # column count, which also counts the 'cent' line and so left an empty
    # panel.
    panels = [(split_tag+"Rand", chosen_clients[n], 'random_dist_{}.png'.format(s))]
    if plot_gpr:
        for i in range(num_settings):
            panels.append((split_tag+'FedCor, $\\beta={}$'.format(args.discount[i]),
                           chosen_clients_settings_gpr[i][n],
                           'gpr[{}]_dist_{}.png'.format(args.discount[i],s)))
    if plot_gpr_0:
        panels.append(("GPR_0", chosen_clients_gpr_0[n], 'gpr0_dist_{}.png'.format(s)))
    if plot_powerd:
        panels.append(("Pow-d", chosen_clients_pd[n], 'powd_dist_{}.png'.format(s)))
    if plot_afl:
        panels.append(("AFL", chosen_clients_afl[n], 'afl_dist_{}.png'.format(s)))

    if not args.sep_hist:
        # squeeze=False always returns an array of axes; with the default a
        # single panel comes back as a bare Axes and indexing it fails.
        fig1,axs = plt.subplots(len(panels),1,squeeze=False)
        axs = axs.ravel()
    else:
        # one stand-alone figure per histogram
        axs = []
        figs = []
        for _ in panels:
            fig,ax = plt.subplots(figsize=(8,6))
            axs.append(ax)
            figs.append(fig)

    for idxs,(title,clients,fname) in enumerate(panels):
        axs[idxs].set_title(title,fontsize=fs)
        # one bin per client index
        axs[idxs].hist(np.array(clients).flatten(),range(args.num_users+1))
        axs[idxs].tick_params(labelsize=ls)
        # same label font size on every panel (it was missing on Pow-d / AFL)
        axs[idxs].set_xlabel('Client Index',fontsize=fs)
        if args.sep_hist:
            figs[idxs].tight_layout()
            figs[idxs].savefig(save_path+'/'+fname)

    if not args.sep_hist:
        fig1.savefig(save_path+'/selection_dist_{}.png'.format(s))


# ---- t-SNE embedding of the GPR projection matrix, one figure per seed ----
if plot_tsne:
    if not plot_gpr:
        # gpr_sd_gpr is only filled while the GPR results are loaded; fail
        # with a clear message instead of a NameError.
        raise RuntimeError("The t-SNE plot needs the GPR results: enable the GPR plot together with the t-SNE plot")
    for n,seed in enumerate(args.seed):
        # dominant training label of every user, used to colour the points
        user_label = get_plot_dataset(args,seed = seed)
        # gpr_sd_gpr holds the entries of the GPR setting that was loaded last
        sd = gpr_sd_gpr[n]
        tsne = TSNE(n_components = 2,init= 'pca')
        X = sd['Projection.PMatrix'].numpy()
        X = (X/np.sqrt(np.sum(X**2,axis=0))).transpose()# Normalize every column to length 1; columns become rows
        Y = tsne.fit_transform(X)
        fig = plt.figure(figsize=(8,8))
        plt.scatter(Y[:,0],Y[:,1],c = user_label,cmap=plt.cm.Spectral)
        plt.tick_params(labelsize = ls)
        plt.title('FMNIST' if args.dataset=='fmnist' else 'CIFAR-10',fontsize = fs)
        
        fig.tight_layout()
        fig.savefig(save_path+'/TSNE_{}.png'.format(seed))


# plt.show()
