Source code for empanada.zarr_utils

import math
import numba
import numpy as np
from multiprocessing import Pool
from empanada.array_utils import put, rle_to_ranges, ranges_to_rle

# Names exported by a wildcard import of this module; only the
# top-level fill routine is listed.
__all__ = [
    'zarr_fill_instances'
]

@numba.jit(nopython=True)
def chunk_ranges(ranges, modulo, divisor):
    r"""Splits index ranges so that no range crosses a chunk boundary.

    The chunk index of a raveled index ``i`` along one axis is taken to be
    ``(i % modulo) // divisor``. Each ``[start, stop)`` range is cut at
    every position where that chunk index changes, so that all indices
    of an output range share a single chunk index.

    Args:
        ranges: Integer array of shape (n, 2). Each row is a
            ``[start, stop)`` range of raveled indices (stop is exclusive).

        modulo: Integer, the period after which the chunk index
            wraps back to 0.

        divisor: Integer, the number of consecutive indices covered
            by one full chunk.

    Returns:
        chunked_ranges: List of ``[start, stop]`` pairs (stop exclusive)
        in the same order as the input. Ranges that already lie inside one
        chunk, and ranges with ``stop <= start``, are passed through
        unchanged.

    """
    chunked_ranges = []

    # when one chunk covers the whole period the chunk index
    # is always 0, so there is never anything to split
    single_chunk = divisor >= modulo

    for r in ranges:
        start = np.int64(r[0])
        stop = np.int64(r[1])

        if single_chunk or stop <= start:
            chunked_ranges.append([start, stop])
            continue

        # walk the range and cut it at each chunk boundary. comparing
        # only the chunk index of the first and last element is not
        # enough: when the last chunk is a short remainder, a range can
        # leave a chunk, wrap around the period and come back to the
        # same chunk index
        while start < stop:
            offset = start % modulo
            # distance to the end of the current full-size chunk
            to_chunk_end = divisor - offset % divisor
            # distance to the wrap-around point (end of a remainder chunk)
            to_wrap = modulo - offset
            cut = min(stop, start + min(to_chunk_end, to_wrap))
            chunked_ranges.append([start, cut])
            start = cut

    return chunked_ranges

@numba.jit(nopython=True)
def fill_func(seg1d, coords, instance_id):
    r"""Paints instance_id into seg1d over every span listed in coords.

    Args:
        seg1d: 1d (raveled) array that is modified in place.

        coords: Iterable of (start, stop) pairs of positions in seg1d;
            stop is exclusive.

        instance_id: Value written into each span.

    Returns:
        seg1d: The same array that was passed in, after filling.

    """
    for span in coords:
        # each span is a half-open interval [lo, hi)
        lo, hi = span
        seg1d[lo:hi] = instance_id

    return seg1d

def fill_zarr_mp(*args):
    r"""Helper function for multiprocessing the filling of zarr slices.

    Args:
        *args: Only ``args[0]`` is used. It must be a tuple of
            ``(slices, instances, array)``:

            slices: Tuple of 3 slice objects (z, y, x) that select the
            region of array to fill.

            instances: Dictionary. Keys are instance ids and values are
            integer arrays of shape (n, 2) holding ``[start, stop)``
            ranges of indices raveled with respect to ``array.shape``.
            Ranges are expected to lie inside the region given by slices.

            array: Array of shape (d, h, w) that is filled in place.

    """
    slices, instances, array = args[0]

    # nothing to paint in this region, so skip reading
    # it and writing the same data straight back
    if not instances:
        return

    z, y, x = [sl.start for sl in slices]

    seg = array[slices]
    chunk_shape = seg.shape

    seg = seg.reshape(-1)
    for instance_id, ranges in instances.items():
        # global (z, y, x) coords of the first and last index of each range;
        # ranges exclude their last index, hence the - 1
        zs, ys, xs = np.unravel_index(ranges[:, 0], array.shape)
        ze, ye, xe = np.unravel_index(ranges[:, 1] - 1, array.shape)

        # shift the coords so they're relative to this region's origin
        zs -= z
        ze -= z
        ys -= y
        ye -= y
        xs -= x
        xe -= x

        # ravel the coords again, this time in the region's own shape
        starts = np.ravel_multi_index((zs, ys, xs), chunk_shape)
        ends = np.ravel_multi_index((ze, ye, xe), chunk_shape)

        # add one to turn the inclusive ends back into exclusive ones
        fill_func(seg, np.stack([starts, ends + 1], axis=1), int(instance_id))

    array[slices] = seg.reshape(chunk_shape)

