import os
from PIL import Image, ImageDraw, ImageFont
import matplotlib.font_manager
import random
import string
import math


def add_text_to_medical_img(src_folder_path, dst_folder_path, mask_dst_folder_path):
    """
    Overlay random text and/or arrows on every image in a folder, and save matching masks.

    Files in ``src_folder_path`` are processed in sorted order. For each one a random overlay
    is chosen -- text only, arrow only, or both, with weights 8:1:3 -- and drawn on the image.
    The identical overlay is also drawn on a black RGB image of the same size, so the mask
    records exactly what was added and results can later be compared to the original text.

    Placement ranges are hard-coded (text origin x in [5, 40], y in [5, 30]; arrow start in
    [10, 44] on both axes). They do not scale with the image size, so on large images the
    overlays stay near the top-left corner.

    Parameters:
    - src_folder_path: Directory of the original patches. Sub-directories are skipped.
    - dst_folder_path: Directory where the annotated patches are saved (created if missing).
    - mask_dst_folder_path: Directory where the mask patches are saved (created if missing).

    Both outputs keep the source file name.

    Raises:
    - RuntimeError: if text must be drawn but no usable system font was found.
    """
    # Get list of system fonts, excluding NotoColorEmoji, which breaks text rendering here
    # (presumably because ImageFont.truetype cannot load that color-bitmap font at arbitrary
    # sizes -- unverified).
    fonts = [
        font
        for font in matplotlib.font_manager.findSystemFonts(fontpaths=None, fontext='ttf')
        if 'NotoColorEmoji.ttf' not in font
    ]

    os.makedirs(dst_folder_path, exist_ok=True)
    os.makedirs(mask_dst_folder_path, exist_ok=True)

    for filename in sorted(os.listdir(src_folder_path)):
        file_path = os.path.join(src_folder_path, filename)
        if not os.path.isfile(file_path):
            # os.listdir also returns sub-directories, which Image.open cannot read.
            continue

        # The context manager closes the source file handle once the outputs are written.
        with Image.open(file_path) as img:
            # Black RGB canvas of the same size that receives exactly the same overlay,
            # so the generated content can be compared against the annotated image.
            mask_img = Image.new("RGB", img.size, "black")
            canvases = (ImageDraw.Draw(img), ImageDraw.Draw(mask_img))

            # Decide on content: 0 = text, 1 = arrow, 2 = both
            content_choice = random.choices([0, 1, 2], weights=[8, 1, 3], k=1)[0]

            if content_choice in (0, 2):  # content includes text
                size = random.randint(10, 16)
                length = random.randint(4, 8)
                if not fonts:
                    raise RuntimeError("No usable system TrueType/OpenType font was found.")
                font = ImageFont.truetype(random.choice(fonts), size=size)
                text = ''.join(
                    random.choice(string.ascii_uppercase + string.digits + string.ascii_lowercase)
                    for _ in range(length))
                loc_x = random.randint(5, 40)
                loc_y = random.randint(5, 30)
                color = random.randint(150, 230)
                # Use a color string (like "white" for the arrow) so Pillow converts it to each
                # image's mode. A raw (r, g, b) tuple raises TypeError on single-band images
                # such as grayscale "L" or 16-bit "I;16".
                fill = f"rgb({color}, {color}, {color})"
                for canvas in canvases:
                    canvas.text((loc_x, loc_y), text, font=font, fill=fill)

            if content_choice in (1, 2):  # content includes an arrow
                arrow_start = (random.randint(10, 44), random.randint(10, 44))
                arrow_end = (arrow_start[0] + random.choice([-1, 1]) * random.randint(5, 20),
                             arrow_start[1] + random.choice([-1, 1]) * random.randint(5, 20))

                # Arrowhead: a triangle whose two back corners sit 10 px behind the tip,
                # rotated +/- 30 degrees (pi / 6) from the shaft direction.
                arrow_angle = math.atan2(arrow_end[1] - arrow_start[1],
                                         arrow_end[0] - arrow_start[0])
                head_length = 10
                head = [
                    (arrow_end[0], arrow_end[1]),
                    (arrow_end[0] - head_length * math.cos(arrow_angle + math.pi / 6),
                     arrow_end[1] - head_length * math.sin(arrow_angle + math.pi / 6)),
                    (arrow_end[0] - head_length * math.cos(arrow_angle - math.pi / 6),
                     arrow_end[1] - head_length * math.sin(arrow_angle - math.pi / 6)),
                ]
                for canvas in canvases:
                    canvas.line([arrow_start, arrow_end], fill="white", width=2)
                    canvas.polygon(head, fill="white")

            # Save the annotated slice and its mask under the original file name.
            img.save(os.path.join(dst_folder_path, filename))
            mask_img.save(os.path.join(mask_dst_folder_path, filename))


