Proposal: new merge_dims reshaping op

@szha The cases have to do with multiple dynamic dimensions.

I apologize if you’re already familiar with anything I’m going to say. I’ll just repeat it for anyone who is not.

Imagine you have a net that is logically agnostic with respect to the dimensions of some of its input shapes. For example, an LSTM can logically operate on a sequence of any length (as long as the operations that follow can too, e.g. global pooling), and convolution and pooling layers can operate on images of any size above some minimum (again, as long as the operations that follow can too).

To make this concept more concrete, let’s call a dimension that is logically allowed to vary, like the length of the LSTM’s input sequence or the width or height of the image being convolved, a “dynamic dimension”, and let’s use a letter to refer to it. So the input tensor to the LSTM layer might be referred to as (b, n, 10), where b refers to the batch size and n refers to the sequence length. Remember, at this point we’re operating at the “logical level”: we’re talking about layers that a user might create in a high-level framework, rather than the exact MXNet ops that would be used to implement them.

If you start with one of these logical nets and then produce an MXSymbol from it, there is a particular property, which I call “shape polymorphism”, that it is very useful for the resulting MXSymbol to have. Specifically, it is the property that you can simply call MXSymbolInferShape, providing concrete values for all of the dynamic dimensions, and a valid MXSymbol will result.

Shape polymorphism is a highly desirable property because it gives a very fast way to take an initial dimension-agnostic MXSymbol and specialize it to produce an MXSymbol that will operate on a specific batch size, sequence length, image size, etc. All you need to do is feed in concrete values for the initial dynamic dimensions and rely on shape inference to propagate them through the network.
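For concreteness, here is a minimal sketch of that specialization step, assuming the MXNet Python API (the particular layer and the concrete numbers are just illustrative):

```python
import mxnet as mx

# A dimension-agnostic symbol: the input is logically (b, n, 10).
data = mx.sym.Variable('data')
net = mx.sym.FullyConnected(data=data, num_hidden=4, flatten=False)

# Specialize by supplying concrete values for the dynamic dimensions b and n;
# shape inference propagates them through the network.
arg_shapes, out_shapes, aux_shapes = net.infer_shape(data=(32, 25, 10))
print(out_shapes)  # [(32, 25, 4)]
```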

Shape polymorphism also makes nets very straightforward to deploy in production, because the user of the net just needs to create an initial “master” MXSymbol from a single JSON specification and then specialize it as needed by calling MXSymbolInferShape with the dimensions of the input tensors it will operate on (a very simple recipe).
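A deployment-style sketch of that recipe, again assuming the MXNet Python API (the file name, the input name, and the input shape are hypothetical):

```python
import mxnet as mx

# Load the single "master" symbol from its JSON specification
# (hypothetical file name, with an input variable assumed to be called 'data').
master = mx.sym.load('master-symbol.json')

# Specialize it for the concrete input shapes this deployment will see.
arg_shapes, out_shapes, aux_shapes = master.infer_shape(data=(16, 100, 10))
```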

Without shape polymorphism it is still possible to build new nets, but it is much more complicated and involves, for example, templatizing the JSON, which is slower, more error-prone, harder to implement, harder to deploy and maintain, and harder to communicate.

Ok, I’ve described the background to this issue. Apologies if it was obvious; in my experience, not everyone encounters these issues of shape polymorphism in their particular work.

Now, specifically, what the proposed merge_dims operator does better than the existing reshape operator is to allow flattening of tensors that involve more than two dynamic dimensions.

Let’s first take an example we can handle with the existing reshape operator, which is something like (b, n, 3) being flattened to (b * n, 3). This is straightforward with a reshape op having shape (-1, 3).
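As a sketch (assuming the MXNet Python API), with concrete values substituted for b and n:

```python
import mxnet as mx

x = mx.sym.Variable('x')                  # logically (b, n, 3)
flat = mx.sym.reshape(x, shape=(-1, 3))   # -1 absorbs b * n

_, out_shapes, _ = flat.infer_shape(x=(4, 5, 3))
print(out_shapes)  # [(20, 3)]
```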

But if you have the shape (b, n, m, 3) and you wish to flatten it to (b * n, m, 3), you cannot use the same technique. If you use -1 to combine b and n, you cannot use a second -1 to propagate m, and the 0 code refers to the wrong position. The -3, -2 code sequence will work in this case, but it will not allow you to flatten three dimensions together simultaneously, so you’re only delaying the problem.
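To make that limitation concrete, here is a sketch of the -3, -2 case (assuming the MXNet Python API):

```python
import mxnet as mx

x = mx.sym.Variable('x')                      # logically (b, n, m, 3)
# -3 merges the first two dimensions (b and n); -2 copies the remaining ones (m, 3).
merged = mx.sym.reshape(x, shape=(-3, -2))

_, out_shapes, _ = merged.infer_shape(x=(4, 5, 6, 3))
print(out_shapes)  # [(20, 6, 3)]

# There is no single code that merges three dimensions at once,
# e.g. (b, n, m, 3) -> (b * n * m, 3), without falling back to -1.
```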

We could add more features to reshape, for example a code that flattens together more than two dimensions at a time, but it seems more reasonable to me to treat merging dimensions as a concept worthy of its own op, rather than shoe-horning it into an existing operator via yet more obscure negative shape codes.
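To illustrate the semantics I have in mind, here is a purely hypothetical sketch of the proposed behaviour written against NumPy arrays; the name merge_dims and the begin/end parameters are invented for illustration only, and the real op would of course operate on symbols and keep shape inference working for dynamic dimensions:

```python
import numpy as np

def merge_dims(x, begin, end):
    """Hypothetical semantics: merge the dimensions [begin, end) of x into one."""
    shape = list(x.shape)
    merged = int(np.prod(shape[begin:end]))
    return x.reshape(shape[:begin] + [merged] + shape[end:])

x = np.zeros((4, 5, 6, 3))            # stand-in for (b, n, m, 3)
print(merge_dims(x, 0, 3).shape)      # (120, 3), i.e. (b * n * m, 3)
```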

What do others think? Should we introduce new codes into reshape, which is already complex and hard to understand? Or should we introduce a new op that handles merging/permutation of axes and possibly even deprecate the -3 code in reshape?