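# Reference configuration: the values listed below are hard-coded into the script body.
# (Command-line arguments are still parsed and forwarded to LossMaker/MultiUpdater.)
# args.dataset_path = '/path/to/images/directory'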
# args.save_result_path = '/path/to/save'
# args.num_dataset_test = 1000
# args.network_model = 'LTBC'
# args.batchsize = 30
# args.size_image = 224
# args.augmentation = True
# args.size_image_augmentation = 256
# args.gpu = -1
# args.max_epoch = 1000
# args.save_result_iteration = 1000
# args.random_seed_test = 0
# args.path_pretrained_model = None
# args.disable_ltbc_global = False
# args.line_drawing_mode = 'otsu_threshold'
# args.optimizer_adam_alpha = 0.0001
# args.weight_decay = 0.0001
# args.path_tag_list = '/path/to/label_ID_list.json'
# args.path_tag_list_each_image = '/path/to/label_ID_list_for_each_image.json'
# args.blend_mse_color = 1.0
# args.alpha_ltbc_classification = 0.00333333333
# args.ltbc_classification_num_output_list = [512, 428]
# args.use_histogram_network = True
# args.num_bins_histogram = 6
# args.threshold_histogram_palette = 0.0
# args.use_multidimensional_histogram = True
# args.reinput_mode = None
# args.loss_blend_ratio_reinput = []
# args.separate_backward_reinput = False
# args.separate_model_reinput = False
# args.use_residual_reinput = False
# args.ltbc_classification_loss_function = 'multi_label'
# args.max_pixel_drawing = 15
# args.max_size_pixel_drawing = 1
# args.loss_type = 'Lab'
# args.mse_loss_mode = 'color_space'
# args.use_adversarial_network = True
# args.blend_adversarial_generator = 1.0
# args.discriminator_first_pooling_size = 2
# args.log_interval = 200
# args.verbose = False

import sys
import os
import json
import glob
import typing
import chainer
from chainer.training import extensions
import numpy
ROOT_PATH = os.path.join(os.path.dirname(__file__), "..")
sys.path.append(ROOT_PATH)
import comicolorization
from comicolorization.extensions import SaveRawImageExtension, SaveGeneratedImageExtension

parser = comicolorization.utility.config.get_train_parser()

# NOTE(review): `args` is printed, forwarded to LossMaker/MultiUpdater and
# snapshotted to argument.json, but the hyperparameters used below are
# hard-coded literals rather than read from `args`; keep them in sync by hand.
args = parser.parse_args()
print(args)

# Classification loss weight (args.alpha_ltbc_classification): the
# classification branch is enabled whenever a weight is configured.
use_classification = 0.00333333333 is not None

model = comicolorization.models.Ltbc(
    use_global=not False,  # i.e. `not args.disable_ltbc_global`
    use_classification=use_classification,
    classification_num_output_list=[512, 428],
    use_histogram=True,
    use_multidimensional_histogram=True,
    num_bins_histogram=6,
    threshold_histogram_palette=0.0,
    reinput_mode=None,
    loss_type='Lab',
)
# No re-input models are configured.
model_reinput_list = []
discriminator = comicolorization.models.Discriminator(
    size=224,
    first_pooling_size=2,
)

# make dataset
# glob() returns entries in arbitrary, filesystem-dependent order. Sort first so
# that the seeded permutation, and therefore the train/test split, is
# reproducible across machines and runs.
paths = sorted(glob.glob(os.path.join('/path/to/images/directory', '*')))
random_state = numpy.random.RandomState(0)
paths = random_state.permutation(paths)

datasets = comicolorization.utility.dataset.choose_dataset(
    paths=paths,
    num_dataset_test=1000,
    loss_type='Lab',
    augmentation=True,
    size_image_augmentation=[256, 256],
    size_image=[224, 224],
    use_ltbc_classification=use_classification,
    path_tag_list='/path/to/label_ID_list.json',
    path_tag_list_each_image='/path/to/label_ID_list_for_each_image.json',
    line_drawing_mode='otsu_threshold',
    max_pixel_drawing=15,
    max_size_pixel_drawing=1,
    use_binarization_dataset=False,
)
train_dataset = datasets['train']
test_dataset = datasets['test']
train_for_evaluate_dataset = datasets['train_for_evaluate']

# NOTE(review): the training iterator does not reshuffle between epochs
# (shuffle=False); the only randomization of order is the one-off permutation
# above. Confirm this is intended.
train_iterator = chainer.iterators.MultiprocessIterator(
    train_dataset,
    batch_size=30,
    repeat=True,
    shuffle=False,
)
# Evaluation iterators make a single, ordered pass over their datasets.
test_iterator = chainer.iterators.MultiprocessIterator(
    test_dataset,
    batch_size=30,
    repeat=False,
    shuffle=False,
)
train_for_evaluate_iterator = chainer.iterators.MultiprocessIterator(
    train_for_evaluate_dataset,
    batch_size=30,
    repeat=False,
    shuffle=False,
)

# Value ranges reported by the training dataset; used to configure the loss
# and to pick the image mode for the sample-image dumps.
range_input = train_dataset.get_input_range()
range_input_luminance = train_dataset.get_input_luminance_range()
range_output_luminance = train_dataset.get_output_range()[0]

# make loss
loss_maker = comicolorization.loss.LossMaker(
    args=args,
    model=model,
    model_reinput_list=model_reinput_list,
    range_input_luminance=range_input_luminance,
    range_output_luminance=range_output_luminance,
    discriminator=discriminator
)


# make trainer
def make_optimizer(_model, alpha=0.0001, weight_decay=0.0001):
    """Create an Adam optimizer for ``_model`` with a weight-decay hook.

    Args:
        _model: the chainer link whose parameters are optimized.
        alpha: Adam step size (default matches the script's configured
            ``optimizer_adam_alpha``).
        weight_decay: rate passed to ``chainer.optimizer.WeightDecay``
            (default matches the script's configured ``weight_decay``).

    Returns:
        The ``chainer.optimizers.Adam`` instance, already set up on ``_model``.
    """
    adam = chainer.optimizers.Adam(alpha=alpha)
    adam.setup(_model)

    adam.add_hook(chainer.optimizer.WeightDecay(weight_decay))

    return adam


optimizer = make_optimizer(model)
discriminator_optimizer = make_optimizer(discriminator)


def main_lossfun(loss_detail):
    """Pick the generator's total loss out of the loss-detail mapping."""
    return loss_detail['sum_loss']


def discriminator_lossfun(loss_detail):
    """Pick the discriminator's total loss out of the loss-detail mapping."""
    return loss_detail['sum_loss_discriminator']


# No re-input models are configured, so they get neither a loss nor an optimizer.
reinput_lossfun, reinput_optimizer = None, None

updater = comicolorization.updater.MultiUpdater(
    args=args,
    loss_maker=loss_maker,
    main_optimizer=optimizer,
    main_lossfun=main_lossfun,
    reinput_optimizer=reinput_optimizer,
    reinput_lossfun=reinput_lossfun,
    iterator=train_iterator,
    device=-1,
    discriminator_optimizer=discriminator_optimizer,
    discriminator_lossfun=discriminator_lossfun
)
stop_trigger = (1000, 'epoch')
trainer = chainer.training.Trainer(updater, stop_trigger, out='/path/to/save')


def save_json(filename, obj):
    """Write ``obj`` to ``filename`` as indented JSON with sorted keys.

    The ``(filename, obj)`` signature lets it serve as the ``savefun`` of
    ``extensions.snapshot_object``.
    """
    # Use a context manager so the file is always flushed and closed, even if
    # serialization fails part-way through.
    with open(filename, 'w') as f:
        json.dump(obj, f, sort_keys=True, indent=4)


input_image_mode = 'gray' if len(range_input) != 3 else 'Lab'

# Number of fixed example images dumped for visual inspection.
num_sample_images = 10


def _stack_samples(dataset, num):
    """Return (color, gray, rgb) batches stacked from the first ``num`` examples.

    Fields 0, 1 and 2 of each example are the color, gray and rgb images.
    ``num`` is clamped to the dataset length, so small datasets do not raise
    IndexError.
    """
    num = min(num, len(dataset))
    samples = [dataset[i] for i in range(num)]
    return tuple(numpy.stack([sample[field] for sample in samples]) for field in range(3))


def _make_sample_extensions(dataset, prefix):
    """Build the (generated, gray, raw) image-dump extensions for ``dataset``.

    Output directories are prefixed with ``prefix`` (e.g. 'train' or 'test').
    """
    color_images, gray_images, rgb_images = _stack_samples(dataset, num_sample_images)
    generated = SaveGeneratedImageExtension(
        gray_images, rgb_images, model,
        prefix_directory=prefix + '_{.updater.iteration}', image_mode='Lab')
    gray = SaveRawImageExtension(
        gray_images, prefix_directory=prefix + '_gray_image', prefix_filename='gray_',
        image_mode=input_image_mode, linedrawing='otsu_threshold')
    raw = SaveRawImageExtension(
        color_images, prefix_directory=prefix + '_raw_image', prefix_filename='color_',
        image_mode='Lab')
    return generated, gray, raw


train_extend_generated_image, train_extend_gray_image, train_extend_raw_image = \
    _make_sample_extensions(train_for_evaluate_dataset, 'train')
test_extend_generated_image, test_extend_gray_image, test_extend_raw_image = \
    _make_sample_extensions(test_dataset, 'test')

trainer.extend(extensions.dump_graph('main/sum_loss', out_name='main_graph.dot'))

# Save the parsed command-line arguments once, before training starts.
trainer.extend(extensions.snapshot_object(args.__dict__, 'argument.json', savefun=save_json), invoke_before_training=True)
# The input/ground-truth images never change: dump them once before training.
# The always-False trigger suppresses any later invocation.
for _once_extension in (train_extend_gray_image, train_extend_raw_image,
                        test_extend_gray_image, test_extend_raw_image):
    trainer.extend(_once_extension, invoke_before_training=True, trigger=lambda _: False)

save_interval = (1000, 'iteration')
trainer.extend(train_extend_generated_image, trigger=save_interval)
trainer.extend(test_extend_generated_image, trigger=save_interval)
trainer.extend(extensions.snapshot_object(model, '{.updater.iteration}.model'), trigger=save_interval)


num_reinput = len([])  # no re-input models are configured

# NOTE(review): report_target enumerates every reported loss key but is not
# passed to any extension (e.g. extensions.PrintReport); wire it up or drop it.
report_target = ['epoch', 'iteration']
for evaluater_name in ['', 'validation/', 'validation/train/']:
    for model_name in ['main/'] + ['discriminator/'] + ['reinput{}/'.format(i) for i in range(num_reinput)]:
        for reinput_name in [''] + ['reinput{}/'.format(i) for i in range(num_reinput)]:
            for loss_name in loss_maker.get_loss_names():
                report_target.append(evaluater_name + model_name + reinput_name + loss_name)

log_interval = (200, 'iteration')
targets = {'main': model}
# The first Evaluator reports under Chainer's default 'validation' name; the
# second evaluates on training samples under 'validation/train'.
trainer.extend(extensions.Evaluator(test_iterator, target=targets, eval_func=loss_maker.loss_test, device=-1), trigger=log_interval)
trainer.extend(extensions.Evaluator(train_for_evaluate_iterator, target=targets, eval_func=loss_maker.loss_test, device=-1), name='validation/train', trigger=log_interval)
trainer.extend(extensions.LogReport(trigger=log_interval, log_name="log.txt"))
trainer.extend(extensions.ProgressBar(update_interval=10))

trainer.run()