import os
from tqdm import tqdm
import glob
import random

import numpy as np
import matplotlib.pyplot as plt
import torch
from torch.utils.data import Dataset, DataLoader

from styleaug.text_embedder import TextEmbedder
from FastCLIPstyler import FastCLIPStyler

# Evaluation prompts as (style text, content image name) pairs.
# The content name is the basename (without extension) of a ``.jpg`` file in
# the content folder. The style text is lower-cased before it is used by
# TrainStylePredictor.test().
test_prompts = [
    ("wooden cartoon", "boat"),
    ("watercolor painting", "house_painting"),
    ("moonlit darkness", "church"),
    ("swirly", "bird"),
    ("surrealism in vibrant colors", "starry"),
    ("A sketch with blue and red pencils", "rock"),
    ("Style of picasso's cubism", "flower"),
    ("A sketch with black pencil", "giraffe"),
    ("Colorful aurora evokes feelings of happiness and excitement", "car"),
    ("Great melding of colors and brush technique. Colors seem to flow across the canvas", "farmer"),
]

class params:
    """Mutable configuration namespace for the style-transfer runner.

    The class object itself (not an instance) is handed to the model. The
    evaluation loop also attaches ``content_path`` and ``text`` to it at
    runtime, one prompt at a time.
    """

    # Image dimensions (presumably pixels; confirm against the consumer).
    img_width = 512
    img_height = 512

    # Crop count; its exact meaning is defined by the model that reads it.
    num_crops = 16

    # Which text encoder to use: "fastclipstyler" or "edgeclipstyler".
    text_encoder = "fastclipstyler"

class TrainStylePredictor:
    """Run the style-transfer model over a list of (style text, content) prompts."""

    def __init__(self):
        pass

    def test(self, content_folder="./content_images", prompts=None):
        """Stylize each content image with its paired style prompt.

        For every ``(text, content_name)`` pair, the shared ``params``
        namespace is updated with:

        - ``content_path``: ``<content_folder>/<content_name>.jpg``
        - ``text``: the lower-cased style text

        Then ``trainer.test()`` is invoked.

        Args:
            content_folder: Directory containing the ``.jpg`` content images.
                Defaults to ``"./content_images"``.
            prompts: Iterable of ``(style_text, content_name)`` pairs.
                Defaults to the module-level ``test_prompts``.
        """
        if prompts is None:
            prompts = test_prompts

        # Construct the model once, outside the loop. ``params`` is passed as
        # the class object (not a copy), so the attributes assigned below are
        # visible to anything holding a reference to it. Presumably
        # FastCLIPStyler reads content_path/text inside test(); TODO confirm.
        trainer = FastCLIPStyler(params)

        for text, content_name in tqdm(prompts):
            params.content_path = os.path.join(content_folder, f"{content_name}.jpg")
            params.text = text.lower()
            trainer.test()

if __name__ == "__main__":
    # Run the evaluation pipeline only when executed as a script, so that
    # importing this module does not trigger it as a side effect.
    style_predictor = TrainStylePredictor()
    style_predictor.test()
