RNN gradient explodes


#1

Problem solved with gradient clipping.

The RNN class below is borrowed from this tutorial:

class RNN(gluon.Block):
    """Recurrent model: a Gluon RNN/LSTM/GRU layer followed by a Dense decoder.

    Unlike the tutorial this was adapted from, there is no embedding or
    input-dropout layer. ``inputs`` given to :meth:`forward` must already
    carry ``num_embed`` features per step, because the recurrent layer is
    built with ``input_size=num_embed``.

    This block does not clip gradients. If gradients explode during
    training, clip them in the training loop after ``backward()``, e.g.
    with ``gluon.utils.clip_global_norm``.

    Parameters
    ----------
    mode : str
        One of ``'rnn_relu'``, ``'rnn_tanh'``, ``'lstm'`` or ``'gru'``.
    seed : int or None
        If not None, passed to ``mx.random.seed`` before the layers are
        created. ``0`` is a valid seed and is honoured.
    vocab_size : int
        Number of output units of the decoder.
    num_embed : int
        Input feature size of the recurrent layer.
    num_hidden : int
        Hidden size of the recurrent layer; also the decoder's input size.
    num_layers : int
        Number of stacked recurrent layers.
    dropout : float
        Passed to the recurrent layer's ``dropout`` argument.
    **kwargs
        Forwarded to ``gluon.Block.__init__`` (e.g. ``prefix``, ``params``).

    Raises
    ------
    ValueError
        If ``mode`` is not one of the supported options.
    """

    def __init__(self, mode, seed, vocab_size, num_embed, num_hidden,
                 num_layers, dropout, **kwargs):
        super(RNN, self).__init__(**kwargs)
        # Compare against None explicitly: a truthiness test would silently
        # skip seeding when the caller asks for seed 0.
        if seed is not None:
            mx.random.seed(seed)

        with self.name_scope():
            if mode == 'rnn_relu':
                self.rnn = rnn.RNN(num_hidden, num_layers, activation='relu',
                                   dropout=dropout, input_size=num_embed)
            elif mode == 'rnn_tanh':
                self.rnn = rnn.RNN(num_hidden, num_layers, dropout=dropout,
                                   input_size=num_embed)
            elif mode == 'lstm':
                self.rnn = rnn.LSTM(num_hidden, num_layers, dropout=dropout,
                                    input_size=num_embed)
            elif mode == 'gru':
                self.rnn = rnn.GRU(num_hidden, num_layers, dropout=dropout,
                                   input_size=num_embed)
            else:
                raise ValueError("Invalid mode %s. Options are rnn_relu, "
                                 "rnn_tanh, lstm, and gru"%mode)
            self.decoder = nn.Dense(vocab_size, in_units=num_hidden)
            # Kept so forward() can flatten the recurrent output for the decoder.
            self.num_hidden = num_hidden

    def forward(self, inputs, hidden):
        """Run the recurrent layer, then decode every step to ``vocab_size`` scores.

        Parameters
        ----------
        inputs : NDArray
            Input sequence with ``num_embed`` features per step. Presumably
            laid out as (seq_len, batch, num_embed), the default Gluon ``TNC``
            layout, since no layout is set here; TODO confirm against the caller.
        hidden : list of NDArray
            Recurrent state, e.g. as produced by :meth:`begin_state`.

        Returns
        -------
        tuple
            ``(decoded, hidden)``. ``decoded`` is 2-D, ``(N, vocab_size)``,
            because the recurrent output is reshaped to ``(-1, num_hidden)``
            before decoding. ``hidden`` is the updated recurrent state.
        """
        # Make the input's device context the default for ops in this scope.
        with inputs.context:
            output, hidden = self.rnn(inputs, hidden)
            # Flatten all leading axes so the Dense decoder scores each step.
            decoded = self.decoder(output.reshape((-1, self.num_hidden)))
            return decoded, hidden

    def begin_state(self, *args, **kwargs):
        """Return the initial recurrent state by delegating to the wrapped layer."""
        return self.rnn.begin_state(*args, **kwargs)
# Softmax cross-entropy loss for the decoder scores. With sparse_label=False,
# Gluon expects labels shaped like the prediction (e.g. one-hot or
# probability vectors) rather than integer class indices.
# NOTE(review): batch_axis=1 presumably assumes a time-major (seq, batch, ...)
# layout, but RNN.forward returns a 2-D (N, vocab_size) tensor. Confirm which
# axis of the tensors fed to this loss is really the batch axis.
loss = gluon.loss.SoftmaxCrossEntropyLoss(sparse_label=False, batch_axis=1)

Switching to Linux runs into the same issue.


#3

@ShootingSpace, you solved your problem using gradient clipping, right?