import torch
from torch import Tensor
import torch.nn.functional as F
import torch.nn as nn
import numpy as np
from scipy.stats import entropy
from sklearn.svm import SVR
import time
import json
import pickle

from argparse import Namespace
import os
import warnings
# from defaultConfigQB import _EXPs as config

warnings.filterwarnings("ignore")


"""Quantization Layer:
Supports the following regimes:
1) Dynamic quantization
2) Static quantization
3) QPP quantization
4) Statistics collection
5) Dense+Sparse Dynamic
6) Dense+Sparse Static
7) Dense+Sparse QPP
8) Dense fraction collection (QPP, Static)
"""


class Dynamic(torch.autograd.Function):
    """Fake quantization whose clipping range is measured on every batch.

    Each sample (index along dim 0) gets its own range ``[low, up]``: either
    its max/min or a quantile of its values.  Values are rounded onto a grid
    with step ``(up - low) / 2 ** bits`` and clipped to that range.

    ``argsTensor`` layout (built by ``fill_args_tensor``):
        [0] number of bits
        [1] upper quantile (1 selects the per-sample maximum)
        [2] lower quantile (1 selects the per-sample minimum)
        [3] mode: 0 quantizes every value; 1 ("dense + sparse") leaves the
            values outside ``[low, up]`` unquantized.
    """

    @staticmethod
    def forward(ctx, input, argsTensor):
        """Quantize ``input`` sample by sample.

        Args:
            ctx: autograd context; receives the boolean gradient mask.
            input: activations with the batch on dim 0 and at least one more
                dimension (any rank >= 2).
            argsTensor: 1-D tensor laid out as described on the class.

        Returns:
            Tensor shaped like ``input`` holding the quantized values.
        """
        acts_q_bit = argsTensor[0]
        mode = argsTensor[3]
        up_threshold, low_threshold = Dynamic.__thresholds(
            input, argsTensor[1], argsTensor[2])

        # One threshold per sample, broadcast over every remaining dimension.
        per_sample = (-1,) + (1,) * (input.dim() - 1)
        up_threshold = up_threshold.reshape(per_sample)
        low_threshold = low_threshold.reshape(per_sample)

        # A constant sample has up == low; without the guard the scale is
        # infinite and round(0 * inf) turns the whole sample into NaN.
        span = up_threshold - low_threshold
        span = torch.clamp(span, min=torch.finfo(span.dtype).eps)
        scale_factor = 2 ** acts_q_bit / span

        out = torch.round(input * scale_factor) / scale_factor
        out = torch.min(torch.max(out, low_threshold), up_threshold)

        if mode == 1:  # DENSE + SPARSE
            # Outliers (the sparse part) keep their original value.
            dense_mask = (input <= up_threshold) & (input >= low_threshold)
            out = torch.where(dense_mask, out, input)

        # Only the mask is needed by backward; storing it avoids redoing the
        # sort-based threshold computation there.
        ctx.save_for_backward((input > low_threshold) & (input < up_threshold))
        return out

    @staticmethod
    def backward(ctx, grad_output):
        """Straight-through gradient: pass where ``low < input < up``, else zero.

        NOTE(review): the mask is applied in dense + sparse mode as well, so
        the pass-through outliers receive zero gradient -- confirm intended.
        """
        (pass_mask,) = ctx.saved_tensors
        # Multiply out of place: grad_output belongs to the autograd engine
        # and must not be modified by this function.
        return grad_output * pass_mask, None

    @staticmethod
    def fill_args_tensor(cfg, up_scale_factor=None, low_scale_factor=None):
        """Pack the dynamic-regime settings of ``cfg`` into ``argsTensor``.

        ``up_scale_factor`` / ``low_scale_factor`` are accepted only for
        signature parity with the other regimes and are ignored.
        """
        tensor_list = Dynamic.__get_tensor_list(cfg)
        return Tensor(tensor_list)

    @staticmethod
    def __get_tensor_list(cfg, up_scale_factor=None, low_scale_factor=None):
        """Return ``[bits, up_quantile, low_quantile, mode]`` read from ``cfg``."""
        mode = 1 if cfg.DENSE_SPARSE else 0
        return [
            cfg.NUM_BITS,  # 0
            cfg.DYNAMIC.UP_QUANTILE,
            cfg.DYNAMIC.LOW_QUANTILE,
            mode,
        ]

    @staticmethod
    def __thresholds(input, up_quantile, low_quantile):
        """Return per-sample ``(up, low)`` clipping thresholds, each of shape (batch,).

        The lower threshold is zero whenever the whole batch is non-negative.
        """
        flat = input.reshape(input.shape[0], -1)
        if up_quantile == 1:
            up_threshold = flat.max(dim=1)[0]
        else:
            up_threshold = Dynamic.__quantile(input, up_quantile)

        if input.min() < 0:
            if low_quantile == 1:
                low_threshold = flat.min(dim=1)[0]
            else:
                # Lower quantile = mirrored upper quantile of the negated data.
                low_threshold = - Dynamic.__quantile(- input, low_quantile)
        else:
            low_threshold = torch.zeros(up_threshold.shape).to(input)
        return up_threshold, low_threshold

    @staticmethod
    def __quantile(x, q):
        """Per-sample quantile ``q`` over the positive part of ``x``.

        Zeros count as positive when the batch is non-negative; when the
        batch contains negatives only half of each sample's zeros do.
        Returns a tensor of shape (batch,); all zeros when the batch has no
        positive element.

        NOTE(review): the index offset is ``width - 1 - tail``, one slot
        below the usual linear-interpolation position inside the top
        ``tail`` elements (q = 1 yields the second-largest value).  The
        numerics are kept as they were; confirm whether this is intended.
        """
        flat = x.reshape(x.shape[0], -1)
        width = flat.shape[1]
        positives = torch.sum(flat > 0, dim=1)
        zeros = torch.sum(flat == 0, dim=1)
        if torch.sum(positives) == 0:
            return torch.zeros(flat.shape[0]).to(x)

        if flat.min() < 0:
            tail = (positives + torch.round(zeros / 2.0)).long()
        else:
            tail = positives + zeros

        in_sorted, _ = torch.sort(flat, dim=1)
        positions = q * (tail - 1)
        floored = torch.floor(positions).long()
        ceiled = torch.clamp(floored + 1, max=width - 1)
        weight_ceiled = positions - floored
        weight_floored = 1.0 - weight_ceiled

        # Pick one element per row directly instead of gathering a
        # (batch x batch) matrix and keeping only its diagonal.
        offset = width - 1 - tail
        rows = torch.arange(flat.shape[0], device=flat.device)
        d0 = in_sorted[rows, floored + offset] * weight_floored
        d1 = in_sorted[rows, ceiled + offset] * weight_ceiled
        return (d0 + d1).view(flat.shape[0])


