Sparse Layered Graphs for Multi-Object Segmentation, CVPR 2020 Paper
Author: Niels Jeppesen (niejep@dtu.dk)
In this notebook we prepare the data used for our comparison of two graph-based segmentation methods.
We start by loading the modules we'll be using.
import os
from zipfile import ZipFile
import numpy as np
import matplotlib.pyplot as plt
from scipy.ndimage import morphology, filters
from skimage.io import imread
from tqdm import tqdm
We'll use the image set BBBC038v1, available from the Broad Bioimage Benchmark Collection [Ljosa et al., Nature Methods, 2012]. The dataset contains images of nuclei and was used for the Kaggle 2018 Data Science Bowl. We will be using images from the training set provided in the stage1_train.zip file. This training data contains 670 images, each with a set of ground truth instance segmentation masks.
From the Data Description: This dataset contains a large number of segmented nuclei images. The images were acquired under a variety of conditions and vary in the cell type, magnification, and imaging modality (brightfield vs. fluorescence). The dataset is designed to challenge an algorithm's ability to generalize across these variations.
Download the stage1_train.zip file and place it in a directory called nuclei_comparison_data.
To avoid extracting the large number of files to disk, we simply read them directly into memory from the zip file.
comparison_dir = './nuclei_comparison_data'
# File path.
zip_path = os.path.join(comparison_dir, 'stage1_train.zip')
# Open zip file.
zip_file = ZipFile(zip_path)
# Create dictionary for data.
data_dic = {}
mask_dic = {}
# Read data from zip file.
for file_name in tqdm(zip_file.namelist()):
    if file_name.endswith('.png'):
        with zip_file.open(file_name) as file:
            img_name = file_name.split('/')[0]
            if img_name not in mask_dic:
                mask_dic[img_name] = []
            # Read image data from file.
            img = imread(file)
            # Put data or mask in respective dictionaries.
            if '/images/' in file_name:
                data_dic[img_name] = img
            elif '/masks/' in file_name:
                mask_dic[img_name].append(img.astype(bool))

print(len(data_dic), 'images read into dictionary.')
print(sum(len(mask_dic[k]) for k in mask_dic), 'image masks read into dictionary.')
# All image names as list.
names = list(data_dic.keys())
Let's have a look at some of the images and some of the corresponding ground truth segmentation masks.
# Images and masks to show.
show_image_count = 10
show_mask_count = 10
# For each image...
for i, name in enumerate(data_dic):
    if i >= show_image_count:
        break
    fig = plt.figure(figsize=(20, 2))
    fig.suptitle(name)
    # Show image.
    ax = plt.subplot(1, show_mask_count + 1, 1)
    ax.set_xticks([])
    ax.set_yticks([])
    ax.imshow(data_dic[name])
    # For each image mask...
    for j, mask in enumerate(mask_dic[name]):
        if j >= show_mask_count:
            break
        # Show mask.
        ax = plt.subplot(1, show_mask_count + 1, 2 + j)
        ax.set_xticks([])
        ax.set_yticks([])
        ax.imshow(mask)
    plt.show()
The ground truth masks are not perfect and contain some incorrect segmentations. At least two public corrected datasets exist (here and here). In this notebook we will use the original training data; however, we will correct one obvious type of error: holes in the masks.
# Fill holes in all masks.
for name in tqdm(mask_dic):
    for mask in mask_dic[name]:
        morphology.binary_fill_holes(mask, output=mask)
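To illustrate what hole filling does, here is a minimal self-contained example on a synthetic ring-shaped mask (the 5×5 array is made up for illustration; the modern `scipy.ndimage.binary_fill_holes` import is used directly):

```python
import numpy as np
from scipy.ndimage import binary_fill_holes

# A ring of True pixels enclosing a False "hole" in the middle.
ring = np.array([[0, 1, 1, 1, 0],
                 [1, 1, 0, 1, 1],
                 [1, 0, 0, 0, 1],
                 [1, 1, 0, 1, 1],
                 [0, 1, 1, 1, 0]], dtype=bool)

# The enclosed hole is filled; background touching the border is not.
filled = binary_fill_holes(ring)
```

The enclosed plus-shaped hole becomes True, while the corner background pixels, which touch the image border, stay False.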
For our segmentation method, we need an initialization for each object (nucleus) in the form of an approximate center position. We could manually annotate these, or use another method for finding the centers. However, since we have the ground truth segmentations for the dataset, the easiest way to get approximate centers is simply to find the center of mass of each instance mask.
Since all points in the mask are weighted equally, finding the center of mass for each object is simple.
# Generate centers from segmentations.
center_dic = {}
for name in tqdm(mask_dic):
    centers = []
    for mask in mask_dic[name]:
        center = np.mean(np.where(mask), axis=-1)
        centers.append(center)
    center_dic[name] = np.asarray(centers)
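As a quick sanity check of the expression above: `np.where(mask)` returns the row and column indices of all True pixels, so averaging along the last axis gives the (row, col) centroid. A toy example with a made-up mask:

```python
import numpy as np

# A 3x3 block of True pixels in a 5x6 mask, centered at row 2, col 3.
mask = np.zeros((5, 6), dtype=bool)
mask[1:4, 2:5] = True

# Rows 1..3 average to 2.0, columns 2..4 average to 3.0.
center = np.mean(np.where(mask), axis=-1)
```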
It could be argued that calculating the center based on the ground truth segmentation, as described above, is somewhat cheating. To better simulate human center annotations, we can instead pick positions randomly within the mask, while avoiding positions too close to the edge of the mask.
# Generate approximate centers from segmentations.
# Pick random seed.
np.random.seed(42)
center_approx_dic = {}
for name in tqdm(names):
    centers = []
    center_only = np.empty(mask_dic[name][0].shape, dtype=bool)
    for mask, center in zip(mask_dic[name], center_dic[name]):
        if np.count_nonzero(mask) < 2:
            # Very small mask, keep center of mass as center.
            centers.append(center)
            continue
        # Create mask with only the center of mass.
        center_only[:] = False
        center_only[tuple(np.round(center).astype(np.int32))] = True
        # Get mask index information.
        indices = np.asarray(np.where(mask))
        min_indices = np.maximum(0, np.min(indices, axis=-1) - 1)
        max_indices = np.max(indices, axis=-1) + 2
        # Crop.
        mask_crop = mask[min_indices[0]:max_indices[0], min_indices[1]:max_indices[1]]
        center_only_crop = center_only[min_indices[0]:max_indices[0], min_indices[1]:max_indices[1]]
        # Create probability map using Gaussian filter.
        p = filters.gaussian_filter(center_only_crop.astype(np.float64), sigma=np.array(center_only_crop.shape) / 20, mode='constant')
        # Take points inside mask.
        p = p[mask_crop]
        # Make sure probabilities sum to 1.
        if np.sum(p) == 0:
            p += 1
        p /= np.sum(p)
        # Flatten indices.
        indices_flat = np.moveaxis(indices, 0, -1).reshape(-1, indices.shape[0])
        # Randomly pick approx. center based on probabilities.
        center_approx = indices_flat[np.random.choice(np.arange(len(indices_flat)), p=p)]
        centers.append(center_approx)
    center_approx_dic[name] = np.asarray(centers)
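The key step above is `np.random.choice` with a probability vector `p`: one pixel is sampled with probability proportional to its Gaussian weight, so positions near the center of mass are favored. A minimal self-contained sketch of the same idea, with made-up coordinates and weights:

```python
import numpy as np

np.random.seed(42)

# Candidate pixel coordinates inside a mask (row, col), shape (N, 2).
indices_flat = np.array([[0, 0], [0, 1], [1, 0], [1, 1]])

# Unnormalized weights favoring the last pixel (made up for illustration).
p = np.array([0.1, 0.1, 0.1, 0.7])
p /= np.sum(p)  # np.random.choice requires probabilities summing to 1

# Draw one coordinate pair according to p.
center_approx = indices_flat[np.random.choice(len(indices_flat), p=p)]
```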
To make sure we've found the centers correctly, let's plot the centers on top of some of the images. Yellow dots are the centers of mass, while red dots are the approximate (randomly sampled) centers.
def show_images(data_dic, center_dic, center_dic_2=None, rows=1, columns=4):
    # Images to show.
    fig = plt.figure(figsize=(20, 5 * rows))
    # For each image...
    for i, name in enumerate(data_dic):
        if i >= rows * columns:
            break
        centers = center_dic[name]
        # Show image and centers.
        ax = plt.subplot(rows, columns, i + 1, title='...' + name[-8:])
        ax.imshow(data_dic[name])
        if center_dic_2 is not None and name in center_dic_2:
            centers_2 = center_dic_2[name]
            ax.scatter(centers_2[:, 1], centers_2[:, 0], c='y', s=5)
        ax.scatter(centers[:, 1], centers[:, 0], c='r', s=5)
    plt.show()
show_images(data_dic, center_approx_dic, center_dic, rows=4)
Before we start segmenting data, we need to do a bit of preprocessing.
The first thing to notice about the images is that the nuclei are dark in some images and bright in others. Since we will be using the gradients to determine the nuclei boundaries, we need to make sure the boundaries are consistently defined for all images. Secondly, to obtain smooth gradients, we'll smooth the images a little using a standard Gaussian filter. Lastly, we'll normalize the image data.
To make the images consistent, we negate the intensity of the images where the nuclei are dark. In this dataset, they are dark in the purple-colored images and bright in the grayscale images.
data_prep_dic = {}
for name in tqdm(data_dic):
    data = data_dic[name].astype(np.float32)
    if np.all(data[..., :3] == data[..., :1]):
        if np.mean(data) / np.max(data) < 0.5:
            # This is a grayscale and dark image. Let's take the negative values.
            data = -data
    # Convert to grayscale image by taking the mean of all channels.
    data_prep_dic[name] = np.mean(data, axis=-1)
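The grayscale test above relies on broadcasting: `data[..., :3] == data[..., :1]` compares all three channels to the first one, which is only all-True when R, G, and B are identical everywhere. A toy illustration with made-up arrays:

```python
import numpy as np

# A 2x2 RGB image where all channels are identical (grayscale stored as RGB).
gray_rgb = np.full((2, 2, 3), 7.0)

# A 2x2 RGB image where the channels differ (a genuine color image).
color_rgb = np.dstack([np.ones((2, 2)), np.zeros((2, 2)), np.zeros((2, 2))])

gray_detected = bool(np.all(gray_rgb[..., :3] == gray_rgb[..., :1]))
color_detected_as_gray = bool(np.all(color_rgb[..., :3] == color_rgb[..., :1]))
```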
Let's take a look at the grayscale images. There are better ways of ensuring high contrast when converting the color images to grayscale than taking the mean value. However, that is not the focus of this notebook.
show_images(data_prep_dic, center_approx_dic, rows=2)
Since we will be using simple image gradients for our segmentation, we need to make sure they're relatively smooth, or the results will suffer. By smoothing the images a little, the gradients will provide a much better segmentation. After smoothing, we normalize all images so that they have float values between 0 and 1.
# Smooth all images using a small Gaussian filter.
for name in tqdm(data_prep_dic):
    # Smooth.
    data = filters.gaussian_filter(data_prep_dic[name], sigma=.5)
    # Normalize.
    data -= np.min(data)
    data /= np.max(data)
    data_prep_dic[name] = data
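After the min-max rescaling above, every image spans exactly [0, 1]. A small self-contained check of the same two steps (Gaussian smoothing, then min-max normalization) on random data, using the modern `scipy.ndimage.gaussian_filter` import:

```python
import numpy as np
from scipy.ndimage import gaussian_filter

# Random image with an arbitrary intensity range, standing in for real data.
rng = np.random.RandomState(0)
img = rng.rand(32, 32).astype(np.float32) * 100

# Smooth, then normalize to [0, 1].
smooth = gaussian_filter(img, sigma=0.5)
smooth -= np.min(smooth)
smooth /= np.max(smooth)
```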
show_images(data_prep_dic, center_approx_dic, rows=10)
Now that all the data is ready, let's save it to NumPy files that we can use for segmentation in Part 2.
# Save prepared data.
np.savez_compressed(os.path.join(comparison_dir, 'data_dic.npz'), **data_dic)
np.savez_compressed(os.path.join(comparison_dir, 'data_prep_dic.npz'), **data_prep_dic)
np.savez_compressed(os.path.join(comparison_dir, 'mask_dic.npz'), **mask_dic)
np.savez_compressed(os.path.join(comparison_dir, 'center_dic.npz'), **center_dic)
np.savez_compressed(os.path.join(comparison_dir, 'center_approx_dic.npz'), **center_approx_dic)
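In Part 2 these files can be read back with `np.load`; each `.npz` file behaves like a dictionary keyed by image name. A minimal round-trip sketch using a temporary directory and dummy data (the key `'img_a'` is made up for illustration):

```python
import os
import tempfile
import numpy as np

# Dummy stand-in for one of the dictionaries saved above.
dummy = {'img_a': np.arange(6, dtype=np.float32).reshape(2, 3)}

with tempfile.TemporaryDirectory() as tmp:
    path = os.path.join(tmp, 'data_dic.npz')
    np.savez_compressed(path, **dummy)
    # NpzFile supports the context-manager protocol and lazy array access.
    with np.load(path) as npz:
        restored = {k: npz[k] for k in npz.files}
```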