import os
import math
from itertools import islice

from b2b.config            import Args
from b2b.consts            import (
    MODEL_STATE_TRAIN, MODEL_STATE_EVAL, MERGE_NONE
)
from b2b.data              import construct_data_loaders
from b2b.torch.funcs       import get_torch_device_smart, seed_everything
from b2b.cgan              import construct_model
import numpy as np
import torch
import cv2


def _make_grid(tensor, nrow=8, padding=2, pad_value=0.0):
    """Tile a (B x C x H x W) batch into a single (C x H' x W') image.

    Local equivalent of ``torchvision.utils.make_grid(..., normalize=False)``
    with its default ``padding=2`` / ``pad_value=0``:
    - single-channel batches are replicated to 3 channels;
    - images are laid out ``nrow`` per row;
    - images are separated by ``padding`` pixels of ``pad_value``.
    """
    if tensor.size(1) == 1:
        tensor = torch.cat((tensor, tensor, tensor), 1)
    if tensor.size(0) == 1:
        return tensor.squeeze(0)

    nmaps = tensor.size(0)
    # max(1, ...) guards against nrow == 0 (e.g. an empty batch).
    xmaps = max(1, min(nrow, nmaps))
    ymaps = int(math.ceil(nmaps / xmaps))
    height = tensor.size(2) + padding
    width = tensor.size(3) + padding
    grid = tensor.new_full(
        (tensor.size(1), height * ymaps + padding, width * xmaps + padding),
        pad_value)
    for k in range(nmaps):
        y, x = divmod(k, xmaps)
        grid[:, y * height + padding:(y + 1) * height,
             x * width + padding:(x + 1) * width] = tensor[k]
    return grid


def tensor2img(tensor, rgb2bgr=True, out_type=np.uint8, min_max=(0, 1)):
    """Convert torch Tensors into image numpy arrays.

    Each tensor is clamped to [min, max] and then rescaled to [0, 1]. The
    input tensors are never modified.

    Args:
        tensor (Tensor or list[Tensor]): Accept shapes:
            1) 4D mini-batch Tensor of shape (B x 3/1 x H x W); a batch of
               more than one image is tiled into a single grid image;
            2) 3D Tensor of shape (3/1 x H x W);
            3) 2D Tensor of shape (H x W).
            Tensor channel should be in RGB order.
        rgb2bgr (bool): Whether to change rgb to bgr (3-channel results only).
        out_type (numpy type): output types. If ``np.uint8``, transform outputs
            to uint8 type with range [0, 255]; otherwise values stay in
            [0, 1] and are cast to ``out_type``. Default: ``np.uint8``.
        min_max (tuple[int]): min and max values for clamp.

    Returns:
        ndarray or list[ndarray]: one 3D ndarray of shape (H x W x C) or 2D
        ndarray of shape (H x W) per input tensor. A bare array (not a list)
        is returned when exactly one tensor was converted. The channel order
        is BGR when ``rgb2bgr`` is True, RGB otherwise.

    Raises:
        TypeError: if the input is not a tensor or list of tensors, or a
            tensor has an unsupported number of dimensions.
    """
    if not (torch.is_tensor(tensor) or
            (isinstance(tensor, list)
             and all(torch.is_tensor(t) for t in tensor))):
        raise TypeError(
            f'tensor or list of tensors expected, got {type(tensor)}')

    if torch.is_tensor(tensor):
        tensor = [tensor]
    result = []
    for _tensor in tensor:
        _tensor = _tensor.float().detach().cpu()
        # Drop a leading singleton (batch) axis, but never collapse a 2D
        # (1 x W) image into an unsupported 1D tensor.
        if _tensor.dim() > 2:
            _tensor = _tensor.squeeze(0)
        # Out-of-place clamp. For a float32 CPU input, .float(), .detach()
        # and .cpu() all return views of the caller's storage, so clamp_
        # would silently modify the caller's tensor.
        _tensor = _tensor.clamp(*min_max)
        _tensor = (_tensor - min_max[0]) / (min_max[1] - min_max[0])

        n_dim = _tensor.dim()
        if n_dim == 4:
            img_np = _make_grid(
                _tensor, nrow=int(math.sqrt(_tensor.size(0)))).numpy()
            img_np = img_np.transpose(1, 2, 0)
            if rgb2bgr:
                img_np = cv2.cvtColor(img_np, cv2.COLOR_RGB2BGR)
        elif n_dim == 3:
            img_np = _tensor.numpy()
            img_np = img_np.transpose(1, 2, 0)
            if img_np.shape[2] == 1:  # gray image
                img_np = np.squeeze(img_np, axis=2)
            elif img_np.shape[2] == 3:
                if rgb2bgr:
                    img_np = cv2.cvtColor(img_np, cv2.COLOR_RGB2BGR)
        elif n_dim == 2:
            img_np = _tensor.numpy()
        else:
            raise TypeError('Only support 4D, 3D or 2D tensor. '
                            f'But received with dimension: {n_dim}')
        if out_type == np.uint8:
            # Unlike MATLAB, numpy.unit8() WILL NOT round by default.
            img_np = (img_np * 255.0).round()
        img_np = img_np.astype(out_type)
        result.append(img_np)
    if len(result) == 1:
        result = result[0]
    return result