def generate_text_on_black_patches(img_num, dst_folder_path, patch_size=(63, 63)):
    """
    Generate black patches carrying random text and/or arrows.

    Each patch is an all-black RGB image that receives, with equal probability, text only,
    an arrow only, or both. Patches are saved as ``0.png`` ... ``<img_num - 1>.png``.

    Placement ranges are hard-coded (text origin x in [5, 40], y in [5, 30]; arrow start in
    [10, 44] on both axes) and do not scale with ``patch_size``. Overlays may therefore be
    clipped on smaller patches, or clustered in the top-left corner of larger ones.

    Parameters:
    - img_num: Number of images to generate.
    - dst_folder_path: Directory where patches will be saved (created if missing).
    - patch_size: (width, height) of the patches. Default is (63, 63).

    Raises:
    - RuntimeError: if text must be drawn but no usable system font was found.
    """
    # Get list of system fonts, excluding NotoColorEmoji, which breaks text rendering here
    # (presumably because ImageFont.truetype cannot load that color-bitmap font at arbitrary
    # sizes -- unverified).
    fonts = [
        font
        for font in matplotlib.font_manager.findSystemFonts(fontpaths=None, fontext='ttf')
        if 'NotoColorEmoji.ttf' not in font
    ]

    os.makedirs(dst_folder_path, exist_ok=True)

    for index in range(img_num):
        black_patch = Image.new(mode="RGB", size=(patch_size[0], patch_size[1]), color=(0, 0, 0))
        draw = ImageDraw.Draw(black_patch)

        # Decide on content: 0 = text, 1 = arrow, 2 = both
        content_choice = random.choice([0, 1, 2])

        if content_choice in (0, 2):  # content includes text
            size = random.randint(10, 16)
            length = random.randint(4, 8)
            if not fonts:
                raise RuntimeError("No usable system TrueType/OpenType font was found.")
            font = ImageFont.truetype(random.choice(fonts), size=size)
            text = ''.join(
                random.choice(string.ascii_uppercase + string.digits + string.ascii_lowercase)
                for _ in range(length))
            loc_x = random.randint(5, 40)
            loc_y = random.randint(5, 30)
            color = random.randint(150, 230)
            draw.text((loc_x, loc_y), text, font=font, fill=(color, color, color))

        if content_choice in (1, 2):  # content includes an arrow
            arrow_start = (random.randint(10, 44), random.randint(10, 44))
            arrow_end = (arrow_start[0] + random.choice([-1, 1]) * random.randint(5, 20),
                         arrow_start[1] + random.choice([-1, 1]) * random.randint(5, 20))

            draw.line([arrow_start, arrow_end], fill="white", width=2)

            # Arrowhead: a triangle whose two back corners sit 10 px behind the tip,
            # rotated +/- 30 degrees (pi / 6) from the shaft direction.
            arrow_angle = math.atan2(arrow_end[1] - arrow_start[1],
                                     arrow_end[0] - arrow_start[0])
            head_length = 10
            head = [
                (arrow_end[0], arrow_end[1]),
                (arrow_end[0] - head_length * math.cos(arrow_angle + math.pi / 6),
                 arrow_end[1] - head_length * math.sin(arrow_angle + math.pi / 6)),
                (arrow_end[0] - head_length * math.cos(arrow_angle - math.pi / 6),
                 arrow_end[1] - head_length * math.sin(arrow_angle - math.pi / 6)),
            ]
            draw.polygon(head, fill="white")

        # Save the patch as "<index>.png".
        black_patch.save(os.path.join(dst_folder_path, f"{index}.png"))