import os
import sys
import random
import math
import json
import collections
import functools
import itertools
import pickle

from PIL import Image, ImageOps, ImageDraw
from IPython.display import display, clear_output

import torch
import torch.nn as nn
import torch.nn.functional as F
import torchvision.transforms as transforms
import torchvision.transforms.functional as T
from torchvision.utils import make_grid, save_image

import numpy as np
import matplotlib.pyplot as plt


def to_image(tensor):
    """Convert a tensor whose values are expected in [0, 1] to a PIL image.

    Values are clamped to [0, 1] before being passed to torchvision's
    ``to_pil_image``.
    """
    clipped = torch.clamp(tensor, min=0, max=1)
    return T.to_pil_image(clipped)

def to_image1(tensor):
    """Convert a tensor whose values are expected in [-0.5, +0.5] to a PIL image.

    Shifts the values up by 0.5 (into [0, 1]) and delegates to ``to_image``.
    """
    shifted = 0.5 + tensor
    return to_image(shifted)

def to_image2(tensor):
    """Convert a tensor whose values are expected in [-1, +1] to a PIL image.

    Maps [-1, 1] onto [0, 1] via ``x / 2 + 0.5`` and delegates to ``to_image``.
    """
    rescaled = 0.5 + tensor / 2
    return to_image(rescaled)

def from_image(image):
    """Convert a PIL image to a tensor via torchvision's ``to_tensor``.

    This is the counterpart of ``to_image`` (values nominally in [0, 1]).
    """
    tensor = T.to_tensor(image)
    return tensor

def from_image1(image):
    """Convert a PIL image to a tensor centred on zero (counterpart of ``to_image1``).

    The output of ``from_image`` is shifted down by 0.5.
    """
    tensor = from_image(image)
    return tensor - 0.5

def from_image2(image):
    """Convert a PIL image to a tensor in the ``to_image2`` range.

    The output of ``from_image`` is scaled by 2 and shifted down by 1.
    """
    tensor = from_image(image)
    return 2 * tensor - 1

def images_to_grid(images, n_rows=None, n_cols=None):
    """Tile equally-sized PIL images into a single RGB grid image.

    Images are placed in row-major order (left to right, then top to bottom).
    Grid cells without an image stay black. Every image is converted to RGB
    before being pasted.

    Args:
        images: iterable of PIL images that all share the same ``size``.
        n_rows: number of grid rows; derived from ``n_cols`` when None.
        n_cols: number of grid columns; derived from ``n_rows`` when None.

    Returns:
        A new RGB image of size ``(width * n_cols, height * n_rows)``.

    Raises:
        ValueError: if neither ``n_rows`` nor ``n_cols`` is given, either is
            less than 1, ``images`` is empty, the image sizes differ, or an
            explicit ``n_rows`` x ``n_cols`` grid cannot hold every image.
    """
    images = list(images)

    # Explicit exceptions rather than asserts: asserts are stripped under -O.
    if n_rows is None and n_cols is None:
        raise ValueError('at least one of n_rows or n_cols must be given')
    for name, value in (('n_rows', n_rows), ('n_cols', n_cols)):
        if value is not None and value < 1:
            raise ValueError(f'{name} must be >= 1, got {value!r}')
    if not images:
        raise ValueError('images must not be empty')

    first_size = images[0].size
    if any(image.size != first_size for image in images[1:]):
        raise ValueError('all images must have the same size')
    width, height = first_size

    if n_rows is None:
        n_rows = (len(images) - 1) // n_cols + 1  # ceil(len(images) / n_cols)
    elif n_cols is None:
        n_cols = (len(images) - 1) // n_rows + 1  # ceil(len(images) / n_rows)
    elif n_rows * n_cols < len(images):
        # Surplus images would be pasted outside the canvas and silently lost.
        raise ValueError(
            f'a {n_rows}x{n_cols} grid cannot hold {len(images)} images')

    grid = Image.new('RGB', (width * n_cols, height * n_rows), (0, 0, 0))
    for n, image in enumerate(images):
        row, col = divmod(n, n_cols)
        grid.paste(image.convert('RGB'), (width * col, height * row))

    return grid

def load_json(path):
    """Parse and return the JSON document stored at ``path``.

    The file is decoded as UTF-8 (the encoding JSON interchange requires),
    not the platform's locale default, so non-ASCII content loads the same
    on every OS. ``utf-8-sig`` additionally skips a leading UTF-8 BOM if one
    is present.
    """
    with open(path, encoding='utf-8-sig') as f:
        return json.load(f)

def get_dataset(root, image_size, **kwargs):
    """Build train/test ``IconContourDataset`` instances over ``root``.

    The train split covers the fraction range (0.0, 0.9) and the test split
    covers (0.9, 1.0). Both are created with ``as_pil_image=True``. Extra
    keyword arguments are forwarded unchanged to both datasets.

    Returns:
        dict with keys 'train' and 'test'.
    """
    from iconflow.dataset import IconContourDataset

    def build(split):
        return IconContourDataset(root, image_size, split=split,
                                  as_pil_image=True, **kwargs)

    train_set = build((0.0, 0.9))
    test_set = build((0.9, 1.0))
    return {'train': train_set, 'test': test_set}

def add_text(image, text, pos=(4, 4), color=None):
    """Return a copy of a PIL image with ``text`` drawn on it.

    The input image is not modified.

    Args:
        image: source PIL image.
        text: string to draw.
        pos: (x, y) position passed to ``ImageDraw.text``.
        color: fill colour in any form ``ImageDraw`` accepts. Defaults to
            black for the image's mode.

    Returns:
        A new image with the text rendered.
    """
    if color is None:
        # Explicit black for the common modes. Any other mode (RGBA, '1', 'P',
        # 'I', ...) falls back to the colour name, which ImageDraw resolves
        # for the image's own mode instead of raising KeyError.
        color = {'L': (0,), 'RGB': (0, 0, 0)}.get(image.mode, 'black')
    image = image.copy()
    ImageDraw.Draw(image).text(pos, text, fill=color)
    return image

def batch_iter(iterable, size):
    """Group an iterable of tuples into column-wise batches of up to ``size``.

    Items are consumed lazily. Each batch is transposed, so the i-th element
    of the yielded tuple is a list of the i-th fields of the batch's items
    (``zip`` truncates to the shortest item). Leftover items form a final,
    shorter batch.

    Example:
        >>> list(batch_iter([(1, 'a'), (2, 'b'), (3, 'c')], 2))
        [([1, 2], ['a', 'b']), ([3], ['c'])]

    Args:
        iterable: iterable whose items are sequences (e.g. tuples).
        size: maximum number of items per batch; must be >= 1.

    Yields:
        Tuples of lists, one list per field.

    Raises:
        ValueError: if ``size`` is less than 1. Because this is a generator,
            the error surfaces on the first iteration.
    """
    if size < 1:
        # Otherwise len(batch) == size never holds and the whole input would
        # silently come back as a single batch.
        raise ValueError(f'size must be >= 1, got {size!r}')

    def transpose(rows):
        return tuple(map(list, zip(*rows)))

    batch = []
    for item in iterable:
        batch.append(item)
        if len(batch) == size:
            yield transpose(batch)
            batch = []
    if batch:
        yield transpose(batch)
