Segmentation Pipeline
Motivation
In our use case, lightsheet microscopy of mouse brain produces several hundred GBs of data, represented as a 3-dimensional array with an extra channel axis. This array is stored as an OME ZARR file. Distributed computation is necessary to keep the analysis time tractable, and we choose Dask as the distributed computing library for our project.
As part of our research, we need automated methods to count objects in the image. Afterwards, we use an atlas map, which maps each pixel location of the image to a brain region, to obtain the density of cells in each region of the mouse brain. We need to test several counting methods and find the one that gives the most accurate counts, and for new incoming data volumes we want to be able to quickly find a set of parameters that works on the new data.
On the algorithm side, object counting is easy on a Numpy array when the dataset is small. On larger datasets, we need long-running distributed computations that are hard to debug and take hours to run.
SegProcess
Below we use examples from cvpl_tools.im.process to describe a convenient way to define the functions of a multi-step image processing pipeline for distributed, interpretable and cached image data analysis.
Consider a function that counts the number of cells in a 3d block of a brightness map:
import dask.array as da

def cell_count(block3d: da.Array):
    """Count the number of cells in a dask 3d brightness image"""
    mask = threshold(block3d, 0.4)  # use simple thresholding to mark pixels where cells are as 1s
    inst = instance_segmentation(mask)  # segment the mask into contours
    cell_cnt = count_inst(inst)  # count each contour as a cell
    return cell_cnt
This code is complete, but it has two problems:
1. Lack of interpretability. Often the first run is when some bug shows up. Without a way to look into intermediate results, debugging becomes a problem, since we don't know whether one of the three steps in the cell_count function did not work as expected, or whether the algorithm simply does not work well on the input data. In either case, to find the root cause of the problem we very often end up adding code and rerunning the pipeline to display the output of each step and check that it matches our expectations.
2. Results are not cached. Ideally the pipeline is run once and we get the result, but more often than not the result needs to be used in different places (visualization, analysis etc.). Caching these results makes sure computation is done only once, which is necessary when we work with costly algorithms on hundreds of GBs of data (of course, it's still best to first test on a slice of <1GB of data).
The basic idea to address 1) is to make visualization part of the cell_count function, and to address 2) is to cache the result of each step into a file in a CacheDirectory. In more detail, we want an image processing pipeline that provides:
1. Dask support. Inputs are expected to be numpy arrays, dask arrays, or cvpl.im.ndblock.NDBlock objects. In particular, dask.Array and NDBlock are suitable for parallel or distributed image processing workflows.
2. Napari integration. The function has a parameter context_args with a keyed item viewer_args that defaults to None. By passing a Napari viewer as viewer_args["viewer"], the function will add intermediate images or centroids to the Napari viewer for easier debugging. After the function returns, we can call viewer.show() to display all added images.
3. Intermediate result caching. A hierarchical caching directory is provided; a call to the function will either create a new directory or load from an existing cache directory, based on the cache_url item of the context_args parameter. (A short preview of such a call follows this list.)
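As a preview of what this looks like in practice, the sketch below imagines the cell_count function above rewritten to follow this convention; here im is the input dask array, the cache path is a placeholder, and cell_count accepting context_args is an assumption (the actual call pattern is shown further down this page):

import napari

viewer = napari.Viewer(ndisplay=2)
result = await cell_count(im, context_args=dict(
    cache_url='path/to/CacheDirectory/cell_count',  # first call computes and caches; later calls load from here
    viewer_args=dict(viewer=viewer),                # intermediate images get added to this viewer
))
viewer.show()  # inspect the intermediate results added during the call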
Now we discuss how to define a process function.
Extending the Pipeline
The first step of building a pipeline is to break a segmentation algorithm down into steps that process the image in different formats. As an example, we may implement a pipeline as IN -> BS -> OS -> CC, where:
IN - Input Image (np.float32), with values between min=0 and max=1; this is the brightness dask image given as input
BS - Binary Segmentation (3d, np.uint8); this is the binary mask of the single-class segmentation
OS - Ordinal Segmentation (3d, np.int32); values range from 0 to N, where contours 1 to N each denote an object; also single class
CC - Cell Count Map (3d, np.float64); a cell count number (an estimate, which can be a float) for each block
Mapping from IN to BS comes in two choices: either simply treat voxels above some threshold as cells and the rest as background, or use a trained machine learning model to do the binary segmentation. Mapping from BS to OS also comes in two choices: either directly treat each connected volume as a separate cell, or use watershed to get a finer segmentation mask. Finally, we can count cells in the instance segmentation mask, for example by counting how many separate contours we have found.
In some cases this is not necessary if we already know which algorithm works best, but abstracting the intermediate results of the algorithm into the four types IN, BS, OS and CC has helped us identify which parts of the pipeline can be reused and which parts may have variations in the algorithm used.
We can then plan the processing steps we need to define as follows (a rough sketch of these transformations follows the list):
thresholding (IN -> BS)
model_prediction (IN -> BS)
direct_inst_segmentation (BS -> OS)
watershed_inst_segmentation (BS -> OS)
cell_cnt_from_inst (OS -> CC)
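To make the four formats concrete, below is a rough sketch of three of these steps written with plain dask and scipy. The function names here are illustrative helpers, not cvpl_tools APIs, and the per-block labeling and counting deliberately ignore cells that cross block boundaries, which a real pipeline has to handle:

import numpy as np
import dask.array as da
from scipy import ndimage

def threshold_mask(im: da.Array, threshold: float = 0.4) -> da.Array:
    """IN -> BS: mark voxels brighter than the threshold as 1."""
    return (im > threshold).astype(np.uint8)

def label_blocks(bs: da.Array) -> da.Array:
    """BS -> OS: label connected components block by block."""
    return bs.map_blocks(lambda blk: ndimage.label(blk)[0].astype(np.int32), dtype=np.int32)

def count_per_block(os_im: da.Array) -> da.Array:
    """OS -> CC: one (possibly fractional) cell count per block."""
    def _count(blk):
        return np.full((1, 1, 1), float(blk.max()), dtype=np.float64)
    return os_im.map_blocks(_count, dtype=np.float64, chunks=(1, 1, 1))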
How do we go from this plan to actually coding these steps? For each step, we define a function process(), which takes arbitrary inputs and one extra parameter, context_args, a dictionary that contains keyed items as follows (a populated example follows the list):
cache_url (str | RDirFileSystem, optional): Points to a directory in which to store the cached image; if not provided, the image will be cached via dask's persist() and its loaded copy will be returned
storage_options (dict, optional): If provided, specifies the compression method to use for image chunks
preferred_chunksize (tuple, optional): Re-chunk before saving; this rechunking will be undone on load
multiscale (int, optional): Specifies the number of downsampling levels of the OME ZARR file
compressor (numcodecs.abc.Codec, optional): Compressor used to compress the chunks
viewer_args (dict, optional): If provided, the image will be displayed as a new layer in the Napari viewer
viewer (napari.Viewer, optional): Only display if a viewer is provided
is_label (bool, optional): Defaults to False; if True, use the viewer's add_labels() instead of add_image() to display the array
layer_args (dict, optional): If provided, used along with viewer_args to specify add_image() kwargs
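Putting the list together, a populated context_args might look like the following sketch, which follows the structure used by the example further down this page; the cache path and layer name are placeholders:

import napari

viewer = napari.Viewer(ndisplay=2)
context_args = dict(
    cache_url='path/to/CacheDirectory/my_step',  # placeholder: where this step caches its result
    viewer_args=dict(
        viewer=viewer,                        # visualization is skipped if this is None
        is_label=True,                        # use viewer.add_labels() instead of viewer.add_image()
        layer_args=dict(name='my_step'),      # extra kwargs forwarded to the add_labels()/add_image() call
        preferred_chunksize=(1, 4096, 4096),  # re-chunk before save; undone when loading back
        multiscale=4,                         # number of OME ZARR downsampling levels
    ),
)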
As a convention, context_args contains a cache_url item, which is required only if the function needs some place to store intermediate results:

async def process(im, context_args: dict):
    cache_url = context_args['cache_url']
    query = tlfs.cdir_commit(cache_url)
    # if the cache does not exist yet, cache_url points to an empty path we can create a folder in:
    if not query.commit:
        result = compute_result(im)
        save(cache_url, result)
    result = load(cache_url)
    return result
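With this convention, running the same step twice with the same cache_url means the second call skips compute_result() and only loads the saved result. A small usage sketch, reusing the process() defined above (im is the input image and the path is a placeholder):

TMP = 'path/to/CacheDirectory'
result1 = await process(im, context_args=dict(cache_url=f'{TMP}/step1'))  # first call: computes, saves, then loads
result2 = await process(im, context_args=dict(cache_url=f'{TMP}/step1'))  # second call: cache exists, so it only loads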
The viewer_args parameter specifies the Napari viewer used to display the intermediate results. If it is not provided (viewer_args=None), then no computation will be done to visualize the image. Within the process function, you should use viewer.add_labels() or tlfs.cache_im() while passing in the viewer_args argument to display your results:

async def process(im, context_args):
    result = compute_result(im)
    result = await tlfs.cache_im(lambda: result, context_args=dict(
        cache_url=context_args.get('cache_url'),
        viewer_args=context_args.get('viewer_args')))
    return result

# ...

viewer = napari.Viewer(ndisplay=2)
viewer_args = dict(
    viewer=viewer,  # the napari viewer; visualization will be skipped if viewer is None
    is_label=True,  # if True, viewer.add_labels() will be called; if False, viewer.add_image() will be called
    preferred_chunksize=(1, 4096, 4096),  # the image will be converted to this chunksize when saved, and converted back when loaded
    multiscale=4,  # maximum downsampling level of the OME ZARR file, necessary for very large images
)
context_args = dict(
    cache_url='gcs://example/cloud/path',
    viewer_args=viewer_args
)
await process(im, context_args=context_args)
viewer_args allows us to visualize the saved results as part of the caching function. The reason we need this is that displaying the saved result often requires a different (flatter) chunk size for fast loading of cross-sectional images (in the example above, the image is converted from its original chunk size of, say, (256, 256, 256) to (1, 4096, 4096)), as well as downsampled levels for zooming in and out of larger images; dask's built-in persist() function does not provide good support for either.
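The effect of the flatter chunking is easy to see with plain dask (the shapes below are made up): with (256, 256, 256) chunks, displaying a single z-slice requires reading 16 x 16 = 256 chunks, whereas after rechunking to (1, 4096, 4096) it reads exactly one.

import dask.array as da

im = da.zeros((512, 4096, 4096), chunks=(256, 256, 256), dtype='f4')
print(im[0].npartitions)    # 256 chunks must be read to show one z-slice

flat = im.rechunk((1, 4096, 4096))
print(flat[0].npartitions)  # 1 chunk per z-slice: fast cross-sectional loading in the viewer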
Running the Pipeline
See Setting Up the Script for the boilerplate code used below; it is required to understand the following example.
With a process function defined, the next step is to write a script that uses the pipeline to segment an input dataset. Note that we need a dask cluster and a temporary directory set up before calling the process function.
if __name__ == '__main__':  # Only for main thread, worker threads will not run this
    TMP_PATH = "path/to/temporary/directory"

    import asyncio
    import dask
    from dask.distributed import Client
    import napari

    with dask.config.set({'temporary_directory': TMP_PATH}):
        client = Client()  # start a local dask cluster; adjust workers/threads as needed
        temp_directory = f'{TMP_PATH}/CacheDirectory'
        im = load_im(path)  # this is our input dask.Array object to be segmented
        viewer = napari.Viewer()
        viewer_args = dict(viewer=viewer)
        context_args = dict(
            cache_url=f'{temp_directory}/example_seg_process',
            viewer_args=viewer_args
        )
        asyncio.run(example_seg_process(im, context_args=context_args))  # run the async pipeline to completion
        client.close()
    viewer.show(block=True)
If instead viewer_args=None is passed, the example_seg_process() function will process the image and cache it, but display nothing.
To summarize the conventions: a process function has the signature process(arg1, ..., argn, context_args), where arg1 to argn are arbitrary arguments and context_args is a dictionary. Parameters that change how the viewer displays the image are provided through the viewer_args item of the context_args dictionary. Parameters that specify how the image is cached and stored locally (storing is often required for display) are provided through the storage_options item of the context_args dictionary.
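For instance, assuming each of the five steps planned earlier has been wrapped as a process function with this signature, a combined example_seg_process can forward sub-directories of its own cache_url to each step. This is only a sketch; the sub-path naming is our own convention, not something mandated by cvpl_tools:

async def example_seg_process(im, context_args: dict):
    cache_url = context_args.get('cache_url')
    viewer_args = context_args.get('viewer_args')

    def sub(name):  # derive a child cache location for each step, if caching is enabled
        return dict(cache_url=None if cache_url is None else f'{cache_url}/{name}',
                    viewer_args=viewer_args)

    bs = await thresholding(im, context_args=sub('bs'))                 # IN -> BS
    os_im = await direct_inst_segmentation(bs, context_args=sub('os'))  # BS -> OS
    cc = await cell_cnt_from_inst(os_im, context_args=sub('cc'))        # OS -> CC
    return cc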
To learn more, see the API pages for the cvpl_tools.im.process, cvpl_tools.tools.fs and cvpl_tools.im.ndblock modules.