I’d like to open discussion of the design of a simple reshaping layer that allows dimensions to be easily merged (by which I mean flattened together).
This would be the inverse of a layer extension that is currently under review, namely sub-setting support in
reshape_like: https://github.com/apache/incubator-mxnet/pull/11928. The extension implemented by that PR makes it very easy to split dimensions that have previously been merged. This proposal takes care of the inverse case: merging several dimensions together. This is a common operation when dealing with spatial or temporal sequences to which we apply the same operation: it is typical to fold these extra dimensions into the batch dimension, apply the operation, and then split the batch dimension back up again.
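To make the motivating fold/apply/split pattern concrete, here is a small sketch using plain NumPy reshapes (the array shapes and the matrix-multiply stand-in for "the operation" are illustrative assumptions, not part of the proposal):

```python
import numpy as np

# A batch of 4 sequences, each with 10 time steps of 8 features.
x = np.zeros((4, 10, 8))

# Fold the time dimension into the batch dimension so a per-step
# operation (here a toy dense layer) treats every step independently.
folded = x.reshape(4 * 10, 8)      # what merge_dims would do
y = folded @ np.zeros((8, 16))     # apply the shared operation
y = y.reshape(4, 10, 16)           # split the batch dimension back up

print(y.shape)  # (4, 10, 16)
```

The proposed op would replace the hand-computed `reshape` call in the "fold" step.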
A merge_dims op would allow different dimensions to be merged. It would take as a parameter a list specifying, for each axis of the input, which axis of the output it should be sent to. If multiple input axes are mapped to the same output axis, they are merged. This list must have the same length as the rank of the input.
So, for example, this table shows on the left the input shape, in the middle the target_axes parameter, and on the right the output shape that the layer would produce:
```
[2, 3, 5] -> [0, 0, 0] -> [30]
[2, 3, 5] -> [0, 0, 1] -> [6, 5]
[2, 3, 5] -> [0, 1, 1] -> [2, 15]
[2, 3, 5] -> [0, 1, 0] -> [10, 3]
[2, 3, 5] -> [1, 0, 1] -> [3, 10]
[2, 3, 5] -> [1, 0, 0] -> [15, 2]
[2, 3, 5] -> [0, 1, 2] -> [2, 3, 5]
[2, 3, 5] -> [2, 1, 0] -> [5, 3, 2]
```
Let’s explain in more detail one of these examples:
```
in_shape = [2, 3, 5]
target_axes = [0, 1, 1]
output_shape = [2, 15]
```
First, keep in mind that target_axes uses zero indexing, so 0 refers to the first axis of the output, 1 refers to the second axis, etc. The first 0 of target_axes means "send the first axis of the input to the first axis of the output"; therefore the first element of output_shape is 2. Next, the remaining 1, 1 of target_axes means "send the second and third axes of the input to the second axis of the output"; their sizes are multiplied together, so the second element of output_shape is 3 × 5 = 15.
This mechanism is quite flexible and allows for arbitrary re-ordering of the dimensions in addition to merging, which you can see in the final two rows of the table.
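For intuition, the proposed semantics can be emulated on a concrete array with a transpose followed by a reshape. The sketch below is a hypothetical NumPy reference, and it assumes that axes sharing the same target are merged in input order (the proposal does not yet pin this down):

```python
import numpy as np

def merge_dims_ref(x, target_axes):
    """Emulate the proposed merge_dims semantics: move axes with the
    same target next to each other (stable order), then collapse each
    group with a reshape."""
    # Stable sort of input axes by their target output axis.
    order = sorted(range(x.ndim), key=lambda i: target_axes[i])
    out_shape = [1] * (1 + max(target_axes))
    for i in range(x.ndim):
        out_shape[target_axes[i]] *= x.shape[i]
    return x.transpose(order).reshape(out_shape)

x = np.arange(2 * 3 * 5).reshape(2, 3, 5)
print(merge_dims_ref(x, [0, 1, 1]).shape)  # (2, 15)
print(merge_dims_ref(x, [1, 0, 0]).shape)  # (15, 2)
print(merge_dims_ref(x, [2, 1, 0]).shape)  # (5, 3, 2)
```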
In Python, the exact semantics would be as follows:
```python
def out_shape(in_shape, target_axes):
    assert len(in_shape) == len(target_axes)
    out_shape = [1] * (1 + max(target_axes))
    assert min(target_axes) >= 0
    for (in_dim, target_axis) in zip(in_shape, target_axes):
        out_shape[target_axis] *= in_dim
    return out_shape
```
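As a sanity check, this shape function reproduces every row of the table above (the function definition is repeated here so the snippet runs standalone):

```python
def out_shape(in_shape, target_axes):
    assert len(in_shape) == len(target_axes)
    assert min(target_axes) >= 0
    shape = [1] * (1 + max(target_axes))
    for in_dim, target_axis in zip(in_shape, target_axes):
        shape[target_axis] *= in_dim
    return shape

# Each pair is (target_axes, expected output shape) for in_shape [2, 3, 5].
table = [
    ([0, 0, 0], [30]),
    ([0, 0, 1], [6, 5]),
    ([0, 1, 1], [2, 15]),
    ([0, 1, 0], [10, 3]),
    ([1, 0, 1], [3, 10]),
    ([1, 0, 0], [15, 2]),
    ([0, 1, 2], [2, 3, 5]),
    ([2, 1, 0], [5, 3, 2]),
]
for target_axes, expected in table:
    assert out_shape([2, 3, 5], target_axes) == expected
print("all rows match")
```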
We will implement this if approved.