

from __future__ import absolute_import
from __future__ import division
from __future__ import print_function
# _______import model___________
from tensorboardX import SummaryWriter
import torch
import torch.nn as nn
import torch.nn.parallel
import torch.backends.cudnn as cudnn
import torch.distributed as dist
import torch.optim
import torch.multiprocessing as mp
import torch.utils.data
import torch.utils.data.distributed
import torchvision.transforms as transforms
import torchvision.datasets as datasets
import torchvision.models as models

# ______________qblock import_______________
import os
import logging
from defaultConfigQB import _EXPs as config_EXP
from defaultConfigQB import update_config_EXP

from defaultConfigOneQB import _QB as config_QB
from defaultConfigOneQB import update_config_QB
from xnor_util import BinOp
import copy
import argparse
from resnet18 import ResNet, resnet18, resnet34
from vgg import VGG, vgg16, vgg16_bn

def parse_args():
    """Parse the command line and feed it into the experiment configuration.

    Reads the required ``--cfg-exp`` option from ``sys.argv``, echoes the
    parsed namespace to stdout and hands it to ``update_config_EXP`` to
    populate the module-level ``config_EXP``.

    Returns:
        argparse.Namespace: the parsed command-line arguments.
    """
    arg_parser = argparse.ArgumentParser(description='Train Face Alignment')
    arg_parser.add_argument(
        '--cfg-exp',
        help='experiment configuration filename',
        required=True,
        type=str,
    )

    parsed = arg_parser.parse_args()
    print('ARGS {}\n\n\n'.format(parsed))
    update_config_EXP(config_EXP, parsed)
    return parsed


def save_checkpoint(states, is_best,
                    output_dir, exp_name):
    """Persist a checkpoint and, when requested, a separate "best" file.

    Args:
        states: dict to serialize with ``torch.save``; must contain
            ``"epoch"``. When it also contains ``"state_dict"`` and
            ``is_best`` is true, ``states['state_dict'].module`` is saved
            to ``checkpoint_<exp_name>_best.pth``.
        is_best: whether this checkpoint is the best one seen so far.
        output_dir: directory that receives both files.
        exp_name: experiment name embedded in the file names.
    """
    epoch = states["epoch"]
    filename = 'checkpoint_{}_{}.pth'.format(exp_name, epoch)
    torch.save(states, os.path.join(output_dir, filename))

    if is_best and 'state_dict' in states:
        best_filename = 'checkpoint_{}_best.pth'.format(exp_name)
        # The best file belongs next to the per-epoch checkpoints, i.e. in
        # output_dir (not the current working directory).
        # NOTE(review): ``.module`` implies states['state_dict'] holds a
        # wrapped model (e.g. DataParallel) rather than a plain state dict --
        # confirm with callers.
        torch.save(states['state_dict'].module,
                   os.path.join(output_dir, best_filename))

def _init_models(config_EXP, loggers_dict):
    """Build one network per experiment listed in ``config_EXP.QB_CONFIGS``.

    For every experiment whose QB config file is not ``None``:

    1. a private deep copy of the default ``config_QB`` is updated from
       ``INPUT_BASE_PATH/<qb config file>``;
    2. the experiment's settings are merged into it via
       ``merge_exp_qblock_config``;
    3. the constructor selected by ``config_EXP.NN_TYPE`` builds the model.

    Experiments whose QB config is ``None`` are logged and skipped. They are
    therefore absent from both returned dicts.

    Args:
        config_EXP: experiment configuration whose per-experiment lists
            (QB_CONFIGS, EXP_NAMES, ...) are indexed in parallel.
        loggers_dict: loggers keyed by experiment name.

    Returns:
        (all_models, qb_configs): dicts keyed by experiment name holding the
        built model and its merged QB config.

    Raises:
        ValueError: if a model has to be built and ``config_EXP.NN_TYPE`` is
            not one of 'resnet18', 'resnet34', 'vgg16', 'vgg16_bn'.
    """
    model_builders = {
        'resnet18': resnet18,
        'resnet34': resnet34,
        'vgg16': vgg16,
        'vgg16_bn': vgg16_bn,
    }
    all_models = dict()
    qb_configs = dict()

    for index, qb_config_file in enumerate(config_EXP.QB_CONFIGS):
        exp_name = config_EXP.EXP_NAMES[index]
        logger = loggers_dict[exp_name]
        if qb_config_file is None:
            msg = "QB_config: {} doesn't exist: Model for experiment" \
                  " {} is not initialized!!".format(qb_config_file, exp_name)
            print(msg)
            logger.info(msg)
            continue

        path_to_qb_config = os.path.join(config_EXP.INPUT_BASE_PATH, qb_config_file)
        # Deep copy so every experiment mutates its own QB config, never the
        # shared default.
        temp_qb_config = copy.deepcopy(config_QB)
        update_config_QB(temp_qb_config, path_to_qb_config)
        merge_exp_qblock_config(config_EXP, index, temp_qb_config)

        # An unknown NN_TYPE must fail loudly. Otherwise ``model`` would be
        # unbound, or would silently be the previous experiment's model.
        build_model = model_builders.get(config_EXP.NN_TYPE)
        if build_model is None:
            raise ValueError(
                "Unsupported NN_TYPE {!r}; expected one of {}".format(
                    config_EXP.NN_TYPE, sorted(model_builders)))
        model = build_model(temp_qb_config)

        all_models[exp_name] = model
        qb_configs[exp_name] = temp_qb_config

        msg = 'Model {} is built!'.format(temp_qb_config.EXP_NAME)
        print(msg)
        logger.info(msg)
    return all_models, qb_configs


