import argparse
import os
import numpy as np
import math

import torchvision.transforms as transforms
from torchvision.utils import save_image

from torch.utils.data import DataLoader
from torchvision import datasets
from torch.autograd import Variable

import torch.nn as nn
import torch.nn.functional as F
import torch
import matplotlib.pyplot as plt

from skimage import io

import time
from datetime import date

# Command-line configuration. The arguments are parsed at import time and the
# result is kept in the module-level ``opt`` namespace that the rest of the
# script reads.
parser = argparse.ArgumentParser()
parser.add_argument("--n_epochs", type=int, default=100000, help="number of epochs of training")
parser.add_argument("--max_iter", type=int, default=100000, help="max iteration of training")
parser.add_argument("--batch_size", type=int, default=64, help="size of the batches")
parser.add_argument("--gen_lr", type=float, default=0.0002, help="adam: generator learning rate")
parser.add_argument("--dis_lr", type=float, default=0.0002, help="adam: discriminator learning rate")
parser.add_argument("--b1", type=float, default=0.5, help="adam: decay of first order momentum of gradient")
# Fixed copy-paste slip: b2 is Adam's second-moment coefficient, not the first.
parser.add_argument("--b2", type=float, default=0.999, help="adam: decay of second order momentum of gradient")
# n_gpu > 0 enables CUDA when --device is not given; n_gpu > 1 enables DataParallel.
parser.add_argument("--n_gpu", type=int, default=1, help="number of GPUs to use (0 selects the CPU unless --device is given)")
parser.add_argument("--latent_dim", type=int, default=100, help="dimensionality of the latent space")
parser.add_argument("--img_size", type=int, default=32, help="size of each image dimension")
parser.add_argument("--channels", type=int, default=3, help="number of image channels")
parser.add_argument("--dataset", type=str, default='cifar10', help="type of dataset")
parser.add_argument("--loss_fn", type=str, default='KL', help='loss function for distribution matching')
parser.add_argument("--out_f", default='checkpoints', help='folder to output images and model checkpoints')
parser.add_argument('--device', type=str, default=None, help='Device to use. Like cuda, cuda:0 or cpu')
parser.add_argument("--save_ckpt", type=int, default=5000, help='point of saving checkpoints')
parser.add_argument("--gf_dim", type=int, default=64)
parser.add_argument("--df_dim", type=int, default=64)
parser.add_argument("--decay_type", type=str, default="expo", help='linear, exponential')
parser.add_argument("--optimal_prob", type=float, default=0.5, help='optimal probability of BCE')
parser.add_argument("--optimal_range", type=float, default=0.0033, help='acceptable range of discriminator loss')
parser.add_argument("--method_type", type=str, default='method1', help='adaptive method1 or adaptive method2')
opt = parser.parse_args()
print(opt)

# Resolve the torch device: an explicit --device wins; otherwise use the first
# GPU when CUDA is available and --n_gpu > 0, and fall back to the CPU.
if opt.device is None:
    device = torch.device('cuda:0' if (torch.cuda.is_available() and opt.n_gpu > 0) else 'cpu')
else:
    device = torch.device(opt.device)

# ``cuda`` picks the default tensor type later in the script. Derive it from
# the device that was actually selected (not merely from CUDA availability) so
# that a CPU run on a CUDA-capable machine never allocates on the GPU.
cuda = device.type == 'cuda'


def weights_init_normal(m):
    """Re-initialise a single module in place (intended for ``Module.apply``).

    * class name contains ``"Conv"``: weights are drawn from N(0, 0.02);
    * otherwise, class name contains ``"BatchNorm2d"``: weights are drawn from
      N(1, 0.02) and the bias is set to zero;
    * any other module is left untouched.
    """
    layer_kind = type(m).__name__

    if "Conv" in layer_kind:
        torch.nn.init.normal_(m.weight.data, mean=0.0, std=0.02)
        return

    if "BatchNorm2d" in layer_kind:
        torch.nn.init.normal_(m.weight.data, mean=1.0, std=0.02)
        torch.nn.init.constant_(m.bias.data, val=0.0)


