import tensorflow as tf

# Note that tensorflow_datasets is not installed in the environment by
# default (install with `pip install 'tensorflow-datasets==3.1.*'`)
import tensorflow_datasets as tfds

# Target spatial size (height, width) that load_imagenet pads/resizes images to.
IMAGENET_H, IMAGENET_W = 160, 160

# Batch size. Not referenced by load_imagenet in this file; presumably
# consumed by the training code elsewhere.
IMAGENET_BATCH_SIZE = 32

# Number of label classes; used as the one-hot depth for labels.
IMAGENET_CLASSES = 1000

# Twice the batch size. Presumably a tf.data prefetch buffer size — it is not
# referenced in this file, so confirm its use at the call sites.
IMAGENET_PREFETCH = IMAGENET_BATCH_SIZE * 2

# Number of examples at the end of the official "train" split that are held
# out as the "val" split when load_imagenet(..., val_split=True).
IMAGENET_N_VAL = 100000


def load_imagenet(split, val_split=True):
    """Load one split of ImageNet-2012 from TensorFlow Datasets.

    Each image is converted to float32 with ``tf.image.convert_image_dtype``
    and resized with padding to ``IMAGENET_H`` x ``IMAGENET_W``. Each label
    is one-hot encoded with depth ``IMAGENET_CLASSES``. The returned dataset
    is not batched, shuffled, or prefetched here.

    Args:
        split: One of ``"train"``, ``"val"`` or ``"test"``.
            ``"test"`` is the official ImageNet ``validation`` split.
            ``"train"`` and ``"val"`` are both taken from the official
            ``train`` split.
        val_split: If True, the last ``IMAGENET_N_VAL`` examples of the
            official ``train`` split form ``"val"``, and ``"train"`` is the
            remainder. If False, ``"train"`` is the whole official ``train``
            split and ``"val"`` is unavailable.

    Returns:
        A ``(dataset, num_examples)`` tuple. ``dataset`` yields
        ``(image, one_hot_label)`` pairs.

    Raises:
        ValueError: If ``split`` is not one of the accepted values, or if
            ``split == "val"`` while ``val_split`` is False.
    """

    def preprocess(item):
        x = tf.image.convert_image_dtype(item["image"], tf.float32)
        x = tf.image.resize_with_pad(x, IMAGENET_H, IMAGENET_W)
        y = tf.one_hot(item["label"], IMAGENET_CLASSES)
        return x, y

    if split == "test":
        # The official ImageNet validation set is used as the test set.
        test, info = tfds.load("imagenet2012", split="validation", with_info=True)
        return test.map(preprocess), info.splits["validation"].num_examples

    if split not in ("train", "val"):
        raise ValueError('split must be "train", "val", or "test"')

    if not val_split:
        # Validate before loading so a bad request fails fast.
        if split == "val":
            raise ValueError('requested split "val", but val_split=False')
        full, info = tfds.load("imagenet2012", split="train", with_info=True)
        return full.map(preprocess), info.splits["train"].num_examples

    # Select the train/val partition with TFDS split slicing instead of
    # Dataset.take()/skip(). skip() must iterate (read and decode) every
    # preceding training example just to discard it. Slicing selects the
    # same examples by absolute index and only reads the records belonging
    # to the requested partition.
    if split == "train":
        tfds_split = "train[:-{}]".format(IMAGENET_N_VAL)
    else:
        tfds_split = "train[-{}:]".format(IMAGENET_N_VAL)
    part, info = tfds.load("imagenet2012", split=tfds_split, with_info=True)

    # info.splits describes the full, unsliced "train" split.
    n_full = info.splits["train"].num_examples
    n_examples = n_full - IMAGENET_N_VAL if split == "train" else IMAGENET_N_VAL
    return part.map(preprocess), n_examples