class Static(torch.autograd.Function):
    """Fake quantization with one fixed clipping range for the whole tensor.

    ``argsTensor`` layout (built by ``fill_args_tensor``):
        [0] number of bits
        [1] upper threshold
        [2] lower threshold (replaced by 0 when the input is non-negative)
        [3] mode: 0 quantizes every value; 1 ("dense + sparse") leaves the
            values outside ``[low, up]`` unquantized.
    """

    @staticmethod
    def forward(ctx, input, argsTensor):
        """Round ``input`` onto the grid defined by the thresholds and clip it.

        Args:
            ctx: autograd context; receives the boolean gradient mask.
            input: activations to quantize.
            argsTensor: 1-D float tensor laid out as described on the class.

        Returns:
            Tensor shaped like ``input`` holding the quantized values.
        """
        acts_q_bit = argsTensor[0]
        up_threshold = argsTensor[1]
        mode = argsTensor[3]
        if input.min() < 0:
            low_threshold = argsTensor[2]
        else:
            # Non-negative input: the stored lower threshold is not used.
            low_threshold = torch.zeros_like(up_threshold)

        # up == low would give an infinite scale and NaN outputs.
        span = up_threshold - low_threshold
        span = torch.clamp(span, min=torch.finfo(span.dtype).eps)
        scale_factor = 2 ** acts_q_bit / span

        out = torch.round(input * scale_factor) / scale_factor
        out = torch.min(torch.max(out, low_threshold), up_threshold)

        if mode == 1:  # DENSE + SPARSE
            # Outliers (the sparse part) keep their original value.
            dense_mask = (input <= up_threshold) & (input >= low_threshold)
            out = torch.where(dense_mask, out, input)

        ctx.save_for_backward((input > low_threshold) & (input < up_threshold))
        return out

    @staticmethod
    def backward(ctx, grad_output):
        """Straight-through gradient: pass where ``low < input < up``, else zero.

        NOTE(review): the mask is applied in dense + sparse mode as well, so
        the pass-through outliers receive zero gradient -- confirm intended.
        """
        (pass_mask,) = ctx.saved_tensors
        # Multiply out of place: grad_output belongs to the autograd engine
        # and must not be modified by this function.
        return grad_output * pass_mask, None

    @staticmethod
    def fill_args_tensor(cfg, up_scale_factor=None, low_scale_factor=None):
        """Pack the bit width, both thresholds and the mode into ``argsTensor``."""
        tensor_list = Static.__get_tensor_list(cfg, up_scale_factor, low_scale_factor)
        return Tensor(tensor_list)

    @staticmethod
    def __get_tensor_list(cfg, up_scale_factor=None, low_scale_factor=None):
        """Return ``[bits, up, low, mode]``; a missing threshold is stored as -1."""
        up_sf = -1 if up_scale_factor is None else up_scale_factor
        low_sf = -1 if low_scale_factor is None else low_scale_factor
        mode = 1 if cfg.DENSE_SPARSE else 0
        return [
            cfg.NUM_BITS,  # 0
            up_sf,  # 1
            low_sf,
            mode,
        ]


