Source code for empanada.consensus

import math
import numpy as np
import networkx as nx
from itertools import combinations
from empanada.array_utils import *

# Minimum intersection size for an overlap between two detections (or the
# average over two groups of detections) to count as non-trivial.
# NOTE(review): presumably a pixel/voxel count as returned by box_iou/rle_iou
# with return_intersection=True -- confirm units against empanada.array_utils.
MIN_OVERLAP = 100
# Minimum IoU for an overlap to count as non-trivial. An overlap is kept when
# it exceeds either this threshold or MIN_OVERLAP.
MIN_IOU = 1e-2

def average_edge_between_clusters(G, cluster1, cluster2, key='iou'):
    r"""Computes the mean edge weight over all node pairs spanning two groups.

    Every pair (node1, node2) with node1 in cluster1 and node2 in cluster2
    contributes one term; pairs that are not connected in G contribute 0.

    Args:
        G: nx.Graph containing the nodes in cluster1 and cluster2

        cluster1: List of nodes in G

        cluster2: List of nodes in G

        key: Name of the edge weight.

    Returns:
        avg_weight: Float, the average edge weight across clusters.

    """
    # one weight per cross-cluster pair, missing edges count as zero
    pair_weights = [
        G[u][v][key] if G.has_edge(u, v) else 0
        for u in cluster1
        for v in cluster2
    ]

    return sum(pair_weights) / len(pair_weights)
def create_graph_of_clusters(G, cluster_iou_thr):
    r"""Builds a graph whose nodes are groups of strongly overlapping detections.

    Detections are grouped by dropping every edge of G whose IoU is not
    greater than cluster_iou_thr and taking the connected components of
    what remains.

    Args:
        G: nx.Graph containing detection nodes

        cluster_iou_thr: Minimum IoU score between nodes for them to be
            put into the same group.

    Returns:
        cluster_graph: nx.Graph containing the grouped detection nodes.
            Each node stores its group under the 'cluster' attribute. Edges
            connect groups with non-trivial overlap and carry the average
            'iou' and 'overlap' between the two groups.

    """
    # work on a copy in which the weak (low iou) edges are removed
    pruned = G.copy()
    weak_edges = [
        (u, v) for u, v, attrs in G.edges(data=True)
        if attrs['iou'] <= cluster_iou_thr
    ]
    pruned.remove_edges_from(weak_edges)

    # every connected component of the pruned graph becomes one cluster node
    cluster_graph = nx.Graph()
    cluster_graph.add_nodes_from(
        (index, {'cluster': members})
        for index, members in enumerate(nx.connected_components(pruned))
    )

    # connect clusters using average weights measured in the
    # original (unpruned) graph
    for first, second in combinations(cluster_graph.nodes, 2):
        members_first = cluster_graph.nodes[first]['cluster']
        members_second = cluster_graph.nodes[second]['cluster']

        mean_iou = average_edge_between_clusters(
            G, members_first, members_second, 'iou'
        )
        mean_overlap = average_edge_between_clusters(
            G, members_first, members_second, 'overlap'
        )

        # skip pairs of clusters with only trivial overlap
        if not (mean_iou > MIN_IOU or mean_overlap > MIN_OVERLAP):
            continue

        cluster_graph.add_edge(first, second, iou=mean_iou, overlap=mean_overlap)

    return cluster_graph
def push_cluster(G, src, dst):
    r"""
    Adds the group stored on node src to the group stored on node dst
    and deletes the edge between the two nodes. The src node itself is
    left in the graph with its group unchanged.

    Args:
        G: nx.Graph whose nodes have a 'cluster' attribute (a set).

        src: Node whose group is copied into dst.

        dst: Node that receives the union of both groups.

    Returns:
        G: The same graph object, modified in place.

    """
    incoming = G.nodes[src]['cluster']
    dst_attrs = G.nodes[dst]
    # union builds a new set, the src group is not mutated
    dst_attrs['cluster'] = dst_attrs['cluster'].union(incoming)

    G.remove_edge(src, dst)
    return G
