Source code for empanada.losses

"""
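Loss functions for panoptic segmentation and boundary contour models,
including a bootstrapped cross-entropy that efficiently weights the loss
toward the largest per-pixel loss values.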
"""

import torch
import torch.nn as nn
from empanada.models.point_rend import point_sample

# Public API of this module: only the two composite losses are exported.
# The component losses defined below (BootstrapCE, HeatmapMSE, OffsetL1,
# PointRendLoss) are not listed and so are skipped by star-imports.
__all__ = [
    'PanopticLoss',
    'BCLoss'
]

class BootstrapCE(nn.Module):
    r"""Standard (binary) cross-entropy loss where only the top
    k percent of largest loss values are averaged.

    Args:
        top_k_percent_pixels: Float, fraction of largest loss values
            to average. Default 0.2

    """
    def __init__(self, top_k_percent_pixels=0.2):
        super(BootstrapCE, self).__init__()
        self.top_k_percent_pixels = top_k_percent_pixels
        self.bce = nn.BCEWithLogitsLoss(reduction='none')
        self.ce = nn.CrossEntropyLoss(reduction='none')

    def forward(self, logits, labels):
        if logits.size(1) == 1:
            # add channel dim for BCE
            # (N, H, W) -> (N, 1, H, W)
            labels = labels.unsqueeze(1)
            pixel_losses = self.bce(logits, labels)
        else:
            pixel_losses = self.ce(logits, labels)

        pixel_losses = pixel_losses.contiguous().view(-1)

        if self.top_k_percent_pixels == 1.0:
            return pixel_losses.mean()

        top_k_pixels = int(self.top_k_percent_pixels * pixel_losses.numel())
        pixel_losses, _ = torch.topk(pixel_losses, top_k_pixels)

        return pixel_losses.mean()

class HeatmapMSE(nn.Module):
    r"""
    Mean squared error (MSE) between a predicted and a target
    instance center heatmap.
    """
    def __init__(self):
        super().__init__()
        # nn.MSELoss with its default settings
        self.mse = nn.MSELoss()

    def forward(self, output, target):
        r"""Returns the MSE between ``output`` and ``target``."""
        heatmap_loss = self.mse(output, target)
        return heatmap_loss

class OffsetL1(nn.Module):
    r"""
    Weighted L1 loss for instance center offsets. Each absolute
    error is multiplied by ``offset_weights`` and the total is
    normalized by the sum of the weights.
    """
    def __init__(self):
        super().__init__()
        # keep per-element errors so they can be weighted before reduction
        self.l1 = nn.L1Loss(reduction='none')

    def forward(self, output, target, offset_weights):
        r"""Returns the weight-normalized L1 loss.

        Args:
            output: Tensor of predicted offsets.
            target: Tensor of ground truth offsets.
            offset_weights: Tensor of weights that multiplies the
                elementwise L1 error.

        Returns:
            Scalar tensor; zero when the weights sum to zero.
        """
        weighted_error = self.l1(output, target) * offset_weights
        total_error = weighted_error.sum()
        total_weight = offset_weights.sum()

        if total_weight == 0:
            # multiply by zero rather than returning a constant so the
            # result is still derived from the weighted error tensor
            return total_error * 0

        return total_error / total_weight

class PointRendLoss(nn.Module):
    r"""Standard (binary) cross-entropy between the point rend
    logits and the labels sampled at the same point coordinates.
    """
    def __init__(self):
        super().__init__()
        self.bce = nn.BCEWithLogitsLoss(reduction='mean')
        self.ce = nn.CrossEntropyLoss(reduction='mean')

    def forward(self, point_logits, point_coords, labels):
        r"""Returns the mean cross-entropy over the sampled points.

        Args:
            point_logits: Tensor of logits with the class/channel
                dimension at index 1.
            point_coords: Coordinates passed through to ``point_sample``.
            labels: Tensor of labels without a channel dimension.

        Returns:
            Scalar tensor, the mean-reduced loss.
        """
        # point_sample is given float labels with an added channel dim
        float_labels = labels.unsqueeze(1).float()
        sampled_labels = point_sample(
            float_labels, point_coords,
            mode="nearest", align_corners=False
        )

        # single channel logits -> binary cross-entropy
        if point_logits.size(1) == 1:
            return self.bce(point_logits, sampled_labels)

        # multiclass cross-entropy takes integer class indices
        # without the channel dim
        class_indices = sampled_labels.squeeze(1).long()
        return self.ce(point_logits, class_indices)

class PanopticLoss(nn.Module):
    r"""Defines the overall panoptic loss function which combines
    semantic segmentation, instance centers and offsets.

    Args:
        ce_weight: Float, weight to apply to the semantic segmentation loss.

        mse_weight: Float, weight to apply to the centers heatmap loss.

        l1_weight: Float, weight to apply to the center offsets loss.

        pr_weight: Float, weight to apply to the point rend semantic
            segmentation loss. Only applies if using a Point Rend enabled model.

        top_k_percent: Float, fraction of largest semantic segmentation
            loss values to consider in BootstrapCE.

    """
    def __init__(
        self,
        ce_weight=1,
        mse_weight=200,
        l1_weight=0.01,
        pr_weight=1,
        top_k_percent=0.2
    ):
        super(PanopticLoss, self).__init__()
        self.mse_loss = HeatmapMSE()
        self.l1_loss = OffsetL1()
        self.ce_loss = BootstrapCE(top_k_percent)
        self.pr_loss = PointRendLoss()

        self.ce_weight = ce_weight
        self.mse_weight = mse_weight
        self.l1_weight = l1_weight
        self.pr_weight = pr_weight

    def forward(self, output, target):
        r"""Computes the weighted sum of the component losses.

        Args:
            output: Dict of model outputs. Reads 'ctr_hmp', 'sem_logits'
                and 'offsets'; also 'sem_points' and 'point_coords'
                when 'sem_points' is present.
            target: Dict of ground truth. Reads 'ctr_hmp', 'sem'
                and 'offsets'.

        Returns:
            Tuple of (total_loss, aux_loss): the total loss tensor and a
            dict of Python numbers for each component and the total.
        """
        mse = self.mse_loss(output['ctr_hmp'], target['ctr_hmp'])
        ce = self.ce_loss(output['sem_logits'], target['sem'])

        # only evaluate loss inside of ground truth segmentation
        offset_weights = (target['sem'] > 0).unsqueeze(1)
        l1 = self.l1_loss(output['offsets'], target['offsets'], offset_weights)

        aux_loss = {'ce': ce.item(), 'l1': l1.item(), 'mse': mse.item()}
        total_loss = self.ce_weight * ce + self.mse_weight * mse + self.l1_weight * l1

        if 'sem_points' in output:
            pr_ce = self.pr_loss(output['sem_points'], output['point_coords'], target['sem'])
            aux_loss['pointrend_ce'] = pr_ce.item()
            total_loss += self.pr_weight * pr_ce

        aux_loss['total_loss'] = total_loss.item()
        return total_loss, aux_loss

class BCLoss(nn.Module):
    r"""Defines the overall loss for a boundary contour prediction model.

    Args:
        pr_weight: Float, weight to apply to the point rend semantic
            segmentation loss. Only applies if using a Point Rend enabled model.

        top_k_percent: Float, fraction of largest semantic segmentation
            loss values to consider in BootstrapCE.

    """
    def __init__(
        self,
        pr_weight=1,
        top_k_percent=0.15
    ):
        super(BCLoss, self).__init__()
        self.ce_loss = BootstrapCE(top_k_percent)
        self.pr_loss = PointRendLoss()
        self.pr_weight = pr_weight

    def forward(self, output, target):
        r"""Computes the summed segmentation and contour losses.

        Args:
            output: Dict of model outputs. Reads 'sem_logits' and
                'cnt_logits'; also 'sem_points', 'sem_point_coords',
                'cnt_points' and 'cnt_point_coords' when 'sem_points'
                is present.
            target: Dict of ground truth. Reads 'sem' and 'cnt'.

        Returns:
            Tuple of (total_loss, aux_loss): the total loss tensor and a
            dict of Python numbers for each component and the total.
        """
        # mask losses
        sem_ce = self.ce_loss(output['sem_logits'], target['sem'])
        cnt_ce = self.ce_loss(output['cnt_logits'], target['cnt'])

        aux_loss = {'sem_ce': sem_ce.item(), 'cnt_ce': cnt_ce.item()}
        total_loss = sem_ce + cnt_ce

        # add the point rend losses from both
        if 'sem_points' in output:
            sem_pr_ce = self.pr_loss(output['sem_points'], output['sem_point_coords'], target['sem'])
            cnt_pr_ce = self.pr_loss(output['cnt_points'], output['cnt_point_coords'], target['cnt'])
            aux_loss['sem_pr_ce'] = sem_pr_ce.item()
            aux_loss['cnt_pr_ce'] = cnt_pr_ce.item()
            total_loss += self.pr_weight * (sem_pr_ce + cnt_pr_ce)

        aux_loss['total_loss'] = total_loss.item()
        return total_loss, aux_loss