

import os
import torch
import torch.nn as nn
import sys
sys.path.append('/home/balitskiy/projects/ResNet18_cl/src')

from defaultConfigQB import _EXPs as config_EXP
from defaultConfigQB import update_config_EXP

from defaultConfigOneQB import _QB as config_QB
from defaultConfigOneQB import update_config_QB


from NewQuantizationBlock import QuantizationBlock

from q_utils import parse_args, _init_models,_load_pretrained_models,\
    _init_bin_ops_dict,_init_dataloader, create_loggers_writers,\
    create_EXP_NAMES

from q_function import validate, collect_statistics

import random




def main():
    """Run every experiment listed in ``config_EXP.EXP_NAMES``.

    For each experiment, the matching model, bin-op object, dataloaders and
    logger are taken from the dicts built by the ``q_utils`` initialisers.
    The model is then moved to ``config_EXP.DEVICES[exp_num]`` and the
    operation named in ``config_EXP.OPERATION[exp_num]`` is executed:

    * ``'validate'``           -> ``validate(...)`` on ``dataloaders[0]``. The
      full-precision reference model is also passed (``None`` when the
      reference experiment is not configured).
    * ``'collect_statistics'`` -> ``collect_statistics(...)`` on
      ``dataloaders[0]``.

    Any other operation value is reported and skipped. The
    ``'validate_FULL_w32_asym_a32'`` experiment is never run on its own. It
    only supplies the reference model for the ``validate`` runs.
    """
    # The returned namespace is not used here. Presumably parse_args updates
    # the global config objects as a side effect -- TODO confirm in q_utils.
    parse_args()

    create_EXP_NAMES(config_EXP)
    loggers_dict, writers_dict = create_loggers_writers(config_EXP)
    all_models, qb_configs = _init_models(config_EXP, loggers_dict)

    _load_pretrained_models(all_models, config_EXP, loggers_dict)

    bin_ops_dict = _init_bin_ops_dict(all_models, config_EXP, loggers_dict)
    dataloaders_dict, dataset_dict = _init_dataloader(config_EXP, loggers_dict)
    # optimizers_dict, schedulers_dict = _init_optimizer(all_models, config_EXP, loggers_dict)

    full_exp_name = 'validate_FULL_w32_asym_a32'
    if full_exp_name in config_EXP.EXP_NAMES:
        model_full = all_models[full_exp_name]
        # NOTE(review): the reference model always goes to DEVICES[0], while
        # each quantized model goes to DEVICES[exp_num]. Confirm validate()
        # copes with the two models living on different devices.
        model_full.to(config_EXP.DEVICES[0])
    else:
        model_full = None

    for exp_num, exp_name in enumerate(config_EXP.EXP_NAMES):
        if exp_name == full_exp_name:
            # Reference model only; it is consumed by the 'validate' runs.
            continue
        # Only Python's `random` module is reseeded per experiment; torch's
        # RNG is left as is.
        random.seed(100)
        print("EXPERIMENT {} STARTS:".format(exp_name))
        device = config_EXP.DEVICES[exp_num]
        model = all_models[exp_name]
        model = model.to(device)
        criterion = torch.nn.CrossEntropyLoss().to(device)
        binOp = bin_ops_dict[exp_name]
        dataloaders = dataloaders_dict[exp_name]
        logger = loggers_dict[exp_name]
        operation = config_EXP.OPERATION[exp_num]

        if operation == 'validate':
            validate(model, model_full, binOp, dataloaders[0], criterion, device,
                     logger)
        elif operation == 'collect_statistics':
            collect_statistics(model, dataloaders[0], device)
        # TODO: the training path below is disabled. Re-enabling it also needs
        # the `_init_optimizer` call above, a `train` import and a definition
        # of `model_config`.
        # elif operation == 'train':
        #     path_save_checkpoints = os.path.join(config_EXP.OUTPUT_BASE_PATH,
        #                                 config_EXP.TRAIN.SAVE_CHECKPOINT_PATH[exp_num])
        #     begin_epoch = config_EXP.TRAIN.BEGIN_EPOCH[exp_num]
        #     last_epoch = config_EXP.TRAIN.END_EPOCH[exp_num]
        #     optimizer = optimizers_dict[exp_name]
        #     lr_scheduler = schedulers_dict[exp_name]
        #     train_dataloader = dataloaders[0]
        #     val_dataloader = dataloaders[1]
        #
        #     train(model, binOp, train_dataloader, val_dataloader, exp_name,
        #           model_config, optimizer, lr_scheduler, begin_epoch,
        #           last_epoch, path_save_checkpoints, device, logger)
        else:
            print("EXPERIMENT {}: unknown operation {!r}, skipped".format(
                exp_name, operation))

        # `del model` alone only drops this local name. all_models still holds
        # the module (nn.Module.to() moves it in place), so its device memory
        # would stay allocated for the rest of the run. Move it back to the CPU
        # and release cached CUDA blocks so the next experiment has room.
        # empty_cache() does nothing when CUDA was never initialised.
        model.to('cpu')
        del model
        torch.cuda.empty_cache()




if __name__ == '__main__':
    # Script entry point: run all configured experiments. Nothing runs when
    # this module is merely imported.
    main()