class Generator(nn.Module):
    """DCGAN-style generator mapping a latent vector to an image in [-1, 1].

    Pipeline: linear projection -> reshape to
    ``(gf_dim * 8, init_size, init_size)`` -> BatchNorm + ReLU
    (``proj_reshape``) -> three transposed-conv blocks with BatchNorm + ReLU
    -> a final transposed conv followed by ``tanh``.  Every transposed conv
    uses kernel 4, stride 2, padding 1 and therefore doubles the spatial
    size, so the output side length is ``16 * (img_size // 16)``.

    Args:
        latent_dim: size of the latent vector (default ``opt.latent_dim``).
        gf_dim: base number of generator feature maps (default ``opt.gf_dim``).
        img_size: side length of the generated image (default ``opt.img_size``).
        channels: number of output channels (default ``opt.channels``).

    Raises:
        ValueError: if ``img_size`` is smaller than 16, which would leave the
            projected feature map without any spatial extent.

    Sub-module names (``l1``, ``proj_reshape``, ``conv_1`` .. ``conv_4``) are
    unchanged, so ``state_dict`` keys stay the same.
    """

    def __init__(self, latent_dim=None, gf_dim=None, img_size=None, channels=None):
        super(Generator, self).__init__()

        # Anything not passed explicitly falls back to the command-line options.
        latent_dim = opt.latent_dim if latent_dim is None else latent_dim
        self.gf_dim = opt.gf_dim if gf_dim is None else gf_dim
        img_size = opt.img_size if img_size is None else img_size
        channels = opt.channels if channels is None else channels

        # Four stride-2 layers upsample by 16 in total.
        self.init_size = img_size // 16
        if self.init_size < 1:
            raise ValueError("img_size must be at least 16, got %d" % img_size)

        def deconv(input_channel, output_channel, stride, kernel, padding):
            """Transposed conv + BatchNorm + ReLU."""
            return [nn.ConvTranspose2d(input_channel, output_channel, stride=stride, kernel_size=kernel, padding=padding),
                    nn.BatchNorm2d(output_channel),
                    nn.ReLU()]

        def deconv_tanh(input_channel, output_channel, stride, kernel, padding):
            """Transposed conv + tanh (output layer)."""
            return [nn.ConvTranspose2d(input_channel, output_channel, stride=stride, kernel_size=kernel, padding=padding),
                    nn.Tanh()]

        top_channels = self.gf_dim * 8
        self.l1 = nn.Sequential(nn.Linear(latent_dim, top_channels * self.init_size ** 2))

        # NOTE(review): momentum=0.9 is kept as originally written. In PyTorch
        # the momentum is the weight given to the *new* batch statistic, so
        # confirm this is the intended value.
        self.proj_reshape = nn.Sequential(
            nn.BatchNorm2d(top_channels, momentum=0.9), nn.ReLU())

        self.conv_1 = nn.Sequential(*deconv(self.gf_dim * 8, self.gf_dim * 4, stride=2, kernel=4, padding=1))
        self.conv_2 = nn.Sequential(*deconv(self.gf_dim * 4, self.gf_dim * 2, stride=2, kernel=4, padding=1))
        self.conv_3 = nn.Sequential(*deconv(self.gf_dim * 2, self.gf_dim * 1, stride=2, kernel=4, padding=1))
        self.conv_4 = nn.Sequential(*deconv_tanh(self.gf_dim * 1, channels, stride=2, kernel=4, padding=1))

    def forward(self, z):
        """Generate images from latent vectors ``z`` of shape (batch, latent_dim)."""
        out = self.l1(z)
        out = out.view(out.shape[0], self.gf_dim * 8, self.init_size, self.init_size)
        # Previously this block was constructed but never called, so the
        # projected features skipped their normalisation and activation.
        out = self.proj_reshape(out)
        out = self.conv_1(out)
        out = self.conv_2(out)
        out = self.conv_3(out)
        img = self.conv_4(out)
        return img


