from typing import Tuple, Union, List
import torch
import torch.nn as nn
from models import STGazeNetCombined

def format_memory(num_bytes: int) -> str:
    """Render a byte count as a short human-readable string.

    Values below 1024 are shown verbatim with a ``Bytes`` suffix. Larger
    values are divided by the matching power of 1024 and printed with two
    decimals as KB or MB. GB is the largest unit, so anything from 1024**3
    upward is expressed in GB.
    """
    if num_bytes < 1024:
        return f"{num_bytes} Bytes"
    # (power of 1024 to divide by, unit label); each unit is used while the
    # value is still below the next power of 1024.
    for power, unit in ((1, "KB"), (2, "MB")):
        if num_bytes < 1024 ** (power + 1):
            return f"{num_bytes / 1024 ** power:.2f} {unit}"
    return f"{num_bytes / 1024 ** 3:.2f} GB"

def _output_nbytes(output) -> int:
    """Return the total byte size of every tensor found in a module output.

    Tensors nested inside lists, tuples and dict values are counted
    recursively; any other kind of object contributes 0 bytes.
    """
    if isinstance(output, torch.Tensor):
        return output.nelement() * output.element_size()
    if isinstance(output, (list, tuple)):
        return sum(_output_nbytes(item) for item in output)
    if isinstance(output, dict):
        return sum(_output_nbytes(item) for item in output.values())
    return 0


def estimate_activation_memory(
    model: nn.Module,
    batch_size: int = 1,
    face_shape: Tuple[int, int, int] = (3, 128, 128),
    eye_shape: Tuple[int, int, int] = (3, 128, 128),
    sequence_length: int = 30,
    dtype: torch.dtype = torch.float32
) -> dict:
    """
    Estimates the memory needed to store activations for a single forward pass.

    This function performs a dummy forward pass and uses forward hooks to capture
    the output of each module. It's useful for estimating the memory
    footprint during training, where activations are kept for the
    backward pass.

    Only one frame per sample is pushed through the model. Each module's output
    size is then multiplied by ``sequence_length`` to approximate a whole
    sequence.

    The model's train/eval mode (per submodule) is restored afterwards. The
    temporary forward hooks are removed even if the forward pass raises.

    Args:
        model (nn.Module): The PyTorch model.
        batch_size (int): The batch size for the dummy input tensors.
            Defaults to 1.
        face_shape (Tuple[int, int, int], optional): The shape of the face input
            tensor (without the batch dimension). Defaults to (3, 128, 128).
        eye_shape (Tuple[int, int, int], optional): The shape of each eye input
            tensor (without the batch dimension). Defaults to (3, 128, 128).
        sequence_length (int, optional): The length of the sequence of images;
            used as a multiplier on each module's output size. Defaults to 30.
        dtype (torch.dtype, optional): The data type of the dummy input tensors.
            Defaults to torch.float32.

    Returns:
        dict: A dictionary containing:
              - 'total_memory_bytes': Total estimated memory for activations in bytes.
              - 'total_memory_str': Human-readable string for the total memory.
              - 'per_module_memory': A list of tuples, where each tuple contains
                                     the module class name and its activation
                                     memory usage as a human-readable string.
    """
    # Remember every submodule's own mode so we can restore it exactly;
    # model.train(flag) would flatten intentionally mixed modes (e.g. frozen BN).
    saved_modes = [(module, module.training) for module in model.modules()]
    model.eval()
    try:
        device = next(model.parameters()).device
    except StopIteration:
        # Handle models with no parameters (e.g., just activation functions)
        device = torch.device("cpu")
        model.to(device)

    # Create the dummy input tensors
    dummy_face_input = torch.randn(batch_size, *face_shape, dtype=dtype, device=device)
    dummy_eye_input = torch.randn(batch_size, *eye_shape, dtype=dtype, device=device)

    # (module class name, bytes) for every module whose output holds tensors
    activation_sizes: List[Tuple[str, int]] = []

    def get_activation_size_hook(module, inputs, output):
        module_memory = _output_nbytes(output)
        if module_memory > 0:
            # Scale the single-frame measurement up to the full sequence.
            activation_sizes.append((type(module).__name__, module_memory * sequence_length))

    hooks = []
    try:
        for module in model.modules():
            hooks.append(module.register_forward_hook(get_activation_size_hook))

        # no_grad: output shapes (and hence sizes) are unaffected, and we avoid
        # building an autograd graph for the dummy pass.
        with torch.no_grad():
            if isinstance(model, STGazeNetCombined):
                input_dict = {
                    'face_patch': dummy_face_input,
                    'left_eye_patch': dummy_eye_input,
                    'right_eye_patch': dummy_eye_input
                }
                model(input_dict, get_pog=False)
            else:
                # NOTE(review): assumes other models take (eye, face) positionally — confirm.
                model(dummy_eye_input, dummy_face_input)
    finally:
        # Always detach the hooks and restore modes, even if the forward failed,
        # so the caller's model is left exactly as it was handed in.
        for handle in hooks:
            handle.remove()
        for module, was_training in saved_modes:
            module.training = was_training

    # --- Calculate total memory and prepare output ---
    total_bytes = sum(mem for _, mem in activation_sizes)

    # NOTE: model.modules() can cause double counting for container modules
    # (like nn.Sequential) whose output is the same as their last child.
    # This provides a safe upper-bound estimate. For most common architectures,
    # the overestimation is minor.

    return {
        'total_memory_bytes': total_bytes,
        'total_memory_str': format_memory(total_bytes),
        'per_module_memory': [(name, format_memory(mem)) for name, mem in activation_sizes]
    }

def set_tags(config, tags=None):
    """Build a list of run tags describing an experiment configuration.

    Args:
        config: Configuration object. The attributes read are ``start_time``,
            ``flip_right_eye``, ``do_offset_augmentation``,
            ``loss_coeff_PoG_cons_initial``, ``st_net_ablation``,
            ``st_net_model_name`` and ``early_fusion``. When
            ``st_net_ablation`` is truthy, ``ablation_eye_encoder``,
            ``sides`` (an iterable of strings), ``ablation_face_encoder``,
            ``ablation_eca``, ``ablation_sam`` and ``ablation_gru`` are also
            read.
        tags: Optional existing list to extend in place. When omitted, a
            fresh list is created for this call.

    Returns:
        list: The tag list (the same object as ``tags`` when one was passed).
    """
    # A literal ``[]`` default is created once and shared by every call, so
    # tags would silently accumulate across invocations; build a new list.
    if tags is None:
        tags = []

    if config.start_time > 0:
        tags.append(f"start_time={config.start_time}")
    if config.flip_right_eye:
        tags.append("flip")
    if config.do_offset_augmentation:
        tags.append("offset")
    if config.loss_coeff_PoG_cons_initial == 0:
        tags.append("no_cons")
    if config.st_net_ablation:
        if config.ablation_eye_encoder:
            tags.append("eye_encoder")
            tags.append('_'.join(config.sides))
        if config.ablation_face_encoder:
            tags.append("face_encoder")
        if config.ablation_eca:
            tags.append("eca")
        if config.ablation_sam:
            tags.append("sam")
        if config.ablation_gru:
            tags.append("gru")
    tags.append(config.st_net_model_name)
    if config.early_fusion:
        tags.append("early_fusion")

    return tags