


def train(
    dataset_root='datasets/icon4/data/in_memory',
    batch_size=2,
    device='cuda:2',
    num_workers=8,
    image_size=256,
    resume=None,
    output_dir='/tmp/dritpp'
):
    """Train a DRIT model on the unpaired icon/contour dataset.

    Training runs from the start (or resumed) epoch up to a fixed number of
    epochs. It stops early once ``max_it`` encoder/generator updates have
    been performed; in that case the final image and model are written with
    epoch index ``-1`` and the function returns immediately.

    Args:
        dataset_root: Root directory passed to ``UnpairedIconContourDataset``.
        batch_size: Mini-batch size. Batches whose size differs from this
            (e.g. a trailing partial batch) are skipped.
        device: CUDA device as the string ``'cuda:<index>'``; a
            ``torch.device`` whose ``str()`` has that form is also accepted.
        num_workers: Number of ``DataLoader`` worker processes.
        image_size: Image size passed to the dataset.
        resume: Checkpoint path handed to ``model.resume``, or ``None`` to
            initialize a fresh model.
        output_dir: Base directory handed to ``Saver``, together with its
            ``display`` and ``result`` subdirectories.

    Raises:
        ValueError: If ``device`` is not of the form ``'cuda:<index>'``.
    """
    # Validate before the heavy / package-relative imports so a bad argument
    # fails fast (and is not stripped under ``python -O`` like an assert).
    device_str = str(device)
    if not device_str.startswith('cuda:'):
        raise ValueError('device must be of the form "cuda:<index>", got %r' % (device,))

    import os
    import torch
    from torch.utils.data import DataLoader
    from .dataset import UnpairedIconContourDataset
    from .model import DRIT
    from .saver import Saver

    device = torch.device(device_str)

    # Fixed training schedule.
    no_display_img = False
    lr_policy = 'lambda'
    n_ep_decay = 600
    n_ep = 1200
    d_iter = 3          # full D + EG update on every d_iter-th iteration
    max_it = 500000     # cap on total encoder/generator updates (total_it)

    # NOTE(review): meaning of the positional flags (True, True, (0.9, 1.0))
    # is defined by UnpairedIconContourDataset — not visible here.
    dataset = UnpairedIconContourDataset(dataset_root, image_size, True, True, (0.9, 1.0))
    train_loader = DataLoader(dataset, batch_size=batch_size, shuffle=True, num_workers=num_workers)
    n_batches = len(train_loader)  # loop-invariant; used to detect the epoch's tail

    model = DRIT()
    model.setgpu(device.index)
    if resume is None:
        model.initialize()
        ep0 = -1
        total_it = 0
    else:
        ep0, total_it = model.resume(resume)
    # The scheduler is told the last *completed* epoch; training starts at the next one.
    model.set_scheduler(lr_policy, n_ep_decay, n_ep, last_ep=ep0)
    ep0 += 1
    print('start the training at epoch %d'%(ep0))

    # saver for display and output
    saver = Saver(
        output_dir,
        os.path.join(output_dir, 'display'),
        os.path.join(output_dir, 'result'),
        'name',
        display_freq=1,
        img_save_freq=5,
        model_save_freq=10,
    )

    print('\n--- train ---')
    for ep in range(ep0, n_ep):
        for it, (images_a, images_b) in enumerate(train_loader):
            # Skip batches that do not have exactly batch_size samples.
            if images_a.size(0) != batch_size or images_b.size(0) != batch_size:
                continue

            images_a = images_a.to(device).detach()
            images_b = images_b.to(device).detach()

            # Most iterations only update the content discriminator; every
            # d_iter-th iteration, and always on the last two iterations of
            # the epoch, the discriminators and encoders/generators are updated.
            if (it + 1) % d_iter != 0 and it < n_batches - 2:
                model.update_D_content(images_a, images_b)
                continue
            model.update_D(images_a, images_b)
            model.update_EG()

            if not no_display_img:
                saver.write_display(total_it, model)

            print('total_it: %d (ep %d, it %d), lr %08f' % (total_it, ep, it, model.gen_opt.param_groups[0]['lr']))
            total_it += 1
            if total_it >= max_it:
                # Final save uses epoch index -1. write_model takes
                # (ep, total_it, model), matching the per-epoch call below.
                saver.write_img(-1, model)
                saver.write_model(-1, total_it, model)
                # Stop training entirely; a plain `break` would only leave the
                # batch loop and keep training in later epochs.
                return

        # Decay the learning rate once per epoch (n_ep_decay = -1 disables it).
        if n_ep_decay > -1:
            model.update_lr()

        # save result image
        saver.write_img(ep, model)

        # Save network weights
        saver.write_model(ep, total_it, model)

    return

if __name__ == '__main__':
    # Script entry point: launch training with all default arguments.
    train()
