import numpy as np
import matplotlib.pyplot as plt
import torch
from torch import nn
import torch.nn.functional as F
import copy

class biclassifier(nn.Module):
    """Bias-free logistic-regression classifier for 2-D points.

    ``forward(x) = sigmoid(x @ w.T)`` with a single weight vector
    ``w = (w0, w1)`` and no bias, so the decision boundary
    ``w0 * x + w1 * y = 0`` always passes through the origin.

    Args:
        init_value: initial weights ``(w0, w1)``; anything ``torch.tensor``
            accepts (tuple, list, NumPy array). The resulting tensor's dtype
            becomes the weight dtype (e.g. a float64 array gives float64
            weights).
    """

    def __init__(self, init_value=(-1.0, 1.0)):
        # Immutable tuple default: a list default is one object shared by
        # every call, the classic mutable-default pitfall.
        super().__init__()
        self.linear = nn.Linear(2, 1, bias=False)
        self.sigmoid = nn.Sigmoid()
        # Overwrite nn.Linear's random initialisation with init_value,
        # reshaped to the (1, 2) weight layout.
        self.linear.weight.data = torch.tensor(init_value).reshape(self.linear.weight.data.shape)

    def forward(self, x):
        """Return ``sigmoid(w . x)`` for each row of ``x`` (last dim of size 1)."""
        return self.sigmoid(self.linear(x))

    def visualization(self, fig, marker='-', label=None, c=None):
        """Plot the decision boundary ``y = -(w0 / w1) * x`` for x in [0, 1.2].

        ``fig`` is anything with a matplotlib-style ``plot`` method (e.g. an
        ``Axes``); ``marker``, ``label`` and ``c`` are passed straight through.
        """
        ticks = np.linspace(0, 1.2, 100)
        weight = self.linear.weight.data.clone().detach().numpy().flatten()
        y = -ticks * weight[0] / weight[1]
        fig.plot(ticks, y, marker, label=label, c=c)

    def update(self, data):
        """Set the weights to ``(sum_i x_i / y_i, 1.0)`` over the rows of ``data``.

        ``data`` is reshaped to ``(-1, 2)`` and each row is read as
        ``(x_i, y_i)``.
        """
        k = 0
        data = data.reshape([-1, 2])
        for d in data:
            k += d[0] / d[1]
        self.linear.weight.data = torch.tensor([k, 1.0]).reshape(self.linear.weight.data.shape)

    def avg(self, model):
        """Return a new model whose weights average ``self``'s and ``model``'s.

        Each weight vector is first scaled so its second component is 1
        (which keeps its boundary line unchanged), then the two are averaged.
        The result is a deep copy of ``model``; neither input is modified.
        NOTE: a second weight component of 0 divides by zero (NumPy yields
        inf/nan with a RuntimeWarning).
        """
        res = copy.deepcopy(model)
        weight = self.linear.weight.data.clone().detach().numpy().flatten()
        weight /= weight[1]
        oweight = model.linear.weight.data.clone().detach().numpy().flatten()
        oweight /= oweight[1]
        res.linear.weight.data = torch.tensor((weight + oweight) / 2).reshape(res.linear.weight.data.shape)
        return res

    

def biclass_loss(output, label):
    """Mean binary cross-entropy between predicted probabilities and 0/1 labels.

    Args:
        output: tensor of predicted probabilities in [0, 1], e.g. the
            ``(n, 1)`` output of ``biclassifier``.
        label: 0/1 targets with the same number of elements as ``output``
            (tensor, array or list; integer or float dtype).

    Returns:
        0-dim tensor ``mean(-(1 - y) * log(1 - p) - y * log(p))``.
    """
    # Align the target with the prediction. Without this, an (n,) label
    # broadcast against an (n, 1) output formed an (n, n) matrix, and the mean
    # mixed every prediction with every label whenever n > 1.
    target = torch.as_tensor(label, device=output.device).to(output.dtype).reshape(output.shape)
    # binary_cross_entropy clamps each log term at -100, so saturated outputs
    # of exactly 0 or 1 give a finite loss instead of inf/nan.
    return F.binary_cross_entropy(output, target)

