MXNet ONNX loader does not support Embedding layer?


#1

Hi, I was trying to load a model trained in PyTorch into MXNet via ONNX.
I exported the PyTorch model as an ONNX file and then loaded that file in MXNet.

I found that models containing an Embedding layer cannot be imported into MXNet.
(If the model contains only Convolution layers, importing from ONNX works fine.)

Here is a minimal example that reproduces the error:

import torch
import torch.nn as nn
from torch.autograd import Variable
import torch.onnx as torch_onnx

class Model2(nn.Module):
    def __init__(self):
        super(Model2, self).__init__()
        self.embed = nn.Embedding(10, 100)

    def forward(self, inputs):
        h0 = self.embed(inputs)
        return h0

# Use this as an input trace to serialize the model
input_shape = (30, 1, 1)
model_onnx_path = "torch_model2.onnx"
model = Model2()
model.train(False)

# Export the model to an ONNX file
dummy_input = Variable(torch.randint(10, (1,) + input_shape).long())
output = torch_onnx.export(model, dummy_input, model_onnx_path, verbose=True)
print("Export of {} complete!".format(model_onnx_path))


import mxnet as mx
import mxnet.contrib.onnx as onnx_mxnet

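# Import the ONNX file as an MXNet symbol plus its parameters
sym, arg, aux = onnx_mxnet.import_model(model_onnx_path)
# Graph inputs that are neither parameters nor auxiliary states are the data inputs
data_names = [graph_input for graph_input in sym.list_inputs()
              if graph_input not in arg and graph_input not in aux]

mod = mx.mod.Module(symbol=sym, data_names=data_names, context=mx.cpu(), label_names=None)
# The data input is named "0" in the exported graph
mod.bind(for_training=False, data_shapes=[(data_names[0], (1,) + input_shape)], label_shapes=None)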

The exported ONNX graph looks like this:

graph(%0 : Long(1, 30, 1, 1)
      %1 : Float(10, 100)) {
  %2 : Float(1, 30, 1, 1, 100) = onnx::Gather(%1, %0), scope: Model2/Embedding[embed]
  return (%2);
}

The error message is:

MXNetError: [16:21:32] src/executor/graph_executor.cc:481: 
InferType pass cannot decide dtypes for the following arguments (-1 means unknown dtype).     
Please consider providing them as inputs: 1: -1,

If even this simple model does not work, MXNet's ONNX support cannot be used for NLP applications, because almost all neural network models for NLP have an embedding layer.


#2

ONNX actually doesn't have an Embedding operator. In your example, PyTorch exports the embedding as an onnx::Gather operator.
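To see why the mapping works: ONNX Gather with axis 0 (its default) picks rows of the data tensor using an integer index tensor, and its output shape is indices.shape + data.shape[1:]. That is exactly an embedding lookup. The following is a minimal NumPy sketch, assuming a randomly initialised table in place of trained weights; `embedding_lookup` and `onnx_gather_axis0` are hypothetical helper names used only for illustration. It reproduces the (1, 30, 1, 1, 100) output shape seen in the exported graph.

```python
import numpy as np

# Hypothetical embedding table with the same shape as nn.Embedding(10, 100):
# 10 vocabulary entries, each a 100-dimensional vector (random, not trained).
rng = np.random.default_rng(0)
weight = rng.standard_normal((10, 100)).astype(np.float32)

# Integer token ids with the same shape as the dummy input, (1, 30, 1, 1).
indices = rng.integers(0, 10, size=(1, 30, 1, 1))

def embedding_lookup(weight, indices):
    """Embedding as an explicit loop: each index is replaced by its table row."""
    out = np.empty(indices.shape + weight.shape[1:], dtype=weight.dtype)
    for pos in np.ndindex(indices.shape):
        out[pos] = weight[indices[pos]]
    return out

def onnx_gather_axis0(data, indices):
    """Gather along axis 0: output shape is indices.shape + data.shape[1:]."""
    return np.take(data, indices, axis=0)

looked_up = embedding_lookup(weight, indices)
gathered = onnx_gather_axis0(weight, indices)
print(gathered.shape)                       # (1, 30, 1, 1, 100)
print(np.array_equal(looked_up, gathered))  # True
```

Because the two results are identical, an importer only needs a working Gather (row lookup) implementation to run PyTorch embeddings exported this way.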

The onnx::Gather operator is now supported in the master version of MXNet, so uninstall your current version and install the master version as explained here: https://mxnet.incubator.apache.org/install/index.html?platform=Linux&language=Python&processor=CPU&version=master