def _match_checkpoint(model, checkpoint_file):
    """Load ``checkpoint_file`` and adapt it to ``model``'s parameter layout.

    The adaptation is chosen from substrings of the file name:

    * no ``'initial'``: the loaded object is returned unchanged;
    * ``'initial'`` and ``'resnet'``: values whose keys also exist in
      ``model.state_dict()`` overwrite the model's own values, and keys the
      model does not have are dropped;
    * ``'initial'`` otherwise (VGG): ``features.*`` and ``classifier.*``
      entries are renumbered to the model's indices (see comments below).

    Args:
        model: module whose ``state_dict()`` supplies the target keys
            (not used in the first case).
        checkpoint_file: path passed to ``torch.load``.

    Returns:
        A mapping intended for ``model.load_state_dict``.
    """
    # Load exactly once. Every branch works from this object, so a large
    # checkpoint is never read from disk twice.
    checkpoint_state_dict = torch.load(checkpoint_file)
    if 'initial' not in checkpoint_file:
        return checkpoint_state_dict

    model_state_dict = model.state_dict()
    if 'resnet' in checkpoint_file:
        model_state_dict.update({k: v for k, v in checkpoint_state_dict.items()
                                 if k in model_state_dict})
        return model_state_dict

    # VGG: the model's ``features`` indices are offset by the number of
    # weighted layers seen so far (+1 for the first, +2 for the second, ...),
    # presumably because an extra (quantization) module precedes each one.
    # TODO confirm against vgg.py.
    shift = 1
    for name in checkpoint_state_dict.keys():
        if 'features' in name and 'weight' in name:
            # 9 == len('features.'): parse the layer index after the prefix.
            layer_num_init = int(name[9: name.find(".", 9)])
            target_prefix = 'features.' + str(layer_num_init + shift)
            model_state_dict[target_prefix + '.weight'] = checkpoint_state_dict[name]
            model_state_dict[target_prefix + '.bias'] = \
                checkpoint_state_dict['features.' + str(layer_num_init) + '.bias']
            shift += 1

    # Classifier layers 0/3/6 in the checkpoint live at 0/4/8 in the model.
    for src_idx, dst_idx in ((0, 0), (3, 4), (6, 8)):
        for suffix in ('weight', 'bias'):
            model_state_dict['classifier.{}.{}'.format(dst_idx, suffix)] = \
                checkpoint_state_dict['classifier.{}.{}'.format(src_idx, suffix)]
    return model_state_dict


def _load_pretrained_models(all_models, config_EXP, loggers_dict):
    """Load each experiment's checkpoint into its model, in place.

    For experiment ``i`` the checkpoint path is
    ``INPUT_BASE_PATH/CHECKPOINTS[i]``. It is adapted with
    ``_match_checkpoint`` and applied with ``load_state_dict``.

    Experiments with no entry in ``all_models`` are logged and skipped
    instead of raising ``KeyError``. ``_init_models`` skips those whose QB
    config is ``None``.
    """
    for index, exp_name in enumerate(config_EXP.EXP_NAMES):
        logger = loggers_dict[exp_name]
        model = all_models.get(exp_name)
        if model is None:
            msg = "Model for experiment {} was not built: " \
                  "checkpoint not loaded".format(exp_name)
            print(msg)
            logger.info(msg)
            continue

        t_check = os.path.join(config_EXP.INPUT_BASE_PATH,
                               config_EXP.CHECKPOINTS[index])
        model.load_state_dict(_match_checkpoint(model, t_check))
        msg = "Model for experiment {} loaded : checkpoint file {} ".format(
            exp_name, t_check)
        print(msg)
        logger.info(msg)


