Implementing lovasz_loss for keras-mxnet


#1

Hi,

I’ve been trying to port an implementation of the lovasz_loss, but I’ve run into a few issues.

Here is the original:

# code download from: https://github.com/bermanmaxim/LovaszSoftmax
def lovasz_grad(gt_sorted):
    """
    Computes gradient of the Lovasz extension w.r.t sorted errors
    See Alg. 1 in paper
    """
    gts = tf.reduce_sum(gt_sorted)
    intersection = gts - tf.cumsum(gt_sorted)
    union = gts + tf.cumsum(1. - gt_sorted)
    jaccard = 1. - intersection / union
    jaccard = tf.concat((jaccard[0:1], jaccard[1:] - jaccard[:-1]), 0)
    return jaccard


# --------------------------- BINARY LOSSES ---------------------------

def lovasz_hinge(logits, labels, per_image=True, ignore=None):
    """
    Binary Lovasz hinge loss
      logits: [B, H, W] Variable, logits at each pixel (between -\infty and +\infty)
      labels: [B, H, W] Tensor, binary ground truth masks (0 or 1)
      per_image: compute the loss per image instead of per batch
      ignore: void class id
    """
    if per_image:
        def treat_image(log_lab):
            log, lab = log_lab
            log, lab = tf.expand_dims(log, 0), tf.expand_dims(lab, 0)
            log, lab = flatten_binary_scores(log, lab, ignore)
            return lovasz_hinge_flat(log, lab)
        losses = tf.map_fn(treat_image, (logits, labels), dtype=tf.float32)
        loss = tf.reduce_mean(losses)
    else:
        loss = lovasz_hinge_flat(*flatten_binary_scores(logits, labels, ignore))
    return loss


def lovasz_hinge_flat(logits, labels):
    """
    Binary Lovasz hinge loss
      logits: [P] Variable, logits at each prediction (between -\infty and +\infty)
      labels: [P] Tensor, binary ground truth labels (0 or 1)
      ignore: label to ignore
    """

    def compute_loss():
        labelsf = tf.cast(labels, logits.dtype)
        signs = 2. * labelsf - 1.
        errors = 1. - logits * tf.stop_gradient(signs)
        errors_sorted, perm = tf.nn.top_k(errors, k=tf.shape(errors)[0], name="descending_sort")
        gt_sorted = tf.gather(labelsf, perm)
        grad = lovasz_grad(gt_sorted)
        loss = tf.tensordot(tf.nn.relu(errors_sorted), tf.stop_gradient(grad), 1, name="loss_non_void")
        return loss

    # deal with the void prediction case (only void pixels)
    loss = tf.cond(tf.equal(tf.shape(logits)[0], 0),
                   lambda: tf.reduce_sum(logits) * 0.,
                   compute_loss,
                   strict=True,
                   name="loss"
                   )
    return loss


def flatten_binary_scores(scores, labels, ignore=None):
    """
    Flattens predictions in the batch (binary case)
    Remove labels equal to 'ignore'
    """
    scores = tf.reshape(scores, (-1,))
    labels = tf.reshape(labels, (-1,))
    if ignore is None:
        return scores, labels
    valid = tf.not_equal(labels, ignore)
    vscores = tf.boolean_mask(scores, valid, name='valid_scores')
    vlabels = tf.boolean_mask(labels, valid, name='valid_labels')
    return vscores, vlabels

def lovasz_loss(y_true, y_pred):
    y_true, y_pred = K.cast(K.squeeze(y_true, -1), 'int32'), K.cast(K.squeeze(y_pred, -1), 'float32')
    #logits = K.log(y_pred / (1. - y_pred))
    logits = y_pred #Jiaxin
    loss = lovasz_hinge(logits, y_true, per_image = True, ignore = None)
    return loss

And here is what I’ve added to keras/backend/mxnet_backend.py:

def lovasz_grad(gt_sorted):
      """
      Computes gradient of the Lovasz extension w.r.t sorted errors
      See Alg. 1 in paper
      """
      gts = mx.sym.sum(gt_sorted) #tf.reduce_sum(gt_sorted)
      intersection = gts - np.cumsum(gt_sorted) #tf.cumsum(gt_sorted)
      union = gts + np.cumsum(1. - gt_sorted) #tf.cumsum(1. - gt_sorted)
      jaccard = 1. - intersection / union
      jaccard = mx.sym.concat(jaccard[0:1], jaccard[1:] - jaccard[:-1], axis=0) #tf.concat((jaccard[0:1], jaccard[1:] - jaccard[:-1]), 0)
      return jaccard

