import torch
import torch.nn as nn

import functools


class FeatureExtractor(nn.Module):
    """Wrap a backbone and capture the outputs of selected sub-modules.

    Sub-modules are addressed by the names reported by
    ``backbone.named_modules()``. With no hooks registered, calling the
    extractor simply runs the backbone. With hooks registered, it returns a
    dict of the captured outputs and aborts the backbone pass as soon as
    every requested feature has been collected.

    Args:
        backbone: The ``nn.Module`` whose intermediate outputs are captured.
        input_shape: Per-sample input shape (without the batch dimension);
            the first entry is treated as the channel count.
    """

    class StopForward(Exception):
        """Raised inside a forward hook to cut the backbone pass short."""

    def __init__(self, backbone, input_shape):
        super().__init__()
        self.input_shape = tuple(input_shape)
        self.backbone = backbone

        # Register input buffer so the model.to(device) affects the input
        self.register_buffer('dummy_input', torch.randn(1, *self.input_shape), persistent=False)

        if self.input_shape[0] < 3:
            # Usually the backbone models expect inputs with 3 channels
            self.to_3_channels = nn.Conv2d(self.input_shape[0], 3, 1)

        # (module_name, output_shape) -> hook handle, in registration order.
        self._feature_hooks = {}
        # Dict being filled during a capturing forward pass; None otherwise,
        # which is what tells the hooks to stay passive.
        self._feature_outputs = None
        # Cache for `get_feature_list`.
        self._feature_list = None

        self._module_lookup = dict(self.backbone.named_modules())

    def forward(self, x):
        """Run the backbone, or capture the hooked features.

        Args:
            x: Input batch whose trailing dimensions equal ``input_shape``.

        Returns:
            The backbone output when no hook is registered; otherwise a dict
            mapping each ``(module_name, output_shape)`` hook key to the
            output captured for it.

        Raises:
            AssertionError: If ``x`` has the wrong per-sample shape, or if the
                backbone finished without every hooked feature being captured.
        """
        # Explicit raise instead of `assert` so the check survives `python -O`.
        if tuple(x.shape[1:]) != self.input_shape:
            raise AssertionError(
                'Expected input of shape (N, {}), got {!r}'.format(
                    ', '.join(map(str, self.input_shape)), tuple(x.shape)))

        if self.input_shape[0] < 3:
            # Extend to 3 channels
            x = self.to_3_channels(x)

        if not self._feature_hooks:
            # No hook, just forward with the backbone
            return self.backbone(x)

        captured = {}
        self._feature_outputs = captured
        try:
            self.backbone(x)
        except self.StopForward:
            # Every requested feature was captured before the backbone ended.
            return captured
        finally:
            # Always disarm the hooks, even when the backbone itself raised,
            # so later non-capturing passes are not affected by stale state.
            self._feature_outputs = None

        # `StopForward` was not raised: some hook never saw a matching output.
        not_captured = set(self._feature_hooks) - set(captured)
        if not_captured:
            raise AssertionError('These features are not captured: {!r}'.format(not_captured))

        return captured

    def add_hook(self, module_name, output_shape):
        """Register a capture hook on a backbone sub-module.

        The output is captured only when its per-sample shape equals
        ``output_shape``. Registering the same ``(module_name, output_shape)``
        again replaces the previous hook.

        Args:
            module_name: Name of the sub-module as listed by
                ``backbone.named_modules()``.
            output_shape: Expected output shape without the batch dimension;
                any sequence of ints (it is normalised to a tuple).

        Raises:
            KeyError: If ``module_name`` is not a sub-module of the backbone.
        """
        # Normalise so lists are hashable and compare equal to tensor shapes.
        output_shape = tuple(output_shape)
        hook_key = (module_name, output_shape)

        # Look the module up first so an unknown name leaves state untouched.
        target = self._module_lookup[module_name]

        if hook_key in self._feature_hooks:
            # Remove hook if `hook_key` is registered
            self._feature_hooks[hook_key].remove()

        def _hook(module, inputs, output):
            if self._feature_outputs is None:
                # Bypass hook if not forwarding
                return

            shape = getattr(output, 'shape', None)
            if shape is None or tuple(shape[1:]) != output_shape:
                # Non-tensor output or a different shape: nothing to capture.
                return

            self._feature_outputs[hook_key] = output

            if len(self._feature_outputs) == len(self._feature_hooks):
                # When there is no output to retrieve, stop forwarding
                raise self.StopForward()

        self._feature_hooks[hook_key] = target.register_forward_hook(_hook)

    def remove_hooks(self):
        """Detach every registered capture hook."""
        for handle in self._feature_hooks.values():
            handle.remove()

        self._feature_hooks.clear()

    @property
    def hooks(self):
        """List of registered ``(module_name, output_shape)`` hook keys."""
        return list(self._feature_hooks.keys())

    def get_feature_list(self):
        """List the tensor outputs produced by the backbone's sub-modules.

        On first use, the backbone is run once under ``torch.no_grad()`` on
        the ``dummy_input`` buffer with a temporary hook on every sub-module;
        the result is cached. The pass runs in whatever train/eval mode the
        module is currently in. Sub-modules whose output has no ``shape``
        attribute (e.g. tuples) are left out.

        Returns:
            A new list of ``(module_name, output_shape, module)`` tuples in
            the order the modules finished their forward call, where
            ``output_shape`` excludes the batch dimension.
        """
        if self._feature_list is None:
            collected = []

            def hook(name, module, inputs, output):
                shape = getattr(output, 'shape', None)
                if shape is None:
                    # Output is not a single tensor; it has no shape to report.
                    return
                collected.append((name, tuple(shape)[1:], module))

            handles = [
                module.register_forward_hook(functools.partial(hook, name))
                for name, module in self.backbone.named_modules()
            ]

            try:
                with torch.no_grad():
                    x = self.dummy_input
                    if self.input_shape[0] < 3:
                        x = self.to_3_channels(x)
                    self.backbone(x)
            finally:
                # Never leave the probing hooks attached, even on failure.
                for handle in handles:
                    handle.remove()

            # Publish only a complete list so a failed probe is not cached.
            self._feature_list = collected

        return list(self._feature_list)

    def get_feature_shape_dict(self, skip_dim=1):
        """Group three-dimensional feature outputs by part of their shape.

        Args:
            skip_dim: Number of leading shape entries dropped to build the
                grouping key; with the default of 1 the first entry is
                ignored and the remaining two form the key.

        Returns:
            Dict mapping ``output_shape[skip_dim:]`` to a list of
            ``(module_name, output_shape)`` pairs, in feature-list order.
        """
        feature_list = self.get_feature_list()
        shape_dict = {}

        for module_name, output_shape, _ in feature_list:
            if len(output_shape) == 3:  # is 2D with channels
                key = output_shape[skip_dim:]
                shape_dict.setdefault(key, []).append((module_name, output_shape))

        return shape_dict


if __name__ == '__main__':
    from fire import Fire
    from torchvision import models

    # CLI entry point: `model` is the name of a constructor in
    # `torchvision.models`. Prints every captured feature with its shape,
    # then, for each group of the shape dict, the last feature in that group
    # together with its position in the forward order.
    def test(model):
        extractor = FeatureExtractor(getattr(models, model)(), [3, 64, 64])

        # Position of each module name in the feature list.
        position = {}
        for index, entry in enumerate(extractor.get_feature_list()):
            layer_name, layer_shape = entry[0], entry[1]
            print('{:>25} {}'.format(layer_name, layer_shape))
            position[layer_name] = index

        print()

        groups = extractor.get_feature_shape_dict()
        for group_key in groups:
            # Only the last entry of each group is reported.
            last_name, last_shape = groups[group_key][-1]
            print('{!r:>15s} {:>25s} {:<3d} {}'.format(
                group_key, last_name, position[last_name], last_shape))

    Fire(test)
