import sys, os
sys.path.append('.')
sys.path.append('./mainldm/clip-score/src')
os.environ['CUDA_VISIBLE_DEVICES'] = '0'
import torch

from clip_score.clip_score import calculate_clip_given_paths
from evalution.sfid import test_fid_sfid
from quant.coco_prompt import get_prompts, center_resize_image

def test_stable_diffusion(benchmark="coco", device="cuda:0"):
    """Evaluate text-to-image samples with CLIP score and, for COCO, FID/sFID.

    Computes the CLIP score between the images in ``imglogdir`` and the
    prompts stored under ``./mainldm/results/<benchmark>/prompt`` and prints
    it. For the ``"coco"`` / ``"coco1"`` benchmarks it additionally writes
    300x300 center-resized copies of the images and runs ``test_fid_sfid``.

    Args:
        benchmark: Benchmark name used to build the result / sample paths.
            FID/sFID is only evaluated for ``"coco"`` and ``"coco1"``.
        device: Device string used for both the CLIP and the FID/sFID
            evaluation. Defaults to ``"cuda:0"`` because this script sets
            ``CUDA_VISIBLE_DEVICES='0'``, which exposes exactly one GPU.
    """
    imglogdir = "./text2img-benchmark/mscoco/"
    prompt_path = f"./mainldm/results/{benchmark}/prompt"

    # CLIP score between the images in imglogdir and the prompts in prompt_path.
    clip_score = calculate_clip_given_paths(
        real_path=imglogdir,
        fake_path=prompt_path,
        batch_size=100,
        device=torch.device(device),
        clip_model="ViT-L/14",
        num_workers=8,
    )
    print(f'CLIP Score:{clip_score}')

    # FID/sFID is only configured for the COCO benchmarks.
    if benchmark not in ("coco", "coco1"):
        return

    image_resize = f"./mainldm/results/{benchmark}/image_resize"
    print(f"image resize save path: {image_resize}")
    # NOTE(review): the resized copies written here are not read below —
    # sample_path still points at the un-resized imglogdir. Confirm which
    # directory test_fid_sfid is meant to consume.
    center_resize_image(imglogdir, image_resize, (300, 300))

    ref_batch = './coco/coco-train.npz'
    sample_batch = f'./text2img-benchmark/{benchmark}-sample.npz'
    # NOTE(review): ref_path and sample_path are the same directory; verify
    # against test_fid_sfid whether the reference set should come from
    # elsewhere.
    ref_path = imglogdir
    sample_path = imglogdir
    # Use the same device as the CLIP evaluation: a hard-coded "cuda:3" does
    # not exist when CUDA_VISIBLE_DEVICES='0' exposes only a single GPU.
    test_fid_sfid(
        ref_batch=ref_batch,
        sample_batch=sample_batch,
        ref_path=ref_path,
        sample_path=sample_path,
        device=device,
    )
    # Repeat the CLIP score after the FID/sFID run's output.
    print(f'CLIP Score:{clip_score}')

def test_sfid(
    ref_batch='./imagenet-train.npz',
    sample_batch='./npz/imagenet-sample.npz',
    ref_path="/mnt/data/liuxuewen/imagenet/train_resize",
    sample_path="./mainldm/imagenet/",
    device="cuda:0",
):
    """Run the FID/sFID evaluation for ImageNet samples.

    All arguments default to the values this script previously hard-coded,
    so calling ``test_sfid()`` with no arguments behaves as before; override
    them to evaluate other runs or machines without editing the function.

    Args:
        ref_batch: Path to the reference ``.npz`` batch.
        sample_batch: Path to the sample ``.npz`` batch.
        ref_path: Directory of reference images.
        sample_path: Directory of generated sample images.
        device: Device string forwarded to ``test_fid_sfid``.
    """
    # test_fid_sfid is already imported at module level; no local re-import.
    test_fid_sfid(
        ref_batch,
        sample_batch,
        ref_path=ref_path,
        sample_path=sample_path,
        device=device,
    )

if __name__ == '__main__':
    # Only the ImageNet FID/sFID evaluation runs by default. To evaluate the
    # COCO text-to-image results instead, call test_stable_diffusion() here.
    test_sfid()
