import torch

def stochastic_pool2d(img, kernel_size=2, stride=2, *, generator=None):
    """
    Stochastic pooling with uniform sampling inside each window.

    For every pooling window, one element is chosen uniformly at random
    and copied to the output. This is NOT the activation-weighted
    (multinomial) stochastic pooling of Zeiler & Fergus: every position
    in the window is equally likely, regardless of its value.

    Parameters:
        img: Input tensor with shape (N, C, H, W).
        kernel_size: Side length of the square pooling window (int).
        stride: Pooling stride (int). Windows overlap when
            stride < kernel_size.
        generator: Optional ``torch.Generator`` used for sampling, for
            reproducible results without touching the global RNG. It must
            live on the same device as ``img``. When ``None`` (the
            default), the global RNG is used exactly as before.

    Returns:
        Pooled tensor with shape (N, C, out_H, out_W), where
        out_H = (H - kernel_size) // stride + 1 and
        out_W = (W - kernel_size) // stride + 1.
        Gradients flow only to the sampled element of each window.
    """
    N, C, H, W = img.shape
    window_area = kernel_size * kernel_size

    # Extract sliding windows (overlapping unless stride == kernel_size).
    # F.unfold lays the channel axis out channel-major, i.e. index
    # c * k*k + offset, so splitting it into (C, k*k) is valid.
    cols = torch.nn.functional.unfold(img, kernel_size=kernel_size, stride=stride)  # (N, C*k*k, L)
    windows = cols.reshape(N, C, window_area, -1)  # (N, C, k*k, L)
    num_windows = windows.shape[-1]

    # One uniformly random in-window offset per (sample, channel, window).
    rand_idx = torch.randint(
        window_area,
        (N, C, num_windows),
        generator=generator,
        device=img.device,
    )

    # gather along the window axis picks windows[n, c, rand_idx[n, c, l], l].
    # It runs entirely on img's device, with no auxiliary arange index tensors.
    pooled = torch.gather(windows, 2, rand_idx.unsqueeze(2)).squeeze(2)  # (N, C, L)

    # unfold enumerates window positions row-major, matching (out_H, out_W).
    out_H = (H - kernel_size) // stride + 1
    out_W = (W - kernel_size) // stride + 1
    return pooled.reshape(N, C, out_H, out_W)