Hi there,
I’m wondering how to multiply two tensors along the batch axis. It should work like the following naive implementation:
```python
import mxnet.ndarray as nd

A = nd.ones((2, 3, 4))  # CTN layout: axis 2 is the batch axis
B = nd.ones((2, 3, 4))
C = nd.zeros((3, 3, 4))
# For each batch slice n: C[:, :, n] = A[:, :, n]^T . B[:, :, n]
for batch in range(A.shape[2]):
    C[:, :, batch] = nd.dot(A[:, :, batch].transpose(), B[:, :, batch])
print(C)
```
In TF, I can use `tf.matmul(A, B, transpose_b=True)`. Is there any equivalent to that?
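For reference, here is a minimal sketch of what the loop computes and one batched way to write it. It uses NumPy rather than MXNet so it can be checked anywhere; `np.matmul` and `np.einsum` are stand-ins here, not MXNet APIs. The idea is to move the batch axis to the front (the layout that `tf.matmul` and MXNet's `nd.batch_dot` treat as the batch dimension) and multiply all slices in one call. Note that the loop computes A_n^T B_n, so the *first* operand is transposed; `transpose_b=True` would instead give A_n B_n^T. With batch-first tensors the matching flag would be `transpose_a`.

```python
import numpy as np

# Same shapes as the MXNet snippet: axis 0 is contracted, axis 2 is the batch.
# Random values (instead of ones) so the comparison is not trivially true.
rng = np.random.default_rng(0)
A = rng.standard_normal((2, 3, 4))
B = rng.standard_normal((2, 3, 4))

# Reference: the naive per-batch loop from the question.
C_loop = np.zeros((3, 3, 4))
for n in range(A.shape[2]):
    C_loop[:, :, n] = A[:, :, n].T @ B[:, :, n]

# Batched version: move the batch axis to the front so each slice is (2, 3),
# then compute A_n^T @ B_n for every n in a single matmul.
A_b = A.transpose(2, 0, 1)                      # (4, 2, 3)
B_b = B.transpose(2, 0, 1)                      # (4, 2, 3)
C_b = np.matmul(A_b.transpose(0, 2, 1), B_b)    # (4, 3, 3)
C_batched = C_b.transpose(1, 2, 0)              # back to (3, 3, 4)

# The same contraction written with einsum directly on the original layout.
C_einsum = np.einsum('tin,tjn->ijn', A, B)

print(np.allclose(C_loop, C_batched), np.allclose(C_loop, C_einsum))
```

In MXNet, the `np.matmul(A_b.transpose(0, 2, 1), B_b)` step should correspond to `nd.batch_dot(A_b, B_b, transpose_a=True)` on the batch-first arrays (untested in this sketch), with the same transposes before and after to keep the CTN layout.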
Thanks.