How to remember the state of LSTM when using gluon?


I’m sorry for posting such a question; the solution is quite easy:


   out: output tensor with shape (sequence_length, batch_size, num_hidden) when layout is “TNC”. If bidirectional is True, output shape will instead be (sequence_length, batch_size, 2*num_hidden)
   out_states: a list of two output recurrent state tensors with the same shape as in states. If states is None out_states will not be returned.

I am trying to implement torch-rnn using MXNet, but I can't find any “remember state” or “reset” function in Gluon.
If a “remember state” parameter were available, the program would be very easy to write:

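    from mxnet.gluon import rnn

    for i in range(self.num_layers):
        prev_dim = H if i > 0 else D  # only the first layer sees the input size D
        if self.model_type == 'rnn':
            layer = rnn.GRU(H, input_size=prev_dim, dropout=self.dropout)
        elif self.model_type == 'lstm':
            layer = rnn.LSTM(H, input_size=prev_dim, dropout=self.dropout)
        else:
            layer = rnn.RNN(H, input_size=prev_dim, dropout=self.dropout)
        # wished-for (no such option exists in Gluon): layer.remember_states = True
        if self.batchnorm == 1:
            ...  # the rest of this snippet was cut off in the original post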

But I don't know how to use such a parameter (as far as I know, it does not exist). It seems the only solution is to use mx.sym, unroll the LSTM manually, and reset its states manually.

Is there an easier solution?


Hi @Neutron,

When working with Gluon’s recurrent layers, you have an explicit reference to the state so you can choose to ‘remember’ or ‘reset’ the RNNs depending on how you use it. At the start you can use begin_state to generate an initial hidden state (essentially a ‘reset’), and then the RNN layer will return the outputs and the hidden state, which you can optionally feed as the initial hidden state to the next RNN (essentially a ‘remember’).

See the example below for generating a hidden state, and getting the hidden state after the forward pass of the RNN layer has completed. One thing to note is that you must pass an initial hidden state to the RNN layer if you want the final hidden state returned.

import mxnet as mx

sequence_length = 4
batch_size = 5
channels = 3
hid_layers = 1
hid_units = 6

rnn1 = mx.gluon.rnn.RNN(hidden_size=hid_units, num_layers=hid_layers, layout='TNC')
rnn2 = mx.gluon.rnn.RNN(hidden_size=hid_units, num_layers=hid_layers, layout='TNC')
rnn3 = mx.gluon.rnn.RNN(hidden_size=hid_units, num_layers=hid_layers, layout='TNC')
# Gluon parameters must be initialized before the first forward pass
for layer in (rnn1, rnn2, rnn3):
    layer.initialize()


inputs1 = mx.nd.random.uniform(shape=(sequence_length, batch_size, channels))
hid_init1 = rnn1.begin_state(batch_size)
outputs1, hid_states1 = rnn1(inputs1, hid_init1)

# 'remember' state (from rnn1)
inputs2 = mx.nd.random.uniform(shape=(sequence_length, batch_size, channels))
outputs2, hid_states2 = rnn2(inputs2, hid_states1)

# 'reset' state
inputs3 = mx.nd.random.uniform(shape=(sequence_length, batch_size, channels))
hid_init3 = rnn3.begin_state(batch_size)
outputs3, hid_states3 = rnn3(inputs3, hid_init3)