Simple RNN - convLSTM example?

Dear all,

I am taking my first steps with RNN models (in fact, I am interested in convolutional 2D RNN/LSTM models). My goal is to use them in a way similar to conditional random fields (CRFs), for refining image segmentations. So for RNNs I am interested in the regression / time-series forecasting perspective. Could someone please give me a simple example of a single forward call for a gluon.contrib.rnn.Conv2DRNNCell / Conv2DLSTMCell?

Something like:

```python
import mxnet as mx
from mxnet import gluon, nd

nbatch = 10
nfilters = 32
shape = [nbatch, nfilters, 64, 64]
xx = nd.random_uniform(shape=shape)

net = gluon.contrib.rnn.Conv2DLSTMCell(some_params)  # I am not completely sure about the params.
states = ...  # ? What goes in here?
temp = net(xx, states)  # This is where I am getting errors
```


My ultimate goal is to create a custom Conv2DLSTM layer that runs for a fixed number of iterations.

I’ve been through various sources on the web about RNNs, most notably this and this, but I am having a bit of trouble understanding the particulars of the implementation. For example, I assumed that the input_shape argument of Conv2DLSTMCell (part of some_params) should be (nbatch, nfilters, 64, 64), but according to the documentation it should be (nfilters, 64, 64)? This in turn confuses my understanding of the dimensions of the states (which, for a convolutional layer, I expected to be the same as those of xx). In any case, I get errors when trying to run a simple example.

Any pointers to documentation or examples for RNNs/convRNNs would be extremely appreciated. Thank you very much.

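Getting there:

```python
import mxnet as mx
from mxnet import gluon, nd, autograd

nbatch = 10
nfilters = 12
shape = [nbatch, nfilters, 64, 64]
xx = nd.random_uniform(shape=shape)

# single layer; input_shape is (C, H, W) for one sample, without the batch axis
net = gluon.contrib.rnn.Conv2DRNNCell(input_shape=[nfilters, 64, 64],
                                      hidden_channels=nfilters,
                                      i2h_kernel=(3, 3),
                                      h2h_kernel=(3, 3))
net.initialize()  # parameters must be initialized before the first forward call

# zero-valued initial state(s): a list with a single entry for Conv2DRNNCell
init_state = net.begin_state(batch_size=nbatch)
state = init_state
with autograd.record():
    out, state = net(xx, state)

# 64 - 3 + 1 = 62: the default i2h_pad=(0, 0) shrinks the feature map
print(out.shape)
# (10, 12, 62, 62)
print(state[0].shape)
# (10, 12, 62, 62)
```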