When running mx.nd.argmax over NDArrays with large reduction dimensions, it is considerably slower than, e.g., mx.nd.max.
This only happens on the GPU, and only with large reduction dimensions (above 10k).
Is this expected behaviour? I am seeing it with the pip package mxnet-cu91, which currently ships MXNet 1.2.
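For reference, the two operators do different amounts of bookkeeping per row. The pure-Python sketch below shows this conceptually: max carries one running value, while argmax carries a (value, index) pair. It is only an illustration of the semantics of reducing over the last axis. It is not MXNet's GPU kernel and says nothing about why MXNet 1.2 is slower. The function names are hypothetical, and the first-index tie-breaking is a choice made in this sketch, not necessarily MXNet's.

```python
from typing import List, Sequence


def max_last_axis(rows: Sequence[Sequence[float]]) -> List[float]:
    """Reduce each row to its maximum value (one running value per row)."""
    out = []
    for row in rows:
        best = row[0]
        for v in row[1:]:
            if v > best:
                best = v
        out.append(best)
    return out


def argmax_last_axis(rows: Sequence[Sequence[float]]) -> List[int]:
    """Reduce each row to the index of its maximum (a running (value, index) pair).

    Ties keep the first index in this sketch (hypothetical choice).
    """
    out = []
    for row in rows:
        best_val, best_idx = row[0], 0
        for i in range(1, len(row)):
            if row[i] > best_val:
                best_val, best_idx = row[i], i
        out.append(best_idx)
    return out


rows = [[1.0, 5.0, 3.0], [7.0, 2.0, 7.0]]
print(max_last_axis(rows))     # [5.0, 7.0]
print(argmax_last_axis(rows))  # [1, 0]
```

Indexing each row with its argmax result recovers the max result, which is a quick way to sanity-check that both GPU operators agree.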
Here is a small Python script that compares the performance of both operators over a range of reduction dimensions:
```python
import time
import mxnet as mx

# Named max_op/argmax_op to avoid shadowing Python's built-in max.
def max_op(x, ctx):
    return mx.nd.max(x, axis=1)

def argmax_op(x, ctx):
    return mx.nd.argmax(x, axis=1)

def measure_time(func, iters, inputs, ctx):
    begin = time.time()
    for i in range(iters):
        result = func(inputs[i, :, :], ctx=ctx)
        # MXNet executes asynchronously; block until this result is ready.
        result.wait_to_read()
    return time.time() - begin

ctx = mx.gpu()
batch_size = 32
iterations = 500

for reduction_dimension in [25, 50, 100, 1000, 10000, 100000]:
    print('reduction dimension: {}'.format(reduction_dimension))
    inputs = mx.nd.random.uniform(0, 100,
                                  shape=(iterations, batch_size, reduction_dimension),
                                  ctx=ctx)
    # Make sure random generation has finished so it is not counted in the first timing.
    inputs.wait_to_read()
    t = measure_time(max_op, iterations, inputs, ctx)
    print("max took {} seconds".format(t))
    t = measure_time(argmax_op, iterations, inputs, ctx)
    print("argmax took {} seconds".format(t))
    print('')
```
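If it helps rule out measurement noise (first-call overhead, one-off stalls), a timing helper with warm-up runs and a median over repeats can be used instead of a single total. This is a minimal sketch using only the standard library. The name `benchmark` and its parameters are hypothetical, and `sync` is where you would pass a blocking call such as NDArray's `wait_to_read()`, since MXNet returns before GPU work finishes.

```python
import statistics
import time
from typing import Callable, Optional


def benchmark(fn: Callable[[], object], iters: int = 100, warmup: int = 5,
              sync: Optional[Callable[[object], None]] = None) -> float:
    """Return the median wall-clock time (seconds) of a single call to fn.

    Hypothetical helper: warm-up calls are run first and not timed; if `sync`
    is given, it is called on each result so asynchronous work is included.
    """
    for _ in range(warmup):
        out = fn()
        if sync is not None:
            sync(out)
    samples = []
    for _ in range(iters):
        start = time.perf_counter()
        out = fn()
        if sync is not None:
            sync(out)  # e.g. block until the GPU result is ready
        samples.append(time.perf_counter() - start)
    return statistics.median(samples)


if __name__ == "__main__":
    # Plain-Python stand-in workload; with MXNet one might instead pass
    # fn=lambda: mx.nd.argmax(x, axis=1), sync=lambda r: r.wait_to_read()
    print(benchmark(lambda: sum(range(10000)), iters=20))
```

Reporting a median per call rather than a total over 500 calls makes the max/argmax ratio at each reduction dimension easier to compare.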