def client_loss(model, data, label, data_partition, loss_fn=None):
    """Evaluate ``model`` on each client's data and return one loss per client.

    Puts ``model`` into eval mode and leaves it there.

    Args:
        model: torch module mapping an input tensor to predictions.
        data: array indexable by a client's index list (``data[idx]``).
        label: targets indexed the same way as ``data``.
        data_partition: iterable of index lists/arrays; entry k selects the
            rows owned by client k.
        loss_fn: ``loss_fn(prediction, target) -> 0-dim tensor``; defaults
            to ``biclass_loss``.

    Returns:
        List of Python floats, one per entry of ``data_partition``, in order.
    """
    if loss_fn is None:
        loss_fn = biclass_loss
    model.eval()
    # Pure evaluation: don't build an autograd graph for these forward passes.
    with torch.no_grad():
        return [
            loss_fn(model(torch.tensor(data[idx])), torch.tensor(label[idx])).item()
            for idx in data_partition
        ]


def Ellipse(center, ax1, ax2, angle, fig, c=None):
    """Draw an ellipse outline on ``fig`` (anything with a ``plot`` method).

    The ellipse has semi-axes ``ax1`` (local x) and ``ax2`` (local y). It is
    rotated counter-clockwise by ``angle`` radians, then translated to
    ``center``. It is traced with 100 samples of the parametric angle over
    [0, 2*pi], both endpoints included so the curve closes.
    """
    theta = np.linspace(0, 2 * np.pi, 100)
    local_x = ax1 * np.cos(theta)
    local_y = ax2 * np.sin(theta)
    cos_a, sin_a = np.cos(angle), np.sin(angle)
    cx, cy = np.array(center).reshape(2)
    xs = cos_a * local_x - sin_a * local_y + cx
    ys = sin_a * local_x + cos_a * local_y + cy
    fig.plot(xs, ys, c=c)

def disc_uniform(center, radius, size):
    """Sample ``size`` points uniformly from the disc of ``radius`` at ``center``.

    Draws all angles first, then all radii, from NumPy's global RNG. Radii
    are ``sqrt(U(0, radius**2))``, which makes the density uniform in area.
    Returns a ``(size, 2)`` float array (``center`` is broadcast-added).
    """
    theta = np.random.uniform(low=0.0, high=2 * np.pi, size=size)
    rho = np.sqrt(np.random.uniform(low=0.0, high=radius ** 2, size=size))
    offsets = np.column_stack((rho * np.cos(theta), rho * np.sin(theta)))
    return offsets + center


# Toggle for the two-client figure: True writes each point's index next to it
# and names the averaged models by their client pair; False (default) instead
# highlights the selected clients with markers.
plot_annotate = False

# Fix NumPy's global RNG so the sampled data (disc_uniform draws from it), and
# hence every selection below, is reproducible.
np.random.seed(seed=17)
# Per-label edge colours; only referenced by commented-out plotting code.
edge_colors = np.array(['r','k'])
# Earlier hand-crafted data layouts, kept for reference:
# data = np.array([[0.5,i/4] for i in [-3,-2,-1,1,2,3]])
# data = np.array([[0.25,0.5],[0.5,0.5],[0.75,0.5],[0.25,-0.5],[0.5,-0.5],[0.75,-0.5]])
# data = np.concatenate([data,np.array([[i/5,0.0] for i in range(1,5)])],axis = 0)

# angles = np.array([5,4,3,2,1,-1,-2,-3,-4,-5]).reshape([-1,1])*np.pi/12
# data = np.concatenate([np.cos(angles),np.sin(angles)],axis = 1)
# Number of points drawn from the first and the second disc, respectively.
label_num = [13,7]
# Radius of each sampling disc.
r = 0.2

