[C++ Symbol API] min function


#1

Why is there no function for the element-wise minimum of two symbols? The existing min operator only computes the minimum value within a single array (a reduction).

Use case: reinforcement learning with a loss function that involves a minimum. I need to compute the gradient of that function.

For example, the following code does not compile:

Symbol a = Symbol::Variable("a");
Symbol b = Symbol::Variable("b");
Symbol c = min(a*a, b);

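Update: I got around this as follows (a crude solution, though):

Symbol min(Symbol a, Symbol b) {
    // Mask selecting a where a <= b (non-strict, so ties are kept) plus
    // mask selecting b where b < a. With two strict masks, positions
    // where a == b would get 0 instead of the common value.
    return broadcast_lesser_equal(a, b) * a + broadcast_lesser(b, a) * b;
}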
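The mask trick can be checked outside MXNet. Below is a framework-free sketch on plain std::vector<double>; the names less_mask, mask_min and mask_min_grad_a are hypothetical and are not MXNet API. It shows that with two strict comparisons the formula returns 0 wherever a[i] == b[i] (both masks are 0), that making one mask non-strict fixes this, and which input receives the gradient. It assumes the comparison masks are treated as constants during differentiation, so the gradient with respect to a is simply the mask that selects a (one valid subgradient at ties).

```cpp
#include <cassert>
#include <cstddef>
#include <vector>

// Hypothetical helpers mirroring the Symbol expression mask_a * a + mask_b * b
// on plain vectors. This is an illustration, not MXNet code.

// Element-wise mask: 1.0 where x[i] < y[i] (or x[i] <= y[i] if or_equal), else 0.0.
std::vector<double> less_mask(const std::vector<double>& x,
                              const std::vector<double>& y, bool or_equal) {
    assert(x.size() == y.size());
    std::vector<double> m(x.size());
    for (std::size_t i = 0; i < x.size(); ++i) {
        bool take = or_equal ? (x[i] <= y[i]) : (x[i] < y[i]);
        m[i] = take ? 1.0 : 0.0;
    }
    return m;
}

// mask_a * a + mask_b * b. If fix_ties is false, both masks are strict and
// positions where a[i] == b[i] evaluate to 0.
std::vector<double> mask_min(const std::vector<double>& a,
                             const std::vector<double>& b, bool fix_ties) {
    std::vector<double> ma = less_mask(a, b, fix_ties);  // selects a
    std::vector<double> mb = less_mask(b, a, false);     // selects b only where b < a
    std::vector<double> out(a.size());
    for (std::size_t i = 0; i < a.size(); ++i) {
        out[i] = ma[i] * a[i] + mb[i] * b[i];
    }
    return out;
}

// Gradient of sum(min(a, b)) with respect to a, treating masks as constants:
// it is the (tie-fixed) mask selecting a, so ties send the gradient to a.
std::vector<double> mask_min_grad_a(const std::vector<double>& a,
                                    const std::vector<double>& b) {
    return less_mask(a, b, true);
}
```

With a = {1, 5, 3} and b = {2, 4, 3}, the strict version gives {1, 4, 0} and the tie-fixed version gives {1, 4, 3}.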

#2

Hi @han-so1omon,

Could you instead concatenate the two symbols and then take the minimum over the concatenated symbol?


#3

I think so. I could concatenate the two symbols along a new extra dimension and then take the minimum along that dimension, which I assume performs a reduction. Off-hand, I don't know the trade-offs between the two approaches.
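The concatenate-then-reduce idea can be sketched without MXNet: stack the two equal-length inputs as the rows of a 2 x n array (the "extra dimension"), then reduce with min over that new axis, which yields the element-wise minimum. The names stack_rows and min_axis0 are hypothetical helpers on plain vectors, not MXNet operators.

```cpp
#include <algorithm>
#include <cassert>
#include <cstddef>
#include <vector>

// Hypothetical illustration, not MXNet API.
// Stack two equal-length vectors into shape (2, n): row 0 is a, row 1 is b.
std::vector<std::vector<double>> stack_rows(const std::vector<double>& a,
                                            const std::vector<double>& b) {
    assert(a.size() == b.size());
    return {a, b};
}

// Reduce with min over axis 0 (the new axis), leaving shape (n).
std::vector<double> min_axis0(const std::vector<std::vector<double>>& t) {
    assert(!t.empty());
    std::vector<double> out = t[0];
    for (std::size_t r = 1; r < t.size(); ++r) {
        assert(t[r].size() == out.size());
        for (std::size_t i = 0; i < out.size(); ++i) {
            out[i] = std::min(out[i], t[r][i]);
        }
    }
    return out;
}
```

One rough trade-off visible in this sketch: it materializes a 2 x n intermediate, whereas the mask formula builds two n-sized masks; on the other hand, the stacked form extends to more than two inputs by stacking more rows.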


#4

I don't know whether this is available in the C++ API, but in Python you can also use where(a < b, a, b).
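The select semantics behind where(a < b, a, b) can be illustrated framework-free: out[i] is a[i] where the condition holds and b[i] elsewhere, so ties return b[i], which equals a[i] there. The names select_where and where_min are hypothetical helpers on plain vectors, not MXNet functions.

```cpp
#include <cassert>
#include <cstddef>
#include <vector>

// Hypothetical element-wise select: out[i] = cond[i] ? x[i] : y[i].
// Mirrors the semantics of where(cond, x, y); not an MXNet function.
std::vector<double> select_where(const std::vector<bool>& cond,
                                 const std::vector<double>& x,
                                 const std::vector<double>& y) {
    assert(cond.size() == x.size() && x.size() == y.size());
    std::vector<double> out(x.size());
    for (std::size_t i = 0; i < x.size(); ++i) {
        out[i] = cond[i] ? x[i] : y[i];
    }
    return out;
}

// Element-wise minimum written as where(a < b, a, b).
std::vector<double> where_min(const std::vector<double>& a,
                              const std::vector<double>& b) {
    assert(a.size() == b.size());
    std::vector<bool> lt(a.size());
    for (std::size_t i = 0; i < a.size(); ++i) {
        lt[i] = a[i] < b[i];
    }
    return select_where(lt, a, b);
}
```

This needs only one comparison and does not lose ties. In a select of this form, the gradient flows to x where the condition is true and to y elsewhere.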