import torch
import torch.nn as nn
import numpy as np
import os


"""Class used for storing and quantization of weights"""

class BinOp:
    """Save, quantize in place, and restore the weights of a model.

    On construction every ``nn.Conv2d`` / ``nn.Linear`` sub-module of
    ``model`` is inspected. Modules that expose a falsy ``quantize_weights``
    attribute are recorded in ``skipped_modules``. For every other module a
    detached copy of its weight is stored in ``saved_params`` and the weight
    parameter itself in ``target_modules`` (same index in both lists).

    NOTE(review): only ``nn.Conv2d`` weights are modified by
    ``binarization``. ``nn.Linear`` weights are saved and restored but never
    quantized, although the constructor prints "quantizing weights of" for
    them as well. Confirm whether that is intended.

    Args:
        model: Module whose sub-modules are scanned via ``named_modules()``.
        num_bits: Bit width of the quantized weights. The value 32 turns
            ``binarization`` into a no-op.
        q_type: ``'symmetric'`` or ``'asymmetric'``. Any other value is
            accepted here, but ``binarization`` raises ``ValueError`` for it
            unless ``num_bits`` is 32.
    """

    def __init__(self, model, num_bits, q_type):
        self.weights_bit_count = num_bits
        self.model = model

        self.saved_params = []      # detached full-precision weight copies
        self.target_modules = []    # the live weight parameters
        self.num_of_params = None
        self.skipped_modules = []   # modules with quantize_weights == False

        # Kept only to produce a helpful error message later on.
        self.__q_type = q_type
        if q_type == 'symmetric':
            self.__quantize = self.__quantize_multibit_symmetrically
        elif q_type == 'asymmetric':
            self.__quantize = self.__quantize_multibit_asymmetrically
        else:
            # Deferred failure: an unknown scheme is harmless as long as
            # binarization() is never asked to quantize (e.g. 32 bits).
            self.__quantize = None

        for name, m in model.named_modules():
            if not isinstance(m, (nn.Conv2d, nn.Linear)):
                continue
            if hasattr(m, 'quantize_weights') and not m.quantize_weights:
                print('BinOp: skipping layer {}'.format(name))
                self.skipped_modules.append(m)
            else:
                self.saved_params.append(m.weight.detach().clone())
                self.target_modules.append(m.weight)
                print('BinOp: quantizing weights of {}'.format(name))
        self.num_of_params = len(self.target_modules)

    def binarization(self):
        """Back up the current weights, then quantize them in place.

        Does nothing when ``weights_bit_count`` is 32.

        Raises:
            ValueError: if the ``q_type`` given at construction is not
                ``'symmetric'`` or ``'asymmetric'``, or if the bit width
                yields no quantization levels.
        """
        if self.weights_bit_count == 32:
            return None
        if self.__quantize is None:
            raise ValueError(
                "BinOp: unknown q_type {!r}; expected 'symmetric' or "
                "'asymmetric'".format(self.__q_type)
            )
        self.save_params()
        self.__quantize()
        return None

    def save_params(self):
        """Copy the current value of every tracked weight into its backup."""
        # no_grad keeps the backups out of the autograd graph; copying from a
        # grad-requiring Parameter would otherwise attach a grad_fn to them.
        with torch.no_grad():
            for saved, param in zip(self.saved_params, self.target_modules):
                saved.copy_(param)

    def restore(self):
        """Overwrite every tracked weight with its backed-up value."""
        for saved, param in zip(self.saved_params, self.target_modules):
            param.data.copy_(saved)

    def _quantizable_weights(self):
        """Yield ``weight.data`` of each non-skipped ``nn.Conv2d`` in the model."""
        for m in self.model.modules():
            if m in self.skipped_modules:
                continue
            if isinstance(m, nn.Conv2d):
                yield m.weight.data

    @staticmethod
    def _per_filter(reduce_fn, w):
        """Reduce ``w`` over all dims but the first; result broadcasts against ``w``.

        ``reduce_fn`` is ``torch.max`` or ``torch.min`` (the ``(tensor, dim)``
        form returning ``(values, indices)``).
        """
        values = reduce_fn(w.reshape(w.shape[0], -1), 1)[0]
        return values.reshape((w.shape[0],) + (1,) * (w.dim() - 1))

    @staticmethod
    def _nonzero_step(step):
        """Return ``step`` with non-positive entries replaced by 1.

        A zero step occurs for filters whose weights are all equal; dividing
        by it would turn the whole filter into NaN.
        """
        return torch.where(step > 0, step, torch.ones_like(step))

    def __quantize_multibit_symmetrically(self):
        """Quantize each conv filter to integer multiples of ``max|w| / N``.

        ``N = 2 ** (bits - 1) - 1``; quantized values lie in ``[-N, N]`` steps.
        """
        levels = 2 ** (self.weights_bit_count - 1) - 1
        if levels < 1:
            raise ValueError(
                'BinOp: symmetric quantization needs at least 2 bits, got '
                '{}'.format(self.weights_bit_count)
            )
        for w in self._quantizable_weights():
            step = self._per_filter(torch.max, w.abs()).mul_(1.0 / levels)
            divisor = self._nonzero_step(step)
            w.div_(divisor.float()).round_().clamp_(-levels, levels)
            # Multiply by the true step so an all-zero filter stays zero.
            w.mul_(step)
        return None

    def __quantize_multibit_asymmetrically(self):
        """Quantize each conv filter onto ``N + 1`` levels spanning [min, max].

        ``N = 2 ** bits - 1``; the per-filter minimum is used as the offset.
        """
        levels = 2 ** self.weights_bit_count - 1
        if levels < 1:
            raise ValueError(
                'BinOp: asymmetric quantization needs at least 1 bit, got '
                '{}'.format(self.weights_bit_count)
            )
        for w in self._quantizable_weights():
            w_min = self._per_filter(torch.min, w)
            w_max = self._per_filter(torch.max, w)
            step = (w_max - w_min).mul_(1.0 / levels)
            divisor = self._nonzero_step(step)
            w.sub_(w_min).div_(divisor.float()).round_().clamp_(0, levels)
            # A constant filter has step 0 and collapses back onto w_min.
            w.mul_(step).add_(w_min)
        return None