def slice_data_loader(loader, batch_size, n_samples = None):
    """Return an evenly strided subset of ``loader`` covering ~``n_samples``.

    Args:
        loader: A sized iterable of batches; ``len(loader)`` is treated as
            the number of batches it yields.
        batch_size: Samples per batch; used to turn ``n_samples`` into a
            number of batches.
        n_samples: Desired number of samples, or ``None`` to use everything.

    Returns:
        tuple: ``(iterable, steps)``, where ``steps`` is exactly the number
        of batches the iterable yields. When ``n_samples`` is None, the
        original ``loader`` and ``len(loader)`` are returned unchanged.
    """
    if n_samples is None:
        return (loader, len(loader))

    n_batches = len(loader)
    steps = min(math.ceil(n_samples / batch_size), n_batches)
    if steps <= 0:
        # Nothing requested (n_samples <= 0) or an empty loader.
        # Avoid the division by zero below and yield nothing.
        return (islice(loader, 0), 0)

    stride = n_batches // steps
    # Stop at steps * stride so exactly `steps` batches are yielded.
    # Stopping at len(loader) could fit one extra batch
    # (10 batches, 3 steps -> indices 0, 3, 6, 9), so the returned count
    # would disagree with the output.
    # NOTE: islice still advances through the skipped batches, so for a
    # DataLoader they are loaded and discarded.
    sliced_loader = islice(loader, 0, steps * stride, stride)

    return (sliced_loader, steps)

def tensor_to_image(tensor):
    """Convert a (C, H, W) or (1, C, H, W) tensor to an (H, W, C) float32 array.

    The tensor is detached and copied to host memory first. For a 4D input,
    the leading axis is dropped; numpy raises ValueError if that axis is not
    of size 1.
    """
    array = tensor.detach().cpu().numpy()
    if tensor.ndim == 4:
        array = np.squeeze(array, axis=0)
    hwc = np.transpose(array, (1, 2, 0))
    return hwc.astype(np.float32)

def override_config(config, config_overrides):
    """Write every key/value pair of ``config_overrides`` into ``config``.

    ``config`` is modified in place via item assignment. Passing ``None``
    as ``config_overrides`` is a no-op. Always returns None.
    """
    if config_overrides is not None:
        for key, value in config_overrides.items():
            config[key] = value

def get_evaldir(root, epoch, mkdir = False):
    """Return the evaluation directory for ``epoch`` under ``root``.

    The path is ``root/evals/final`` when ``epoch`` is None, otherwise
    ``root/evals/epoch_NNNN`` (zero-padded to four digits). When ``mkdir``
    is true, the directory (and any parents) is created if missing.
    """
    subdir = 'final' if epoch is None else 'epoch_%04d' % epoch
    evaldir = os.path.join(root, 'evals', subdir)

    if mkdir:
        os.makedirs(evaldir, exist_ok = True)

    return evaldir

