mx.nd.argmax slow on GPU with high reduction dimensions

When running mx.nd.argmax over NDArrays with a large reduction dimension, it is considerably slower than comparable reduction operators such as mx.nd.max.

This effect only occurs on the GPU, and only once the reduction dimension grows large (over 10k).
Is this expected behaviour? This happens with the pip package mxnet-cu91, which currently ships mxnet 1.2.
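
For a quick sanity check, the effect already shows up with a single call at a large reduction dimension. A minimal repro (the comments only restate the behaviour described above; exact timings of course depend on the hardware):

import mxnet as mx

x = mx.nd.random_uniform(0, 100, shape=(32, 100000), ctx=mx.gpu())
x.wait_to_read()                         # make sure the input is materialised first
mx.nd.max(x, axis=1).wait_to_read()      # fast
mx.nd.argmax(x, axis=1).wait_to_read()   # noticeably slower at this size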

Here is a small Python script that compares the performance of both operators over a range of reduction dimensions:

import time
import mxnet as mx

def run_max(x):
    return mx.nd.max(x, axis=1)

def run_argmax(x):
    return mx.nd.argmax(x, axis=1)

def measure_time(func, iters, inputs):
    begin = time.time()
    for i in range(iters):
        result = func(inputs[i, :, :])
        # block until the asynchronous GPU computation has finished,
        # so the wall-clock time actually covers the kernel execution
        result.wait_to_read()
    return time.time() - begin

ctx = mx.gpu()
batch_size = 32
iterations = 500
for reduction_dimension in [25, 50, 100, 1000, 10000, 100000]:
    print('reduction dimension: {}'.format(reduction_dimension))
    inputs = mx.nd.random_uniform(0, 100,
                                  shape=(iterations, batch_size, reduction_dimension),
                                  ctx=ctx)

    t = measure_time(run_max, iterations, inputs)
    print('max took {} seconds'.format(t))

    t = measure_time(run_argmax, iterations, inputs)
    print('argmax took {} seconds'.format(t))
    print('')
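
To help narrow down where the time goes, the same workload can also be run under MXNet's built-in profiler instead of timing from Python. A rough sketch, assuming the mx.profiler API that ships with mxnet 1.2 (the filename and iteration count are arbitrary):

import mxnet as mx

mx.profiler.set_config(profile_all=True, filename='argmax_profile.json')

ctx = mx.gpu()
x = mx.nd.random_uniform(0, 100, shape=(32, 100000), ctx=ctx)
x.wait_to_read()  # keep the input setup out of the profile

mx.profiler.set_state('run')
for _ in range(100):
    mx.nd.argmax(x, axis=1).wait_to_read()
mx.profiler.set_state('stop')

# argmax_profile.json can be loaded in chrome://tracing to inspect
# the per-operator timings

If argmax dominates there as well, that would point at the reduction kernel itself rather than framework overhead.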

It does look weird. I have submitted a GitHub issue so that someone can look into it: https://github.com/apache/incubator-mxnet/issues/11337

Thanks for looking into it and creating an issue.