import numpy as np
import torch
from torch.nn.functional import interpolate
import time
from rich import print


def seg2box(seg: torch.Tensor) -> tuple[int, int, int, int]:
    """Return the inclusive bounding box ``(x_1, y_1, x_2, y_2)`` of ``seg``'s nonzero entries.

    ``seg`` must be 4-D. The ``x`` coordinates come from dim 2 and the ``y``
    coordinates from dim 3; the two leading dims are ignored.
    NOTE(review): for an NCHW layout dim 2 would be height (rows) — confirm
    this naming matches what callers expect.

    An all-zero ``seg`` raises from the empty ``min()`` reduction.
    """
    assert seg.dim() == 4
    _, _, xs, ys = seg.nonzero(as_tuple=True)
    return (xs.min().item(), ys.min().item(), xs.max().item(), ys.max().item())


def histogram(xs: torch.Tensor, bins: int):
    """Like ``torch.histogram``, but works with CUDA tensors (built on ``torch.histc``).

    The histogram spans the data range ``[xs.min(), xs.max()]`` split into
    ``bins`` equal-width bins.

    Returns:
        counts: ``bins`` float counts, on the same device as ``xs``.
        boundaries: the ``bins + 1`` bin edges, created by ``torch.linspace``
            on the default device (no ``device`` argument is passed).
    """
    lo, hi = xs.min().item(), xs.max().item()
    if lo == hi:
        # A constant input has a zero-width range. Widen it explicitly so the
        # edges returned below describe exactly the bins histc counts into.
        lo, hi = lo - 1, hi + 1
    counts = torch.histc(xs, bins, min=lo, max=hi)
    boundaries = torch.linspace(lo, hi, bins + 1)
    return counts, boundaries


def otsu(tensor: torch.Tensor):
    """
    Compute Otsu's threshold for the values in ``tensor``.

    This is the more common (optimized) formulation of the algorithm, the one you
    see on Wikipedia pages. It builds a 256-bin histogram over the value range and
    picks the split that maximises the between-class variance
    ``w_b * w_f * (mu_b - mu_f) ** 2``.

    Returns:
        A 0-d tensor holding the bin edge that separates the two classes, so that
        ``tensor > otsu(tensor)`` selects the foreground class. Returns ``0`` when
        no split into two non-empty classes exists (e.g. every value is equal).
    """
    tensor = tensor.squeeze()
    hist, bin_edges = histogram(tensor, bins=256)
    edge_index = _otsu_split_edge(hist.cpu().numpy())
    if edge_index is None:
        # No threshold separates two non-empty classes; keep the historical fallback.
        return 0
    return bin_edges[edge_index]


def _otsu_split_edge(hist: np.ndarray) -> "int | None":
    """Return the index of the bin edge that best splits ``hist`` (Otsu), or None.

    For a candidate ``t``, bins ``0..t`` form the background and bins ``t+1..``
    the foreground. The separating edge is therefore ``t + 1``, the upper edge
    of the last background bin. Ties resolve to the smallest ``t``.
    Returns None when ``hist`` is empty or no ``t`` leaves both classes non-empty.
    """
    hist = np.asarray(hist, dtype=np.float64)
    if hist.size == 0:
        return None
    levels = np.arange(hist.size, dtype=np.float64)

    # Running class weights/sums for every candidate t at once.
    weight_background = np.cumsum(hist)
    sum_background = np.cumsum(levels * hist)
    weight_foreground = weight_background[-1] - weight_background
    sum_foreground = sum_background[-1] - sum_background

    valid = (weight_background > 0) & (weight_foreground > 0)
    if not valid.any():
        return None

    # Invalid candidates divide by zero here; they are masked out just below.
    with np.errstate(divide="ignore", invalid="ignore"):
        mean_background = sum_background / weight_background
        mean_foreground = sum_foreground / weight_foreground
        variances = (
            weight_background
            * weight_foreground
            * (mean_background - mean_foreground) ** 2
        )
    variances = np.where(valid, variances, -np.inf)
    return int(np.argmax(variances)) + 1


def otsu_mask(tensor: torch.Tensor):
    """Binarise ``tensor`` with the threshold returned by ``otsu``.

    Returns a float tensor shaped like ``tensor``: 1.0 where the value is
    strictly greater than the threshold, 0.0 elsewhere.
    """
    threshold = otsu(tensor)
    foreground = torch.gt(tensor, threshold)
    return foreground.float()
