import numpy as np
import torch
import torch.nn as nn
import torch.nn.functional as F
from empanada.inference.matcher import fast_matcher
[docs]class EMAMeter:
r"""Computes and stores an exponential moving average and current value"""
def __init__(self, momentum=0.98):
self.mom = momentum
self.reset()
def reset(self):
self.val = 0
self.avg = 0
self.sum = 0
self.count = 0
def update(self, val):
self.val = val
self.sum = (self.sum * self.mom) + (val * (1 - self.mom))
self.count += 1
self.avg = self.sum / (1 - self.mom ** (self.count))
[docs]class AverageMeter:
r"""Computes and stores a moving average and current value"""
def __init__(self):
self.reset()
def reset(self):
self.val = 0
self.avg = 0
self.sum = 0
self.count = 0
def update(self, val):
self.val = val
self.sum = self.sum + self.val
self.count += 1
self.avg = self.sum / self.count
class _BaseMetric:
r"""Default Metric class"""
def __init__(self, meter, labels):
self.meters = {l: meter() for l in labels}
self.labels = labels
def update(self, value_dict):
for l,v in value_dict.items():
self.meters[l].update(v)
def reset(self):
for l in self.labels:
self.meters[l].reset()
def average(self):
return {l: meter.avg for l,meter in self.meters.items()}
[docs]class IoU(_BaseMetric):
r"""Computes the IoU between output and target.
Input is expected to be a dictionary for each.
Args:
meter: EMAMeter or AverageMeter to track.
labels: List of all semantic/instance labels to compare.
output_key: Key in the output dictionary to compare.
target_key: Key in the target dictionary to compare.
"""
def __init__(
self,
meter,
labels,
output_key='sem_logits',
target_key='sem',
**kwargs
):
super().__init__(meter, labels)
self.output_key = output_key
self.target_key = target_key
def calculate(self, output, target):
# only require the semantic segmentation
output = output[self.output_key]
target = target[self.target_key]
# make target the same shape as output by unsqueezing
# the channel dimension, if needed
if target.ndim == output.ndim - 1:
target = target.unsqueeze(1)
# get the number of classes from the output channels
n_classes = output.size(1)
# get reshape size based on number of dimensions
# can exclude first 2 dims, which are always batch and channel
empty_dims = (1,) * (target.ndim - 2)
if n_classes > 1:
# one-hot encode the target (B, 1, H, W) --> (B, N, H, W)
k = torch.arange(0, n_classes).view(1, n_classes, *empty_dims).to(target.device)
target = (target == k)
# softmax the output
output = nn.Softmax(dim=1)(output)
else:
# sigmoid the output
output = (torch.sigmoid(output) > 0.5).long()
# cast target to the correct type for operations
target = target.type(output.dtype)
# multiply the tensors, everything that is still as 1 is part of the intersection
# (N,)
dims = (0,) + tuple(range(2, target.ndimension()))
intersect = torch.sum(output * target, dims)
# compute the union, (N,)
union = torch.sum(output + target, dims) - intersect
# avoid division errors by adding a small epsilon
iou = (intersect + 1e-7) / (union + 1e-7)
if n_classes == 1:
return {self.labels[0]: iou.item()}
else:
return {l: iou[l] for l in self.labels}
[docs]class PQ(_BaseMetric):
r"""Computes the panoptic quality between output and target.
Input is expected to be a dictionary for each.
Args:
meter: EMAMeter or AverageMeter to track
labels: List of all semantic/instance labels to compare
label_divisor: Integer. Label divisor used during postprocessing.
output_key: Key in the output dictionary to compare.
target_key: Key in the target dictionary to compare.
"""
def __init__(
self,
meter,
labels,
label_divisor,
output_key='pan_seg',
target_key='pan_seg',
**kwargs
):
super().__init__(meter, labels)
self.label_divisor = label_divisor
self.output_key = output_key
self.target_key = target_key
def _to_class_seg(self, pan_seg, label):
instance_seg = np.copy(pan_seg) # copy for safety
min_id = label * self.label_divisor
max_id = min_id + self.label_divisor
# zero all objects/semantic segs outside of instance_id range
outside_mask = np.logical_or(instance_seg < min_id, instance_seg >= max_id)
instance_seg[outside_mask] = 0
return instance_seg
def calculate(self, output, target):
# convert tensors to numpy
output = output[self.output_key].squeeze().long().cpu().numpy()
target = target[self.target_key].squeeze().long().cpu().numpy()
# compute the panoptic quality, per class
per_class_results = {}
for label in self.labels:
pred_class_seg = self._to_class_seg(output, label)
tgt_class_seg = self._to_class_seg(target, label)
# match the segmentations
matched_labels, all_labels, matched_ious = \
fast_matcher(tgt_class_seg, pred_class_seg, iou_thr=0.5)
tp = len(matched_labels[0])
fn = len(np.setdiff1d(all_labels[0], matched_labels[0]))
fp = len(np.setdiff1d(all_labels[1], matched_labels[1]))
if tp + fp + fn == 0:
# by convention, PQ is 1 for empty masks
per_class_results[label] = 1.
continue
sq = matched_ious.sum() / (tp + 1e-5)
rq = tp / (tp + 0.5 * fp + 0.5 * fn)
per_class_results[label] = sq * rq
return per_class_results
[docs]class F1(_BaseMetric):
r"""Computes the F1 between output and target instance segmentation
classes. Input is expected to be a dictionary for each.
Args:
meter: EMAMeter or AverageMeter to track
labels: List of all instance labels to compare
label_divisor: Integer. Label divisor used during postprocessing.
iou_thr: Float, IoU threshold at which to determine TP, FP, FN detections.
output_key: Key in the output dictionary to compare.
target_key: Key in the target dictionary to compare.
"""
def __init__(
self,
meter,
labels,
label_divisor,
iou_thr=0.5,
output_key='pan_seg',
target_key='pan_seg',
**kwargs
):
super().__init__(meter, labels)
self.label_divisor = label_divisor
self.iou_thr = iou_thr
self.output_key = output_key
self.target_key = target_key
def _to_class_seg(self, pan_seg, label):
instance_seg = np.copy(pan_seg) # copy for safety
min_id = label * self.label_divisor
max_id = min_id + self.label_divisor
# zero all objects/semantic segs outside of instance_id range
outside_mask = np.logical_or(instance_seg < min_id, instance_seg >= max_id)
instance_seg[outside_mask] = 0
return instance_seg
def calculate(self, output, target):
# convert tensors to numpy
output = output[self.output_key].squeeze().long().cpu().numpy()
target = target[self.target_key].squeeze().long().cpu().numpy()
# compute the panoptic quality, per class
per_class_results = {}
for label in self.labels:
pred_class_seg = self._to_class_seg(output, label)
tgt_class_seg = self._to_class_seg(target, label)
# match the segmentations
matched_labels, all_labels, matched_ious = \
fast_matcher(tgt_class_seg, pred_class_seg, iou_thr=self.iou_thr)
tp = len(matched_labels[0])
fn = len(np.setdiff1d(all_labels[0], matched_labels[0]))
fp = len(np.setdiff1d(all_labels[1], matched_labels[1]))
if tp + fp + fn == 0:
# by convention, F1 is 1 for empty masks
per_class_results[label] = 1.
else:
f1 = tp / (tp + 0.5 * fn + 0.5 * fp)
per_class_results[label] = f1
return per_class_results
[docs]class ComposeMetrics:
r"""Bundles multiple metrics together for easy
evaluation, printing and logging during training.
Args:
metrics_dict: Dictionary, keys are the names of metrics and values are
the _BaseMetric class than records/calculate that metric.
class_names: Dictionary, keys are class_ids and values are names.
reset_on_print: Bool. If True, the history of each metric is wiped
after results are printed.
"""
def __init__(
self,
metrics_dict,
class_names,
reset_on_print=True
):
self.metrics_dict = metrics_dict
self.class_names = class_names
self.reset_on_print = reset_on_print
self.history = {}
def evaluate(self, output, target):
# calculate all the metrics in the dict
for metric in self.metrics_dict.values():
value = metric.calculate(output, target)
metric.update(value)
def display(self):
print_names = []
print_values = []
for metric_name, metric in self.metrics_dict.items():
avg_values = metric.average()
for l, v in avg_values.items():
value_name = self.class_names[l]
print_values.append(v)
print_names.append(f'{value_name}_{metric_name}')
if self.reset_on_print:
metric.reset()
# print out the metrics
for name, value in zip(print_names, print_values):
if name not in self.history:
self.history[name] = [value]
else:
self.history[name].append(value)
print(name, value)