def merge_clusters(G):
    r"""Merges together clusters in the cluster graph iteratively.

    On every iteration the node with the most neighbors is selected. If one
    of its neighbors holds a larger group, the selected node's group is pushed
    into all of its neighbors and the node is removed. Otherwise the selected
    node absorbs the groups of all its neighbors, inherits their remaining
    edges, and the neighbors are removed. Iteration stops when no edges remain.

    Args:
        G: nx.Graph containing nodes that represent groups of detections.
            Nodes must have a 'cluster' attribute and edges an 'iou' attribute.

    Returns:
        H: nx.Graph containing nodes that represent the merged groups of
            detections. H has no edges. G is not modified.

    """
    # copy to avoid inplace changes
    H = G.copy()

    while len(H.edges()) > 0:
        # node with the most neighbors; on ties max() keeps the first
        # such node in iteration order, same as a stable descending sort
        most_connected = max(H.nodes, key=lambda x: len(H[x]))

        # sort neighbors by the size of their clusters
        neighbors = sorted(
            H.neighbors(most_connected),
            key=lambda x: len(H.nodes[x]['cluster']),
            reverse=True
        )

        # decide whether to push the most connected cluster to
        # merge with its neighbors or to merge all the neighbors
        # into the most connected cluster
        most_connected_cluster = H.nodes[most_connected]['cluster']

        # if a neighbor has a bigger cluster then push most connected
        largest_neighbor_cluster = H.nodes[neighbors[0]]['cluster']
        push_most_connected = len(largest_neighbor_cluster) > len(most_connected_cluster)

        if push_most_connected:
            # most connected cluster is rejected as an instance
            for neighbor in neighbors:
                push_cluster(H, most_connected, neighbor)

            H.remove_node(most_connected)
        else:
            # most connected cluster is accepted as an instance
            # pull all the neighboring clusters
            for neighbor in neighbors:
                push_cluster(H, neighbor, most_connected)

                # push secondary neighbors to most connected node.
                # The new edge must end at the secondary neighbor: an edge
                # to `neighbor` would vanish when `neighbor` is removed
                # below and the secondary neighbor would be disconnected
                # from the cluster that just absorbed `neighbor`.
                for sn in list(H.neighbors(neighbor)):
                    if not H.has_edge(most_connected, sn):
                        edge_iou = H[neighbor][sn]['iou']
                        H.add_edge(most_connected, sn, iou=edge_iou)

                H.remove_node(neighbor)

    return H
def merge_instances(instances_dict):
    r"""Merges an arbitrary number of instances into one.

    Args:
        instances_dict: Dictionary mapping instance_id to instance_attrs.
            Each instance_attrs is a dictionary with 'box', 'starts'
            and 'runs' entries.

    Returns:
        merged: If there is a single instance, its attribute dictionary is
            returned as is. Otherwise a new dictionary with the merged
            'box', 'starts' and 'runs' of all instances.

    """
    all_attrs = list(instances_dict.values())
    if len(all_attrs) < 2:
        return all_attrs[0]

    # seed the running result with the first instance and
    # fold the remaining instances into it one at a time
    seed, *remaining = all_attrs
    box = seed['box']
    starts = seed['starts']
    runs = seed['runs']

    for attrs in remaining:
        box = merge_boxes(box, attrs['box'])
        starts, runs = merge_rles(starts, runs, attrs['starts'], attrs['runs'])

    return dict(box=box, starts=starts, runs=runs)