def lovasz_hinge_flat(logits, labels):
     """
     Binary Lovasz hinge loss
        logits: [P] Variable, logits at each prediction (between -\infty and +\infty)
        labels: [P] Tensor, binary ground truth labels (0 or 1)
        ignore: label to ignore
     """

  def compute_loss():
      labelsf = mx.sym.cast(labels, logits.dtype) #tf.cast(labels, logits.dtype)
      signs = 2. * labelsf - 1.
      errors = 1. - logits * mx.sym.stop_gradient(signs) #tf.stop_gradient(signs)
      errors_sorted, perm = mx.sym.topk(errors, k=, name="descending_sort")
      #tf.nn.top_k(errors, k=tf.shape(errors)[0], name="descending_sort")
      gt_sorted = mx.sym.gather_nd(labelsf, perm) #tf.gather(labelsf, perm)
      grad = lovasz_grad(gt_sorted)
      loss = mx.sym.dot(mx.sym.relu(errors_sorted), mx.sym.stop_gradient(grad), 1, name="loss_non_void")
      #tf.tensordot(tf.nn.relu(errors_sorted), tf.stop_gradient(grad), 1, name="loss_non_void")
      return loss

  # deal with the void prediction case (only void pixels)
  """
  loss = tf.cond(tf.equal(tf.shape(logits)[0], 0),
                 lambda: tf.reduce_sum(logits) * 0.,
                 compute_loss,
                 strict=True,
                 name="loss"
                 )
  """
  loss = mx.contrib.cond(mx.sym.broadcast_equal, logits.infer_shape(), 0),
          lambda: mx.sym.sum(logits) * 0.,
          compute_loss,
          name="loss"
          )
 return loss

 
 def flatten_binary_scores(scores, labels, ignore=None):
      """
      Flattens predictions in the batch (binary case)
      Remove labels equal to 'ignore'
      """
      scores = mx.sym.reshape(scores, (-1,)) #tf.reshape(scores, (-1,))
      labels = mx.sym.reshape(labels, (-1,)) #tf.reshape(labels, (-1,))
      if ignore is None:
          return scores, labels
      valid = my.sym.broadcast_not_equal(labels, ignore) #tf.not_equal(labels, ignore)
      vscores = tf.boolean_mask(scores, valid, name='valid_scores')
      vlabels = tf.boolean_mask(labels, valid, name='valid_labels')
      return vscores, vlabels
 
 
 @keras_mxnet_symbol
 def lovasz_hinge(logits, labels, per_image=True, ignore=None):
      """
      Binary Lovasz hinge loss
        logits: [B, H, W] Variable, logits at each pixel (between -\infty and +\infty)
        labels: [B, H, W] Tensor, binary ground truth masks (0 or 1)
        per_image: compute the loss per image instead of per batch
        ignore: void class id
      """
      if per_image:
          def _step(log_lab):
              log, lab = log_lab
              log, lab = mx.sym.expand_dims(log, 0), mx.sym.expand_dims(lab, 0) #tf.expand_dims(log, 0), tf.expand_dims(lab, 0)
              log, lab = flatten_binary_scores(log, lab, ignore)
              return lovasz_hinge_flat(log, lab)
          losses = mx.sym.foreach(_step, (logits, labels), []) #tf.map_fn(treat_image, (logits, labels), dtype=tf.float32)
          loss = mx.sym.mean(losses) #tf.reduce_mean(losses)
      else:
          loss = lovasz_hinge_flat(*flatten_binary_scores(logits, labels, ignore))
      return KerasSymbol(loss)

Here are the issues I’ve encountered:

If anyone has suggestions on these or spots obvious errors in my code. Please let me know!

@roywei @thomelane


#2

As discussed off line, for cumsum the current workaround is to use numpy. However, for dynamic shape, keras-mxnet requires support in mxnet symbol interface, which may come at a later time. We will add this support once it’s out, it will be a major improvement on keras-mxnet especially on RNN use cases.

Again thanks for trying out keras-mxnet!


#3

Thanks again for looking into it!