"""
Copyright (C) 2018 NVIDIA Corporation.  All rights reserved.
Licensed under the CC BY-NC-SA 4.0 license (https://creativecommons.org/licenses/by-nc-sa/4.0/legalcode).
"""
from .my_dataset import MyDataset
from .utils import get_all_data_loaders, prepare_sub_folder, write_html, write_loss, get_config, write_2images, Timer
import argparse
from torch.autograd import Variable
from torch.utils.data import DataLoader
from .trainer import MUNIT_Trainer, UNIT_Trainer
import torch.backends.cudnn as cudnn
import torch
try:
    from itertools import izip as zip
except ImportError: # will be 3.x series
    pass
import os
import sys
from torch.utils.tensorboard import SummaryWriter
import shutil




def train(
    output_dir,
    device,
    resume=False,
    dataset_root='datasets/icon4/data/in_memory',
    image_size=256,
    num_workers=4,
    batch_size=1,
    max_iter=1000000,
    save_iter=10000,
    sample_iter=1000,
    display_iter=100,
):
    """Run the MUNIT training loop on the 'contour' and 'icon' datasets.

    The experiment configuration is read from
    ``configs/edges2handbags_folder.yaml`` next to this module; a few of its
    entries are then overridden from the arguments below.

    Args:
        output_dir: Directory handed to ``SummaryWriter``. Checkpoints, images
            and a copy of the config go under
            ``<output_dir>/outputs/<config file stem>``.
        device: Device string; must start with ``'cuda:'``. Its index is
            exported as ``CUDA_VISIBLE_DEVICES``.
        resume: When truthy, the iteration counter is taken from
            ``trainer.resume(checkpoint_directory, ...)`` instead of 0.
        dataset_root: Root path passed to ``MyDataset``.
        image_size: Must be 256 (asserted); passed to ``MyDataset``.
        num_workers: ``DataLoader`` worker count.
        batch_size: ``DataLoader`` batch size.
        max_iter: Training stops once the iteration counter reaches this.
        save_iter: Stored as ``config['snapshot_save_iter']``.
        sample_iter: Stored as ``config['image_save_iter']``.
        display_iter: Stored as ``config['image_display_iter']``.

    Raises:
        AssertionError: If ``image_size`` is not 256 or ``device`` does not
            start with ``'cuda:'``.
        SystemExit: Always, with the message ``'Finish training'``, once
            ``max_iter`` is reached -- this function never returns normally.
            NOTE(review): a string argument to ``sys.exit`` produces exit
            status 1; confirm callers expect that on successful completion.
    """
    assert image_size == 256, 'only image_size == 256 is supported'
    assert device.startswith('cuda:'), "device must look like 'cuda:<index>'"
    device = torch.device(device)
    # Restrict the process to the requested GPU; the plain ``.cuda()`` calls
    # below then presumably land on it.
    # NOTE(review): this only takes effect if CUDA has not been initialised
    # earlier in this process -- confirm with callers.
    os.environ['CUDA_VISIBLE_DEVICES'] = str(device.index)
    os.environ['CUDA_DEVICE_ORDER'] = 'PCI_BUS_ID'

    current_dir = os.path.abspath(os.path.dirname(__file__))
    config_path = os.path.join(current_dir, 'configs/edges2handbags_folder.yaml')

    opts_trainer = 'MUNIT'

    cudnn.benchmark = True

    # Load experiment setting, then override entries from the arguments.
    config = get_config(config_path)
    display_size = config['display_size']
    config['vgg_model_path'] = 'temp/vgg16.pt'
    config['image_save_iter'] = sample_iter
    config['image_display_iter'] = display_iter
    config['snapshot_save_iter'] = save_iter

    # Setup model and data loader
    if opts_trainer == 'MUNIT':
        trainer = MUNIT_Trainer(config)
    elif opts_trainer == 'UNIT':
        trainer = UNIT_Trainer(config)
    else:
        sys.exit("Only support MUNIT|UNIT")
    trainer.cuda()

    def get_dataloader(dataset, shuffle):
        """Wrap ``dataset`` in a DataLoader using this run's batch settings."""
        return DataLoader(dataset=dataset,
                          batch_size=batch_size,
                          shuffle=shuffle,
                          drop_last=True,
                          num_workers=num_workers)

    def fixed_batch(loader):
        """Stack the first ``display_size`` dataset items and move them to GPU."""
        return torch.stack([loader.dataset[i] for i in range(display_size)]).cuda()

    # Train loaders use split (0.0, 0.9), test loaders use split (0.9, 1.0).
    train_loader_a = get_dataloader(MyDataset(dataset_root, 'contour', image_size, True, True, split=(0.0, 0.9)), True)
    train_loader_b = get_dataloader(MyDataset(dataset_root, 'icon', image_size, True, True, split=(0.0, 0.9)), True)
    test_loader_a = get_dataloader(MyDataset(dataset_root, 'contour', image_size, True, True, split=(0.9, 1.0)), False)
    test_loader_b = get_dataloader(MyDataset(dataset_root, 'icon', image_size, True, True, split=(0.9, 1.0)), False)
    # Fixed sample batches reused for every visualisation.
    train_display_images_a = fixed_batch(train_loader_a)
    train_display_images_b = fixed_batch(train_loader_b)
    test_display_images_a = fixed_batch(test_loader_a)
    test_display_images_b = fixed_batch(test_loader_b)

    # Setup logger and output folders
    model_name = os.path.splitext(os.path.basename(config_path))[0]
    train_writer = SummaryWriter(output_dir)

    def log_images(image_outputs, tag, step):
        """Send every image produced by ``write_2images`` to TensorBoard."""
        for key, image in write_2images(image_outputs, display_size, tag).items():
            train_writer.add_image(f'munit/{key}', image, step)

    # The function leaves via sys.exit (SystemExit) or an error; the finally
    # clause makes sure pending TensorBoard events are flushed either way.
    try:
        output_directory = os.path.join(output_dir, 'outputs', model_name)
        checkpoint_directory, image_directory = prepare_sub_folder(output_directory)
        # Keep a copy of the config file next to the run's outputs.
        shutil.copy(config_path, os.path.join(output_directory, 'config.yaml'))

        # Start training
        iterations = trainer.resume(checkpoint_directory, hyperparameters=config) if resume else 0
        while True:
            # Each pass of the for-loop walks both loaders in lockstep; zip
            # stops at the shorter one and the while-loop starts them over.
            for images_a, images_b in zip(train_loader_a, train_loader_b):
                trainer.update_learning_rate()
                images_a, images_b = images_a.cuda().detach(), images_b.cuda().detach()

                with Timer("Elapsed time in update: %f"):
                    # Main training code
                    trainer.dis_update(images_a, images_b, config)
                    trainer.gen_update(images_a, images_b, config)
                    torch.cuda.synchronize()

                # Dump training stats in log file
                if iterations % config['log_iter'] == 0:
                    print("Iteration: %08d/%08d" % (iterations, max_iter))
                    write_loss(iterations, trainer, train_writer)

                # Write images from the fixed test and train batches
                if iterations % config['image_save_iter'] == 0:
                    with torch.no_grad():
                        test_image_outputs = trainer.sample(test_display_images_a, test_display_images_b)
                        train_image_outputs = trainer.sample(train_display_images_a, train_display_images_b)
                    log_images(test_image_outputs, 'test', iterations)
                    log_images(train_image_outputs, 'train', iterations)

                if iterations % config['image_display_iter'] == 0:
                    with torch.no_grad():
                        image_outputs = trainer.sample(train_display_images_a, train_display_images_b)
                    log_images(image_outputs, 'train_current', iterations)

                # Save network weights
                if iterations % config['snapshot_save_iter'] == 0:
                    trainer.save(checkpoint_directory, iterations)

                iterations += 1
                if iterations >= max_iter:
                    sys.exit('Finish training')
    finally:
        train_writer.close()