# Disc centres on the unit circle at 60 and 30 degrees: one above and one
# below the ground-truth boundary y = x.
centers = np.array([[np.cos(np.pi/3),np.sin(np.pi/3)],[np.cos(np.pi/6),np.sin(np.pi/6)]])
# data = np.random.normal(scale=0.05,size=[np.sum(label_num),2])
# data[:label_num[0],:]+=centers[0,:]
# data[label_num[0]:,:]+=centers[1,:]
# Rows 0..12 come from the first disc, rows 13..19 from the second.
data =np.concatenate([disc_uniform(centers[0,:],r,label_num[0]),disc_uniform(centers[1,:],r,label_num[1])],axis=0)
# Hand-tuned nudges of three specific points (presumably to shape the example
# or keep the figure legible -- TODO confirm intent).
data[13,:]+=np.array([0.03,-0.03])
data[11,:]+=np.array([-0.03,0.0])
data[3,:]+=np.array([0.0,0.03])
# Ground-truth label: 1 if the point lies above the line y = x, else 0.
label = np.array(data[:,1]>data[:,0],dtype=np.int64)

# One client per data point: client i owns only row i.
data_partition = np.array([[i,] for i in range(len(data))])
# color_map = ['orange','blue','purple']
# x-grid on which boundary lines are drawn.
tick = np.linspace(0,1.2,100)

# Initial global model: float64 weights [-0.7, 1.0], i.e. boundary y = 0.7 x.
init_model = biclassifier(np.array([-0.7,1.0],dtype=np.double))



# ---- Data figures ----
# fig0 shows the data and the initial model; fig1 shows the same picture and
# later also receives the models picked by one-client selection.


def _make_data_axes():
    """Open a 6x6 figure showing the labelled data and the initial model.

    Axis spines are moved through the origin with ticks every 0.3. Points are
    coloured by ``label`` and annotated with their index. The dashed
    ground-truth line y = x and the initial model's boundary are drawn.
    Returns the figure's axes.
    """
    plt.figure(figsize=(6, 6))
    ax = plt.gca()
    plt.xticks([0.0, 0.3, 0.6, 0.9, 1.2])
    plt.yticks([0.0, 0.3, 0.6, 0.9, 1.2])
    plt.tick_params(labelsize=20)
    ax.spines['right'].set_visible(False)
    ax.spines['top'].set_color('none')
    ax.xaxis.set_ticks_position('bottom')
    ax.spines['bottom'].set_position(('data', 0))
    ax.yaxis.set_ticks_position('left')
    ax.spines['left'].set_position(('data', 0))
    ax.scatter(data[:, 0], data[:, 1], c=label)
    for idx in range(len(data)):
        ax.annotate('%d' % idx, xy=data[idx, :], fontsize=20)
    # Ground-truth decision boundary.
    ax.plot(tick, tick, 'k--', label='Ground Truth')
    # Boundary of the initial model.
    init_model.visualization(ax, '-.', 'Initialization', c='purple')
    return ax


# fig0 for initialization
fig0 = _make_data_axes()
print(len(data))
fig0.legend(fontsize=26, loc='center left')

# fig1 for one client selection (its legend is added after selection)
fig1 = _make_data_axes()


# ax2: main two-client-selection figure. The canvas is wide because the
# legend is placed to the left of the axes at the end of the script.
fig2, ax2 = plt.subplots(figsize=(12, 6))

ax2.tick_params(labelsize=20)
ax2.set_xticks([0.0, 0.3, 0.6, 0.9, 1.2])
ax2.set_yticks([0.0, 0.3, 0.6, 0.9, 1.2])
ax2.set_xlim(0, 1.2)
ax2.set_ylim(0, 1.2)
# Put the bottom/left spines through the origin.
for axis_obj, side in ((ax2.xaxis, 'bottom'), (ax2.yaxis, 'left')):
    axis_obj.set_ticks_position(side)
    ax2.spines[side].set_position(('data', 0))

# The first label_num[0] rows come from the positive disc, the rest from the
# negative one.
n_positive = label_num[0]
ax2.scatter(data[:n_positive, 0], data[:n_positive, 1], c='k', label='Positive')
ax2.scatter(data[n_positive:, 0], data[n_positive:, 1], c='r', label='Negative')

