import os

import cv2
import torch
import numpy as np
from PIL import Image
from pathlib import Path
from matplotlib import pyplot as plt
from torch.nn import functional as F
from torchvision import transforms, utils as vutils
from torchvision.transforms.functional import adjust_contrast


def load_image2(img_path, img_height=None, img_width=None):
    """Read an image as grayscale and return it as a 3-channel float tensor.

    The file is decoded in grayscale by OpenCV. The single channel is then
    replicated into three identical RGB channels, so any colour information
    in the source file is discarded.

    Args:
        img_path: Path of the image file (``str`` or ``os.PathLike``).
        img_height: Target height in pixels. Only used when ``img_width`` is
            given. If it is None, the height is derived from ``img_width``
            so that the source aspect ratio is kept.
        img_width: Target width in pixels. When None, no resizing is done.

    Returns:
        A float tensor of shape ``(3, H, W)`` produced by
        ``transforms.ToTensor`` from the uint8 pixel array.

    Raises:
        FileNotFoundError: If OpenCV cannot read or decode ``img_path``.
    """
    image = cv2.imread(str(img_path), cv2.IMREAD_GRAYSCALE)
    # cv2.imread signals failure by returning None rather than raising.
    # Without this check the error surfaces later as an opaque cvtColor assert.
    if image is None:
        raise FileNotFoundError(f"Could not read image: {img_path}")
    image = cv2.cvtColor(image, cv2.COLOR_GRAY2RGB)

    if img_width is not None:
        if img_height is None:
            src_height, src_width = image.shape[:2]
            img_height = max(1, round(src_height * img_width / src_width))
        # cv2.resize expects dsize as (width, height), not (rows, cols).
        image = cv2.resize(image, (img_width, img_height))

    return transforms.ToTensor()(image)[:3, :, :]


def clip_normalize(image,device):
    """Resize a batch to 224x224 and apply CLIP's per-channel normalization.

    Args:
        image: 4-D tensor accepted by ``F.interpolate`` in bicubic mode.
            Presumably (N, 3, H, W) RGB in [0, 1]; confirm against callers.
        device: Device the normalization constants are moved to. It should
            match ``image``'s device.

    Returns:
        ``(resized - mean) / std``, with mean/std broadcast over dim 1.
    """
    resized = F.interpolate(image, size=224, mode='bicubic', align_corners=False)

    # Row 0 holds the channel means, row 1 the channel standard deviations.
    stats = torch.tensor([
        [0.48145466, 0.4578275, 0.40821073],
        [0.26862954, 0.26130258, 0.27577711],
    ]).to(device)
    channel_mean = stats[0].view(1, -1, 1, 1)
    channel_std = stats[1].view(1, -1, 1, 1)

    return (resized - channel_mean) / channel_std

def _style_output_filename(text):
    """Build a file name with an image extension from a free-form style text."""
    # A '/' or '\' in a style prompt would make cv2.imwrite target a
    # sub-directory of outputs/ that does not exist, so flatten separators.
    name = text.replace("/", "_").replace("\\", "_")
    # cv2.imwrite chooses the encoder from the extension and fails without
    # a recognised one. Keep an explicit image extension, else add ".png".
    known_extensions = {".png", ".jpg", ".jpeg", ".bmp", ".tif", ".tiff", ".webp"}
    if os.path.splitext(name)[1].lower() not in known_extensions:
        name += ".png"
    return name


def save_img(target, text):
    """Save the first image of a batch to ``./outputs`` with a caption banner.

    Steps:
    1. Clamp the tensor to [0, 1].
    2. Boost contrast by a factor of 1.5 and scale to 0-255.
    3. Pad a 100-pixel black banner on top and caption it with
       ``"Style: {text}"``.
    4. Convert RGB -> BGR for OpenCV, round to uint8, and write to
       ``./outputs/<name>``.

    ``<name>`` is *text* with path separators replaced by ``_``. ``.png`` is
    appended when *text* lacks an image extension.

    Args:
        target: Image batch. ``target[0]`` must be (3, H, W) (it is permuted
            to HWC). Presumably RGB with values in [0, 1]; confirm against
            callers.
        text: Style description, used as both the caption and the file name.

    Raises:
        OSError: If OpenCV fails to write the output file.
    """
    # detach() keeps autograd out of the export; clamp returns a new tensor,
    # so the caller's target is never modified.
    output_image = torch.clamp(target.detach(), 0, 1)
    output_image = adjust_contrast(output_image, 1.5)

    final_img = output_image[0].permute(1, 2, 0).cpu().numpy() * 255.0
    final_img = cv2.copyMakeBorder(final_img, 100, 0, 0, 0, cv2.BORDER_CONSTANT)

    print_text = f"Style: {text}"
    final_img = cv2.putText(final_img, print_text, (20, 20), cv2.FONT_HERSHEY_SIMPLEX, 0.55, (255, 255, 255), 1, cv2.LINE_AA)

    final_img = cv2.cvtColor(final_img, cv2.COLOR_RGB2BGR)
    # Encoders such as PNG/JPEG are 8-bit. Convert explicitly, rounding and
    # saturating, instead of relying on imwrite's implicit float conversion.
    final_img = np.clip(np.rint(final_img), 0, 255).astype(np.uint8)

    os.makedirs("outputs", exist_ok=True)
    out_path = os.path.join(".", "outputs", _style_output_filename(text))
    if not cv2.imwrite(out_path, final_img):
        raise OSError(f"cv2.imwrite failed to write {out_path}")
