Is sym.where equivalent to tf.cond?

#1

Hi guys. I am new to mxnet. Now I want to add some conditional control in my network. I notice that sym.where can do this work by using:

f_x = a - b
f_y = a + b

result = mx.sym.where(condition=gt, x=f_x, y=f_y)

I am curious to know: if the condition is true, will f_y = a + b still be executed? (Maybe f_y is executed but its result is not returned?)

I read the example of tensorflow, the explanation is quite clear:

z = tf.multiply(a, b)
result = tf.cond(x < y, lambda: tf.add(x, z), lambda: tf.square(y))
If x < y, the tf.add operation will be executed and tf.square operation will not be executed.

So, is mx.sym.where equivalent to tf.cond? Thanks.

#2

The documentation is clear on this: the result is `x` if `condition == True`, else `y`:

``````
fx <- Symbol.where(condition, fx, fy)  # if condition == True
fy <- Symbol.where(condition, fx, fy)  # if condition == False
``````

The NDArray documentation says the same (and it is straightforward to experiment with NDArray to test this).

Hope this helps

#3

To add to the comments from @feevos: `sym.where` is very similar to `tf.where`, and both differ from `tf.cond` in the following ways:

1. `tf.cond` accepts a scalar for the `pred` argument and returns “the result of true_fn or false_fn” depending on this scalar. `sym.where`, however, accepts an array for `condition`, and elements of `x` or `y` are returned depending on the respective element in `condition`.
2. `tf.cond` only executes one of the operators, depending on the condition. `sym.where`, however, requires both `x` and `y` to be evaluated before `sym.where()` is executed.
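
A rough sketch of the two points above, in plain Python rather than MXNet or TensorFlow. The helpers `where_like`, `cond_like`, `branch_x`, and `branch_y` are hypothetical stand-ins that only mimic the semantics described here; they are not either library's implementation. The `calls` list records which branch computations actually ran.

```python
# Hypothetical stand-ins written in plain Python; they mimic the semantics
# described above and are not the MXNet or TensorFlow implementations.

calls = []  # records which branch computations actually ran


def branch_x(a, b):
    """Stand-in for f_x = a - b."""
    calls.append("x")
    return [ai - bi for ai, bi in zip(a, b)]


def branch_y(a, b):
    """Stand-in for f_y = a + b."""
    calls.append("y")
    return [ai + bi for ai, bi in zip(a, b)]


def where_like(condition, x, y):
    """Element-wise select: x and y are already-computed sequences."""
    return [xi if c else yi for c, xi, yi in zip(condition, x, y)]


def cond_like(pred, true_fn, false_fn):
    """Scalar branch: only the chosen callable is invoked."""
    return true_fn() if pred else false_fn()


a, b = [5, 1, 7], [2, 4, 3]

# where-style: both operands must exist before the select runs.
calls.clear()
mask = [ai > bi for ai, bi in zip(a, b)]
selected = where_like(mask, branch_x(a, b), branch_y(a, b))
where_calls = list(calls)

# cond-style: the predicate is a single boolean, and one branch is skipped.
calls.clear()
branched = cond_like(sum(a) > sum(b), lambda: branch_x(a, b), lambda: branch_y(a, b))
cond_calls = list(calls)
```

In this sketch `where_calls` ends up containing both branches while `cond_calls` contains only one, which is the behavioural difference the list describes.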

#4

Thank you @feevos @safrooze.
Is there any operator in mxnet that is equivalent to tf.cond in tensorflow? Currently I don’t want both x and y to be evaluated before the conditional control.

#5

There isn’t any branching symbol in MXNet, but it offers an imperative interface with gluon/autograd that is much more flexible than TensorFlow’s symbolic branching. You should seriously consider using MXNet’s imperative API. Not only will you get a much more flexible framework for composing complex computational graphs, but you will also save yourself days (and many frustrations) when debugging your network.