import torch
import torch.nn as nn
from NewQuantizationBlock import QuantizationBlock


def forward_module_list(mod_list, input, features, collect_statistics=False, qb_id_collect=None):
    """Run ``input`` through ``mod_list`` layer by layer, threading quantization features.

    Every ``QuantizationBlock`` encountered first updates ``features`` through its
    ``predict_qp`` method (given the block's input) and is then called with the
    statistics-collection arguments. Any other module is called as ``module(out)``.
    ``nn.Sequential`` containers are descended into recursively, so quantization
    blocks nested at any depth get the same treatment.

    Args:
        mod_list: Iterable of modules, e.g. an ``nn.Sequential``.
        input: Value fed to the first module.
        features: Current quantization features, passed to ``predict_qp``.
        collect_statistics: Forwarded to every ``QuantizationBlock`` call.
        qb_id_collect: Forwarded to every ``QuantizationBlock`` call.

    Returns:
        Tuple ``(out, features)``: the output of the last module and the features
        after the last ``predict_qp`` update. (``input`` and ``features`` unchanged
        if ``mod_list`` is empty.)
    """
    out = input
    for module in mod_list:
        if isinstance(module, nn.Sequential):
            # Recurse so that QuantizationBlocks inside nested containers still
            # get predict_qp + statistics arguments instead of a bare call.
            out, features = forward_module_list(module, out, features,
                                                collect_statistics=collect_statistics,
                                                qb_id_collect=qb_id_collect)
        elif isinstance(module, QuantizationBlock):
            # predict_qp must see the block's *input*, so it runs before the block.
            features = module.predict_qp(out, features)
            out = module(out, collect_statistics=collect_statistics, qb_id_collect=qb_id_collect)
        else:
            out = module(out)
    return out, features