# Hand-tuned offsets for index labels that would otherwise overlap.
annotation_offsets = {3: (0.04, 0.02), 16: (0.06, 0.05), 11: (0.06, 0.02)}
if plot_annotate:
    for idx, point in enumerate(data):
        shift = annotation_offsets.get(idx)
        anchor = point if shift is None else point - np.array(shift)
        ax2.annotate('%d' % idx, xy=anchor, fontsize=14)

# Ground-truth boundary and the initial model's boundary.
bound = ax2.plot(tick, tick, 'k--', label='Ground Truth')
init_model.visualization(ax2, '-.', 'Initialization', c='purple')


# Baseline: every client's loss under the initial model.
init_loss = np.array(client_loss(init_model, data, label, data_partition))
updated_models = []

# One-client selection: for each candidate client, take a single SGD step
# (lr = 0.9) from the initial model on that client's data only, keep the
# updated model, and record how every client's loss changed.
part_loss = []
total_loss = []
for client_rows in data_partition:
    model = copy.deepcopy(init_model)
    opt = torch.optim.SGD(model.parameters(), lr=0.9)
    model.train()
    opt.zero_grad()
    batch_x = torch.tensor(data[client_rows])
    batch_y = torch.tensor(label[client_rows])
    loss = biclass_loss(model(batch_x), batch_y)
    loss.backward()
    opt.step()

    updated_models.append(copy.deepcopy(model))

    new_loss = np.array(client_loss(model, data, label, data_partition))
    total_loss.append(np.mean(new_loss))
    part_loss.append(new_loss - init_loss)


# part_loss[k, m]: change in client k's loss after the one-step update on
# client m's data (rows = evaluated clients, columns = candidate updates;
# both are indexed by position in data_partition).
part_loss = np.array(part_loss).transpose()
total_loss = np.array(total_loss)
# Covariance between clients' loss changes across the candidate updates.
# NOTE(review): a client whose loss change is identical for every candidate
# has zero variance and makes `corr` NaN; unlike the two-client section, no
# diagonal jitter is added here.
covar = np.cov(part_loss, bias=True)
var = np.diagonal(covar)
corr = (covar / np.sqrt(var.reshape([-1, 1]))) / np.sqrt(var.reshape([1, -1]))

# FedGP criterion: ranks[i] = (1/n) * sum_j corr[i, j] * std[j].
# The accumulator is `score`: reusing `r` would overwrite the module-level
# disc radius `r` from the data setup.
ranks = []
for i in range(len(var)):
    score = 0
    for j in range(len(var)):
        score += corr[i, j] * np.sqrt(var[j])
    ranks.append(score / len(var))
ranks = np.array(ranks)
print("===================One Client Selection===================")
# print(corr)
print("Init Loss:", init_loss)
print("Total Loss", total_loss)
print("FedGP Criterion:", ranks)

# Client orderings, best first.
ind_sel = init_loss.argsort()[::-1]   # highest initial loss first
gp_sel = ranks.argsort()[::-1]        # highest FedGP criterion first
gt_sel = total_loss.argsort()         # lowest mean loss after the update (oracle)
print("Independent Rank:", ind_sel)
print("FedGP Rank:", gp_sel)
print("Groundtruth Rank:", gt_sel)
updated_models[ind_sel[0]].visualization(fig1, '-', "Ind Sel: %d" % ind_sel[0])
updated_models[gp_sel[0]].visualization(fig1, '-', "GP Sel: %d" % gp_sel[0])
updated_models[gt_sel[0]].visualization(fig1, '-', "GT Sel: %d" % gt_sel[0])
fig1.legend(fontsize=26)

# Two-client selection: average every unordered pair (i < j) of one-step
# models and evaluate each average on all clients. Also record where the
# independent heuristic's pair sits, and which pair has the lowest mean loss.
part_loss = []
min_loss = np.inf
averaged_models = []
total_loss = []
ind_pair = (np.min(ind_sel[0:2]), np.max(ind_sel[0:2]))
n_models = len(updated_models)
for i in range(n_models - 1):
    for j in range(i + 1, n_models):
        avg_model = updated_models[i].avg(updated_models[j])

        new_loss = np.array(client_loss(avg_model, data, label, data_partition))
        mean_loss = np.mean(new_loss)
        total_loss.append(mean_loss)
        part_loss.append(new_loss - init_loss)

        pos = len(averaged_models)
        if (i, j) == ind_pair:
            ind_pos = pos
        if mean_loss < min_loss:
            min_loss, gt_sel2, gt_pos = mean_loss, (i, j), pos

        averaged_models.append(avg_model)

