import torch
from ..model import Flux
from .split_utils import split, merge
from torch import Tensor

def denoise_split(
    model: Flux,
    # model input
    img: Tensor,
    img_ids: Tensor,
    txt: Tensor,
    txt_ids: Tensor,
    vec: Tensor,
    # sampling parameters
    timesteps: list[float],
    guidance: float = 4.0,
    # extra img tokens
    img_cond: Tensor | None = None,
):
    # this is ignored for schnell
    guidance_vec = torch.full((img.shape[0],), guidance, device=img.device, dtype=img.dtype)

    step = 0
    split_step = 40
    interval = 2

    #img, img_ids, txt, txt_ids, vec = split_prepare(img, img_ids, txt, txt_ids, vec)
    for t_curr, t_prev in zip(timesteps[:-1], timesteps[1:]):
        
        small = not ((step % interval == 0) or (step < 30))
        if small:
            img, img_ids[1], txt, txt_ids, vec = split(img, img_ids[1], txt, txt_ids, vec)

        t_vec = torch.full((img.shape[0],), t_curr, dtype=img.dtype, device=img.device)
        pred = model(
            img=torch.cat((img, img_cond), dim=-1) if img_cond is not None else img,
            img_ids=img_ids[1] if small else img_ids[0],
            #img_ids=img_ids,
            txt=txt,
            txt_ids=txt_ids,
            y=vec,
            timesteps=t_vec,
            guidance=guidance_vec,
        )
        
        img = img + (t_prev - t_curr) * pred
        step += 1

        if small:
            img, img_ids[1], txt, txt_ids, vec = merge(img, img_ids[1], txt, txt_ids, vec)
        #if step == split_step:
        #    small = True
        #    img, img_ids[1], txt, txt_ids, vec = split(img, img_ids[1], txt, txt_ids, vec)

    #img, img_ids[1], txt, txt_ids, vec = merge(img, img_ids[1], txt, txt_ids, vec)
    
    return img