class VGG(nn.Module):
    """VGG classifier whose activations pass through ``QuantizationBlock`` modules.

    The convolutional trunk (``features``) is supplied by the caller. This class
    adds adaptive average pooling to 7x7, a quantization block before the
    classifier, and a three-layer fully-connected classifier with a quantization
    block after each hidden ReLU.

    Args:
        QB_CONFIG: Configuration forwarded unchanged to every ``QuantizationBlock``
            constructed here.
        features: Feature-extractor module, iterated layer by layer in ``forward``
            via ``forward_module_list``. Its output must have 512 channels, since
            the first classifier layer expects ``512 * 7 * 7`` inputs.
        num_classes: Output size of the final ``nn.Linear``.
        init_weights: If True, re-initialize Conv2d/BatchNorm2d/Linear layers via
            ``_initialize_weights``.
    """

    def __init__(self, QB_CONFIG, features, num_classes=1000, init_weights=True):
        super().__init__()
        # Submodule registration order determines state_dict key order and the
        # order in which self.modules() visits layers during initialization.
        self.features = features
        self.avgpool = nn.AdaptiveAvgPool2d((7, 7))
        self.q_block_feature = QuantizationBlock(QB_CONFIG, is_quantized=True, qblock_id='before_classifier_qblock')
        self.classifier = nn.Sequential(
            nn.Linear(512 * 7 * 7, 4096),
            nn.ReLU(True),
            QuantizationBlock(QB_CONFIG, is_quantized=True, qblock_id='classifier_1_qblock'),
            nn.Dropout(),
            nn.Linear(4096, 4096),
            nn.ReLU(True),
            QuantizationBlock(QB_CONFIG, is_quantized=True, qblock_id='classifier_2_qblock'),
            nn.Dropout(),
            nn.Linear(4096, num_classes),
        )
        # Flag the output Linear layer with quantize_weights = False (presumably read
        # by the quantization logic to leave its weights unquantized — confirm).
        # Indexing with -1 keeps this pointing at the last layer if the stack changes.
        self.classifier[-1].quantize_weights = False
        if init_weights:
            self._initialize_weights()

    def forward(self, x, collect_statistics=False, qb_id_collect=None):
        """Run the network, threading quantization features through every block.

        ``self.q_features`` is reset to a pair of empty tensors, updated by each
        ``QuantizationBlock.predict_qp`` along the way, and left on the module
        after the call so callers can inspect it.

        Args:
            x: Network input, passed to the first layer of ``self.features``.
            collect_statistics: Forwarded to every ``QuantizationBlock`` call.
            qb_id_collect: Forwarded to every ``QuantizationBlock`` call.

        Returns:
            The output of the final ``nn.Linear`` (no softmax applied).
        """
        # NOTE(review): the seed tensors are created on the default device/dtype;
        # confirm predict_qp handles a model running on another device.
        self.q_features = (torch.tensor([]), torch.tensor([]))
        x, self.q_features = forward_module_list(self.features, x, features=self.q_features,
                                                 collect_statistics=collect_statistics,
                                                 qb_id_collect=qb_id_collect)
        x = self.avgpool(x)
        self.q_features = self.q_block_feature.predict_qp(x, self.q_features)
        x = self.q_block_feature(x, collect_statistics=collect_statistics, qb_id_collect=qb_id_collect)
        x = torch.flatten(x, 1)

        x, self.q_features = forward_module_list(self.classifier, x, features=self.q_features,
                                                 collect_statistics=collect_statistics,
                                                 qb_id_collect=qb_id_collect)
        return x

    def _initialize_weights(self):
        """Re-initialize Conv2d, BatchNorm2d and Linear layers in place.

        - Conv2d: Kaiming-normal weights (fan_out, ReLU gain), zero bias if present.
        - BatchNorm2d: weight 1, bias 0.
        - Linear: weights drawn from N(0, 0.01), zero bias.

        ``self.modules()`` recurses into every submodule, so matching layers inside
        QuantizationBlocks (if they contain any) are re-initialized as well.
        """
        for m in self.modules():
            if isinstance(m, nn.Conv2d):
                nn.init.kaiming_normal_(m.weight, mode='fan_out', nonlinearity='relu')
                if m.bias is not None:
                    nn.init.constant_(m.bias, 0)
            elif isinstance(m, nn.BatchNorm2d):
                nn.init.constant_(m.weight, 1)
                nn.init.constant_(m.bias, 0)
            elif isinstance(m, nn.Linear):
                nn.init.normal_(m.weight, 0, 0.01)
                nn.init.constant_(m.bias, 0)

    def proccess_model_statistics(self):
        """Process collected statistics on every QuantizationBlock that has some.

        A block is processed when its statistic container is not None and its
        ``'MEAN'`` entry is non-empty.
        """
        for m in self.modules():
            if not isinstance(m, QuantizationBlock):
                continue
            # Fetch the container once instead of calling the getter twice.
            container = m.getattr_statistic_container()
            # len() rather than truthiness: the 'MEAN' entry may be a tensor.
            if container is not None and len(container['MEAN']) != 0:
                m.process_statistics_container()

    # Correctly spelled alias; the misspelled name above is kept for existing callers.
    process_model_statistics = proccess_model_statistics

    def dump_all_dense_fractions(self):
        """Call ``dump_dense_fraction()`` on every QuantizationBlock in the model."""
        for m in self.modules():
            if isinstance(m, QuantizationBlock):
                m.dump_dense_fraction()


def make_layers(QB_CONFIG, cfg, batch_norm=False):
    """Build the convolutional feature extractor of a quantized VGG.

    Each entry of ``cfg`` becomes:

    - ``'M'``: ``MaxPool2d(kernel_size=2, stride=2)``.
    - an int ``v``: ``[QuantizationBlock, Conv2d(in_channels, v, 3, padding=1),
      (BatchNorm2d(v) if batch_norm), ReLU(inplace=True)]``.

    The entry at index 0 gets an unquantized input block
    (``is_quantized=False``, ``qblock_id='input_qblock'``), and its conv is tagged
    ``quantize_weights = False``. Every later conv gets a quantized block named
    ``conv_1``, ``conv_2``, ... in order.

    Args:
        QB_CONFIG: Configuration forwarded to every ``QuantizationBlock``.
        cfg: Layer recipe, e.g. one of the lists in ``cfgs``.
        batch_norm: Insert ``BatchNorm2d`` between each conv and its ReLU.

    Returns:
        ``nn.Sequential`` of the layers above. The input is assumed to have
        3 channels.
    """
    layers = []
    in_channels = 3
    counter_conv = 1
    for index, v in enumerate(cfg):
        if v == 'M':
            layers += [nn.MaxPool2d(kernel_size=2, stride=2)]
            continue
        # Construction order within each branch is kept as-is so that RNG draws
        # made by module constructors happen in the same sequence.
        if index == 0:
            conv2d = nn.Conv2d(in_channels, v, kernel_size=3, padding=1)
            conv2d.quantize_weights = False
            q_block = QuantizationBlock(QB_CONFIG, is_quantized=False, qblock_id='input_qblock')
        else:
            q_block = QuantizationBlock(QB_CONFIG, is_quantized=True, qblock_id='conv_' + str(counter_conv))
            conv2d = nn.Conv2d(in_channels, v, kernel_size=3, padding=1)
            counter_conv += 1
        # The conv layer is always included, with or without batch norm;
        # BatchNorm2d(v) expects the conv's v output channels.
        layers += [q_block, conv2d]
        if batch_norm:
            layers.append(nn.BatchNorm2d(v))
        layers.append(nn.ReLU(inplace=True))
        in_channels = v
    return nn.Sequential(*layers)


