I'm training a neural network model with MXNet. The input is a sparse one-hot vector with 1 million dimensions, of which only a few dozen are 1. The input is fully connected to a hidden layer with 200 nodes, and training is very slow. Is there any way to speed up the FullyConnected computation, given that the input is so sparse?
You can encode the data in CSR format and replace FC with sparse.dot and broadcast_add. You can see the linear classification example here: https://github.com/apache/incubator-mxnet/tree/master/example/sparse
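To make the suggestion concrete, here is a minimal pure-Python sketch of what a CSR matrix stores and why a sparse dot followed by a broadcast add only needs to touch the weight rows that match non-zero inputs. The helper names `to_csr` and `csr_dot_add_bias` are made up for illustration; this is not how MXNet implements `sparse.dot`, only the same arithmetic on a toy example.

```python
# Illustrative sketch only: hypothetical helpers, not MXNet's implementation.

def to_csr(dense_rows):
    """Convert a list of dense rows into CSR arrays (data, indices, indptr)."""
    data, indices, indptr = [], [], [0]
    for row in dense_rows:
        for col, value in enumerate(row):
            if value != 0:
                data.append(value)    # non-zero value
                indices.append(col)   # its column index
        indptr.append(len(data))      # where the next row starts in data
    return data, indices, indptr


def csr_dot_add_bias(data, indices, indptr, weight, bias):
    """Compute X * weight + bias, with X in CSR and weight shaped (in_dim, hidden)."""
    out = []
    for r in range(len(indptr) - 1):
        acc = list(bias)  # start from the bias: this is the broadcast add
        # Only the stored non-zeros of row r contribute to the product.
        for k in range(indptr[r], indptr[r + 1]):
            w_row = weight[indices[k]]
            for j in range(len(acc)):
                acc[j] += data[k] * w_row[j]
        out.append(acc)
    return out


# Toy input: 2 samples, 5 input features, 2 hidden units.
x = [[0, 1, 0, 0, 1],
     [0, 0, 0, 1, 0]]
weight = [[1.0, 2.0], [3.0, 4.0], [5.0, 6.0], [7.0, 8.0], [9.0, 10.0]]
bias = [0.5, -0.5]

data, indices, indptr = to_csr(x)
result = csr_dot_add_bias(data, indices, indptr, weight, bias)
print(data, indices, indptr)
print(result)
```

Note that the weight here is laid out as (input_dim, num_hidden), which is the layout a plain dot product expects.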
@eric-haibin-lin I tried encoding the data in CSR format and replacing ndarray.FullyConnected with nd.sparse.dot, but it turned out to be even slower. My code is as follows:
original: time cost 0:00:00.365667
#out1 = mx.nd.FullyConnected(features, self.w1.data(ctx), self.b1.data(ctx), num_hidden=self.num_hidden)
#act1 = mx.nd.Activation(out1, act_type='relu')
new: time cost 0:00:00.495941
out1 = mx.nd.sparse.dot(features, self.w1.data(ctx))
act1 = mx.nd.broadcast_add(out1, self.b1.data(ctx))
where w1 is the weight matrix and b1 is the bias. features is the input, a 200 x 1000000 matrix with about 2000 non-zero values, which I have encoded in CSR format.
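For a rough sense of scale, the sizes quoted above can be turned into multiply-add counts. This is back-of-the-envelope arithmetic only, assuming the hidden width of 200 from the original question; it ignores memory traffic, format overhead, and how the operators are actually implemented.

```python
# Rough multiply-add counts from the sizes mentioned in the thread.
batch_size = 200      # rows of the feature matrix
input_dim = 1000000   # columns of the feature matrix
num_hidden = 200      # hidden layer width (taken from the original question)
nnz = 2000            # approximate number of non-zero inputs

# A dense product visits every (row, column, hidden unit) combination.
dense_macs = batch_size * input_dim * num_hidden

# A sparse product only visits each stored non-zero once per hidden unit.
sparse_macs = nnz * num_hidden

ratio = dense_macs // sparse_macs
print(dense_macs, sparse_macs, ratio)
```

On paper the sparse product does far less work, so a slower measurement is worth double-checking.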
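Did you call act1.wait_to_read() to make sure the operation is completed? MXNet executes operators asynchronously, so without it the timer can stop before the computation has finished.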
I am a bit confused: what is the shape of the w1 you are using? FullyConnected computes dot(feature, w1_transpose), which is different from dot(feature, w1). Are you getting consistent results here?
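On the shape question, a small pure-Python sketch may help. FullyConnected takes its weight as (num_hidden, input_dim) and multiplies by the transpose, while a plain dot needs the weight as (input_dim, num_hidden). The `matmul` and `transpose` helpers below are hypothetical stand-ins for the library calls, and the numbers are toy values.

```python
# Illustrative sketch only: hypothetical helpers standing in for library calls.

def matmul(a, b):
    """Multiply two matrices given as lists of rows."""
    if len(a[0]) != len(b):
        raise ValueError("inner dimensions differ: %d vs %d" % (len(a[0]), len(b)))
    return [[sum(a[i][k] * b[k][j] for k in range(len(b)))
             for j in range(len(b[0]))]
            for i in range(len(a))]


def transpose(m):
    """Swap rows and columns of a matrix given as a list of rows."""
    return [list(col) for col in zip(*m)]


# 2 samples, 3 input features, 2 hidden units.
x = [[1.0, 0.0, 2.0],
     [0.0, 1.0, 0.0]]

# FullyConnected-style weight: one row per hidden unit, (num_hidden, input_dim).
w_fc = [[1.0, 2.0, 3.0],
        [4.0, 5.0, 6.0]]

# What FullyConnected computes without bias: x times the transposed weight.
fc_out = matmul(x, transpose(w_fc))

# To get the same result from a plain dot, store the weight as (input_dim, num_hidden).
w_dot = transpose(w_fc)
dot_out = matmul(x, w_dot)

print(fc_out)
print(dot_out)
```

If the same w1 is passed unchanged to both FullyConnected and dot, at most one of the two calls is using the layout it expects.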
The following code works for me:
import time
import mxnet as mx
import scipy.sparse as spsp

# 200 x 1000000 random CSR matrix with density 1e-5 (about 2000 non-zeros)
csr = spsp.rand(200, 1000000, format='csr', density=0.00001)
x_sparse = mx.nd.sparse.csr_matrix(csr)
w = mx.nd.ones((1000000, 100))

# Sparse path: finish pending work, then time sparse.dot
mx.nd.waitall()
a = time.time()
y = mx.nd.sparse.dot(x_sparse, w)
y.wait_to_read()
b = time.time()
print(b - a)  # 0.00143098831177

# Dense path: FullyConnected takes the weight as (num_hidden, input_dim)
w_t = w.T
x_dense = x_sparse.tostype('default')
mx.nd.waitall()
c = time.time()
y2 = mx.nd.FullyConnected(x_dense, w_t, no_bias=True, num_hidden=100)
y2.wait_to_read()
d = time.time()
print(d - c)  # 0.451608896255