def zarr_fill_instances(array, instances, processes=4):
    r"""Fills a zarr array in-place with instances.

    Args:
        array: zarr.Array of size (d, h, w)

        instances: Dictionary. Keys are instance_ids (integers) and
            values are another dictionary containing the run length
            encoding (keys: 'starts', 'runs'). Instances without any
            runs are skipped.

        processes: Integer, the number of processes to run.

    Returns:
        None. The array is modified in place. Only chunks that contain
        at least one instance are read and written.

    """
    d, h, w = array.shape
    dc, hc, wc = array.chunks
    # number of chunks along the y and x axes
    ch, cw = math.ceil(h / hc), math.ceil(w / wc)

    # enumerate and define ROIs for each chunk
    chunks = {}
    for i, z1 in enumerate(range(0, d, dc)):
        for j, y1 in enumerate(range(0, h, hc)):
            for k, x1 in enumerate(range(0, w, wc)):
                # clip the last chunk on each axis to the array bounds
                z2 = min(d, z1 + dc)
                y2 = min(h, y1 + hc)
                x2 = min(w, x1 + wc)

                chunk_idx = (i * ch * cw) + (j * cw) + k
                chunks[chunk_idx] = {
                    'slices': (slice(z1, z2), slice(y1, y2), slice(x1, x2)),
                    'instances': {}
                }

    # period and chunk length, in raveled indices, for each axis
    zmodulo, zdivisor = d * h * w, dc * h * w
    ymodulo, ydivisor = h * w, hc * w
    xmodulo, xdivisor = w, wc

    for instance_id, instance_attrs in instances.items():
        # an instance without runs has nothing to fill and would
        # yield an empty array that can't be split into chunks
        if len(instance_attrs['starts']) == 0:
            continue

        # convert from runs to ranges
        ranges = rle_to_ranges(np.stack(
            [instance_attrs['starts'], instance_attrs['runs']], axis=1
        ))

        # split ranges at chunk boundaries, one axis at a time
        ranges = np.array(chunk_ranges(ranges, zmodulo, zdivisor))
        ranges = np.array(chunk_ranges(ranges, ymodulo, ydivisor))
        ranges = np.array(chunk_ranges(ranges, xmodulo, xdivisor))

        # get chunk_ids for each range
        zc = (ranges[:, 0] % zmodulo) // zdivisor
        yc = (ranges[:, 0] % ymodulo) // ydivisor
        xc = (ranges[:, 0] % xmodulo) // xdivisor
        chunk_indices = (zc * ch * cw) + (yc * cw) + xc

        # group ranges by the chunk in which they reside
        sort_idx = np.argsort(chunk_indices)
        ranges = ranges[sort_idx]
        chunk_indices = chunk_indices[sort_idx]
        unq_chunks, unq_idx = np.unique(chunk_indices, return_index=True)
        chunked_ranges = np.split(ranges, unq_idx[1:])

        for chunk_id, cranges in zip(unq_chunks, chunked_ranges):
            chunks[int(chunk_id)]['instances'][instance_id] = cranges

    # setup for multiprocessing; chunks without any
    # instances are left out because they'd be unchanged
    work = [
        (cdict['slices'], cdict['instances'], array)
        for cdict in chunks.values() if cdict['instances']
    ]
    if not work:
        return

    # make sure number of processes doesn't exceed the chunks to fill
    processes = min(processes, len(work))

    # fill the zarr volume. nothing to
    # return because it's done inplace
    with Pool(processes) as pool:
        pool.map(fill_zarr_mp, work)