import sys, os
from .test import *
from .hist import draw_style
from tqdm import tqdm
import glob, ei.patched


# Command-line interface: ``<script> SUB GPU``
#   SUB -- forwarded to ``load_method(sub=...)`` and used as the name of the
#          output directory (output/images/SUB).
#   GPU -- written verbatim to CUDA_VISIBLE_DEVICES (e.g. "0" or "0,1").
# Fail with a usage line instead of a bare IndexError traceback.
if len(sys.argv) < 3:
    sys.exit(f'usage: {sys.argv[0]} SUB GPU')
sub = sys.argv[1]
gpu = sys.argv[2]
print(sub, gpu)

# Number devices by PCI bus ID (the numbering nvidia-smi shows), then expose
# only the requested GPU(s) to this process.
# NOTE(review): these variables only take effect if CUDA has not been
# initialised yet; the imports at the top of the file run before this point --
# confirm none of them touch the GPU.
os.environ['CUDA_DEVICE_ORDER'] = 'PCI_BUS_ID'
os.environ['CUDA_VISIBLE_DEVICES'] = gpu

# Device string used when moving the model (see load_flow).
device = 'cuda'

def load_flow():
    """Load the flow model selected by the command-line ``sub`` argument.

    The model returned by ``load_method`` is moved to ``device`` and then
    ``.eval()`` is called on it (presumably a torch ``nn.Module``, i.e.
    inference mode). The result of that call is returned.
    """
    # Imported inside the function, so iconflow is only loaded when this runs.
    from iconflow.train_flow import load_method

    model = load_method(sub=sub)
    placed = model.to(device)
    return placed.eval()

# Single model instance shared by every style rendered below.
flow = load_flow()

# Indices into the dataset's 'test' split of the content images to render;
# each one becomes a row of every style grid, in this order.
# NOTE(review): 903 is listed before 900 -- harmless, but confirm intentional.
ci_list = [
    34, 43, 67, 108, 134, 214, 244, 274, 309, 368, 430, 438, 480, 540, 560,
    577, 589, 664, 682, 687, 691, 696, 717, 754, 793, 798, 803, 829, 832,
    855, 881, 903, 900, 919, 923, 928, 1006, 1012, 1077, 1097, 1116, 1154,
    1163,
]


RESOLUTION = 128
dataset = get_dataset('datasets/icon/data/in_memory', RESOLUTION, contour_width='')
# Every selected test sample is a pair; keep only its second element (the
# content input later passed to flow.sample) and discard the first.
_, cs = zip(*[dataset['test'][i] for i in tqdm(ci_list)])

from utils.style import load_style_list, get_style_img

# Style name -> (pos, cmb), in the order load_style_list() yields them.
# '/' is replaced by ':' because the name becomes an output file name below,
# where a slash would be treated as a directory separator.
# ``pos`` is passed to flow.sample and ``cmb`` to get_style_img.
style_dict = collections.OrderedDict(
    {
        style_name.replace('/', ':'): (pos, cmb)
        for style_name, (pos, cmb) in load_style_list()
    }
)


# Values of ``t`` passed to flow.sample (labelled "T" on each tile --
# presumably a sampling temperature) and samples drawn per (content, t).
ts = [0.3, 0.5, 0.7]
n = 5
# Edge length, in pixels, every generated sample is resized to in the grid.
THUMB_SIZE = 128

out_dir = f'output/images/{sub}'
os.makedirs(out_dir, exist_ok=True)

# load_flow() already switched the model to eval mode; call it once here as a
# guard instead of on every batch.
flow.eval()

# One PNG grid per style: each row is one content image, starting with the
# style reference image and followed by len(ts) * n labelled samples.
for name, (pos, cmb) in style_dict.items():
    out_path = os.path.join(out_dir, f'{name}.png')
    # The reference image does not depend on the row, so fetch it only once.
    style_img = get_style_img(cmb, name)
    outs = [[style_img] for _ in cs]
    for t in tqdm(ts, desc=name):
        rs_list = []
        for bc, in batch_iter(zip(cs), 100):
            rs_list += flow.sample(bc, [pos] * len(bc), n, t)
        # zip() below would silently drop rows if results were missing, which
        # leaves short rows and shifts every following tile in the grid.
        if len(rs_list) < len(cs):
            raise RuntimeError(
                f'flow.sample returned {len(rs_list)} results for {len(cs)} contents'
            )
        for i, (ci, rs) in enumerate(zip(ci_list, rs_list)):
            outs[i] += [
                add_text(
                    r.resize((THUMB_SIZE, THUMB_SIZE), Image.BICUBIC),
                    f'{ci} T{t:.1f}',
                )
                for r in rs
            ]
    # Flatten row-major; n_cols makes images_to_grid wrap after each row.
    outs = [r for out in outs for r in out]
    grid = images_to_grid(outs, n_cols=len(ts) * n + 1)
    grid.save(out_path)