class QPP(torch.autograd.Function):
    """Fake quantization with externally supplied per-sample thresholds.

    ``argsTensor`` layout (built by ``fill_args_tensor``), ``B`` being the
    batch size of the input:
        [0]          number of bits
        [1]          mode: 0 quantizes every value; 1 ("dense + sparse")
                     leaves the values outside ``[low, up]`` unquantized
        [2 : 2 + B]  upper threshold of every sample
        [2 + B :]    lower threshold of every sample
    """

    @staticmethod
    def forward(ctx, input, argsTensor):
        """Quantize every sample of ``input`` with its own thresholds.

        When ``argsTensor`` carries no thresholds (at most the two header
        entries) the input is passed through unchanged.

        Args:
            ctx: autograd context; receives the boolean gradient mask.
            input: activations with the batch on dim 0 (any rank >= 2).
            argsTensor: 1-D float tensor laid out as described on the class.

        Returns:
            Tensor shaped like ``input``.
        """
        ctx.passthrough = argsTensor.shape[0] <= 2
        if ctx.passthrough:
            return input.view_as(input)

        acts_q_bit = argsTensor[0]
        mode = argsTensor[1]
        batch = input.shape[0]
        # One threshold per sample, broadcast over every remaining dimension.
        per_sample = (-1,) + (1,) * (input.dim() - 1)
        up_threshold = argsTensor[2:2 + batch].reshape(per_sample)
        if input.min() < 0:
            low_threshold = argsTensor[2 + batch:].reshape(per_sample)
        else:
            # Non-negative input: the supplied lower thresholds are not used.
            low_threshold = torch.zeros(up_threshold.shape).to(input)

        # up == low would give an infinite scale and NaN outputs.
        span = up_threshold - low_threshold
        span = torch.clamp(span, min=torch.finfo(span.dtype).eps)
        scale_factor = 2 ** acts_q_bit / span

        out = torch.round(input * scale_factor) / scale_factor
        out = torch.min(torch.max(out, low_threshold), up_threshold)

        if mode == 1:  # DENSE + SPARSE
            # Outliers (the sparse part) keep their original value.
            dense_mask = (input <= up_threshold) & (input >= low_threshold)
            out = torch.where(dense_mask, out, input)

        # The mask is built from the very thresholds used above, so forward
        # and backward can never disagree on the argsTensor layout.
        ctx.save_for_backward((input > low_threshold) & (input < up_threshold))
        return out

    @staticmethod
    def backward(ctx, grad_output):
        """Straight-through gradient: pass where ``low < input < up``, else zero.

        In the pass-through case the gradient is returned unchanged.

        NOTE(review): the mask is applied in dense + sparse mode as well, so
        the pass-through outliers receive zero gradient -- confirm intended.
        """
        if ctx.passthrough:
            return grad_output, None
        (pass_mask,) = ctx.saved_tensors
        # Multiply out of place: grad_output belongs to the autograd engine
        # and must not be modified by this function.
        return grad_output * pass_mask, None

    @staticmethod
    def fill_args_tensor(cfg, up_scale_factor=None, low_scale_factor=None):
        """Pack the bit width, the mode and all thresholds into ``argsTensor``."""
        tensor_list = QPP.__get_tensor_list(cfg, up_scale_factor, low_scale_factor)
        return Tensor(tensor_list)

    @staticmethod
    def __get_tensor_list(cfg, up_scale_factor=None, low_scale_factor=None):
        """Return ``[bits, mode, *up_thresholds, *low_thresholds]``.

        A threshold sequence that is ``None`` contributes no entries.
        """
        mode = 1 if cfg.DENSE_SPARSE else 0
        tensor_list = [cfg.NUM_BITS, mode]
        if up_scale_factor is not None:
            tensor_list.extend(up_scale_factor)
        if low_scale_factor is not None:
            tensor_list.extend(low_scale_factor)
        return tensor_list