part_loss = np.array(part_loss).transpose()
total_loss = np.array(total_loss)

# ---- Two-client selection: greedy FedGP choice ----
gp_sel2 = np.zeros(2, dtype=np.int64)

# Covariance between clients' loss changes, estimated across all averaged
# pairs (rows of part_loss are clients). The 1e-6 diagonal jitter keeps
# every variance strictly positive.
covar = np.cov(part_loss, bias=True) + np.eye(sum(label_num)) * 1e-6

var = np.diagonal(covar)
corr = (covar / np.sqrt(var.reshape([-1, 1]))) / np.sqrt(var.reshape([1, -1]))
# print(corr)
# Figure 3: prior over loss changes -- zero mean with a +/- one-std band.
plt.figure(figsize=(8, 6))
fig3 = plt.gca()
plt.tick_params(labelsize=20)
fig3.plot(range(len(var)), np.zeros(len(var)), 'k')
fig3.fill_between(range(len(var)), -np.sqrt(var), np.sqrt(var), alpha=0.3)
max_cor = -np.inf
# Select first: the client i that maximizes sum_j covar[i, j] / std[j].
# The accumulator is `score`, not `r`, so the module-level disc radius `r`
# is not overwritten.
for i in range(len(var)):
    score = 0
    for j in range(len(var)):
        score += covar[i, j] / np.sqrt(var[j])
    if score > max_cor:
        max_cor = score
        gp_sel2[0] = i

# Arrow: the first client's loss change is taken to be one std below zero.
fig3.arrow(gp_sel2[0], 0, 0, -np.sqrt(var[gp_sel2[0]]), length_includes_head=True, width=0.1, head_width=0.3, head_length=0.01, ec='r', fc='r')
fig3.set_ylabel("Predictive Loss Change", fontsize=26)
fig3.set_xlabel("Client Index", fontsize=26)
# fig3.plot(gp_sel2[0],-np.sqrt(var[gp_sel2[0]]),'rx',markersize = 26)
# Select Second
# Conditional mean of every client's loss change, given that the first
# client's change equals -std: covar[:, k] / var[k] * (-std[k]).
mu = covar[:, gp_sel2[0]] / (-np.sqrt(var[gp_sel2[0]]))
fig3.plot(range(len(var)), mu, 'r-')
plt.tight_layout()

# Conditional covariance after observing client k (Gaussian conditioning):
# covar - covar[:, k] covar[k, :] / var[k]. The jitter is re-added, and
# negative variances from round-off are clipped to 0.
sub = covar[:, gp_sel2[0]:gp_sel2[0] + 1].dot(covar[gp_sel2[0]:gp_sel2[0] + 1, :]) / var[gp_sel2[0]]
covar = covar - sub + np.eye(sum(label_num)) * 1e-6
var = np.clip(np.diagonal(covar), a_min=0.0, a_max=None)
# print(np.sqrt(var))
# Figure 4: conditional mean +/- one conditional std.
plt.figure(figsize=(8, 6))
fig4 = plt.gca()
plt.tick_params(labelsize=20)
fig4.plot(range(len(var)), mu, 'k')
fig4.fill_between(range(len(var)), mu - np.sqrt(var), mu + np.sqrt(var), alpha=0.3)
max_cor = -np.inf
# Select second: same criterion under the conditional covariance, skipping
# the client already chosen.
for i in range(len(var)):
    if i == gp_sel2[0]:
        continue
    score = 0
    for j in range(len(var)):
        score += covar[i, j] / (np.sqrt(var[j]))
    if score > max_cor:
        max_cor = score
        gp_sel2[1] = i