def merge_overlapping(cluster_instances):
    r"""Merges together instances that have non-trivial overlap
    with each other.

    Args:
        cluster_instances: Dictionary mapping an instance id to a dictionary
            of instance attributes ('box', 'starts', 'runs').

    Returns:
        merged_instances: List of instance attribute dictionaries, one for
            each group of instances that are linked by overlaps.

    """
    # nothing to resolve unless there are at least 2 instances
    if len(cluster_instances) < 2:
        return list(cluster_instances.values())

    ids = list(cluster_instances)
    overlap_graph = nx.Graph()
    overlap_graph.add_nodes_from(ids)

    # link every pair of instances whose masks overlap non-trivially
    for id_a, id_b in combinations(ids, 2):
        inst_a = cluster_instances[id_a]
        inst_b = cluster_instances[id_b]
        pair_iou, inter_area = rle_iou(
            inst_a['starts'], inst_a['runs'],
            inst_b['starts'], inst_b['runs'],
            return_intersection=True
        )

        if pair_iou > MIN_IOU or inter_area > MIN_OVERLAP:
            overlap_graph.add_edge(id_a, id_b)

    # each connected component collapses into a single instance
    return [
        merge_instances(
            {k: v for k, v in cluster_instances.items() if k in component}
        )
        for component in nx.connected_components(overlap_graph)
    ]
def bounding_box_screening(boxes, source_indices):
    r"""Finds pairs of bounding boxes from different sources that have
    non-trivial overlap with each other.

    Args:
        boxes: Array of size (n, 4) or (n, 6) where bounding box
            is defined as (y1, x1, y2, x2) or (z1, y1, x1, z2, y2, x2).

        source_indices: Array of size (n,) that records the source of
            each bounding box. Pairs of bounding boxes that come from
            the same source are always discarded.

    Returns:
        box_matches: Array of size (k, 2). Each item is a unique pair of
            indices into boxes (smaller index first) for bounding boxes
            that have non-trivial overlap with each other.

    """
    # compute pairwise overlaps for all boxes
    # TODO: replace pairwise intersection calculation with something
    # more memory efficient (only matters when N is large ~10^4)
    pairwise_iou, pairwise_overlap = box_iou(boxes, return_intersection=True)

    # small thresholds weed out really trivial overlaps
    nontrivial = np.logical_or(
        pairwise_overlap > MIN_OVERLAP, pairwise_iou > MIN_IOU
    )
    pairs = np.argwhere(nontrivial)

    # drop pairs in which both boxes have the same source (mask or tracker)
    different_source = source_indices[pairs[:, 0]] != source_indices[pairs[:, 1]]
    pairs = pairs[different_source]

    # (i, j) and (j, i) are the same match: order each pair
    # and then keep only the unique rows
    return np.unique(np.sort(pairs, axis=-1), axis=0)
