MXNet function to slice a Symbol


#1

If I want to extract a few rows of a matrix (a Symbol), how should I do it?


#2

Hi @dhuqqq,

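Use slice for this.

It takes a begin tuple and an end tuple with one entry per axis, so you can slice both the rows and the columns of a matrix in one call. Use None in begin to start from the beginning of an axis, or in end to slice to the end of it.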
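To make the begin/end/None rule concrete, here is a minimal NumPy sketch of the same semantics. It assumes a step of 1 and only positive indices, and slice_like_mx is a hypothetical helper written for illustration, not part of MXNet. Each (begin, end) pair becomes a Python slice for that axis, and axes not listed are kept whole, which matches how slice treats a shorter begin/end tuple.

```python
import numpy as np


def slice_like_mx(arr, begin, end):
    """Hypothetical helper: mimic slice(begin=..., end=...) with step 1.

    None in begin means "from the start of that axis"; None in end means
    "to the end of that axis". Axes beyond len(begin) are kept whole.
    """
    index = tuple(slice(b, e) for b, e in zip(begin, end))
    return arr[index]


m = np.arange(9).reshape(3, 3)  # [[0 1 2] [3 4 5] [6 7 8]]

rows = slice_like_mx(m, begin=(1, None), end=(3, None))  # 2nd and 3rd rows
cols = slice_like_mx(m, begin=(None, 0), end=(None, 2))  # first two columns
print(rows)
print(cols)
```

Because Python's slice(None, None) already means "the whole axis", None maps directly onto the "slice to the end" behaviour described above.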

See below for an example of how to slice the 2nd and 3rd rows of a matrix.

import mxnet as mx

# A 3x3 matrix of uniform random values (yours will differ)
input_data = mx.random.uniform(shape=(3,3))
input_data

[[ 0.54881352 0.59284461 0.71518934]
[ 0.84426576 0.60276335 0.85794562]
[ 0.54488319 0.84725171 0.42365479]]
<NDArray 3x3 @cpu(0)>

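# Symbolic input; the actual data is supplied when the symbol is bound
a_sym = mx.sym.Variable('input')
# Rows 1 up to (not including) 3, i.e. the 2nd and 3rd rows; None keeps every column
b_sym = a_sym.slice(begin=(1,None), end=(3,None))

# Bind the symbol to a device and its input data, then run a forward pass
executor = b_sym.bind(mx.cpu(), {'input': input_data})
output = executor.forward()  # forward() returns a list of output NDArrays
output[0]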

[[ 0.84426576 0.60276335 0.85794562]
[ 0.54488319 0.84725171 0.42365479]]
<NDArray 2x3 @cpu(0)>