from __future__ import print_function
import argparse
import numpy as np
import torch
import torch.nn as nn
import torch.optim as optim
import torch.nn.functional as F
import matplotlib.pyplot as plt
import imageio
import os
import platform
if platform.system() == 'Darwin':
    import matplotlib
    matplotlib.use('TkAgg')

class Generator(nn.Module):
    """Feature extractor mapping 2-D input points to 15-D non-negative features.

    Architecture: three Linear layers (2 -> 15 -> 15 -> 15), each followed by
    a ReLU, so every output feature is >= 0.
    """

    def __init__(self):
        super(Generator, self).__init__()
        # Attribute names and creation order are kept (fc1, relu1, fc2, ...)
        # so state_dict keys and RNG consumption during init are unchanged.
        self.fc1 = nn.Linear(2, 15)
        self.relu1 = nn.ReLU()
        self.fc2 = nn.Linear(15, 15)
        self.relu2 = nn.ReLU()
        self.fc3 = nn.Linear(15, 15)
        self.relu3 = nn.ReLU()

    def forward(self, x):
        """Run ``x`` (last dimension 2) through the three Linear+ReLU stages."""
        stages = (
            (self.fc1, self.relu1),
            (self.fc2, self.relu2),
            (self.fc3, self.relu3),
        )
        hidden = x
        for linear, activation in stages:
            hidden = activation(linear(hidden))
        return hidden


class Classifier(nn.Module):
    """Classification head mapping 15-D features to ``n_classes`` raw logits.

    Architecture: Linear(15, 15) -> ReLU -> Linear(15, 15) -> ReLU ->
    Linear(15, n_classes). No softmax is applied here; callers apply it
    where they need probabilities.

    Args:
        n_classes: Number of output logits. If omitted, a module-level
            ``n_classes`` global is used when one exists (the script entry
            point defines it); otherwise the head defaults to 3 classes.
    """

    def __init__(self, n_classes=None):
        super(Classifier, self).__init__()
        if n_classes is None:
            # Previously the head read the global unconditionally, so
            # importing this module and calling Classifier() raised
            # NameError. The global stays the default for the script but is
            # no longer required.
            n_classes = globals().get('n_classes', 3)
        self.fc1 = nn.Linear(15, 15)
        self.relu1 = nn.ReLU()
        self.fc2 = nn.Linear(15, 15)
        self.relu2 = nn.ReLU()
        self.fc3 = nn.Linear(15, n_classes)

    def forward(self, x):
        """Return logits with shape ``(..., n_classes)`` for features ``x`` (last dim 15)."""
        x = self.relu1(self.fc1(x))
        x = self.relu2(self.fc2(x))
        return self.fc3(x)


def load_data(path='blob_data_s20.npz'):
    """Load the toy dataset from an ``.npz`` archive.

    Args:
        path: Archive containing arrays named ``x_s``, ``y_s`` and ``x_t``.
            Defaults to the file name this function previously hard-coded.

    Returns:
        Tuple ``(x_s, y_s, x_t)``: source inputs, source labels and target
        inputs, exactly as stored in the archive.

    Raises:
        FileNotFoundError: if ``path`` does not exist.
        KeyError: if one of the three arrays is missing from the archive.
    """
    # np.load on an .npz returns an NpzFile that keeps the underlying file
    # open; the context manager closes it after the arrays are read out.
    with np.load(path) as archive:
        x_s = archive['x_s']
        y_s = archive['y_s']
        x_t = archive['x_t']
    return x_s, y_s, x_t


def generate_grid_point(x_s=None, x_t=None, step=0.01, margin=0.5):
    """Build a mesh grid covering every source and target point.

    The grid spans the bounding box of the first two columns of ``x_s`` and
    ``x_t`` combined, padded by ``margin`` on every side, and is sampled
    every ``step`` units (``np.arange`` semantics: upper bound exclusive).

    Args:
        x_s: Source points, shape (N, >=2). Defaults to the module-level
            ``x_s`` global that the script entry point defines.
        x_t: Target points, shape (M, >=2). Defaults to the module-level
            ``x_t`` global.
        step: Grid spacing along both axes (previously hard-coded to 0.01).
        margin: Padding beyond the data extent (previously hard-coded to 0.5).

    Returns:
        ``(xx, yy)`` coordinate matrices produced by ``np.meshgrid``.

    Raises:
        NameError: if an array is omitted and no module-level global with
            that name exists (same error the global-only version raised).
    """
    if x_s is None or x_t is None:
        try:
            x_s = globals()['x_s'] if x_s is None else x_s
            x_t = globals()['x_t'] if x_t is None else x_t
        except KeyError as err:
            raise NameError(
                "name %r is not defined; pass x_s/x_t explicitly" % err.args[0]
            ) from None

    x_s = np.asarray(x_s)
    x_t = np.asarray(x_t)

    # Per-axis extent of both point sets together, padded by the margin.
    lo = np.minimum(x_s[:, :2].min(axis=0), x_t[:, :2].min(axis=0)) - margin
    hi = np.maximum(x_s[:, :2].max(axis=0), x_t[:, :2].max(axis=0)) + margin

    xx, yy = np.meshgrid(np.arange(lo[0], hi[0], step), np.arange(lo[1], hi[1], step))
    return xx, yy