def _init_bin_ops_dict(all_models, config_EXP, loggers_dict):
    """Create one ``BinOp`` weight quantizer per built model.

    Experiment ``i`` uses ``config_EXP.WEIGHT.NUM_BITS[i]`` bits and the
    ``config_EXP.WEIGHT.Q_ALGORITHM[i]`` algorithm.

    Experiments with no entry in ``all_models`` are logged and skipped
    rather than raising ``KeyError``. ``_init_models`` skips those whose QB
    config is ``None``.

    Returns:
        dict: ``BinOp`` instances keyed by experiment name.
    """
    bin_ops_dict = dict()

    for index, model_name in enumerate(config_EXP.EXP_NAMES):
        num_bits = config_EXP.WEIGHT.NUM_BITS[index]
        q_algorithm = config_EXP.WEIGHT.Q_ALGORITHM[index]
        if model_name not in all_models:
            msg = "BinOp for experiment {} skipped: model was not built".format(
                model_name)
        else:
            bin_ops_dict[model_name] = BinOp(all_models[model_name],
                                             num_bits, q_algorithm)
            msg = ("BinOp for experiment {}, number of bits:{}, "
                   "quantization methods {}".format(model_name, num_bits,
                                                    q_algorithm))
        print(msg)
        loggers_dict[model_name].info(msg)
    return bin_ops_dict


def _init_dataloader(config_EXP, loggers_dict):
    """Create per-experiment datasets and dataloaders.

    Depending on ``config_EXP.OPERATION[i]``:

    * ``'validate'``: ImageFolder over ``TEST_DATA``, not shuffled;
    * ``'collect_statistics'``: ImageFolder over ``TRAIN_DATA``, shuffled;
    * ``'train'``: not implemented yet. One ``None`` placeholder is stored
      in each dict so both stay index-aligned.

    Both real cases share the same preprocessing: resize to 256, center
    crop to 224, tensor conversion, and fixed per-channel mean/std
    normalization (the usual ImageNet statistics). They also share batch
    size ``VALIDATE.BATCH_SIZE_PER_GPU``, 4 workers and pinned memory. Any
    other operation leaves the experiment's lists empty.

    Returns:
        (dataloader_dict, dataset_dict): dicts keyed by experiment name whose
        values are lists.
    """
    # The preprocessing is identical for every experiment and the transforms
    # are stateless, so build the pipeline once.
    normalize = transforms.Normalize(mean=[0.485, 0.456, 0.406],
                                     std=[0.229, 0.224, 0.225])
    eval_transform = transforms.Compose([
        transforms.Resize(256),
        transforms.CenterCrop(224),
        transforms.ToTensor(),
        normalize,
    ])

    def _make_loader(dataset, shuffle):
        # Shared DataLoader settings; only the shuffle flag differs.
        return torch.utils.data.DataLoader(
            dataset,
            batch_size=config_EXP.VALIDATE.BATCH_SIZE_PER_GPU, shuffle=shuffle,
            num_workers=4, pin_memory=True)

    dataloader_dict = dict()
    dataset_dict = dict()
    for index, exp_name in enumerate(config_EXP.EXP_NAMES):
        logger = loggers_dict[exp_name]
        operation = config_EXP.OPERATION[index]

        dataloader_dict[exp_name] = []
        dataset_dict[exp_name] = []

        if operation == 'validate':
            val_dataset = datasets.ImageFolder(config_EXP.TEST_DATA, eval_transform)
            dataset_dict[exp_name].append(val_dataset)
            dataloader_dict[exp_name].append(_make_loader(val_dataset, shuffle=False))

        elif operation == 'collect_statistics':
            train_dataset = datasets.ImageFolder(config_EXP.TRAIN_DATA, eval_transform)
            dataset_dict[exp_name].append(train_dataset)
            dataloader_dict[exp_name].append(_make_loader(train_dataset, shuffle=True))

        elif operation == 'train':  # TODO: ADD TRAIN
            # One placeholder per dict, mirroring the branches above.
            dataset_dict[exp_name].append(None)
            dataloader_dict[exp_name].append(None)

        msg = "Dataloader for {} initialized".format(exp_name)
        print(msg)
        logger.info(msg)
    return dataloader_dict, dataset_dict


