import torch
import math
from torch import Tensor
from einops import rearrange, repeat

def split(img: Tensor, img_ids: Tensor, txt: Tensor, txt_ids: Tensor, vec: Tensor):
    """Cut each image token grid into 2x2 tiles and stack them on the batch axis.

    ``img`` is read as ``(b, h*w, c)``, with the token axis taken to be a
    row-major flattening of a square grid (``h == w``). Every grid is cut
    into four ``h/2 x w/2`` tiles, and each tile becomes its own batch entry,
    ordered as ``(sample, tile_row, tile_col)``. The result therefore has
    shape ``(4*b, h*w/4, c)``.

    The other tensors are not partitioned: every batch entry is duplicated
    four times back to back, so output rows ``4*i .. 4*i + 3`` all belong to
    input sample ``i``. Note that ``img_ids`` keeps its original token-axis
    length; only ``img`` gets shorter.

    Args:
        img: rank-3 tensor ``(b, h*w, c)``.
        img_ids: rank-3 tensor ``(b, t, c)``.
        txt: rank-3 tensor ``(b, t, c)``.
        txt_ids: rank-3 tensor ``(b, t, c)``.
        vec: rank-2 tensor ``(b, t)``.

    Returns:
        ``(img, img_ids, txt, txt_ids, vec)``, each with a batch axis four
        times as long as its input.

    The einops pattern only matches when the token count of ``img`` is a
    perfect square whose side is even; otherwise einops rejects the input.
    """
    tiles_per_side = 2
    copies = tiles_per_side * tiles_per_side

    # Side of the full square grid, then the side of one tile.
    grid_side = int(math.sqrt(img.shape[1]))
    tile_side = grid_side // tiles_per_side

    # (b, h*w, c) -> (4 * b, h*w/4, c) in one step: un-flatten the token axis
    # into (tile_row, row, tile_col, col), move the tile indices next to the
    # batch axis, and flatten the in-tile (row, col) back into tokens.
    img = rearrange(
        img,
        "b (ph h pw w) c -> (b ph pw) (h w) c",
        ph=tiles_per_side,
        pw=tiles_per_side,
        h=tile_side,
        w=tile_side,
    )

    # Conditioning tensors are replicated once per tile:
    # (b, t, c) -> (4 * b, t, c)
    img_ids = repeat(img_ids, "b t c -> (b k) t c", k=copies)
    txt_ids = repeat(txt_ids, "b t c -> (b k) t c", k=copies)
    txt = repeat(txt, "b t c -> (b k) t c", k=copies)
    # (b, t) -> (4 * b, t)
    vec = repeat(vec, "b t -> (b k) t", k=copies)

    return img, img_ids, txt, txt_ids, vec

def merge(img: Tensor, img_ids: Tensor, txt: Tensor, txt_ids: Tensor, vec: Tensor):
    """Reassemble 2x2 image tiles from the batch axis into full token grids.

    ``img`` is read as ``(4*b, n, c)``, where each group of four consecutive
    batch entries holds the tiles of one sample in ``(tile_row, tile_col)``
    order and ``n`` tokens form a row-major square tile. The tiles are placed
    back into a grid twice as tall and twice as wide, and the grid is
    flattened row-major, giving ``(b, 4*n, c)``. This is the layout that
    ``split`` in this module starts from.

    For the other tensors, only the first entry of every group of four is
    kept; the remaining three are discarded without being compared.

    Args:
        img: rank-3 tensor ``(4*b, n, c)``.
        img_ids: rank-3 tensor ``(4*b, t, c)``.
        txt: rank-3 tensor ``(4*b, t, c)``.
        txt_ids: rank-3 tensor ``(4*b, t, c)``.
        vec: rank-2 tensor ``(4*b, t)``.

    Returns:
        ``(img, img_ids, txt, txt_ids, vec)``, each with a batch axis one
        quarter as long as its input.

    The einops patterns only match when every batch size is a multiple of
    four and ``n`` is a perfect square; otherwise einops rejects the input.
    """
    tiles_per_side = 2
    group = tiles_per_side * tiles_per_side

    # Side of one square tile.
    tile_side = int(math.sqrt(img.shape[1]))

    # (4 * b, n, c) -> (b, 4 * n, c) in one step: peel the tile indices off
    # the batch axis, interleave them with the in-tile (row, col) so rows and
    # columns of the full grid are contiguous, then flatten to tokens.
    img = rearrange(
        img,
        "(b ph pw) (h w) c -> b (ph h pw w) c",
        ph=tiles_per_side,
        pw=tiles_per_side,
        h=tile_side,
        w=tile_side,
    )

    # Keep copy 0 out of each group of four: (4 * b, t, c) -> (b, t, c)
    img_ids = rearrange(img_ids, "(b k) t c -> b k t c", k=group)[:, 0]
    txt_ids = rearrange(txt_ids, "(b k) t c -> b k t c", k=group)[:, 0]
    txt = rearrange(txt, "(b k) t c -> b k t c", k=group)[:, 0]
    # (4 * b, t) -> (b, t)
    vec = rearrange(vec, "(b k) t -> b k t", k=group)[:, 0]

    return img, img_ids, txt, txt_ids, vec