class QuantizationBlock(nn.Module):
    FIX_POINT_DYNAMIC = 'dynamic'
    FIX_POINT_QPP = 'QPP'
    FIX_POINT_STATIC = 'static'

    def __init__(self, cfg, quantize_negative=False, is_quantized=True, qblock_id=None):
        """Set the block up for one quantization regime and/or statistics collection.

        Args:
            cfg: configuration namespace; this constructor reads NUM_BITS,
                COLLECT.APPLY, DYNAMIC.*, QPP.APPLY, STATIC.APPLY, EXP_NAME
                and DEVICE from it.
            quantize_negative: stored on the instance as given.
            is_quantized: when False the block never quantizes.
            qblock_id: name of the block; used in log messages and file names.
        """
        super(QuantizationBlock, self).__init__()
        self.cfg = cfg
        self.qblock_id = qblock_id
        self.quantize_negative = quantize_negative
        self.is_quantized = is_quantized
        self.quantization = None
        self.dense_fraction = []
        # Thresholds used by the static and QPP regimes.
        self.up_threshold = None
        self.low_threshold = None
        # QPP predictors, scalers and their setup.
        self.__up_predictor = None
        self.__low_predictor = None
        self.__up_scaler = None
        self.__low_scaler = None
        self.up_QPP_positions = None
        self.low_QPP_positions = None
        self.low_QPP_features = None
        self.up_QPP_features = None
        self.up_QPP_position = False
        self.low_QPP_position = False
        # Dynamic regime quantiles.
        self.__up_dynamic_quantile = None
        self.__low_dynamic_quantile = None
        # Statistics collection state.
        self.__statistics_dict = None
        self.__statistics_container = None
        self.__statistics_function_container = None

        full_precision = self.cfg.NUM_BITS >= 32 or not is_quantized

        if not full_precision:
            self.forward = self.__forward_quantize
            regime_selected = True
            if self.cfg.DYNAMIC.APPLY:
                self.__up_dynamic_quantile = self.cfg.DYNAMIC.UP_QUANTILE
                self.__low_dynamic_quantile = self.cfg.DYNAMIC.LOW_QUANTILE
                self.quantization = Dynamic.apply
                self.fill_args_tensor = Dynamic.fill_args_tensor
            elif self.cfg.QPP.APPLY:
                self.quantization = QPP.apply
                self.__load_predictor()
                self.__QPP_init()
                self.fill_args_tensor = QPP.fill_args_tensor
            elif self.cfg.STATIC.APPLY:
                self.quantization = Static.apply
                self.__load_thresholds()
                self.fill_args_tensor = Static.fill_args_tensor
            else:
                regime_selected = False
            if regime_selected:
                print('{} quantized {}'.format(self.cfg.EXP_NAME, self.qblock_id))

        if self.cfg.COLLECT.APPLY:
            # Collection wraps the quantizing forward, whatever the regime.
            if self.__statistics_container is None:
                self.__init_statistics()
            self.forward = self.__forward_collect_statistics
        elif full_precision:
            self.forward = self.__forward_quantize
            self.quantization = None
            print('{}:\tnot activated!!!'.format(qblock_id))

        if self.quantization is not None:
            self.argsTensorCPU = self.fill_args_tensor(self.cfg, self.up_threshold,
                                                       self.low_threshold)
            self.argsTensor = self.argsTensorCPU.to(cfg.DEVICE)

    # COLLECT STATISTICS BLOCK
    # ________________________________________________

    def __forward_collect_statistics(self, x, collect_statistics=False):
        """Optionally record the enabled statistics of ``x``, then quantize it.

        Args:
            x: activations entering the block.
            collect_statistics: when True, every enabled statistic of ``x``
                is appended to its list in the statistics container.

        Returns:
            The result of the regular quantizing forward pass on ``x``.
        """
        if collect_statistics:
            reducers = self.__statistics_function_container
            for name, bucket in self.__statistics_container.items():
                reducer = reducers[name]
                if name == 'UP_QUANTILES':
                    values = reducer(x, self.cfg.COLLECT.UP_QUANTILES).tolist()
                elif name == 'LOW_QUANTILES':
                    # Lower quantiles are mirrored upper quantiles of -x.
                    values = (-reducer(-x, self.cfg.COLLECT.LOW_QUANTILES)).tolist()
                else:
                    values = reducer(x)
                bucket.extend(values)
        return self.__forward_quantize(x)

    def getattr_statistic_container(self):
        """Return the statistics container (the live object, not a copy).

        It maps statistic names to the lists collected so far and is
        ``None`` while collection is not initialised or after a dump.
        """
        container = self.__statistics_container
        return container

    def __init_statistics(self):
        """Create an empty result list for every statistic enabled in ``cfg.COLLECT``.

        For each enabled statistic the container gets an empty list and the
        function container gets the reducer that computes it; a message is
        printed for every disabled one.
        """
        self.__statistics_dict = dict()
        self.__statistics_container = dict()
        self.__statistics_function_container = dict()

        # (cfg.COLLECT attribute == container key, reducer, message when disabled)
        catalogue = (
            ("LOW_QUANTILES", self.__compute_quantiles,
             '{}: skipping self.__compute_low_quantiles...'),
            ("UP_QUANTILES", self.__compute_quantiles,
             '{}: skipping self.__compute_up_quantiles...'),
            ("MEAN", self.__compute_mean,
             '{}: skipping self.__compute_mean...'),
            ("STD", self.__compute_std,
             '{}: skipping self.__compute_std...'),
            ("MAX", self.__compute_max,
             '{}: skipping self.__compute_max...'),
            ("MIN", self.__compute_min,
             '{}: skipping self.__compute_min...'),
            ("ABS_MAX", self.__compute_abs_max,
             '{}: skipping self.__compute_abs_max...'),
        )
        for key, reducer, skip_message in catalogue:
            if getattr(self.cfg.COLLECT, key):
                self.__statistics_container[key] = []
                self.__statistics_function_container[key] = reducer
            else:
                print(skip_message.format(self.qblock_id))

    @staticmethod
    def __compute_mean(x):
        if x is not None:
            if x.shape[0] == 1:
                return [torch.mean(x, axis=(1, 2, 3)).data.cpu().item()]
            return torch.mean(x, axis=(1, 2, 3)).data.cpu().tolist()

    @staticmethod
    def __compute_std(x):
        if x is not None:
            if x.shape[0] == 1:
                return [torch.std(x, axis=(1, 2, 3)).data.cpu().item()]
            return torch.std(x, axis=(1, 2, 3)).data.cpu().tolist()

    @staticmethod
    def __compute_abs_max(x):
        if x is not None:
            if x.shape[0] == 1:
                return [torch.max(torch.max(torch.max(torch.abs(x), 3)[0], 2)[0], 1)[0].data.cpu().item()]
            return torch.max(torch.max(torch.max(torch.abs(x), 3)[0], 2)[0], 1)[0].data.cpu().tolist()

    @staticmethod
    def __compute_max(x):
        if x is not None:
            if x.shape[0] == 1:
                return [torch.max(torch.max(torch.max(x, 3)[0], 2)[0], 1)[0].data.cpu().item()]
            return torch.max(torch.max(torch.max(x, 3)[0], 2)[0], 1)[0].data.cpu().tolist()

    @staticmethod
    def __compute_min(x):
        if x is not None:
            if x.shape[0] == 1:
                return [torch.min(torch.min(torch.min(x, 3)[0], 2)[0], 1)[0].data.cpu().item()]
            return torch.min(torch.min(torch.min(x, 3)[0], 2)[0], 1)[0].data.cpu().tolist()

    # compute quantile only for positive part of input;
    @staticmethod
    def __quantile(in_sorted, q):
        # print("input shape{} q {}".format(x.shape, q))
        # print("input shape{} q {}".format(x.shape, q))
        input = in_sorted.reshape(in_sorted.shape[0], -1)
        input_shape = input.shape[1]
        number_of_positive_elements_per_batch = torch.sum(input > 0, dim=1)
        number_of_zeros = torch.sum(input == 0, dim=1)
        if torch.sum(number_of_positive_elements_per_batch) == 0:
            return torch.zeros(input.shape[0]).to(in_sorted)

        if input.min() < 0:
            number_of_positive_elements_per_batch = number_of_positive_elements_per_batch + torch.round(
                number_of_zeros / torch.tensor(2.))
            number_of_positive_elements_per_batch = number_of_positive_elements_per_batch.long()
        else:
            number_of_positive_elements_per_batch = number_of_positive_elements_per_batch + number_of_zeros

        positions = torch.tensor(q) * (number_of_positive_elements_per_batch - 1)

        floored = torch.floor(positions).long()
        ceiled = (floored + 1).long()
        ceiled[ceiled > (input.shape[1] - 1)] = input.shape[1] - 1
        weight_ceiled = positions - floored
        weight_floored = 1.0 - weight_ceiled
        d0 = in_sorted[:, floored + input_shape - 1 - number_of_positive_elements_per_batch] * weight_floored
        d1 = in_sorted[:, ceiled + input_shape - 1 - number_of_positive_elements_per_batch] * weight_ceiled
        result = torch.diag((d0 + d1)).view(input.shape[0])
        # if q == 0.99:
        #     print("COMPUTING QUANTIEL", d0[0][0], d1[0][0], q, positions[0], in_sorted.shape, result[0], ceiled[0], floored[0],number_of_positive_elements_per_batch[0])
        #     print("in_sorted QUANTILE", in_sorted[0])
        return result

    def __compute_quantiles(self, x, quantiles):
        """Evaluate several quantile levels for every sample of ``x``.

        Args:
            x: activations with the batch on dim 0; flattened per sample.
            quantiles: sequence of quantile levels.

        Returns:
            CPU tensor of shape (batch, len(quantiles)).
        """
        batch = x.shape[0]
        table = torch.zeros((len(quantiles), batch))
        # Sort once; every level reuses the same sorted rows.
        ordered = torch.sort(x.reshape(batch, -1), dim=1)[0]
        for slot, level in enumerate(quantiles):
            table[slot, :] = self.__quantile(ordered, level).data.cpu()
        return table.t()

    def process_statistics_container(self):
        """Turn the collected lists into the statistics dict and write it to disk.

        Scalar statistics are copied under their own key.  Quantile tables
        are split into one entry per level, keyed by the level's digits
        after ``0.`` plus a ``_LOW_QUANTILE`` / ``_UP_QUANTILE`` suffix.
        The statistics and then the thresholds are dumped afterwards.
        """
        copied_as_is = ('MEAN', 'STD', 'MAX', 'MIN', 'ABS_MAX')
        quantile_suffix = {'LOW_QUANTILES': "_LOW_QUANTILE",
                           'UP_QUANTILES': "_UP_QUANTILE"}
        for key in self.__statistics_container:
            print("Proccess statistics for all model")
            if key in copied_as_is:
                self.__statistics_dict[key] = self.__statistics_container[key]
            elif key in quantile_suffix:
                # Rows are samples, columns are quantile levels.
                table = np.array(self.__statistics_container[key])
                levels = getattr(self.cfg.COLLECT, key)
                for column, level in enumerate(levels):
                    name = str(level)[2:] + quantile_suffix[key]
                    self.__statistics_dict[name] = table[:, column].tolist()
        self.__dump_statistics()
        self.__dump_thresholds()

    def __dump_thresholds(self):
        """Save the mean of every collected quantile as a static threshold.

        For each ``*QUANTILE`` entry of the statistics dict the mean over
        all samples is written to ``<root>/<key>/<qblock_id>.npy``; the root
        is the low-threshold directory when the key contains ``LOW`` and the
        up-threshold directory otherwise.  The collection state is cleared
        afterwards.
        """
        up_sc_path = os.path.join(self.cfg.OUTPUT_BASE_PATH, self.cfg.COLLECT.UP_THRESHOLDS_PATH)
        low_sc_path = os.path.join(self.cfg.OUTPUT_BASE_PATH, self.cfg.COLLECT.LOW_THRESHOLDS_PATH)
        # makedirs creates missing parents too and, with exist_ok, does not
        # fail when another block created the directory in the meantime.
        os.makedirs(up_sc_path, exist_ok=True)
        os.makedirs(low_sc_path, exist_ok=True)

        for key, values in self.__statistics_dict.items():
            if 'QUANTILE' not in key:
                continue
            root = low_sc_path if 'LOW' in key else up_sc_path
            t_sc_path = os.path.join(root, key)
            os.makedirs(t_sc_path, exist_ok=True)
            filename = os.path.join(t_sc_path, self.qblock_id + '.npy')
            np.save(filename, np.mean(values))

        self.__statistics_dict = None
        self.__statistics_container = None
        self.__statistics_function_container = None
        # The old message had no placeholder, so the block id was dropped.
        print("{}: thresholds dumped".format(self.qblock_id))

    def __dump_statistics(self):
        """Write the statistics dict to ``<statistics dir>/<qblock_id>.json``.

        NumPy arrays stored in the dict are converted to lists in place so
        that the dict is JSON-serialisable.
        """
        for key in list(self.__statistics_dict):
            if isinstance(self.__statistics_dict[key], np.ndarray):
                self.__statistics_dict[key] = self.__statistics_dict[key].tolist()

        STATISTICS_PATH = os.path.join(self.cfg.OUTPUT_BASE_PATH,
                                       self.cfg.COLLECT.STATISTICS_PATH)
        # makedirs creates missing parents too and, with exist_ok, does not
        # fail when another block created the directory in the meantime.
        os.makedirs(STATISTICS_PATH, exist_ok=True)

        filename = os.path.join(STATISTICS_PATH, self.qblock_id + '.json')
        with open(filename, 'w') as file:
            json.dump(self.__statistics_dict, file)

        print("ALL STATISTICS ARE DUMPED")

    # ________________________________________________
    # FORWARD VALIDATION AND TRAIN

    def __forward_quantize(self, x_i, collect_statistics=False):
        """Quantize ``x_i`` with the configured regime.

        The input is returned untouched when no regime is set, the block is
        not quantized, or ``cfg.APPLY_FULL`` is set.  2-D inputs are
        temporarily expanded to ``(batch, 1, 1, features)`` and restored on
        return; inputs of any other rank keep their shape.  In dense + sparse
        mode with the static or QPP regime, the per-sample fraction of values
        inside the thresholds is appended to ``self.dense_fraction``.

        Args:
            x_i: activations to quantize.
            collect_statistics: unused here; kept so both forward variants
                share one signature.

        Returns:
            The quantized tensor, shaped like ``x_i``.
        """
        if self.quantization is None or not self.is_quantized or self.cfg.APPLY_FULL:
            return x_i
        self.__updateArgsTensor()

        expanded = len(x_i.shape) == 2
        x = x_i.unsqueeze(1).unsqueeze(1) if expanded else x_i

        # (up, low) threshold pairs whose dense fraction has to be recorded.
        bounds = []
        if self.cfg.STATIC.APPLY and self.cfg.DENSE_SPARSE:
            bounds.append((self.argsTensor[1], self.argsTensor[2]))
        if self.cfg.QPP.APPLY and self.cfg.DENSE_SPARSE:
            batch = x.shape[0]
            bounds.append((self.argsTensor[2:2 + batch].reshape(-1, 1, 1, 1),
                           self.argsTensor[2 + batch:].reshape(-1, 1, 1, 1)))
        for up_threshold, low_threshold in bounds:
            dense_part = ((x <= up_threshold) & (x >= low_threshold)).float()
            dense_sum = torch.sum(dense_part, axis=(1, 2, 3))
            size_ = x.shape[1] * x.shape[2] * x.shape[3]
            fraction = dense_sum / size_
            self.dense_fraction.extend(fraction.detach().data.cpu().tolist())

        out = self.quantization(x, self.argsTensor)

        # Undo only the expansion done above; squeezing unconditionally
        # would also drop genuine size-1 dims (e.g. a single channel).
        if expanded:
            out = out.squeeze(1).squeeze(1)
        return out

    def dump_dense_fraction(self):
        """Save the recorded dense fractions and drop them from memory.

        Does nothing unless dense + sparse mode is combined with the static
        or QPP regime.  The values are written to
        ``<OUTPUT_BASE_PATH>/dense_fractions/<EXP_NAME>/<qblock_id>.npy`` and
        ``self.dense_fraction`` is set to ``None`` afterwards.
        """
        if not (self.cfg.DENSE_SPARSE and (self.cfg.STATIC.APPLY or self.cfg.QPP.APPLY)):
            return

        FRACTIONS_PATH = os.path.join(self.cfg.OUTPUT_BASE_PATH,
                                      'dense_fractions', self.cfg.EXP_NAME)
        # One call creates every missing level and, with exist_ok, does not
        # fail when another block created the directory in the meantime.
        os.makedirs(FRACTIONS_PATH, exist_ok=True)

        filename = os.path.join(FRACTIONS_PATH, self.qblock_id + '.npy')
        np.save(filename, np.array(self.dense_fraction))
        self.dense_fraction = None

    def __updateArgsTensor(self):
        """Rebuild the argument tensor from the current thresholds.

        The fresh tensor is kept as ``argsTensorCPU`` and a copy moved to
        ``cfg.DEVICE`` is kept as ``argsTensor``.
        """
        refreshed = self.fill_args_tensor(self.cfg, self.up_threshold,
                                          self.low_threshold)
        self.argsTensorCPU = refreshed
        self.argsTensor = refreshed.to(self.cfg.DEVICE)

    # STATIC LOAD SCALE FACTORS
    def __load_thresholds(self):
        """Load this block's static thresholds from ``<qblock_id>.npy`` files.

        The upper and lower thresholds are read from the directories named
        by ``cfg.STATIC.UP_THRESHOLDS_PATH`` and ``LOW_THRESHOLDS_PATH``
        under ``cfg.INPUT_BASE_PATH``.
        """
        file_name = self.qblock_id + '.npy'
        up_filename = os.path.join(self.cfg.INPUT_BASE_PATH,
                                   self.cfg.STATIC.UP_THRESHOLDS_PATH,
                                   file_name)
        low_filename = os.path.join(self.cfg.INPUT_BASE_PATH,
                                    self.cfg.STATIC.LOW_THRESHOLDS_PATH,
                                    file_name)

        self.up_threshold = np.load(up_filename)
        self.low_threshold = np.load(low_filename)

        message = '{}: scale factor ({} : {}) is loaded!'
        print(message.format(self.qblock_id, self.up_threshold, self.low_threshold))

    # LOAD AND INIT QPP

    def __QPP_init(self):
        """Flag the sides for which this block is listed as a QPP position.

        ``up_QPP_position`` / ``low_QPP_position`` become True when
        ``qblock_id`` occurs in the matching positions list; otherwise they
        are left as they are.
        """
        for side in ('up', 'low'):
            if self.qblock_id in getattr(self, side + '_QPP_positions'):
                setattr(self, side + '_QPP_position', True)

    def __load_predictor(self):
        up_predictor_path = os.path.join(self.cfg.INPUT_BASE_PATH,
                                         self.cfg.QPP.UP_PREDICTORS_PATH)
        low_predictor_path = os.path.join(self.cfg.INPUT_BASE_PATH,
                                          self.cfg.QPP.LOW_PREDICTORS_PATH)

        up_predictor_filename = os.path.join(up_predictor_path,
                                             "predictor_" + self.qblock_id + '.sav')
        low_predictor_filename = os.path.join(low_predictor_path,
                                              "predictor_" + self.qblock_id + '.sav')
        up_scaler_filename = os.path.join(up_predictor_path,
                                          "scaler_" + self.qblock_id + '.sav')
        low_scaler_filename = os.path.join(low_predictor_path,
                                           "scaler_" + self.qblock_id + '.sav')
        up_setup_filename = os.path.join(up_predictor_path,
                                         'QPP_setup.json')
        low_setup_filename = os.path.join(low_predictor_path,
                                          'QPP_setup.json')

        with open(up_predictor_filename, 'rb') as fp:
            self.__up_predictor = pickle.load(fp)
        with open(low_predictor_filename, 'rb') as fp:
            self.__low_predictor = pickle.load(fp)

        with open(up_scaler_filename, 'rb') as fp:
            self.__up_scaler = pickle.load(fp)
        with open(low_scaler_filename, 'rb') as fp:
            self.__low_scaler = pickle.load(fp)

        with open(up_setup_filename, 'r') as fp:
            t_QPP_setup = json.load(fp)
            self.up_QPP_positions = t_QPP_setup["QPP_POSITIONS"]
            self.up_QPP_features = t_QPP_setup["FEATURES"]
            # self.up_NAME = t_QPP_setup["EXP_NAME"]
        with open(low_setup_filename, 'r') as fp:
            t_QPP_setup = json.load(fp)
            self.low_QPP_positions = t_QPP_setup["QPP_POSITIONS"]
            self.low_QPP_features = t_QPP_setup["FEATURES"]
            # self.low_NAME = t_QPP_setup["EXP_NAME"]
        # print("QPP setup: {}, {},{},{}".format(self.up_QPP_positions, self.up_QPP_features,
        #                                        self.low_QPP_positions, self.low_QPP_features))

    # PREDICT QP

    def gen_features(self, x, type='UP'):
        """Compute the QPP predictor features of every sample of ``x``.

        Args:
            x: activations with the batch on dim 0; every sample is
                flattened, so any rank >= 2 is accepted.
            type: ``'UP'`` uses ``self.up_QPP_features``; any other value
                uses ``self.low_QPP_features``.

        Returns:
            CPU tensor of shape (batch, number of features).  A feature name
            is ``'MEAN'``, ``'STD'``, ``'ABS_MAX'``, ``'MAX'``, ``'MIN'`` or
            a number: a non-negative number ``q`` is the q-quantile of ``x``,
            a negative number ``-q`` is the q-quantile of ``-x``.

        NOTE(review): the quantile of ``-x`` is stored without negating it
        back, unlike the LOW_QUANTILES recorded during statistics
        collection -- confirm which sign the predictors expect.
        """
        flat = x.reshape(x.shape[0], -1)
        x_sorted, _ = torch.sort(flat, dim=1)
        # Ascending sort of -x: negate and reverse every row.  The built-in
        # reversed() flips a tensor along dim 0, i.e. it would reorder the
        # samples of the batch instead.
        neg_x_sorted = torch.flip(-x_sorted, dims=[1])

        features_names = self.up_QPP_features if type == 'UP' else self.low_QPP_features
        features_all = torch.zeros((x.shape[0], len(features_names)))
        for i, name in enumerate(features_names):
            if name == 'MEAN':
                features_all[:, i] = flat.mean(dim=1)
            elif name == 'STD':
                features_all[:, i] = flat.std(dim=1)
            elif name == 'ABS_MAX':
                features_all[:, i] = flat.abs().max(dim=1)[0]
            elif name == 'MAX':
                features_all[:, i] = flat.max(dim=1)[0]
            elif name == 'MIN':
                features_all[:, i] = flat.min(dim=1)[0]
            elif name < 0:
                features_all[:, i] = self.__quantile(neg_x_sorted, - name)
            else:
                features_all[:, i] = self.__quantile(x_sorted, name)
        return features_all

    def predict_qp(self, x, features):
        """Predict this block's thresholds from the accumulated features.

        Runs only when the QPP regime is active for this block or when
        ``qblock_id`` contains ``'input'``; otherwise ``features`` is handed
        back unchanged.  On a side flagged as a QPP position, the features of
        ``x`` are appended to that side's history before predicting.  The
        predictions are stored in ``self.up_threshold`` and
        ``self.low_threshold`` (zeros when a predictor is missing).

        Args:
            x: activations entering the block.
            features: pair ``(up_features, low_features)`` collected so far;
                an empty history has length 0.

        Returns:
            The (possibly extended) ``(up_features, low_features)`` pair.
        """
        up_features = features[0]
        low_features = features[1]
        qpp_active = self.cfg.QPP.APPLY and self.quantization is not None
        if not (qpp_active or 'input' in self.qblock_id):
            return up_features, low_features

        def extend_history(history, feature_names, side):
            # The history keeps, for every feature, its values at all
            # previous positions next to each other; the new value is placed
            # at the end of each such group.
            per_position = len(feature_names)
            current = x.detach()
            if len(history) == 0:
                return self.gen_features(current, side)
            widened = torch.zeros(history.shape[0], history.shape[1] + per_position)
            step = int(widened.shape[1] / per_position)
            for slot in range(step - 1):
                widened[:, slot::step] = history[:, slot::step - 1]
            widened[:, step - 1::step] = self.gen_features(current, side)
            return widened

        if self.up_QPP_position:
            up_features = extend_history(up_features, self.up_QPP_features, 'UP')
        up_cpu_features = up_features.cpu()
        if self.low_QPP_position:
            low_features = extend_history(low_features, self.low_QPP_features, 'LOW')
        low_cpu_features = low_features.cpu()

        up_inputs = up_cpu_features.numpy()
        low_inputs = low_cpu_features.numpy()
        if self.__up_scaler is not None:
            up_inputs = self.__up_scaler.transform(up_inputs)
        if self.__low_scaler is not None:
            low_inputs = self.__low_scaler.transform(low_inputs)

        if self.__up_predictor is None:
            self.up_threshold = np.zeros(up_inputs.shape[0])
        else:
            self.up_threshold = self.__up_predictor.predict(up_inputs)
        if np.sum(self.up_threshold < 0) > 0:
            print("bad_layer", self.qblock_id)

        if self.__low_predictor is None:
            self.low_threshold = np.zeros(low_inputs.shape[0])
        else:
            self.low_threshold = self.__low_predictor.predict(low_inputs)
        return up_features, low_features