# def _init_optimizer(all_models, config_EXP, loggers_dict):
#     optimizers_dict = dict()
#     schedulers_dict = dict()
#     for num_EXP, exp_name in enumerate(config_EXP.EXP_NAMES):
#         logger = loggers_dict[exp_name]
#         if config_EXP.OPERATION[num_EXP] == 'train' and \
#                 config_EXP.OPTIMIZER[num_EXP] is not None:
#             if config_EXP.TRAIN.OPTIMIZER[num_EXP] == 'sgd':
#                 optimizers_dict[exp_name] = optim.SGD(
#                     filter(lambda p: p.requires_grad, all_models[exp_name].parameters()),
#                     lr=config_EXP.TRAIN.LR[num_EXP],
#                     momentum=config_EXP.TRAIN.MOMENTUM[num_EXP],
#                     weight_decay=config_EXP.TRAIN.WD[num_EXP],
#                     nesterov=config_EXP.TRAIN.NESTEROV[num_EXP]
#                 )
#
#                 msg = "OPTIMIZER TYPE: {}, TRAIN_LR: {}, LR_FACTOR: {}," \
#                       " MOMENTUM: {}, WD: {}, NESTEROV: {}".format(config_EXP.TRAIN.OPTIMIZER[num_EXP],
#                                                                    config_EXP.TRAIN.LR[num_EXP],
#                                                                    config_EXP.TRAIN.LR_FACTOR[num_EXP],
#                                                                    config_EXP.TRAIN.MOMENTUM[num_EXP],
#                                                                    config_EXP.TRAIN.WD[num_EXP],
#                                                                    config_EXP.TRAIN.NESTEROV[num_EXP])
#                 print(msg)
#                 logger.info(msg)
#             elif config_EXP.TRAIN.OPTIMIZER == 'adam':
#                 optimizers_dict[exp_name] = optim.Adam(
#                     filter(lambda p: p.requires_grad, all_models[exp_name].parameters()),
#                     lr=config_EXP.TRAIN.LR[num_EXP]
#                 )
#
#                 msg = "OPTIMIZER TYPE: {}, TRAIN_LR: {}".format(config_EXP.TRAIN.OPTIMIZER[num_EXP],
#                                                                    config_EXP.TRAIN.LR[num_EXP])
#                 print(msg)
#                 logger.info(msg)
#
#             else:
#                 print('{} for experiment {} is NOT supported, ' \
#                       'choose another one'.format(config_EXP.TRAIN.OPTIMIZER[num_EXP],
#                                                   exp_name))
#             # schedulers_dict init
#             if isinstance(config_EXP.TRAIN.LR_STEP[num_EXP], list):
#                 optimizers_dict[exp_name] = torch.optim.lr_scheduler.MultiStepLR(
#                     optimizers_dict[exp_name], config_EXP.TRAIN.LR_STEP[num_EXP],
#                     config_EXP.TRAIN.LR_FACTOR[num_EXP], 0
#                 )
#             else:
#                 optimizers_dict[exp_name] = torch.optim.lr_scheduler.StepLR(
#                     optimizers_dict[exp_name], config_EXP.TRAIN.LR_STEP[num_EXP],
#                     config_EXP.TRAIN.LR_FACTOR[num_EXP], 0
#                 )
#         else:
#             msg = "NO OPTIMIZER NEEDED"
#             print(msg)
#             logger.info(msg)
#             optimizers_dict[exp_name] = None
#             schedulers_dict[exp_name] = None
#
#         # elif config_EXP.TRAIN.OPTIMIZER == 'rmsprop':
#         #     optimizer = optim.RMSprop(
#         #         filter(lambda p: p.requires_grad, model.parameters()),
#         #         lr=config_EXP.TRAIN.LR[num_EXP],
#         #         momentum=config_EXP.TRAIN.MOMENTUM[num_EXP],
#         #         weight_decay=config_EXP.TRAIN.WD[num_EXP],
#         #         alpha=config_EXP.TRAIN.RMSPROP_ALPHA[num_EXP],
#         #         centered=config_EXP.TRAIN.RMSPROP_CENTERED[num_EXP]
#         #     )
#
#     return optimizers_dict, schedulers_dict