# Layer recipes consumed by make_layers, keyed by configuration letter
# (A: 8 convs -> VGG11, B: 10 -> VGG13, D: 13 -> VGG16, E: 16 -> VGG19, counting
# the 3 fully-connected layers). An int is the output-channel count of a 3x3 conv
# with padding 1; 'M' is a 2x2 max-pool with stride 2.
cfgs = dict(
    A=[64, 'M', 128, 'M', 256, 256, 'M', 512, 512, 'M', 512, 512, 'M'],
    B=[64, 64, 'M', 128, 128, 'M', 256, 256, 'M', 512, 512, 'M', 512, 512, 'M'],
    D=[64, 64, 'M', 128, 128, 'M', 256, 256, 256, 'M', 512, 512, 512, 'M', 512, 512, 512, 'M'],
    E=[64, 64, 'M', 128, 128, 'M', 256, 256, 256, 256, 'M', 512, 512, 512, 512, 'M', 512, 512, 512, 512, 'M'],
)


def _vgg(QB_CONFIG, cfg, batch_norm, **kwargs):
    """Construct a quantized ``VGG`` from the configuration named ``cfg``.

    Args:
        QB_CONFIG: Configuration forwarded to every ``QuantizationBlock``.
        cfg: Key into ``cfgs`` (e.g. ``'D'``).
        batch_norm: Whether the feature extractor includes ``BatchNorm2d`` layers.
        **kwargs: Forwarded to the ``VGG`` constructor.
    """
    layer_recipe = cfgs[cfg]
    feature_extractor = make_layers(QB_CONFIG, layer_recipe, batch_norm=batch_norm)
    return VGG(QB_CONFIG, feature_extractor, **kwargs)


def vgg16(QB_CONFIG, **kwargs):
    r"""Quantized VGG 16-layer model (configuration "D") without batch normalization.

    `"Very Deep Convolutional Networks For Large-Scale Image Recognition" <https://arxiv.org/pdf/1409.1556.pdf>`_

    Args:
        QB_CONFIG: Configuration forwarded to every ``QuantizationBlock``.
        **kwargs: Forwarded to ``VGG`` (``num_classes``, ``init_weights``).

    Note:
        No pretrained weights are loaded. Unlike torchvision's ``vgg16``, this
        function does not accept ``pretrained`` or ``progress``; passing them
        reaches ``VGG.__init__`` and raises ``TypeError``.
    """
    return _vgg(QB_CONFIG, 'D', False, **kwargs)


def vgg16_bn(QB_CONFIG, **kwargs):
    r"""Quantized VGG 16-layer model (configuration "D") with batch normalization.

    Identical to ``vgg16`` except that a ``BatchNorm2d`` follows every conv layer.

    `"Very Deep Convolutional Networks For Large-Scale Image Recognition" <https://arxiv.org/pdf/1409.1556.pdf>`_

    Args:
        QB_CONFIG: Configuration forwarded to every ``QuantizationBlock``.
        **kwargs: Forwarded to ``VGG`` (``num_classes``, ``init_weights``).

    Note:
        No pretrained weights are loaded. Unlike torchvision's ``vgg16_bn``, this
        function does not accept ``pretrained`` or ``progress``; passing them
        reaches ``VGG.__init__`` and raises ``TypeError``.
    """
    return _vgg(QB_CONFIG, 'D', True, **kwargs)
