import os
import sys
import time
import shutil
from argparse import ArgumentParser
from matplotlib.pyplot import violinplot
import torch
import torch.optim as optim
from torch.utils.data import DataLoader
import torch.nn.functional as F
import torchvision.utils as vutils
from torch.utils.tensorboard import SummaryWriter
from tqdm import tqdm
import ei.patched

from .network import Generator, Discriminator
from .dataset import IconDataset


def _clear_dir(path):
    """Remove ``path`` if it exists and recreate it as an empty directory."""
    if os.path.exists(path):
        shutil.rmtree(path)
    os.mkdir(path)


def _dis_hinge_loss(D, real, fake):
    """Return the discriminator hinge loss for one real/fake batch pair.

    ``fake`` is detached before being scored, so back-propagating this loss
    does not send gradients into whatever produced ``fake``.
    """
    d_out_real = D(real)
    d_out_fake = D(fake.detach())
    loss_real = F.relu(1.0 - d_out_real).mean()
    loss_fake = F.relu(1.0 + d_out_fake).mean()
    return loss_real + loss_fake


def _gen_hinge_loss(D, fake):
    """Return the generator loss: the negated mean discriminator score."""
    return -D(fake).mean()


def _save_checkpoint(save_path, prefix, items, global_step, epoch):
    """Write every state dict in ``items`` plus the (step, epoch) counters.

    Files are named ``<prefix>_<name>.pth`` inside ``save_path``; the counters
    go to ``<prefix>_state.pth``.
    """
    for name, obj in items.items():
        torch.save(obj.state_dict(), os.path.join(save_path, f'{prefix}_{name}.pth'))
    torch.save((global_step, epoch), os.path.join(save_path, f'{prefix}_state.pth'))


def _load_checkpoint(save_path, prefix, items, device):
    """Restore every state dict in ``items`` and return ``(global_step, epoch)``.

    Reads the files written by ``_save_checkpoint`` with the same ``prefix``.
    """
    for name, obj in items.items():
        state = torch.load(os.path.join(save_path, f'{prefix}_{name}.pth'),
                           map_location=device)
        obj.load_state_dict(state)
    global_step, epoch = torch.load(os.path.join(save_path, f'{prefix}_state.pth'))
    return global_step, epoch


