import torch
from ..model import Flux
from torch import Tensor
from ..modules.cache_functions import cache_init

def denoise_parallel(
    model: Flux,
    # model input
    img: Tensor,
    img_ids: Tensor,
    txt: Tensor,
    txt_ids: Tensor,
    vec: Tensor,
    # sampling parameters
    timesteps: list[float],
    guidance: float = 4.0,
    *,
    threshold: float = 2e-6,
    verbose: bool = True,
):
    """Euler-sample ``img`` along ``timesteps``, speculatively taking two steps per model call.

    The first step is a plain single-input model call. Every later call is
    batched over two inputs:

    * the current latent at ``t_curr``, and
    * a linear extrapolation of the latent to ``t_prev``
      (kept in ``cache_dic['img_pred']``).

    After the regular Euler step to ``t_prev``, the stepped latent is compared
    (MSE) with the extrapolation. If the two agree within ``threshold``, the
    prediction made on the extrapolated latent is reused to take the following
    Euler step as well. The next loop iteration is then skipped.

    Args:
        model: Flux model. It is called with the keyword arguments ``img``,
            ``img_ids``, ``txt``, ``txt_ids``, ``y``, ``timesteps`` and
            ``guidance``.
        img: Latent to denoise. ``img.shape[0]`` is used as the batch size for
            the timestep/guidance vectors and for splitting batched outputs.
        img_ids: Image position ids. Only ``img_ids[0]`` is passed to the model.
        txt: Text tokens, duplicated along dim 0 for the batched calls.
        txt_ids: Text ids, duplicated along dim 0 for the batched calls.
        vec: Vector passed as ``y``, duplicated along dim 0 for the batched calls.
        timesteps: Sampling schedule. One Euler step is taken per consecutive
            pair (some are taken speculatively).
        guidance: Guidance value filled into a per-batch vector.
        threshold: MSE below which the extrapolated latent is accepted and the
            next step is skipped. Defaults to the previously hard-coded 2e-6.
        verbose: Print per-step accept/reject/skip diagnostics. Defaults to the
            previous behaviour of always printing.

    Returns:
        The latent after the last timestep of the schedule.
    """
    # init cache
    cache_dic, current = cache_init(timesteps)
    batch = img.shape[0]
    # this is ignored for schnell
    guidance_vec = torch.full((batch,), guidance, device=img.device, dtype=img.dtype)
    current['step'] = 0
    current['num_steps'] = len(timesteps) - 1

    # Conditioning for the batched [current, extrapolated] model calls.
    # NOTE(review): the ids are built from img_ids[0] concatenated along dim 0,
    # while the other tensors are duplicated along their batch dim. Confirm this
    # matches what the model expects for a doubled batch.
    img_ids2 = torch.cat((img_ids[0], img_ids[0]), dim=0)
    txt2 = torch.cat((txt, txt), dim=0)
    txt_ids2 = torch.cat((txt_ids, txt_ids), dim=0)
    vec2 = torch.cat((vec, vec), dim=0)
    guidance_vec2 = torch.cat((guidance_vec, guidance_vec), dim=0)

    n_ts = len(timesteps)
    skip_next = False

    for i in range(n_ts - 1):
        if skip_next:
            # Skip this iteration: its step was already taken speculatively.
            skip_next = False
            if verbose:
                print(f'skipping {i}')
            continue

        t_curr, t_prev = timesteps[i], timesteps[i + 1]
        # Look-ahead timesteps. Near the end of the schedule they are clamped
        # so the corresponding step has zero length.
        t_p2 = timesteps[i + 2] if i + 2 < n_ts else t_prev
        t_p3 = timesteps[i + 3] if i + 3 < n_ts else t_p2
        t_vec = torch.full((batch,), t_curr, dtype=img.dtype, device=img.device)

        if i == 0:
            pred = model(
                img=img,
                img_ids=img_ids[0],
                txt=txt,
                txt_ids=txt_ids,
                y=vec,
                timesteps=t_vec,
                guidance=guidance_vec,
            )

            img = img + (t_prev - t_curr) * pred
            # Extrapolate to t_p2 (the next iteration's t_prev) with the same velocity.
            cache_dic['img_pred'] = img + (t_p2 - t_prev) * pred

        else:
            t_vec_next = torch.full((batch,), t_prev, dtype=img.dtype, device=img.device)
            # Second half of the batch is the *latest* extrapolation to t_prev.
            # It must come from cache_dic: a local copy would stay frozen at
            # step 0's extrapolation.
            img2 = torch.cat((img, cache_dic['img_pred']), dim=0)
            t_vec2 = torch.cat((t_vec, t_vec_next), dim=0)

            pred = model(
                img=img2,
                img_ids=img_ids2,
                txt=txt2,
                txt_ids=txt_ids2,
                y=vec2,
                timesteps=t_vec2,
                guidance=guidance_vec2,
            )

            # pred_now: velocity at (img, t_curr).
            # pred_next: velocity at (extrapolated img, t_prev).
            pred_now, pred_next = torch.split(pred, batch, dim=0)

            img = img + (t_prev - t_curr) * pred_now

            # How far the extrapolation was from the actual Euler step.
            loss = torch.nn.functional.mse_loss(cache_dic['img_pred'], img)

            if loss < threshold:
                # Extrapolation is close enough: reuse pred_next to step
                # t_prev -> t_p2 now, and extrapolate on to t_p3 (the t_prev
                # of the iteration after the skipped one).
                img = img + (t_p2 - t_prev) * pred_next
                cache_dic['img_pred'] = img + (t_p3 - t_p2) * pred_next
                skip_next = True
                if verbose:
                    print(f'i={i}, accepted {i+1}, loss: {loss:.10f}')

            else:
                # Rejected: extrapolate to the next iteration's t_prev from pred_now.
                cache_dic['img_pred'] = img + (t_p2 - t_prev) * pred_now
                if verbose:
                    print(f'i={i}, rejected {i+1}, loss: {loss:.10f}')

        # Counts only iterations that called the model; skipped ones `continue` above.
        current['step'] += 1

    return img
