import clip

import os

from render_wrapper import Renderer
from options import RendererOptions

import torch

import utils.general as utils
from utils import rend_utils

import tqdm
def init_clip(model_name="ViT-B/32", device=None):
    """Load a CLIP model together with its image preprocessing transform.

    Args:
        model_name: Architecture name forwarded to ``clip.load``
            (default ``"ViT-B/32"``).
        device: Device to load the model onto. When ``None`` (default),
            ``"cuda"`` is used if ``torch.cuda.is_available()``, else ``"cpu"``.

    Returns:
        The ``(model, preprocess)`` pair returned by ``clip.load``.
    """
    if device is None:
        device = "cuda" if torch.cuda.is_available() else "cpu"
    model, preprocess = clip.load(model_name, device=device)
    return model, preprocess

if __name__ == '__main__':
    exp_dir = '../ablation_data'
    exp_name = 'cat'
    text = 'a cat wearing a chef hat'
    num_rand_poses = 30
    num_experiments = 10

    exp_dir = os.path.join(exp_dir, exp_name)

    # Two renderers over the same experiment directory, differing only in the
    # checkpoint they load ('<exp_name>_text.pth' vs '<exp_name>_sked.pth').
    opts_text = RendererOptions()
    opts_text.exp_dir = exp_dir
    opts_text.ckpt_path = os.path.join(exp_dir, '{}_text.pth'.format(exp_name))

    opts_sked = RendererOptions()
    opts_sked.exp_dir = exp_dir
    opts_sked.ckpt_path = os.path.join(exp_dir, '{}_sked.pth'.format(exp_name))

    renderer_text = Renderer(opts_text)
    renderer_sked = Renderer(opts_sked)

    clip_model, clip_preprocess = init_clip()
    # Send every CLIP input to whatever device CLIP was actually loaded on;
    # a hard-coded .cuda() breaks when init_clip falls back to CPU.
    clip_device = next(clip_model.parameters()).device

    # The prompt never changes, so encode and L2-normalize it only once.
    with torch.no_grad():
        text_features = clip_model.encode_text(clip.tokenize(text).to(clip_device))
        text_features /= text_features.norm(dim=-1, keepdim=True)

    def mean_clip_score(images):
        """Return the mean scaled CLIP score (100 * cosine) of PIL images vs. the prompt."""
        batch = torch.stack([clip_preprocess(image) for image in images]).to(clip_device)
        with torch.no_grad():
            image_features = clip_model.encode_image(batch)
            image_features /= image_features.norm(dim=-1, keepdim=True)
            return (100.0 * image_features @ text_features.T).mean().item()

    for e in tqdm.trange(num_experiments):
        images_text = renderer_text.get_image(output_type='rgb', disable_bg=False,
                                              num_poses=num_rand_poses, return_pil=True)
        images_sked = renderer_sked.get_image(output_type='rgb', disable_bg=False,
                                              num_poses=num_rand_poses, return_pil=True)

        # NOTE: filenames do not include the experiment index, so each
        # experiment overwrites the previous one's images.
        for i, (image_text, image_sked) in enumerate(zip(images_text, images_sked)):
            image_text.save(os.path.join(exp_dir, 'text_{}.png'.format(i)))
            image_sked.save(os.path.join(exp_dir, 'sked_{}.png'.format(i)))

        text_sim = mean_clip_score(images_text)
        sked_sim = mean_clip_score(images_sked)

        # tqdm.write keeps the progress bar intact, unlike a bare print.
        tqdm.tqdm.write('Experiment: {}'.format(e))
        tqdm.tqdm.write('sked similarity: {}'.format(sked_sim))
        tqdm.tqdm.write('text similarity: {}'.format(text_sim))