def merge_objects_from_trackers(
    object_trackers,
    pixel_vote_thr=2,
    cluster_iou_thr=0.75,
    bypass=False
):
    r"""Performs the consensus creation algorithm for instances from
    an arbitrary number of trackers (see empanada.inference.trackers).

    Args:
        object_trackers: List of empanada.inference.InstanceTracker. Only
            the `instances` attribute of each tracker is read. An empty
            list yields an empty result.

        pixel_vote_thr: Integer. Number of votes for a pixel/voxel to be
            in the consensus segmentation. Default 2, assumes there are 3
            object trackers.

        cluster_iou_thr: Float. IoU threshold for merging groups of instances.
            Default 0.75. Overridden to 0 when pixel_vote_thr is smaller
            than the minimum cluster size.

        bypass: Bool. If True, instances that appear in just 1 of the
            object trackers can be included in the consensus. This will only
            affect the final segmentation if pixel_vote_thr < 0.5 * len(object_trackers).
            Default False.

    Returns:
        instances: A nested dictionary of instances. Each key is an instance_id
            (consecutive integers starting at 1). Values are themselves
            dictionaries that contain the bounding box and run length
            encoding of the instance ('box', 'starts', 'runs').

    """
    n_votes = len(object_trackers)
    if bypass:
        min_cluster_size = 1
    else:
        # better to require majority clusters
        # even when not majority voxels
        min_cluster_size = (n_votes // 2) + 1

    # default to maximal merging when
    # not using majority vote
    if pixel_vote_thr < min_cluster_size:
        cluster_iou_thr = 0

    # unpack the instances from each tracker
    # into arrays for labels, bounding boxes
    # and voxel locations
    tracker_indices = []
    object_labels = []
    object_boxes = []
    object_starts = []
    object_runs = []
    for tr_index, tr in enumerate(object_trackers):
        for instance_id, instance_attr in tr.instances.items():
            tracker_indices.append(tr_index)
            object_labels.append(int(instance_id))
            object_boxes.append(instance_attr['box'])
            object_starts.append(instance_attr['starts'])
            object_runs.append(instance_attr['runs'])

    # store in arrays for convenient slicing
    tracker_indices = np.array(tracker_indices)
    object_labels = np.array(object_labels)
    object_boxes = np.array(object_boxes)

    if len(object_boxes) == 0:
        # no instances to return
        return {}

    # screen possible matches by bounding box first
    box_matches = bounding_box_screening(object_boxes, tracker_indices)

    # create graph with one node per detection
    graph = nx.Graph()
    for node_id in range(len(object_labels)):
        graph.add_node(
            node_id, box=object_boxes[node_id],
            starts=object_starts[node_id],
            runs=object_runs[node_id]
        )

    # iou as weighted edges
    for r1, r2 in zip(*tuple(box_matches.T)):
        pair_iou, inter_area = rle_iou(
            graph.nodes[r1]['starts'], graph.nodes[r1]['runs'],
            graph.nodes[r2]['starts'], graph.nodes[r2]['runs'],
            return_intersection=True
        )

        # add edge for non-trivial overlaps
        if pair_iou > MIN_IOU or inter_area > MIN_OVERLAP:
            graph.add_edge(r1, r2, iou=pair_iou, overlap=inter_area)

    instance_id = 1
    instances = {}
    for comp in nx.connected_components(graph):
        # components with too few detections cannot reach consensus
        if len(comp) < min_cluster_size:
            continue

        cluster_graph = create_graph_of_clusters(graph.subgraph(comp), cluster_iou_thr)
        cluster_graph = merge_clusters(cluster_graph)

        cluster_id = 1
        cluster_instances = {}
        for node in cluster_graph.nodes:
            cluster = list(cluster_graph.nodes[node]['cluster'])

            if len(cluster) < min_cluster_size:
                continue

            # merge boxes and coords from nodes
            node0 = cluster[0]
            merged_box = graph.nodes[node0]['box']
            for node_id in cluster[1:]:
                merged_box = merge_boxes(merged_box, graph.nodes[node_id]['box'])

            # vote on indices that should belong to an object;
            # each run is converted to a [start, end) range first
            all_ranges = np.concatenate([
                np.stack([
                    graph.nodes[node_id]['starts'],
                    graph.nodes[node_id]['starts'] + graph.nodes[node_id]['runs']
                ], axis=1)
                for node_id in cluster
            ])

            # rle_voting is given the ranges ordered by start index
            sort_idx = np.argsort(all_ranges[:, 0], kind='stable')
            all_ranges = all_ranges[sort_idx]
            voted_ranges = np.array(rle_voting(all_ranges, pixel_vote_thr))

            if len(voted_ranges) > 0:
                cluster_instances[cluster_id] = {
                    # convert numpy scalars to plain python numbers
                    'box': tuple(map(lambda x: x.item(), merged_box)),
                    'starts': voted_ranges[:, 0],
                    'runs': voted_ranges[:, 1] - voted_ranges[:, 0],
                }
                cluster_id += 1

        # merge together instances with higher than trivial overlap
        for instance_attrs in merge_overlapping(cluster_instances):
            instances[instance_id] = instance_attrs
            instance_id += 1

    return instances