def set_model_state(model, state):
    """Put ``model`` into train or eval mode according to ``state``.

    ``MODEL_STATE_TRAIN`` calls ``model.train()``, and ``MODEL_STATE_EVAL``
    calls ``model.eval()``. Any other value raises ValueError.
    Presumably ``model`` is a torch.nn.Module-like object (TODO confirm).
    """
    if state == MODEL_STATE_TRAIN:
        model.train()
        return

    if state == MODEL_STATE_EVAL:
        model.eval()
        return

    raise ValueError(f"Unknown model state '{state}'")

def start_model_eval(path, epoch, model_state, merge_type, **config_overrides):
    """Load a saved model from ``path`` and prepare it for evaluation.

    Args:
        path: Directory that ``Args.load`` reads the run arguments from; the
            evaluation directory is also created under it.
        epoch: Checkpoint epoch to load. ``-1`` selects the value returned by
            ``model.find_last_checkpoint_epoch()``, clamped to at least 0.
        model_state: Mode passed to ``set_model_state`` (train or eval).
        merge_type: Stored into ``args.config.data.merge_type``.
        **config_overrides: Key/value pairs written into ``args.config``.

    Returns:
        tuple: ``(args, model, evaldir)``.
    """
    args = Args.load(path)
    device = get_torch_device_smart()

    override_config(args.config, config_overrides)
    args.config.data.merge_type = merge_type

    model = construct_model(
        args.savedir, args.config, is_train = False, device = device
    )

    if epoch == -1:
        epoch = max(model.find_last_checkpoint_epoch(), 0)
    print("Load checkpoint at epoch %s" % epoch)

    # Seed after the model is built and before its weights are loaded,
    # matching the original call order.
    seed_everything(args.config.seed)
    model.load(epoch)
    set_model_state(model, model_state)

    evaldir = get_evaldir(path, epoch, mkdir = True)
    return (args, model, evaldir)

def load_eval_model_dset_from_cmdargs(
    cmdargs, merge_type = MERGE_NONE, **config_overrides
):
    """Load an evaluation model plus its data loaders from parsed CLI args.

    Reads ``cmdargs.model``, ``.epoch``, ``.model_state``, ``.batch_size``
    and ``.split`` (presumably an argparse Namespace -- confirm against the
    caller).

    NOTE(review): ``batch_size`` is always forwarded as a config override,
    so a ``None`` value in ``cmdargs.batch_size`` would overwrite the
    configured batch size. A caller that also passes ``batch_size`` in
    ``config_overrides`` gets a TypeError (duplicate keyword).

    Returns:
        tuple: ``(args, model, data_it, evaldir)``.
    """
    args, model, evaldir = start_model_eval(
        cmdargs.model,
        cmdargs.epoch,
        cmdargs.model_state,
        merge_type = merge_type,
        batch_size = cmdargs.batch_size,
        **config_overrides
    )

    data_it = construct_data_loaders(
        args.config.data,
        args.config.batch_size,
        split = cmdargs.split,
    )
    return (args, model, data_it, evaldir)

def get_eval_savedir(evaldir, prefix, model_state, split, mkdir = False):
    """Return ``evaldir/<prefix>_<model_state>-<split>``.

    When ``mkdir`` is true, the directory (and any parents) is created if
    it does not already exist.
    """
    leaf = f'{prefix}_{model_state}-{split}'
    savedir = os.path.join(evaldir, leaf)

    if not mkdir:
        return savedir

    os.makedirs(savedir, exist_ok = True)
    return savedir

def make_image_subdirs(model, savedir):
    """Create ``savedir/<name>`` for every ``name`` in ``model.images``.

    Directories that already exist are left as they are (``exist_ok``).
    """
    for image_name in model.images:
        os.makedirs(os.path.join(savedir, image_name), exist_ok = True)