def train(
    output_dir,
    device,
    dataset_root='datasets/icon4/data/in_memory',
    num_workers=4,
    batch_size=64,
    image_size=128,
    pad_ratio=8,
    lrG=5e-5,
    lrD=2e-4,
    epochs=2000,
    log_int=200,
    save_int=3000,
    sample_int=1000,
    dummy_input=False,
    resume=False
):
    """Train the generator against a style and a content discriminator.

    Each step updates the style discriminator, the content discriminator and
    then the generator, all with hinge losses. Scalars and sample grids are
    written to TensorBoard under ``output_dir``; checkpoints go to
    ``output_dir/weights``.

    Args:
        output_dir: Directory for TensorBoard logs, reference images and the
            ``weights`` / ``samples`` sub-directories. Created if missing.
        device: Anything accepted by ``torch.device``.
        dataset_root: Root directory handed to ``IconDataset``.
        num_workers: Worker count for the ``DataLoader``.
        batch_size: Batch size for the ``DataLoader`` (incomplete batches are
            dropped).
        image_size: Forwarded to ``IconDataset``; also the spatial size of the
            random tensors used when ``dummy_input`` is set.
        pad_ratio: Forwarded to ``IconDataset``.
        lrG: Adam learning rate for the generator.
        lrD: Adam learning rate for both discriminators.
        epochs: Epoch index at which training stops.
        log_int: Log losses every ``log_int`` steps.
        save_int: Save a checkpoint every ``save_int`` steps.
        sample_int: Log a grid of fixed-input samples every ``sample_int`` steps.
        dummy_input: If true, run one forward pass on random tensors, print
            the output shapes and exit the process via ``sys.exit(0)``.
        resume: If true, restore networks, optimizers and counters from the
            ``latest_*`` checkpoint. Otherwise the ``weights`` and ``samples``
            directories are wiped and training starts from scratch.

    Raises:
        ValueError: If the data loader yields no batch at all.
    """
    dataset_root = os.path.abspath(dataset_root)

    os.makedirs(output_dir, exist_ok=True)

    save_path = os.path.join(output_dir, 'weights')
    sample_path = os.path.join(output_dir, 'samples')

    device = torch.device(device)
    print('Device:', device)

    # construct networks; each discriminator scores a generated/real image
    # concatenated channel-wise with its conditioning input
    G = Generator(ch_style=3, ch_content=1).to(device)
    Ds = Discriminator(3+3).to(device)
    Dc = Discriminator(3+1).to(device)

    if dummy_input: # debug purpose
        BATCH_SIZE = 16
        s1 = torch.randn(BATCH_SIZE, 3, image_size, image_size).to(device)
        contour = torch.randn(BATCH_SIZE, 1, image_size, image_size).to(device)
        fake = G(s1, contour)
        print('fake.shape', fake.shape)
        ds_out = Ds(torch.cat([fake, s1], dim=1))
        dc_out = Dc(torch.cat([fake, contour], dim=1))
        print('Ds_out.shape', ds_out.shape)
        print('Dc_out.shape', dc_out.shape)
        sys.exit(0)

    optimG = optim.Adam(G.parameters(), lr=lrG, betas=(0, 0.999))
    optimDc = optim.Adam(Dc.parameters(), lr=lrD, betas=(0, 0.999))
    optimDs = optim.Adam(Ds.parameters(), lr=lrD, betas=(0, 0.999))

    # everything that is persisted in a checkpoint, keyed by file-name suffix
    checkpoint_items = {
        'G': G,
        'Ds': Ds,
        'Dc': Dc,
        'optimG': optimG,
        'optimDs': optimDs,
        'optimDc': optimDc,
    }

    # prepare dataset
    dataset = IconDataset(root=dataset_root, image_size=image_size, pad_ratio=pad_ratio)
    dataloader = DataLoader(dataset,
        batch_size=batch_size,
        shuffle=True,
        drop_last=True,
        num_workers=num_workers,
        pin_memory=True,
    )

    # sample fixed inputs used for the periodic sample grids
    try:
        s1, s2, s3, contour = next(iter(dataloader))
    except StopIteration:
        # with drop_last=True this also happens when the dataset is smaller
        # than one batch
        raise ValueError(
            'dataloader produced no batches; dataset at {!r} has {} samples '
            'for batch_size={}'.format(dataset_root, len(dataset), batch_size)
        ) from None

    fixed_s1 = s1.to(device)
    fixed_contour = contour.to(device)
    vutils.save_image(fixed_s1.detach().cpu(), os.path.join(output_dir, 's1.png'), padding=0)
    vutils.save_image(s2.detach().cpu(), os.path.join(output_dir, 's2.png'), padding=0)
    vutils.save_image(s3.detach().cpu(), os.path.join(output_dir, 's3.png'), padding=0)
    vutils.save_image(fixed_contour.detach().cpu(), os.path.join(output_dir, 'contour.png'), padding=0)

    # training loop
    print('training...')
    writer = SummaryWriter(output_dir)
    # the writer buffers events; make sure it is flushed and closed even if
    # resuming or training raises
    try:
        global_step = 0
        start_epoch = 0

        if resume:
            saved_step, start_epoch = _load_checkpoint(
                save_path, 'latest', checkpoint_items, device)
            print('resumed from Epoch: {:04d} Step: {:07d}'.format(start_epoch, saved_step))
            # the checkpoint was written after the optimizer updates of
            # `saved_step`, so continue with the next step instead of
            # repeating (and re-logging / re-saving) that step index.
            # NOTE: the epoch is restarted from its first batch.
            global_step = saved_step + 1
        else:
            _clear_dir(sample_path)
            _clear_dir(save_path)
        iter_epochs = tqdm(range(start_epoch, epochs))

        for epoch in iter_epochs:
            for s1, s2, s3, contour in dataloader:
                # s1 s2 are in same cluster in lab space
                # s3 contour are paired icon and its contour

                s1 = s1.to(device)
                s2 = s2.to(device)
                s3 = s3.to(device)
                contour = contour.to(device)

                fake = G(s1, contour)
                style_fake = torch.cat([fake, s2], dim=1)
                style_real = torch.cat([s1, s2], dim=1)
                content_fake = torch.cat([fake, contour], dim=1)
                content_real = torch.cat([s3, contour], dim=1)

                # update style discriminator
                optimDs.zero_grad()
                Ds_loss = _dis_hinge_loss(Ds, style_real, style_fake)
                Ds_loss.backward()
                optimDs.step()

                # update content discriminator
                optimDc.zero_grad()
                Dc_loss = _dis_hinge_loss(Dc, content_real, content_fake)
                Dc_loss.backward()
                optimDc.step()

                # update generator against the freshly updated discriminators
                optimG.zero_grad()
                Gs_loss = _gen_hinge_loss(Ds, style_fake)
                Gc_loss = _gen_hinge_loss(Dc, content_fake)
                G_loss = Gs_loss + Gc_loss
                G_loss.backward()
                optimG.step()

                # log losses
                if global_step % log_int == 0:
                    log_dict = {
                        'D_style_loss': Ds_loss.item(),
                        'D_content_loss': Dc_loss.item(),
                        'G_style_loss': Gs_loss.item(),
                        'G_content_loss': Gc_loss.item(),
                        'G_loss': G_loss.item(),
                        'global_step': global_step
                    }
                    for key, value in log_dict.items():
                        writer.add_scalar(f'icon/{key}', value, global_step)
                    iter_epochs.set_postfix({k: (f'{v:.4f}' if isinstance(v, float) else v)
                                             for k, v in log_dict.items()})

                # save a numbered snapshot and refresh the 'latest' one
                if global_step % save_int == 0:
                    save_idx = global_step // save_int
                    for prefix in ['{:05d}'.format(save_idx), 'latest']:
                        _save_checkpoint(save_path, prefix, checkpoint_items,
                                         global_step, epoch)

                # render the fixed inputs so samples are comparable over time
                if global_step % sample_int == 0:
                    G.eval()
                    with torch.no_grad():
                        fixed_fake = G(fixed_s1, fixed_contour)
                    G.train()

                    grid = vutils.make_grid(fixed_fake.detach().cpu())
                    writer.add_image('icon/fixed_samples', grid, global_step)

                global_step += 1
    finally:
        writer.close()