Runtime is different between Python and C++?

I built MXNet from source with MKL-DNN and OpenBLAS, and installed the Python package via pip install mxnet-cu100.
Then I tested the inference runtime with the same model and the same image.
My CPU is a 4-core Intel(R) Core™ i5-7500 @ 3.40GHz.
I measured the time with clock().

python-mxnet : forward runtime 0.0045s
c++ : forward runtime 0.7584s

My C++ code is based on the image-classification example. I am
surprised by the order-of-magnitude difference.

MXPredGetOutput(pred_hnd, output_index, &(data[0]), static_cast<mx_uint>(size));

This single line takes about 90% of the time.
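For reference, a minimal sketch of what this per-call timing with clock() looks like around the C predict API (error handling omitted; pred_hnd, image_data and image_size are assumed to be set up as in the image-classification example):

#include <mxnet/c_predict_api.h>
#include <ctime>
#include <cstdio>
#include <vector>

  clock_t t0 = clock();
  MXPredSetInput(pred_hnd, "data", image_data.data(), static_cast<mx_uint>(image_size));
  MXPredForward(pred_hnd);
  clock_t t1 = clock();

  mx_uint output_index = 0;
  mx_uint* shape = nullptr;
  mx_uint shape_len = 0;
  MXPredGetOutputShape(pred_hnd, output_index, &shape, &shape_len);
  mx_uint size = 1;
  for (mx_uint i = 0; i < shape_len; ++i) size *= shape[i];
  std::vector<float> data(size);

  // MXNet's engine is asynchronous: MXPredForward() typically just enqueues the
  // computation, and the actual wait happens inside MXPredGetOutput(), which is
  // why most of the measured time shows up in this call.
  MXPredGetOutput(pred_hnd, output_index, &(data[0]), size);
  clock_t t2 = clock();

  printf("forward: %.4fs, get output: %.4fs\n",
         double(t1 - t0) / CLOCKS_PER_SEC, double(t2 - t1) / CLOCKS_PER_SEC);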

I think it may be a problem with my compile flags. The build commands are:

mkdir build; cd build
cmake -DUSE_CUDA=0 -DUSE_MKLDNN=1 -DUSE_CPP_PACKAGE=1 -GNinja ..
ninja -v

Can someone help me?

Hello @XiaXuehai,
I also built the C++ MXNet package with Intel MKL on my
Intel® Core™ i5-8250U CPU @ 1.60GHz × 8.

For me, the inference runtime at batch size 1 is very similar between Python and C++.
I used the following C++ inference example to measure the runtime:

I changed Predictor::PredictImage() to measure the elapsed time:

/*
 * The following function runs the forward pass on the model.
 * The executor is created in the constructor.
 *
 */
void Predictor::PredictImage(const std::string& image_file) {
  // Load the input image
  NDArray image_data = LoadInputImage(image_file);

  // Normalize the image
  image_data.Slice(0, 1) -= mean_image_data;

  LG << "Running the forward pass on model to predict the image";
  /*
   * The executor->arg_arrays represent the arguments to the model.
   *
   * Copying the image_data that contains the NDArray of input image
   * to the arg map of the executor. The input is stored with the key "data" in the map.
   *
   */

  int best_idx;
  float best_accuracy;

  // warm-up
  for (int i = 0; i < 10; ++i) {
    image_data.CopyTo(&(executor->arg_dict()["data"]));

    // Run the forward pass.
    executor->Forward(false);

    // The output is available in executor->outputs.
    auto array = executor->outputs[0].Copy(Context::cpu());

    /*
     * Find the maximum accuracy and the index associated with it.
     * This is done by using the argmax operator on NDArray.
     */
    auto predicted = array.ArgmaxChannel();

    /*
     * Wait for all previous write operations on the 'predicted' NDArray
     * to complete before we read it. This guarantees that all operations
     * pushed to the backend engine for execution have actually finished.
     */
    predicted.WaitToRead();

    best_idx = predicted.At(0, 0);
    best_accuracy = array.At(0, best_idx);
  }

  // runtime
  std::chrono::steady_clock::time_point begin = std::chrono::steady_clock::now();
  for (int i = 0; i < 1000; ++i) {
    image_data.CopyTo(&(executor->arg_dict()["data"]));

    // Run the forward pass.
    executor->Forward(false);

    // The output is available in executor->outputs.
    auto array = executor->outputs[0].Copy(Context::cpu());

    /*
     * Find the maximum accuracy and the index associated with it.
     * This is done by using the argmax operator on NDArray.
     */
    auto predicted = array.ArgmaxChannel();

    /*
     * Wait for all previous write operations on the 'predicted' NDArray
     * to complete before we read it. This guarantees that all operations
     * pushed to the backend engine for execution have actually finished.
     */
    predicted.WaitToRead();

    best_idx = predicted.At(0, 0);
    best_accuracy = array.At(0, best_idx);
  }
  std::chrono::steady_clock::time_point end = std::chrono::steady_clock::now();
  std::cout << "Elapsed time: "
            << std::chrono::duration_cast<std::chrono::milliseconds>(end - begin).count()
            << "ms" << std::endl;

  if (output_labels.empty()) {
    LG << "The model predicts the highest accuracy of " << best_accuracy << " at index "
       << best_idx;
  } else {
    LG << "The model predicts the input image to be a [" << output_labels[best_idx]
       << " ] with Accuracy = " << best_accuracy << std::endl;
  }
}

You can use unit_test_inception_inference.sh to download the inception model.
To run the program, use:

./inception_inference --symbol "./model/Inception-BN-symbol.json" --params "./model/Inception-BN-0126.params" --synset "./model/synset.txt" --mean "./model/mean_224.nd" --image "./model/dog.jpg"

I adapted the code from the How to use MXNet-TensorRT integration tutorial to load the model in Python and measure its runtime:

import mxnet as mx
from mxnet.gluon.model_zoo import vision
import time
import os
 
batch_shape = (1, 3, 224, 224)
sym, arg_params, aux_params = mx.model.load_checkpoint('model/Inception-BN', 126)

# Create sample input
input = mx.nd.zeros(batch_shape)

print('Building engine')
arg_params.update(aux_params)
all_params = dict([(k, v.as_in_context(mx.cpu())) for k, v in arg_params.items()])
executor = mx.contrib.tensorrt.tensorrt_bind(sym, ctx=mx.cpu(), all_params=all_params,
                                             data=batch_shape, grad_req='null', force_rebind=True)
# Warmup
print('Warming up')
for i in range(0, 10):
    y_gen = executor.forward(is_train=False, data=input)
    y_gen[0].wait_to_read()
    
# Timing
print('Starting timed run')
start = time.time()
for i in range(0, 1000):
    y_gen = executor.forward(is_train=False, data=input)
    y_gen[0].wait_to_read()
    best_idx = y_gen[0].argmax()
    best_accuracy = y_gen[0][best_idx]
end = time.time()
print("Elapsed time: %.3fms" % ((end - start) * 1000))

C++ Debug Build:    Elapsed time: 24044ms
C++ Release Build:  Elapsed time: 22397ms
Python Runtime:     Elapsed time: 22356.505ms

Also, you should use
pip install mxnet-mkl==1.4.1 or
pip install mxnet-cu100mkl==1.4.1 for Python.
If you use pip install mxnet-cu100, you could be measuring the runtime on your GPU, or on the CPU without Intel MKL.

Best,
QueensGambit

Thank you.
But I ran the model on the CPU. If I run the model on the GPU, the Python runtime is 10x faster than on the CPU. If I did not use MKL in the Python version, its runtime should be slower than the C++ version built with MKL-DNN.

I tested your code; it takes about 60–70s for both Python and C++ on my 4-core CPU.

My model is a CRNN.
My C++ code:

  clock_t t0, t1;
  t0 = clock();
  // Run the forward pass to predict the image.
  executor->Forward(false);
  auto array = executor->outputs[0].Copy(Context::cpu());
  auto predicted = array.ArgmaxChannel();
  predicted.WaitToRead();
  t1 = clock();
  std::cout << "time = " << (double)(t1 - t0) / CLOCKS_PER_SEC << std::endl;

Python code:

net.hybridize()
t0 = time()
out = net(input)
max = out.squeeze(axis=1).argmax(axis=1)
print(out.shape)
print('time = ', time()-t0)

Hmm,
that you get a similar runtime with my code example is a good sign.
For measuring the runtime in Python, I would also include a

.wait_to_read()

in your Python code so that it is more comparable with the C++ version.
You can also add a warm-up phase that is excluded from the timed measurement.

I haven’t dealt with RNNs in MXNET before and don’t know your CRNN architecture, but you might need executor_buckets.
Here’s an official example for RNN inference in C++:

You can try out this example, measure its runtime in the same way as in my code above, and compare it with Python on the same model.
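For orientation, the bucketing idea looks roughly like this: bind one executor per sequence-length bucket up front and pick the matching one at inference time. This is only a hypothetical sketch using the mxnet-cpp types (the bucket sizes, the BuildSymbol() helper and the data shape are made up; the official example above is the authoritative reference):

  // One executor per sequence-length bucket; parameters are shared via args_map.
  std::map<int, Executor*> executor_buckets;
  for (int bucket_len : {16, 32, 64}) {
    Symbol sym = BuildSymbol(bucket_len);            // hypothetical: unrolls the net for this length
    std::map<std::string, NDArray> args = args_map;  // shared weights
    args["data"] = NDArray(Shape(1, 3, 32, bucket_len), global_ctx, false);  // made-up shape
    executor_buckets[bucket_len] = sym.SimpleBind(global_ctx, args);
  }

  // At inference time, pick the smallest bucket that fits the input length.
  int seq_len = 20;  // example value
  auto it = executor_buckets.lower_bound(seq_len);
  Executor* exec = (it != executor_buckets.end()) ? it->second
                                                  : executor_buckets.rbegin()->second;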

I added .wait_to_read(). Now the runtime is:

python-mxnet : forward runtime 0.2118s
c++ : forward runtime 0.7584s

There is still a big difference.
I will try the example.

I noticed a crucial bug in both inference C++ examples.
It seems to me that the gradients are actually computed during inference.

I tried out Intel MKL-DNN's graph optimization feature by setting:

export MXNET_SUBGRAPH_BACKEND=MKLDNN

and then I noticed the following message when running inference in C++:

src/executor/graph_executor.cc:1807: Skip subgraph MKLDNN convolution optimization pass as it requires `grad_req=null`.

So I replaced

executor = net.SimpleBind(global_ctx, args_map, std::map<std::string, NDArray>(), std::map<std::string, OpReqType>(), aux_map);

with

  std::vector<NDArray> arg_arrays;
  std::vector<NDArray> grad_arrays;
  std::vector<OpReqType> grad_reqs;
  std::vector<NDArray> aux_arrays;

  net.InferExecutorArrays(global_ctx, &arg_arrays, &grad_arrays, &grad_reqs,
                          &aux_arrays, args_map, std::map<std::string, NDArray>(),
                          std::map<std::string, OpReqType>(), aux_map);
  for (size_t i = 0; i < grad_reqs.size(); ++i) {
    grad_reqs[i] = kNullOp;
  }
  executor = net.Bind(global_ctx, arg_arrays, grad_arrays, grad_reqs, aux_arrays,
                      std::map<std::string, Context>(), nullptr);

Probably there’s a more elegant way to disable gradients for executors in C++, but this should definitely be fixed. Now the graph optimization is actually done and the inference time is reduced from 22397ms to

C++ Release Build:  Elapsed time: 16769ms
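
A possibly more elegant variant might be to pass an explicit grad_req map to SimpleBind instead of calling Bind manually. This is an untested sketch, assuming the cpp-package overload of Symbol::SimpleBind that takes an arg_grad_store map and a grad_req_type map:

  // Request null gradients for every argument explicitly.
  std::map<std::string, OpReqType> grad_req_map;
  for (const auto& arg_name : net.ListArguments()) {
    grad_req_map[arg_name] = kNullOp;
  }
  executor = net.SimpleBind(global_ctx, args_map,
                            std::map<std::string, NDArray>(),  // arg_grad_store
                            grad_req_map,                      // grad_req_type
                            aux_map);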

Thanks a lot.
I figured out that clock() is not correct.
I replaced it with:

std::chrono::steady_clock::time_point begin = std::chrono::steady_clock::now();
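
The full measurement now looks roughly like this (a sketch built from my earlier forward-pass code):

  auto begin = std::chrono::steady_clock::now();
  for (int i = 0; i < 100; ++i) {
    executor->Forward(false);
    auto array = executor->outputs[0].Copy(Context::cpu());
    auto predicted = array.ArgmaxChannel();
    predicted.WaitToRead();  // make sure the work has actually finished
  }
  auto end = std::chrono::steady_clock::now();
  // steady_clock measures elapsed wall-clock time, unlike clock(), which reports CPU time.
  std::cout << "Elapsed time: "
            << std::chrono::duration_cast<std::chrono::milliseconds>(end - begin).count()
            << "ms" << std::endl;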

Now the runtime (100 iterations) is:

C++ Release Build:        Elapsed time: 20285ms, clock time 80574ms
Python Runtime:           Elapsed time: 20533ms
Python mxnet-mkl Runtime: Elapsed time: 8680ms

I tried the MKL you mentioned above; there is no improvement.
I don't know why the clock() time is not correct.


I got a notable improvement by using @QueensGambit's replacement method in my C++ code. I am still curious whether MKL has been fully enabled.
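
One way to sanity-check this might be MKL-DNN's own verbose logging (a minimal sketch, assuming an MKL-DNN build; MKLDNN_VERBOSE is an environment variable of MKL-DNN itself, not of MXNet):

  #include <cstdlib>

  int main(int argc, char** argv) {
    // Enable MKL-DNN's verbose logging before any operator runs. If MKL-DNN
    // kernels are actually used, each executed primitive prints a line
    // starting with "mkldnn_verbose,exec,..." to stdout.
    setenv("MKLDNN_VERBOSE", "1", 1);
    // ... load the model, bind the executor and run Forward() as above ...
    return 0;
  }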