import os
import argparse
import yaml
import torch
import ei.patched

from .solver import Solver
from torch.backends import cudnn
from .data_loader import get_loader


def make_train_directory(config):
    """Create the output directory tree described by ``config``.

    The root is ``config['TRAINING_CONFIG']['TRAIN_DIR']``. The ``LOG_DIR``,
    ``SAMPLE_DIR``, ``RESULT_DIR`` and ``MODEL_DIR`` entries are joined onto
    that root and created in that order. Directories that already exist are
    left untouched (``exist_ok=True``).

    Args:
        config: Parsed configuration mapping whose ``'TRAINING_CONFIG'``
            section holds the five directory keys listed above.

    Raises:
        KeyError: If any of the required keys is missing; directories for
            keys processed before the missing one will already exist.
    """
    training = config['TRAINING_CONFIG']
    root = training['TRAIN_DIR']
    os.makedirs(root, exist_ok=True)
    for sub_key in ('LOG_DIR', 'SAMPLE_DIR', 'RESULT_DIR', 'MODEL_DIR'):
        os.makedirs(os.path.join(root, training[sub_key]), exist_ok=True)


def main(
    output_dir,
    device,
    batch_size=8,
    num_workers=4,
    dataset_root='datasets/icon4/data/in_memory',
    image_size=256,
    end_iter=600000,
    log_int=200,
    sample_int=1000,
    save_int=1000,
):
    """Load ``config.yml``, apply the given overrides, and run the solver.

    The base configuration is read from ``config.yml`` located next to this
    module. The arguments below overwrite the matching config entries, the
    output directory tree is created, and then ``Solver.train()`` or
    ``Solver.test()`` is called depending on ``TRAINING_CONFIG['MODE']``
    from the config file.

    Args:
        output_dir: Root output directory; stored as ``TRAIN_DIR``.
        device: Device string of the form ``'cuda:<index>'``.
        batch_size: Stored as ``TRAINING_CONFIG['BATCH_SIZE']``.
        num_workers: Stored as ``TRAINING_CONFIG['NUM_WORKER']``.
        dataset_root: Stored as ``TRAINING_CONFIG['IMG_DIR']``.
        image_size: Stored as ``MODEL_CONFIG['IMG_SIZE']``; only 256 is
            accepted.
        end_iter: Stored as ``TRAINING_CONFIG['END_ITER']``.
        log_int: Stored as ``TRAINING_CONFIG['LOG_STEP']``.
        sample_int: Stored as ``TRAINING_CONFIG['SAMPLE_STEP']``.
        save_int: Stored as ``TRAINING_CONFIG['SAVE_STEP']``.

    Raises:
        AssertionError: If ``image_size`` is not 256, ``device`` does not
            start with ``'cuda:'``, or the configured ``MODE`` is neither
            ``'train'`` nor ``'test'``.
    """
    assert image_size == 256, \
        'only image_size=256 is supported, got {}'.format(image_size)
    assert device.startswith('cuda:'), \
        "device must look like 'cuda:<index>', got {!r}".format(device)

    config_path = os.path.join(os.path.dirname(__file__), 'config.yml')
    # Context manager so the config file handle is closed rather than leaked.
    with open(config_path, 'r', encoding='utf-8') as config_file:
        config = yaml.load(config_file, Loader=yaml.FullLoader)
    train_cfg = config['TRAINING_CONFIG']

    # MODE comes straight from the config file (it is not overridden below),
    # so reject a bad value before changing the environment or creating
    # any directories.
    assert train_cfg['MODE'] in ('train', 'test'), \
        "TRAINING_CONFIG.MODE must be 'train' or 'test', got {!r}".format(
            train_cfg['MODE'])

    cuda_index = torch.device(device).index
    # Expose only the requested GPU to this process, enumerated by PCI bus ID.
    os.environ['CUDA_DEVICE_ORDER'] = 'PCI_BUS_ID'
    os.environ['CUDA_VISIBLE_DEVICES'] = str(cuda_index)

    # Overwrite config with the caller-supplied values.
    train_cfg.update({
        'TRAIN_DIR': output_dir,
        'BATCH_SIZE': batch_size,
        'NUM_WORKER': num_workers,
        'IMG_DIR': dataset_root,
        'END_ITER': end_iter,
        'SAVE_STEP': save_int,
        'SAMPLE_STEP': sample_int,
        'LOG_STEP': log_int,
        # NOTE(review): if CUDA is not yet initialized when the env var above
        # takes effect, the single visible GPU is renumbered to index 0, yet
        # the physical index is stored here. Confirm how Solver uses 'GPU'.
        'GPU': cuda_index,
    })
    config['MODEL_CONFIG']['IMG_SIZE'] = image_size

    make_train_directory(config)

    cudnn.benchmark = True
    solver = Solver(config, get_loader(config))
    mode = train_cfg['MODE']
    print('{} is started'.format(mode))
    if mode == 'train':
        solver.train()
    elif mode == 'test':
        solver.test()
    print('{} is finished'.format(mode))


if __name__ == '__main__':
    # Hand ``main`` to python-fire, which builds the command-line interface
    # from its signature; imported lazily so library users don't need fire.
    from fire import Fire
    Fire(main)
