import torch
from ..model import Flux
from torch import Tensor
from einops import rearrange, repeat
from .stochastic_pool2d import stochastic_pool2d
import math

def denoise_resolution(
    model: Flux,
    # model input
    img: Tensor,
    img_ids: Tensor,
    txt: Tensor,
    txt_ids: Tensor,
    vec: Tensor,
    # sampling parameters
    timesteps: list[float],
    guidance: float = 4.0,
    # extra img tokens
    img_cond: Tensor | None = None,
):
    # this is ignored for schnell
    guidance_vec = torch.full((img.shape[0],), guidance, device=img.device, dtype=img.dtype)

    #step = 0
    small = True
    # Downsample to 1/4
    #'b (h w) C -> b h w C'
    if small:
        img = rearrange(img, "b (h w) C -> b C h w", h=int(math.sqrt(img.shape[1])), w=int(math.sqrt(img.shape[1])))
        # avg pooling to (b, h/2, w/2, C)
        #img = torch.nn.functional.avg_pool2d(img, 2, 2)
        # Stochastic pooling
        img = stochastic_pool2d(img, kernel_size=2, stride=2)
        #'(b, h/2, w/2, C) -> b (h/2 w/2) C'
        img = rearrange(img, "b C h w -> b (h w) C")

    for t_curr, t_prev in zip(timesteps[:-1], timesteps[1:]):

        t_vec = torch.full((img.shape[0],), t_curr, dtype=img.dtype, device=img.device)
        pred = model(
            img=torch.cat((img, img_cond), dim=-1) if img_cond is not None else img,
            img_ids=img_ids[1] if small else img_ids[0],
            #img_ids=img_ids,
            txt=txt,
            txt_ids=txt_ids,
            y=vec,
            timesteps=t_vec,
            guidance=guidance_vec,
        )

        img = img + (t_prev - t_curr) * pred
        #step += 1

        # After a certain stage, restore the image to its original size
        if (t_curr <= 0.5) and small:
        #if (step == 48) and small:
            img = rearrange(img, "b (h w) C -> b C h w", h=int(math.sqrt(img.shape[1])), w=int(math.sqrt(img.shape[1]))
            )
            #img = torch.nn.functional.interpolate(img, scale_factor=2, mode="bilinear", align_corners=False)
            img = torch.nn.functional.interpolate(img, scale_factor=2, mode="nearest")

            # Add noise with mean 0 and variance eps
            eps = 0.1
            noise = torch.randn_like(img) * eps
            img = img + noise

            img = rearrange(img, "b C h w -> b (h w) C")
            small = False

    return img
