Is sym.where equivalent to tf.cond?


#1

Hi guys. I am new to MXNet. I want to add some conditional control flow to my network, and I noticed that sym.where can do this:

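import mxnet as mx

a = mx.sym.Variable('a')
b = mx.sym.Variable('b')
gt = a > b  # assumed condition: 1 where a > b, 0 elsewhere
f_x = a - b
f_y = a + b
result = mx.sym.where(condition=gt, x=f_x, y=f_y)  # x and y must also be named after condition=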

I am curious: if the condition is true, will f_y = a + b still be executed? (Maybe f_y is executed but its result is simply not returned?)

I read the TensorFlow example, and its explanation is quite clear:

z = tf.multiply(a, b)
result = tf.cond(x < y, lambda: tf.add(x, z), lambda: tf.square(y))
If x < y, the tf.add operation will be executed and the tf.square operation will not be executed.

So, is mx.sym.where equivalent to tf.cond? Thanks.


#2

The documentation is clear on this: where(condition, x, y) returns x where condition is true and y where it is false:

result = mx.sym.where(condition, fx, fy)  # fx where condition == True
                                          # fy where condition == False

The NDArray documentation says the same for nd.where, and it is straightforward to test this by experimenting with NDArray.
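To see the elementwise rule without installing MXNet, here is a plain-Python sketch of the selection that where performs. The function `where` below is a hypothetical stand-in, not MXNet's implementation, and it assumes (as the example in the mx.nd.where documentation suggests) that any nonzero condition element selects from x. Note that x and y are complete, already-computed arrays before any selection happens.

```python
def where(condition, x, y):
    """Hypothetical elementwise select modeled on mx.nd.where.

    Assumption: a nonzero condition element picks from x, zero picks from y.
    All three inputs are 2-D nested lists of the same shape.
    """
    return [
        [xv if c != 0 else yv for c, xv, yv in zip(c_row, x_row, y_row)]
        for c_row, x_row, y_row in zip(condition, x, y)
    ]


cond = [[0, 1], [-1, 0]]
x = [[1, 2], [3, 4]]
y = [[5, 6], [7, 8]]
print(where(cond, x, y))  # [[5, 2], [3, 8]]
```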

Hope this helps


#3

To add to @feevos's comments: sym.where is very similar to tf.where, and both differ from tf.cond in the following ways:

  1. tf.cond accepts a scalar for the pred argument and returns "the result of true_fn or false_fn" depending on that scalar. sym.where, by contrast, accepts an array for condition and returns, element by element, the entry of x or y selected by the corresponding element of condition.
  2. tf.cond executes only one of the two branches, depending on the condition. sym.where, however, requires both x and y to be evaluated before sym.where() itself is executed.

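The difference in point 2 can be modeled in plain Python. In this sketch, `where_like` and `cond_like` are hypothetical helpers, not MXNet or TensorFlow APIs: `where_like` receives branch values the caller has already computed, like the inputs to sym.where, while `cond_like` receives functions and calls only the one chosen by a scalar predicate, like tf.cond. The `calls` list records which branch functions actually ran, mirroring the add/square example from the question.

```python
calls = []  # records which branch functions actually ran


def add(x, z):
    calls.append("add")
    return x + z


def square(y):
    calls.append("square")
    return y * y


def where_like(pred, x_val, y_val):
    # Both x_val and y_val were already computed by the caller.
    return x_val if pred else y_val


def cond_like(pred, true_fn, false_fn):
    # Only the branch selected by the scalar pred is ever called.
    return true_fn() if pred else false_fn()


x, y, z = 1.0, 2.0, 5.0

calls.clear()
eager_result = where_like(x < y, add(x, z), square(y))
eager_calls = list(calls)

calls.clear()
lazy_result = cond_like(x < y, lambda: add(x, z), lambda: square(y))
lazy_calls = list(calls)

print(eager_result, eager_calls)  # 6.0 ['add', 'square']
print(lazy_result, lazy_calls)    # 6.0 ['add']
```

Both versions return the same value; only the cond-style version skips the work of `square`, which is exactly the saving the question is asking about.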
#4

Thank you @feevos @safrooze.
Is there any command in MXNet that is equivalent to tf.cond in TensorFlow? I don't want both x and y to be evaluated before the conditional is applied.


#5

There isn't any branching symbol in MXNet, but it offers an imperative interface with Gluon/autograd that is much more flexible than TensorFlow's symbolic branching: because the forward pass runs as ordinary Python, a plain if/else executes only the branch that is actually taken. You should seriously consider using MXNet's imperative API. Not only will you get a much more flexible framework for composing complex computational graphs, but you will also save yourself days (and many frustrations) when debugging your network.