def seed_everything(seed=2):
    """Seed the Python, PyTorch (CPU and all CUDA devices) and NumPy RNGs.

    Also exports ``PYTHONHASHSEED``. Python reads that variable only at
    interpreter start-up, so it influences child processes launched
    afterwards, not hash randomisation of the already-running process.

    Args:
        seed: Integer seed applied to every generator.
    """
    import os
    import random

    seeders = (
        random.seed,
        torch.manual_seed,
        torch.cuda.manual_seed_all,
        np.random.seed,
    )
    for apply_seed in seeders:
        apply_seed(seed)
    os.environ['PYTHONHASHSEED'] = str(seed)


def cross_entropy(out1, out2, reduce=True):
    """Row-wise cross-entropy between two probability tensors.

    For each row computes ``-sum_k out2[:, k] * log(out1[:, k])``: the
    cross-entropy of ``out1`` (the tensor that is logged) weighted by
    ``out2``. Callers in this file pass softmax outputs, and
    ``cross_entropy(p, p)`` is the entropy of ``p``.

    Args:
        out1: Probabilities whose log is taken; reduced over dim 1.
        out2: Weighting probabilities, broadcast-compatible with ``out1``.
        reduce: If True return the mean over rows, otherwise the per-row values.

    Note:
        No clamping is applied. An exact zero in ``out1`` gives log(0) = -inf,
        and a matching zero weight in ``out2`` turns that entry into NaN.
    """
    weighted_log = out2 * torch.log(out1)
    per_row = torch.sum(weighted_log, dim=1).neg()
    if not reduce:
        return per_row
    return per_row.mean()


def kl_loss(pred, soft_targets, reduce=True):
    """KL divergence KL(softmax(soft_targets) || softmax(pred)) per row.

    Both arguments are logits; each is normalised with a softmax over dim 1
    before the divergence is computed.

    Args:
        pred: Logits of the distribution being matched (the "model" side).
        soft_targets: Logits of the reference distribution.
        reduce: If True return the mean over rows, otherwise the per-row values.

    Returns:
        Scalar tensor if ``reduce`` else a tensor with one value per row.
    """
    # reduction='none' is the supported spelling of the legacy reduce=False,
    # which is deprecated and emits a UserWarning on every call.
    kl = F.kl_div(
        F.log_softmax(pred, dim=1),
        F.softmax(soft_targets, dim=1),
        reduction='none',
    )
    per_row = torch.sum(kl, dim=1)
    return torch.mean(per_row) if reduce else per_row


def loss_select(y_1, y_2, t, forget_rate, co_lambda=0.01):
    """Mean combined loss over the smallest-loss samples of a batch.

    Each sample's loss is

        CE(y_1, t) + CE(y_2, t)
        + co_lambda * (KL(softmax(y_2) || softmax(y_1))
                       + KL(softmax(y_1) || softmax(y_2)))

    i.e. both heads' supervised losses plus a symmetric agreement term.
    Samples are ranked by this value, and only the ``1 - forget_rate``
    fraction with the smallest losses contributes to the returned mean. The
    discarded samples are presumably the noisy or hard ones; that intent is
    not stated in this file.

    Args:
        y_1, y_2: Logits from the two classifier heads, shape (N, C).
        t: Integer class targets, shape (N,).
        forget_rate: Fraction of samples to discard, expected in [0, 1].
        co_lambda: Weight of the symmetric KL agreement term.

    Returns:
        Scalar tensor: mean combined loss over the retained samples.
    """
    # reduction='none' replaces the deprecated reduce=False (same values).
    loss_pick_1 = F.cross_entropy(y_1, t, reduction='none')
    loss_pick_2 = F.cross_entropy(y_2, t, reduction='none')
    loss_pick = (loss_pick_1 + loss_pick_2
                 + co_lambda * kl_loss(y_1, y_2, reduce=False)
                 + co_lambda * kl_loss(y_2, y_1, reduce=False))

    ind_sorted = torch.argsort(loss_pick)

    # Keep at least one sample. For a tiny batch or forget_rate near 1, int()
    # truncates to 0, and torch.mean over an empty selection returns NaN.
    num_remember = max(1, int((1 - forget_rate) * len(loss_pick)))

    return torch.mean(loss_pick[ind_sorted[:num_remember]])


