Proposal: new merge_dims reshaping op


I’d like to open discussion of the design of a simple reshaping layer that allows dimensions to be easily merged (by which I mean flattened together).

This would be the inverse of a layer extension that is currently under review, namely sub-setting support in reshape_like: the extension implemented by that PR makes it very easy to split dimensions that have previously been merged. This proposal takes care of the inverse case: merging several dimensions together. This is a common operation when dealing with spatial or temporal sequences to which we apply the same operation: it is typical to fold these extra dimensions into the batch dimension, apply the operation, and then split that batch dimension back up again.
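A minimal NumPy sketch of that fold/apply/split pattern, assuming a (batch, time, height, width) input. The names apply_per_timestep and per_frame_mean are hypothetical and only illustrate the idea; merging batch and time here is a plain reshape because the two axes are adjacent and leading.

```python
import numpy as np


def per_frame_mean(frames):
    # Hypothetical stand-in for any op that expects (batch, h, w): mean of each frame.
    return frames.mean(axis=(1, 2))


def apply_per_timestep(x, op):
    # Apply `op` independently to every time step of a (b, t, ...) array.
    b, t = x.shape[:2]
    folded = x.reshape(b * t, *x.shape[2:])  # merge time into the batch axis
    y = op(folded)                           # result has leading axis b * t
    return y.reshape(b, t, *y.shape[1:])     # split the batch axis back into (b, t)


x = np.arange(120, dtype=float).reshape(2, 3, 4, 5)
y = apply_per_timestep(x, per_frame_mean)
print(y.shape)  # (2, 3)
```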

The proposed merge_dims op would allow different dimensions to be merged. It would take as a parameter, target_axes, a list giving for each axis of the input the axis of the output it is sent to. If multiple input axes are mapped to the same output axis, they are merged. This list must have the same length as the rank of the input.

So, for example, this table shows on the left the input shape, in the middle the target_axes parameter, and on the right the output shape that the layer would produce:

[2, 3, 5]  ->  [0, 0, 0]  ->  [30]
[2, 3, 5]  ->  [0, 0, 1]  ->  [6, 5]
[2, 3, 5]  ->  [0, 1, 1]  ->  [2, 15]
[2, 3, 5]  ->  [0, 1, 0]  ->  [10, 3]
[2, 3, 5]  ->  [1, 0, 1]  ->  [3, 10]
[2, 3, 5]  ->  [1, 0, 0]  ->  [15, 2]
[2, 3, 5]  ->  [0, 1, 2]  ->  [2, 3, 5]
[2, 3, 5]  ->  [2, 1, 0]  ->  [5, 3, 2]

Let’s explain in more detail one of these examples:

in_shape = [2, 3, 5]
target_axes = [0, 1, 1]
output_shape = [2, 15]

First, keep in mind that the entries of target_axes are zero-based output axes: 0 refers to the first axis of the output, 1 refers to the second axis, etc. The position of each entry corresponds to an axis of the input.

The initial 0 of target_axes means “send the first axis of the input to the first axis of the output”; therefore the first element of output_shape is 2.

Next, the remaining 1, 1 of target_axes means “send the second and third axes of the input to the second axis of the output”, so the second element of output_shape is 3*5=15.

This mechanism is quite flexible and allows for arbitrary re-ordering of the dimensions in addition to merging: the last row of the table is a pure permutation, and the [1, 0, 1] and [1, 0, 0] rows combine re-ordering with merging.
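The proposal only pins down the output shape. One plausible data semantics, which is an assumption on our part and not stated above, is: stably sort the input axes by their target axis, transpose into that order, then reshape. Under this reading, non-decreasing target_axes give a plain reshape with no data movement, while any re-ordering implies a transpose. The function below is a hypothetical NumPy reference, not an MXNet API.

```python
import numpy as np


def merge_dims(x, target_axes):
    """Hypothetical reference for the proposed op; the element order is an assumption."""
    target_axes = list(target_axes)
    assert x.ndim == len(target_axes) and min(target_axes) >= 0
    # Stable sort: input axes sent to the same output axis keep their input order.
    perm = sorted(range(x.ndim), key=lambda i: target_axes[i])
    out_shape = [1] * (max(target_axes) + 1)
    for size, t in zip(x.shape, target_axes):
        out_shape[t] *= size
    return np.transpose(x, perm).reshape(out_shape)


x = np.arange(30).reshape(2, 3, 5)
print(merge_dims(x, [0, 1, 1]).shape)  # (2, 15)
print(merge_dims(x, [1, 0, 1]).shape)  # (3, 10)
```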

In Python, the exact semantics would be as follows:

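def out_shape(in_shape, target_axes):
    # One entry per input axis; each entry is the zero-based output axis it is sent to.
    assert len(in_shape) == len(target_axes)
    assert min(target_axes) >= 0
    # An output axis that receives no input axis keeps size 1.
    shape = [1] * (1 + max(target_axes))
    # Input axes sent to the same output axis are merged: their sizes multiply.
    for in_dim, target_axis in zip(in_shape, target_axes):
        shape[target_axis] *= in_dim
    return shape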
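As a quick check, the definition reproduces every row of the table above. The block repeats the function (with the same semantics) so it runs on its own; the final line shows an edge case the table does not cover.

```python
def out_shape(in_shape, target_axes):
    # Same semantics as the definition above.
    assert len(in_shape) == len(target_axes)
    assert min(target_axes) >= 0
    shape = [1] * (1 + max(target_axes))
    for in_dim, target_axis in zip(in_shape, target_axes):
        shape[target_axis] *= in_dim
    return shape


# (target_axes, expected output shape) for in_shape = [2, 3, 5], copied from the table.
TABLE = [
    ([0, 0, 0], [30]),
    ([0, 0, 1], [6, 5]),
    ([0, 1, 1], [2, 15]),
    ([0, 1, 0], [10, 3]),
    ([1, 0, 1], [3, 10]),
    ([1, 0, 0], [15, 2]),
    ([0, 1, 2], [2, 3, 5]),
    ([2, 1, 0], [5, 3, 2]),
]

for target_axes, expected in TABLE:
    assert out_shape([2, 3, 5], target_axes) == expected

# Not covered by the table: an output axis with no input axis mapped to it has size 1.
print(out_shape([2, 3], [0, 2]))  # [2, 1, 3]
```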

We will implement this if approved.

@ThomasDelteil @sbodenstein

Edit: unwithdrawn.


@taliesinb that seems to make sense to me. To get more eyes on it, I would suggest sending an email to the dev list copying/linking to this proposal.
@szha @piiswrong can you take a look at this?


Thanks, I’ve posted a link to the dev list!


Thanks. What are the cases that are easy for the new proposal but hard for the existing reshape operator?


@szha The cases have to do with multiple dynamic dimensions.

I apologize if you’re already familiar with anything I’m going to say. I’ll just repeat it for anyone who is not.

Imagine you have a net that is logically agnostic with respect to the dimensions of some of its input shapes. For example, an LSTM can logically operate on a sequence of any length (as long as the operations that follow can too, e.g. global pooling), and convolution and pooling layers can operate on images of any size above some minimum (again, as long as the operations that follow can too).

To make this concept more concrete, let’s call a dimension that is logically allowed to vary, like the length of the LSTM’s input sequence or the width or height of the image being convolved, a “dynamic dimension”, and refer to it by a letter. So the input tensor to the LSTM layer might be written as (b, n, 10), where b is the batch size and n is the sequence length. Remember, at this point we’re operating at the “logical level”: we’re talking about layers that a user might create in a high-level framework, rather than the exact MXNet ops that would be used to implement them.

If you start with one of these logical nets and produce an MXSymbol from it, there is a particular property, which I call “shape polymorphism”, that it is very useful for the resulting MXSymbol to have. Specifically, this is the property that you can simply call MXSymbolInferShape, providing concrete values for all of the dynamic dimensions, and a valid, fully specialized MXSymbol will result.

Shape polymorphism is a highly desirable property because it gives a very fast way to take an initial dimension-agnostic MXSymbol and specialize it to produce an MXSymbol that will operate on a specific batch size, sequence length, image size, etc. All you need is to feed in the concrete values for the initial dynamic dimensions and rely on shape inference to propagate them through the network.

Shape polymorphism also makes nets very straightforward to deploy in production, because the user of the net just needs to create an initial “master” MXSymbol from a single JSON specification and then specialize it as needed by calling MXSymbolInferShape with the dimensions of the input tensors it will operate on (a very simple recipe).

If you don’t have the property of shape polymorphism, it is still possible to produce specialized nets, but it is much more complicated: it involves, for example, templatizing the JSON, which is slower, more error prone, harder to implement, harder to deploy and maintain, and harder to communicate.

Ok, I’ve described the background to this issue. Apologies if it was obvious; in my experience, not everyone encounters these issues of shape polymorphism in their particular work.

Now, specifically, what this new op does better than the existing reshape operator is allow flattening of tensors that involve more than two dynamic dimensions.

Let’s first take an example we can handle with the existing reshape operator, which is something like (b, n, 3) being flattened to (b * n, 3). This is straightforward with a reshape op having shape (-1, 3).

But if you have the shape (b, n, m, 3) and you wish to flatten it to (b * n, m, 3), you cannot use the same technique. If you use the -1 to combine b and n, you cannot use a second -1 to propagate m, and the 0 code copies the input dimension at the wrong position (n rather than m). The -3, -2 code sequence works in this case (-3 merges b and n, -2 copies the remaining m and 3), but it does not allow you to flatten three dimensions together in one step, so you’re only delaying the problem.
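To make the comparison concrete, here is a simplified model of how the existing reshape codes turn an input shape into an output shape. It covers only 0, -1, -2 and -3 (not -4 or the reverse flag), the name reshape_out_shape is hypothetical, and the model is our reading of the reshape documentation rather than MXNet’s own code. Using the same spec at two different sizes also illustrates the shape polymorphism described earlier.

```python
from math import prod


def reshape_out_shape(in_shape, spec):
    """Simplified model of MXNet reshape codes 0, -1, -2, -3 (hypothetical helper)."""
    out, src, infer_at = [], 0, None
    for code in spec:
        if code == 0:     # copy the input dim at the current input position
            out.append(in_shape[src])
            src += 1
        elif code == -1:  # infer this dim from the total size (at most one -1)
            assert infer_at is None, "only one -1 allowed"
            infer_at = len(out)
            out.append(1)
            src += 1
        elif code == -2:  # copy all remaining input dims
            out.extend(in_shape[src:])
            src = len(in_shape)
        elif code == -3:  # product of the next two consecutive input dims
            out.append(in_shape[src] * in_shape[src + 1])
            src += 2
        elif code > 0:    # explicit size
            out.append(code)
            src += 1
        else:
            raise ValueError(f"unsupported code {code}")
    if infer_at is not None:
        out[infer_at] = prod(in_shape) // prod(out)
    assert prod(out) == prod(in_shape), "total size must be preserved"
    return out


b, n, m, k = 2, 4, 5, 6
print(reshape_out_shape([b, n, 3], [-1, 3]))        # [8, 3] = (b*n, 3)
print(reshape_out_shape([3, 7, 3], [-1, 3]))        # [21, 3]: same spec, other sizes
print(reshape_out_shape([b, n, m, 3], [-1, 0, 3]))  # [10, 4, 3] = (b*m, n, 3): wrong
print(reshape_out_shape([b, n, m, 3], [-3, -2]))    # [8, 5, 3] = (b*n, m, 3)
step = reshape_out_shape([b, n, m, k, 3], [-3, -2])  # [8, 5, 6, 3]
print(reshape_out_shape(step, [-3, -2]))            # [40, 6, 3]: needs two passes
```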

We could add more features to reshape, for example, we could add a code that flattens together more than two dimensions at a time, but it seems more reasonable to me to identify the whole concept of merging dimensions as an operation worthy of its own op rather than being shoe-horned into an existing operator via yet more obscure negative shape codes.

What do others think? Should we introduce new codes into reshape, which is already complex and hard to understand? Or should we introduce a new op that handles merging/permutation of axes and possibly even deprecate the -3 code in reshape?


@szha sorry to bug you, but any progress on this? our window to work on mxnet code for the release of Mathematica 12 is closing, we’d like to know whether we can implement this soon.


@ThomasDelteil sorry to be impatient. unfortunately we need to synchronize our contributions to MXNet with changes to our framework, and so this is actually holding up other features in Mathematica 12.


So the new operator’s additional value is on merging multiple, potentially non-consecutive axes. Are there real use cases that require merging non-consecutive axes? If yes, then the new op is required to fulfill such use case. If not, and it’s just about merging multiple consecutive axes, then my personal preference would be to extend current reshape. Deprecating -3 is not an option until 2.0, and even then it’s likely not desirable.


@szha it’s not typical to merge non-consecutive axes, and i’m skeptical it would ever happen. it is possible to merge multiple consecutive axes by applying reshape repeatedly with the -3 code. a merge_dims layer would be a cleaner way of merging multiple dimensions, and could take over completely from the -3 code – but if we’re not going to remove the -3 code, there’s indeed not much reason for this layer other than cleaner design.


I had the same thought. In that case, I hope my delay in replying didn’t cause a real blocking issue, as there clearly was a workaround for the use case you had in mind. Personally I’d prefer not to add things just for nicety, but feel free to propose a patch and see how others feel.


Yup, sorry, I gave the wrong impression. What I meant is that if we were going to do this, we would need to decide soon; otherwise it would be too late for our release cycle.