Problem with symbol shapes and gemm

Hi all,

I am trying to build a simple network using the mx.symbol.linalg_gemm2 function. However, I cannot get the model to train because the symbol shapes are not being inferred correctly. I am using the R API.
The code is as follows:

NFEAT = 795
nDest = 780
batchSz = 150

DL = customArrayIter(X.model, data.shape=c(NFEAT, 150), label=Y.model, batch.size=1)

X = mx.symbol.Variable('data')
A = mx.symbol.Variable('A')
B = mx.symbol.Variable('B')

CC = mx.symbol.linalg_gemm2(X, A, name='CC')
Yhat = mx.symbol.linalg_gemm2(B, CC, name='Yhat')
out = mx.symbol.SoftmaxActivation(Yhat, name='out')
loss = mx.symbol.LinearRegressionOutput(Yhat, name='loss')

Note that if I call mx.symbol.infer.shape as follows:

shps=mx.symbol.infer.shape(loss, data=c(NFEAT, 150), A=c(nDest, NFEAT), B=c(150, 1))

The shapes appear to be correctly inferred.
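The shape arithmetic behind that inference can be checked by hand. The base-R sketch below assumes that the MXNet R package lists dimensions in reverse (column-major) order compared with the row-major shapes used by the backend, and that linalg_gemm2 is called with its default transpose_a = transpose_b = FALSE. `gemm2_shape_r` is a hypothetical helper written for illustration, not part of MXNet.

```r
# Hypothetical helper (not an MXNet function): output shape of
# linalg_gemm2(a, b) with no transposes, with shapes given in R order.
# Assumption: R-API shapes are the reverse of the backend's row-major shapes.
gemm2_shape_r <- function(a, b) {
  a_rm <- rev(a)  # row-major shape of a: (rows, cols)
  b_rm <- rev(b)  # row-major shape of b: (rows, cols)
  if (a_rm[2] != b_rm[1]) {
    stop(sprintf("inner dimensions differ: %d vs %d", a_rm[2], b_rm[1]))
  }
  rev(c(a_rm[1], b_rm[2]))  # (rows of a, cols of b), converted back to R order
}

NFEAT <- 795
nDest <- 780

cc_shape   <- gemm2_shape_r(c(NFEAT, 150), c(nDest, NFEAT))  # data x A
yhat_shape <- gemm2_shape_r(c(150, 1), cc_shape)              # B x CC
print(cc_shape)    # 780 150
print(yhat_shape)  # 780 1
```

Under these assumptions, Yhat has R shape c(nDest, 1), so the label fed to LinearRegressionOutput would need a matching shape.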

Can someone advise me on how to specify the shapes in the mx.model.FeedForward.create function?
Thanks a lot!

I really need help with this. Does anyone have experience using these functions?