if __name__ == "__main__":
    parser = argparse.ArgumentParser()
    parser.add_argument('--out', type=str, default="plot_toy")
    opts = parser.parse_args()

    # exist_ok=True replaces the racy "exists? then makedirs" sequence.
    os.makedirs(opts.out, exist_ok=True)

    # Hyper-parameters. The main guard opens no new scope, so these are
    # module-level globals.
    n_classes = 3       # known classes; label n_classes marks "unknown" in plots
    weight = 0.05       # weight of the target-domain loss terms
    lr = 0.005          # SGD learning rate shared by g, c1 and c2
    pi = 0.5493         # threshold on the heads' symmetric cross-entropy
                        # (~ln(3)/2, presumably ln(n_classes)/2 - confirm)
    m = 0               # margin around pi; 0 disables it
    num_steps = 10000   # last training iteration (inclusive)
    warmup_steps = 100  # iterations <= this use plain CE and skip steps B/C
    forget_rate = 0.2   # fraction of largest-loss source samples dropped
    plot_every = 1000   # snapshot interval (iterations)

    # Load data
    x_s, y_s, x_t = load_data()

    # Seed before the networks are built so their initial weights are reproducible.
    seed_everything()

    X = torch.tensor(x_s).float()
    Y = torch.tensor(y_s).long().view(-1)
    X_target = torch.tensor(x_t).float()

    g = Generator()
    c1 = Classifier()
    c2 = Classifier()

    opt_g = optim.SGD(g.parameters(), lr=lr)
    opt_c1 = optim.SGD(c1.parameters(), lr=lr)
    opt_c2 = optim.SGD(c2.parameters(), lr=lr)

    # Dense grid for drawing decision regions. It never changes, so it is
    # converted to a tensor once instead of at every snapshot.
    xx, yy = generate_grid_point()
    grid_points = torch.tensor(np.c_[xx.ravel(), yy.ravel()]).float()

    # Snapshot frames collected for the final GIF.
    gif_images = []

    for step in range(num_steps + 1):
        # ---- Visualisation snapshot -----------------------------------
        if step % plot_every == 0:
            print("Iteration: %d / %d" % (step, num_steps))
            with torch.no_grad():
                feat_grid = g(grid_points)
                logits1_grid = c1(feat_grid)
                logits2_grid = c2(feat_grid)

                prob_grid1 = torch.softmax(logits1_grid, dim=1)
                prob_grid2 = torch.softmax(logits2_grid, dim=1)

                pred1 = torch.argmax(logits1_grid, dim=1).numpy()
                pred2 = torch.argmax(logits2_grid, dim=1).numpy()

                # Symmetric cross-entropy between the two heads at each grid point.
                disagreement = (cross_entropy(prob_grid1, prob_grid2, reduce=False)
                                + cross_entropy(prob_grid2, prob_grid1, reduce=False)).numpy()

            # Region colour = head-1 prediction, except where the heads
            # disagree by more than pi, which get the extra label n_classes.
            Z = pred1.copy()
            Z[disagreement > pi] = n_classes
            Z = Z.reshape(xx.shape)

            pred1 = pred1.reshape(xx.shape)
            pred2 = pred2.reshape(xx.shape)

            # Where each head predicts class 0. Their contours (solid = c1,
            # dashed = c2) are drawn in red on top of the filled regions.
            condition1 = pred1 == 0
            condition2 = pred2 == 0

            f = plt.figure()
            plt.contourf(xx, yy, Z, [-1, 0, 1, 2, 3], colors=['red', 'blue', 'orange', 'black'], alpha=0.3)
            plt.contour(xx, yy, condition1, colors=['red'])
            plt.contour(xx, yy, condition2, linestyles='dashed', colors=['red'])

            plt.scatter(x_s[:, 0][y_s == 0], x_s[:, 1][y_s == 0], color='red', alpha=0.8)
            plt.scatter(x_s[:, 0][y_s == 1], x_s[:, 1][y_s == 1], color='blue', alpha=0.8)
            plt.scatter(x_s[:, 0][y_s == 2], x_s[:, 1][y_s == 2], color='orange', alpha=0.8)
            plt.scatter(x_t[:, 0], x_t[:, 1], color='white', alpha=0.8, marker="s")
            plt.text(4, -8, 'Iter: ' + str(step), fontsize=14, color='#FFD700',
                     bbox=dict(facecolor='dimgray', alpha=0.7))
            plt.axis('off')
            frame_path = os.path.join(opts.out, 'iter' + str(step) + ".png")
            f.savefig(frame_path, bbox_inches='tight',
                      pad_inches=0, dpi=100, transparent=False)
            # NOTE(review): newer imageio releases deprecate top-level imread
            # in favour of imageio.v2.imread / imageio.v3.imread - confirm
            # against the installed version.
            gif_images.append(imageio.imread(frame_path))
            plt.close(f)

        # ---- Step A: update g, c1 and c2 jointly -----------------------
        feat_s = g(X)
        logits1 = c1(feat_s)
        logits2 = c2(feat_s)

        feat_t = g(X_target)
        prob_t1 = torch.softmax(c1(feat_t), dim=1)
        prob_t2 = torch.softmax(c2(feat_t), dim=1)

        # Per-target-sample symmetric cross-entropy between the two heads.
        crs = cross_entropy(prob_t1, prob_t2, reduce=False) + cross_entropy(prob_t2, prob_t1, reduce=False)

        # Target samples on which the heads agree. This mask is frozen here
        # and reused by step C after c1, c2 and g have been updated.
        mask_known = crs < (pi - m)

        # Entropy of the averaged prediction. Minimising loss_ent pushes it
        # away from pi/2.
        prob_t = (prob_t1 + prob_t2) / 2
        ent = cross_entropy(prob_t, prob_t, reduce=False)
        loss_ent = weight * -torch.mean(torch.clamp(torch.abs(ent - pi/2), min=m/2))

        if step <= warmup_steps:
            # Warm-up: plain supervised CE on both heads, no crs term.
            cost1 = nn.CrossEntropyLoss()(logits1, Y)
            cost2 = nn.CrossEntropyLoss()(logits2, Y)
            loss_s = cost1 + cost2
            total_loss = loss_s + loss_ent
        else:
            loss_s = loss_select(logits1, logits2, Y, forget_rate)
            # Push crs away from pi, averaged over samples with crs < 2*pi.
            loss_crs = weight * -torch.mean(torch.clamp(torch.abs(crs - pi), min=m)[crs < 2*pi])
            total_loss = loss_s + loss_crs + loss_ent

        opt_g.zero_grad()
        opt_c1.zero_grad()
        opt_c2.zero_grad()
        total_loss.backward()
        opt_g.step()
        opt_c1.step()
        opt_c2.step()

        if step > warmup_steps:
            # ---- Step B: update only c1/c2 -----------------------------
            # Maximise the capped target crs while keeping the selected
            # source loss low.
            feat_t = g(X_target)
            prob_t1 = torch.softmax(c1(feat_t), dim=1)
            prob_t2 = torch.softmax(c2(feat_t), dim=1)

            crs = cross_entropy(prob_t1, prob_t2, reduce=False) + cross_entropy(prob_t2, prob_t1, reduce=False)
            loss_crs = - weight * torch.mean(torch.clamp(crs, max=2*pi))

            feat_s = g(X)
            logits1 = c1(feat_s)
            logits2 = c2(feat_s)

            loss_s = loss_select(logits1, logits2, Y, forget_rate)

            total_loss = loss_crs + loss_s

            opt_c1.zero_grad()
            opt_c2.zero_grad()
            # backward() also accumulates gradients into g. g is not stepped
            # here, and its gradients are zeroed in step C.
            total_loss.backward()
            opt_c1.step()
            opt_c2.step()

            # ---- Step C: update only g ---------------------------------
            # Minimise the capped crs on the target samples mask_known
            # selected in step A.
            feat_t = g(X_target)
            prob_t1 = torch.softmax(c1(feat_t), dim=1)
            prob_t2 = torch.softmax(c2(feat_t), dim=1)

            crs = cross_entropy(prob_t1, prob_t2, reduce=False) + cross_entropy(prob_t2, prob_t1, reduce=False)
            loss_crs = weight * torch.mean(torch.clamp(crs, max=2*pi)[mask_known])

            opt_g.zero_grad()
            loss_crs.backward()
            opt_g.step()

    # Save GIF.
    # NOTE(review): the unit of `duration` depends on the imageio version and
    # plugin (seconds in the legacy writer, milliseconds in newer pillow-based
    # writers) - confirm against the installed version.
    imageio.mimsave(os.path.join(opts.out, 'train.gif'), gif_images, duration=0.8)
    print(f"[Finished]\n-> Please see the {opts.out} folder for outputs.")