class Discriminator(nn.Module):
    """DCGAN-style discriminator returning a real/fake probability per image.

    The convolutional trunk is three stride-2 convolutions (kernel 4,
    padding 1), each halving the spatial size, followed by one stride-1
    convolution with kernel 4 and no padding, which removes 3 pixels per
    side.  The flattened features feed a linear layer with a sigmoid.

    Args:
        channels: number of input image channels (default ``opt.channels``).
        df_dim: base number of discriminator feature maps (default
            ``opt.df_dim``).
        img_size: side length of the input images (default ``opt.img_size``).

    Raises:
        ValueError: if ``img_size`` is below 32, for which the last
            convolution would have no valid output position.

    Layer order inside ``model`` and ``adv_layer`` is unchanged, so
    ``state_dict`` keys stay the same.
    """

    def __init__(self, channels=None, df_dim=None, img_size=None):
        super(Discriminator, self).__init__()

        # Anything not passed explicitly falls back to the command-line options.
        channels = opt.channels if channels is None else channels
        df_dim = opt.df_dim if df_dim is None else df_dim
        img_size = opt.img_size if img_size is None else img_size

        def discriminator_block(in_filters, out_filters, stride, kernel, padding):
            """Conv + LeakyReLU (no normalisation; used for the first layer)."""
            return [nn.Conv2d(in_filters, out_filters, stride=stride, kernel_size=kernel, padding=padding),
                    nn.LeakyReLU(0.2, inplace=True)]

        def discriminator_bn_block(in_filters, out_filters, stride, kernel, padding):
            """Conv + BatchNorm + LeakyReLU."""
            return [nn.Conv2d(in_filters, out_filters, stride=stride, kernel_size=kernel, padding=padding),
                    nn.BatchNorm2d(out_filters),
                    nn.LeakyReLU(0.2, inplace=True)]

        self.model = nn.Sequential(
            *discriminator_block(channels, df_dim, stride=2, kernel=4, padding=1),
            *discriminator_bn_block(df_dim, df_dim * 2, stride=2, kernel=4, padding=1),
            *discriminator_bn_block(df_dim * 2, df_dim * 4, stride=2, kernel=4, padding=1),
            *discriminator_bn_block(df_dim * 4, df_dim * 8, stride=1, kernel=4, padding=0),
        )

        # Side length of the trunk output: three halvings, then a 4x4 valid
        # convolution.  This yields 1 for img_size 32 and 3 for img_size 48,
        # matching the two sizes that used to be hard-coded, and also covers
        # every other size instead of leaving ds_size undefined.
        ds_size = img_size // 8 - 3
        if ds_size < 1:
            raise ValueError("img_size must be at least 32, got %d" % img_size)

        self.adv_layer = nn.Sequential(nn.Linear(df_dim * 8 * ds_size * ds_size, 1), nn.Sigmoid())

    def forward(self, img):
        """Return a (batch, 1) tensor of probabilities that ``img`` is real."""
        out = self.model(img)
        out = out.view(out.shape[0], -1)
        validity = self.adv_layer(out)

        return validity



# Loss functions: binary cross-entropy for the adversarial game, plus the
# softmax / KL-divergence modules available to the distribution-matching code.
adversarial_loss = torch.nn.BCELoss()
soft_fn = torch.nn.Softmax(dim=1)
distribution_loss = torch.nn.KLDivLoss(reduction='none')

# Build both networks, then move them onto the selected device.
generator = Generator()
discriminator = Discriminator()
generator, discriminator = generator.to(device), discriminator.to(device)

# Wrap in DataParallel only for a CUDA run that asks for more than one GPU;
# otherwise netG / netD are simply the bare modules.
netG, netD = generator, discriminator
if device.type == 'cuda' and opt.n_gpu > 1:
    gpu_ids = list(range(opt.n_gpu))
    netG = nn.DataParallel(generator, gpu_ids)
    netD = nn.DataParallel(discriminator, gpu_ids)

# Initialize weights (generator first, then discriminator)
for net in (netG, netD):
    net.apply(weights_init_normal)

# Configure the training data loader for the dataset chosen with --dataset.
# Both datasets share the same preprocessing: resize to --img_size, convert to
# a tensor and normalise with mean 0.5 / std 0.5.
image_transform = transforms.Compose(
    [transforms.Resize(opt.img_size), transforms.ToTensor(),
     transforms.Normalize([0.5], [0.5])])

if opt.dataset == 'cifar10':
    data_root = "../data/cifar10"
    os.makedirs(data_root, exist_ok=True)
    train_dataset = datasets.CIFAR10(data_root,
                                     train=True,
                                     download=True,
                                     transform=image_transform)
elif opt.dataset == 'stl10':  # was ``else <condition>:``, which is a syntax error
    # The directory that is created is now the same one the dataset is loaded
    # from; previously "../data/STL10" was created while "../../data/STL10"
    # was used. The load location is the one that is kept.
    data_root = "../../data/STL10"
    os.makedirs(data_root, exist_ok=True)
    train_dataset = datasets.STL10(data_root,
                                   split='train',
                                   download=True,
                                   transform=image_transform)
