from genericpath import exists
import os
import shutil
from async_timeout import enum
import numpy as np
import cv2
from tqdm import tqdm
from PIL import Image

# Side length, in pixels, of the square tiles that split_image cuts images and masks into.
GRID_SIZE: int = 256

def _load_array(path):
    """Read the image file at *path* into a NumPy array, closing the file handle."""
    # Image.open is lazy and keeps the file open; the context manager releases it
    # once the pixel data has been copied into the array.
    with Image.open(path) as im:
        return np.asarray(im)


def _pad_to_grid(arr, height_pad, width_pad, fill):
    """Return *arr* copied into the top-left corner of a larger array filled with *fill*.

    The result has shape ``(height_pad, width_pad) + arr.shape[2:]`` and ``arr.dtype``,
    so any trailing channel dimension (grayscale, RGB, RGBA, ...) is preserved.
    """
    padded = np.full((height_pad, width_pad) + arr.shape[2:], fill, dtype=arr.dtype)
    padded[:arr.shape[0], :arr.shape[1]] = arr
    return padded


def _tile(arr, row, col, size):
    """Return the ``size x size`` tile at grid position (*row*, *col*) of *arr*."""
    return arr[row * size:(row + 1) * size, col * size:(col + 1) * size]


def split_image(img_dir, mask_dir, out_dir, *, grid_size=None, pad_value=241):
    """Cut every image in *img_dir* and its mask into square PNG tiles.

    For each regular file ``name`` in *img_dir* (in sorted order), the mask is read
    from ``mask_dir/name``. Both arrays are padded on the bottom/right up to a
    multiple of *grid_size*: the image with *pad_value*, the mask with 0. They are
    then cut into ``grid_size x grid_size`` tiles. Every tile whose mask contains a
    value > 0 is written to ``out_dir/image/name/`` and ``out_dir/mask/name/`` as
    ``{row}_{col}_{rows}_{cols}.png``. If no tile has a positive mask value, all
    tiles are written instead.

    Images whose ``out_dir/image/name`` directory already exists are skipped, so an
    interrupted run can be resumed. Image tiles are first written to
    ``out_dir/image/name.partial`` and renamed into place only once complete. A
    crash therefore never leaves a half-filled directory that a later run would
    mistake for finished work.

    Args:
        img_dir: Directory of input images; subdirectories are ignored.
        mask_dir: Directory containing a mask with the same file name as each image.
        out_dir: Output root; ``image/`` and ``mask/`` are created beneath it.
        grid_size: Tile side length in pixels. Defaults to the module-level
            ``GRID_SIZE``.
        pad_value: Fill value for the padded area of the image (default 241).

    Raises:
        ValueError: If an image and its mask differ in height or width.
    """
    if grid_size is None:
        grid_size = GRID_SIZE
    img_root = os.path.join(out_dir, 'image')
    mask_root = os.path.join(out_dir, 'mask')
    os.makedirs(img_root, exist_ok=True)
    os.makedirs(mask_root, exist_ok=True)

    for name in tqdm(sorted(os.listdir(img_dir))):
        img_path = os.path.join(img_dir, name)
        if os.path.isdir(img_path):
            continue
        img_out = os.path.join(img_root, name)
        if os.path.exists(img_out):
            continue  # completed by an earlier run

        img = _load_array(img_path)
        mask = _load_array(os.path.join(mask_dir, name))
        if img.shape[:2] != mask.shape[:2]:
            raise ValueError(
                f'{name}: image size {img.shape[:2]} does not match mask size {mask.shape[:2]}'
            )
        height, width = img.shape[:2]
        # Round each side up to a whole number of tiles.
        height_num = (height + grid_size - 1) // grid_size
        width_num = (width + grid_size - 1) // grid_size
        height_pad, width_pad = height_num * grid_size, width_num * grid_size
        img_pad = _pad_to_grid(img, height_pad, width_pad, pad_value)
        mask_pad = _pad_to_grid(mask, height_pad, width_pad, 0)
        del img, mask  # only the padded copies are needed from here on

        all_tiles = [(i, j) for i in range(height_num) for j in range(width_num)]
        selected = [
            (i, j) for i, j in all_tiles if _tile(mask_pad, i, j, grid_size).max() > 0
        ]
        if not selected:
            # No foreground anywhere: keep every tile so the image is still represented.
            selected = all_tiles

        # Mask tile names are deterministic, so a rerun simply overwrites them in place.
        mask_out = os.path.join(mask_root, name)
        os.makedirs(mask_out, exist_ok=True)
        # Stage image tiles so img_out (the resume marker checked above) only
        # appears once every tile has been written.
        staging = img_out + '.partial'
        if os.path.isdir(staging):
            shutil.rmtree(staging)  # leftover from an interrupted run
        os.makedirs(staging)
        for i, j in selected:
            fname = f'{i}_{j}_{height_num}_{width_num}.png'
            Image.fromarray(_tile(img_pad, i, j, grid_size)).save(os.path.join(staging, fname))
            Image.fromarray(_tile(mask_pad, i, j, grid_size)).save(os.path.join(mask_out, fname))
        os.replace(staging, img_out)