fig4.arrow(gp_sel2[1], mu[gp_sel2[1]], 0, -np.sqrt(var[gp_sel2[1]]), length_includes_head=True, width=0.1, head_width=0.3, head_length=0.01, ec='r', fc='r')
# fig4.plot(gp_sel2[1],mu[gp_sel2[1]]-np.sqrt(var[gp_sel2[1]]),'rx',markersize = 26)
# Condition on the second client as well: its change is taken to be one
# conditional std below its conditional mean.
mu = mu - covar[:, gp_sel2[1]] / (np.sqrt(var[gp_sel2[1]]))
fig4.plot(range(len(var)), mu, 'r-')
fig4.set_ylabel("Predictive Loss Change", fontsize=26)
fig4.set_xlabel("Client Index", fontsize=26)
plt.tight_layout()

# Position of the FedGP pair in `averaged_models`. Pairs were generated as
# (i, j) with i < j in row-major order, so enumerate them the same way and
# look the pair up; a pair that is not found falls back to position 0.
pair_order = [
    (a, b)
    for a in range(len(updated_models) - 1)
    for b in range(a + 1, len(updated_models))
]
gp_pair = (np.min(gp_sel2), np.max(gp_sel2))
gp_pos = pair_order.index(gp_pair) if gp_pair in pair_order else 0

print("===================Two Client Selection===================")
print("Independent Selection:", ind_sel[0:2])
print("FedGP Selection:", gp_sel2)
print("Groundtruth Selection:", gt_sel2)
for caption, pos in (("Ind Loss", ind_pos), ("GP Loss", gp_pos), ("GT Loss", gt_pos)):
    print(caption, total_loss[pos])

# Draw the averaged model of each selected pair (blue: independent, orange:
# FedGP, green: optimal). Without annotations, the selected clients' points
# are also highlighted and the curves get generic captions.
result_positions = (ind_pos, gp_pos, gt_pos)
result_colors = ('b', 'orange', 'g')
if plot_annotate:
    result_captions = (
        "Ind: (%d,%d)" % (ind_sel[0], ind_sel[1]),
        "GP: (%d,%d)" % (gp_sel2[0], gp_sel2[1]),
        "Opt: (%d,%d)" % (gt_sel2[0], gt_sel2[1]),
    )
else:
    highlighted = (
        (ind_sel[0:2], 'b', '^', 'Ind Selection'),
        (gp_sel2, 'orange', 's', 'GP Selection'),
        (gt_sel2, 'g', 'o', 'Opt Selection'),
    )
    for chosen, color, shape, caption in highlighted:
        ax2.scatter(data[chosen, 0], data[chosen, 1], c='none', edgecolors=color, marker=shape, linewidths=3, s=400, label=caption)
    result_captions = ("Ind Result", "GP Result", "Opt Result")
for pos, caption, color in zip(result_positions, result_captions, result_colors):
    averaged_models[pos].visualization(ax2, '-', caption, c=color)

# Order legend entries by label, not by handle position. The position of a
# handle in get_legend_handles_labels() depends on the matplotlib version:
# older releases list all lines before all collections, newer ones use
# insertion order. Also, with plot_annotate=True only 7 entries exist, so
# hard-coded indices up to h[9] raised IndexError.
h, l = ax2.get_legend_handles_labels()
legend_order = [
    'Positive', 'Negative', 'Ground Truth', 'Initialization',
    'Ind Result', 'GP Result', 'Opt Result',
    'Ind Selection', 'GP Selection', 'Opt Selection',
]
handle_of = dict(zip(l, h))
ordered_labels = [lab for lab in legend_order if lab in handle_of]
# Any other labels (e.g. the "Ind: (i,j)" captions in annotate mode) follow
# in their original order.
ordered_labels += [lab for lab in l if lab not in legend_order]
ax2.legend([handle_of[lab] for lab in ordered_labels], ordered_labels,
           fontsize=26, loc='center', bbox_to_anchor=(-0.5, 0.5))
fig2.subplots_adjust(left=0.5, top=0.95, bottom=0.1, right=0.95)


plt.show()