else:
    # Fail here with a clear message rather than later with a NameError on
    # the undefined train_dataloader.
    raise ValueError("unsupported dataset %r (expected 'cifar10' or 'stl10')" % opt.dataset)

train_dataloader = torch.utils.data.DataLoader(train_dataset, batch_size=opt.batch_size, shuffle=True)
    

def KL_matching(source, target):
    """Return the batch-mean KL divergence between two batches of tensors.

    Each sample is flattened and turned into a probability distribution with
    a softmax over all of its elements; the result is
    ``KL(softmax(target) || softmax(source))`` summed per sample and divided
    by the batch size (``reduction='batchmean'``).

    Args:
        source: tensor whose first dimension is the batch; gradients flow
            through this argument.
        target: tensor with the same shape as ``source``.

    Returns:
        A scalar tensor.
    """
    # Flatten using the batch size of the tensors themselves: the final batch
    # of an epoch can be smaller than opt.batch_size, and reshaping with the
    # configured size would silently regroup values across samples.
    source_flat = torch.reshape(source, (source.shape[0], -1))
    target_flat = torch.reshape(target, (target.shape[0], -1))

    # log_softmax avoids log(0) = -inf when a softmax probability underflows.
    log_p_source = F.log_softmax(source_flat, dim=1)
    p_target = F.softmax(target_flat, dim=1)

    return F.kl_div(log_p_source, p_target, reduction='batchmean')

def expo_decay(current_steps):
    """Return the exponential decay weight ``exp(-current_steps / opt.max_iter)``.

    The weight is 1.0 at step 0 and shrinks to ``exp(-1)`` when
    ``current_steps`` equals ``opt.max_iter``.
    """
    progress = current_steps / opt.max_iter
    return math.exp(-progress)


def distribution_matching(source, target, method_type, current_steps, decay_type, loss_fn, d_loss, optimal_value):
    """Back-propagate a decayed KL term between ``source`` and ``target``.

    Computes ``KL_matching(source, target)``, scales it by
    ``expo_decay(current_steps)`` and calls ``backward()`` on the scaled
    value, so gradients are accumulated as a side effect.

    ``method_type``, ``decay_type``, ``loss_fn``, ``d_loss`` and
    ``optimal_value`` are accepted but not used by this implementation.

    Returns:
        Tuple ``(kl_value, weighted_loss, weight)``: the raw KL tensor, the
        scaled tensor that was back-propagated, and the float decay weight.
    """
    weight = expo_decay(current_steps)
    kl_value = KL_matching(source, target)

    weighted_loss = weight * kl_value
    weighted_loss.backward()

    return kl_value, weighted_loss, weight


# Optimizers: one Adam instance per network, sharing the beta coefficients.
adam_betas = (opt.b1, opt.b2)
optimizer_G = torch.optim.Adam(netG.parameters(), lr=opt.gen_lr, betas=adam_betas)
optimizer_D = torch.optim.Adam(netD.parameters(), lr=opt.dis_lr, betas=adam_betas)

# Default float tensor type, chosen from the ``cuda`` flag.
if cuda:
    Tensor = torch.cuda.FloatTensor
else:
    Tensor = torch.FloatTensor

# ----------
#  Training
# ----------

# Directory that receives the periodic model checkpoints.
checkpoint_path = "../ckpt/"
os.makedirs(checkpoint_path, exist_ok=True)

# Per-iteration loss history, plotted once training has finished.
G_losses, D_losses = [], []

# Reference value the discriminator loss is compared against in the training
# loop: the negative log of --optimal_prob.
optimal_value = -math.log(opt.optimal_prob)

epoch = 0

time_begin = time.time()

# Main training loop. Runs until either --n_epochs epochs or --max_iter
# iterations have been completed, whichever comes first.
reached_max_iter = False