def merge_exp_qblock_config(config_EXP, index, config_Q):
    """Copy experiment ``index``'s settings from ``config_EXP`` into ``config_Q``.

    ``config_Q`` is unfrozen for the update and frozen again afterwards.
    Every field copied is the experiment's own entry ``[index]``, except
    ``OUTPUT_BASE_PATH``, which all experiments share.
    """
    config_Q.defrost()
    # OPERATION is a per-experiment list (indexed per experiment elsewhere,
    # e.g. when experiment names are built), so take this experiment's entry
    # like every other field instead of copying the whole list.
    config_Q.OPERATION = config_EXP.OPERATION[index]
    config_Q.NUM_BITS = config_EXP.ACT_NUM_BITS[index]
    config_Q.OUTPUT_BASE_PATH = config_EXP.OUTPUT_BASE_PATH
    config_Q.DEVICE = config_EXP.DEVICES[index]
    config_Q.EXP_NAME = config_EXP.EXP_NAMES[index]
    config_Q.DENSE_SPARSE = config_EXP.DENSE_SPARSE[index]
    config_Q.freeze()

def create_EXP_NAMES(config_EXP):
    """Derive a descriptive name for every experiment and store it in ``EXP_NAMES``.

    Each name has the form
    ``<operation>_<stem>_w<weight bits>_<algo[:4]>_a<activation bits>``.
    ``stem`` is the last ``/``-separated component of the QB config path with
    its final five characters (e.g. ``.yaml``) removed. ``_DEN_SP`` is
    appended when the experiment's DENSE_SPARSE entry equals ``True``.
    ``config_EXP`` is unfrozen for the update and frozen again afterwards.
    """
    config_EXP.defrost()
    for idx, qb_path in enumerate(config_EXP.QB_CONFIGS):
        stem = qb_path.rsplit('/', 1)[-1][:-5]
        pieces = [
            config_EXP.OPERATION[idx] + '_' + stem,
            '_w' + str(config_EXP.WEIGHT.NUM_BITS[idx]),
            '_' + config_EXP.WEIGHT.Q_ALGORITHM[idx][:4],
            '_a' + str(config_EXP.ACT_NUM_BITS[idx]),
        ]
        # Equality (not truthiness) is intentional: it mirrors the
        # established naming rule exactly.
        if config_EXP.DENSE_SPARSE[idx] == True:  # noqa: E712
            pieces.append('_DEN_SP')
        config_EXP.EXP_NAMES[idx] = ''.join(pieces)
    config_EXP.freeze()

def create_loggers_writers(config_EXP):
    """Create one file logger and one TensorBoard writer per experiment.

    Log files and per-experiment writer directories both live under
    ``OUTPUT_BASE_PATH/LOG``. That directory is created, together with any
    missing parents, if needed. Each log file is truncated when its handler
    is opened. When ``'validate_FULL_w32_asym_a32'`` is among the experiment
    names, every log file name gets a ``_full_comp`` suffix.

    Returns:
        (loggers_dict, writers_dict): dicts keyed by experiment name.
    """
    log_path = os.path.join(config_EXP.OUTPUT_BASE_PATH, config_EXP.LOG)
    # makedirs + exist_ok: plain os.mkdir fails when OUTPUT_BASE_PATH itself
    # does not exist yet.
    os.makedirs(log_path, exist_ok=True)
    # NOTE(review): writers share the log directory (the original computed
    # the identical path twice). Change this if a separate root was intended.
    writer_path = log_path

    # Loop-invariant: if this (presumably full-precision baseline) experiment
    # is present, every log file is tagged.
    log_suffix = ('_full_comp'
                  if 'validate_FULL_w32_asym_a32' in config_EXP.EXP_NAMES
                  else '')

    loggers_dict = dict()
    writers_dict = dict()
    formatter = logging.Formatter('%(asctime)s : %(message)s')
    for exp_name in config_EXP.EXP_NAMES:
        # Logger for this experiment; mode='w' truncates any previous log.
        log_file = os.path.join(log_path, exp_name + log_suffix + '.log')
        handler = logging.FileHandler(log_file, mode='w')
        handler.setFormatter(formatter)
        logger = logging.getLogger(exp_name)
        logger.setLevel(logging.INFO)
        logger.addHandler(handler)
        loggers_dict[exp_name] = logger

        # TensorBoard writer for this experiment.
        writer_exp_dir = os.path.join(writer_path, exp_name)
        os.makedirs(writer_exp_dir, exist_ok=True)
        writers_dict[exp_name] = SummaryWriter(logdir=writer_exp_dir)
    return loggers_dict, writers_dict





