import importlib
import importlib.util
import logging
import os
from copy import deepcopy
from os import path as osp

import torch
# automatically scan and import arch modules
# scan all the files under the 'archs' folder and collect files ending with
# '_arch.py'


# arch_folder = osp.dirname(osp.abspath(__file__))

def scandir(dir_path, suffix=None, recursive=False, full_path=False):
    """Scan a directory to find the interested files.

    Entries whose name starts with '.' are never yielded as files. When
    ``recursive`` is True, every sub-directory (including ones whose name
    starts with '.') is descended into; non-directory entries that are not
    yielded (e.g. hidden files) are simply skipped.

    Args:
        dir_path (str): Path of the directory.
        suffix (str | tuple(str), optional): File suffix that we are
            interested in. Default: None.
        recursive (bool, optional): If set to True, recursively scan the
            directory. Default: False.
        full_path (bool, optional): If set to True, include the dir_path.
            Default: False.

    Returns:
        A generator for all the interested files with relative paths
        (relative to ``dir_path``), or full paths if ``full_path`` is True.

    Raises:
        TypeError: If ``suffix`` is neither None, a str, nor a tuple. This is
            raised immediately, not when the generator is first consumed.
    """

    if (suffix is not None) and not isinstance(suffix, (str, tuple)):
        raise TypeError('"suffix" must be a string or tuple of strings')

    root = dir_path

    def _scandir(dir_path, suffix, recursive):
        # `with` closes the OS directory handle even if the caller stops
        # consuming the generator early.
        with os.scandir(dir_path) as entries:
            for entry in entries:
                if not entry.name.startswith('.') and entry.is_file():
                    if full_path:
                        return_path = entry.path
                    else:
                        return_path = osp.relpath(entry.path, root)

                    if suffix is None or return_path.endswith(suffix):
                        yield return_path
                elif recursive and entry.is_dir():
                    # Only directories can be scanned; without this check a
                    # hidden *file* would reach os.scandir and raise
                    # NotADirectoryError.
                    yield from _scandir(
                        entry.path, suffix=suffix, recursive=recursive)

    return _scandir(dir_path, suffix=suffix, recursive=recursive)


# Dotted name of the package whose '*_arch.py' modules are auto-imported.
arch_package = 'b2b.models.deblur.nafnet.archs'


def _find_arch_folder():
    """Return the directory that holds the ``arch_package`` modules.

    The location is resolved through the import system, so the scan does not
    depend on the current working directory. If the package spec cannot be
    resolved, fall back to the historical path, which is relative to the
    current working directory.
    """
    try:
        spec = importlib.util.find_spec(arch_package)
    except (ImportError, ValueError):
        spec = None
    if spec is not None and spec.submodule_search_locations:
        return list(spec.submodule_search_locations)[0]
    return "b2b/models/deblur/nafnet/archs"


arch_folder = _find_arch_folder()
# Sorted so modules are imported (and later searched by
# dynamic_instantiation) in a deterministic order; os.scandir order is
# filesystem-dependent.
arch_filenames = sorted(
    osp.splitext(osp.basename(v))[0] for v in scandir(arch_folder)
    if v.endswith('_arch.py')
)
# import all the arch modules
_arch_modules = [
    importlib.import_module(f'{arch_package}.{file_name}')
    for file_name in arch_filenames
]

def dynamic_instantiation(modules, cls_type, opt):
    """Dynamically instantiate class.

    The modules are searched in order and the first one that defines a
    non-None attribute named ``cls_type`` wins.

    Args:
        modules (list[importlib modules]): List of modules from importlib
            files.
        cls_type (str): Class type.
        opt (dict): Class initialization kwargs.

    Returns:
        class: Instantiated class.

    Raises:
        ValueError: If no module defines ``cls_type`` (including when
            ``modules`` is empty).
    """

    for module in modules:
        cls_ = getattr(module, cls_type, None)
        if cls_ is not None:
            return cls_(**opt)
    # Reached for an empty `modules` list too, which previously raised
    # UnboundLocalError instead of this ValueError.
    raise ValueError(f'{cls_type} is not found.')


def define_network(opt):
    """Build a network from an option dict.

    Args:
        opt (dict): Must contain ``'type'``, the class name to look up in the
            auto-imported arch modules. All other entries are passed to that
            class's constructor as keyword arguments. The caller's dict is
            not modified.

    Returns:
        The instantiated network.
    """
    # Shallow copy so popping 'type' does not mutate the caller's dict;
    # otherwise reusing the same options for a second call raises KeyError.
    opt = dict(opt)
    network_type = opt.pop('type')
    net = dynamic_instantiation(_arch_modules, network_type, opt)
    return net

def _print_different_keys_loading(crt_net, load_net, strict=True):
    """Print keys with differnet name or different size when loading models.

    1. Print keys with differnet names.
    2. If strict=False, print the same key but with different tensor size.
        It also ignore these keys with different sizes (not load).

    Args:
        crt_net (torch model): Current network.
        load_net (dict): Loaded network.
        strict (bool): Whether strictly loaded. Default: True.
    """
    # crt_net = self.get_bare_model(crt_net)
    crt_net = crt_net.state_dict()
    crt_net_keys = set(crt_net.keys())
    load_net_keys = set(load_net.keys())

    if crt_net_keys != load_net_keys:
        # logger.warning('Current net - loaded net:')
        for v in sorted(list(crt_net_keys - load_net_keys)):
            # logger.warning(f'  {v}')
            pass
        # logger.warning('Loaded net - current net:')
        for v in sorted(list(load_net_keys - crt_net_keys)):
            # logger.warning(f'  {v}')
            pass

    # check the size for the same keys
    if not strict:
        common_keys = crt_net_keys & load_net_keys
        for k in common_keys:
            if crt_net[k].size() != load_net[k].size():
                # logger.warning(
                #     f'Size different, ignore [{k}]: crt_net: '
                #     f'{crt_net[k].shape}; load_net: {load_net[k].shape}')
                load_net[k + '.ignore'] = load_net.pop(k)

def load_network(net, load_path, strict=True, param_key='params'):
    """Load weights from a checkpoint file into ``net`` in place.

    Args:
        net (torch model): Network to load the weights into.
        load_path (str): Checkpoint path passed to ``torch.load``.
        strict (bool): Passed to ``net.load_state_dict``. When False,
            entries whose sizes differ from ``net`` are renamed to
            ``<key>.ignore`` beforehand, so they are skipped. Default: True.
        param_key (str | None): Key under which the state dict is stored in
            the checkpoint. If None, the loaded object itself is used as the
            state dict. Default: 'params'.
    """
    # The identity map_location keeps every storage on the CPU, whatever
    # device it was saved from.
    load_net = torch.load(
        load_path, map_location=lambda storage, loc: storage)
    if param_key is not None:
        load_net = load_net[param_key]
    # Call keys(); the previous code printed the bound method object instead.
    print(' load net keys', list(load_net.keys()))
    # Remove the unnecessary 'module.' prefix (typically present when the
    # model was saved while wrapped in (Distributed)DataParallel). Iterate a
    # snapshot of the items so the dict can be modified in the loop; a
    # deepcopy would needlessly duplicate every tensor.
    prefix = 'module.'
    for k, v in list(load_net.items()):
        if k.startswith(prefix):
            load_net[k[len(prefix):]] = v
            load_net.pop(k)
    _print_different_keys_loading(net, load_net, strict)
    net.load_state_dict(load_net, strict=strict)