while epoch < opt.n_epochs and not reached_max_iter:
    for i, (imgs, _) in enumerate(train_dataloader):
        steps = len(train_dataloader) * epoch + i

        start_time = time.time()

        batch_len = imgs.shape[0]

        # Adversarial ground truths, created directly on the training device.
        valid = torch.ones(batch_len, 1, device=device)
        fake = torch.zeros(batch_len, 1, device=device)

        # Configure input
        real_imgs = imgs.to(device=device, dtype=torch.float32)

        # Sample noise as generator input
        z = torch.as_tensor(np.random.normal(0, 1, (batch_len, opt.latent_dim)),
                            dtype=torch.float32, device=device)

        # Generate a batch of images
        gen_imgs = netG(z)

        # ---------------------
        #  Train Discriminator
        # ---------------------

        # The discriminator only needs the generated pixels, not the
        # generator's graph; detaching avoids back-propagating through the
        # generator here (those gradients were discarded anyway).
        gen_imgs_for_d = gen_imgs.detach()

        # Two discriminator updates per batch, as before: the first on the
        # summed loss, the second on the averaged loss. Gradients are now
        # cleared before each update; previously the second step was taken on
        # the sum of its own gradients and the stale ones from the first step.
        # NOTE(review): confirm that two updates per batch are intended.
        for loss_scale in (1.0, 0.5):
            optimizer_D.zero_grad()

            real_loss = adversarial_loss(netD(real_imgs), valid)
            fake_loss = adversarial_loss(netD(gen_imgs_for_d), fake)
            d_loss = loss_scale * (real_loss + fake_loss)

            d_loss.backward()
            optimizer_D.step()

        # -----------------
        #  Train Generator
        # -----------------

        optimizer_G.zero_grad()

        # Loss measures generator's ability to fool the discriminator
        g_loss = adversarial_loss(netD(gen_imgs), valid)

        ### adaptive ###
        # The matching term is only added while the (averaged) discriminator
        # loss is further than --optimal_range away from optimal_value.
        needs_matching = abs(d_loss.item() - optimal_value) > opt.optimal_range

        # The generator graph has to survive this backward pass only when the
        # matching term is going to back-propagate through it as well.
        g_loss.backward(retain_graph=needs_matching)

        if needs_matching:
            state = 1
            dist_diff_value, dist_loss, gamma = distribution_matching(gen_imgs, real_imgs, opt.method_type, steps, opt.decay_type, opt.loss_fn, d_loss, optimal_value)
        else:
            state = 0
            # Computed for logging only, so no graph is needed.
            with torch.no_grad():
                dist_diff_value = KL_matching(gen_imgs, real_imgs)
            gamma = 0.0
            dist_loss = gamma * dist_diff_value

        optimizer_G.step()

        elapsed = time.time() - start_time

        G_losses.append(g_loss.item())
        D_losses.append(d_loss.item())

        print("[Iteration %d/%d] [G loss: %f] [D loss: %f] [Dist loss : %f] [Dist value : %f] [Gamma : %f] [State : %d] [Time : %.2f]"
            % ((steps+1), opt.max_iter, g_loss.item(), d_loss.item(), dist_loss.item(), dist_diff_value.item(), gamma, state, elapsed))

        # Periodic checkpoints of both networks.
        if (steps + 1) % opt.save_ckpt == 0:
            torch.save(netG.state_dict(),
                    checkpoint_path + 'adaptive_distribution_netG_%d.pth' % (steps+1))
            torch.save(netD.state_dict(),
                    checkpoint_path + 'adaptive_distribution_netD_%d.pth' % (steps+1))

        if (steps + 1) == opt.max_iter:
            reached_max_iter = True
            break

    # The flag (rather than re-reading ``steps``) keeps this safe even when
    # the data loader yields no batches at all.
    if not reached_max_iter:
        epoch += 1


# Report the total wall-clock training time in seconds.
print('Time of training-{}'.format((time.time() - time_begin)))


# Plot the generator / discriminator loss history and save it as PNG and PDF.
plt.figure(figsize=(5,5))
plt.plot(G_losses,label="Generator")
plt.plot(D_losses,label="Discriminator")
plt.xlabel("Steps")
plt.ylabel("Loss")
# Without a legend the labels given to plot() above are never displayed.
plt.legend()
# tight_layout() has to run after the axes content exists to have any effect.
plt.tight_layout()

# The file name (including its original spelling) is kept unchanged so that
# existing outputs and downstream tooling are unaffected.
plot_name = "adaptive distribution loss while trainig_%s_%s_%s_max_iter_%d_%s_%s" % (opt.dataset, opt.loss_fn, opt.decay_type, opt.max_iter, opt.method_type, opt.optimal_range)
plt.savefig(plot_name + ".png")
plt.savefig(plot_name + ".pdf")
