Symbol API + multiple GPUs: how to improve MNIST training performance

Hi,
How can I improve the performance of my MNIST handwritten-digit classifier built with the Symbol API (`mx.sym`)?

import logging
import random
import time

import mxnet as mx
import numpy as np
from mxnet import autograd as ag
from mxnet.io import NDArrayIter
from mxnet.metric import Accuracy
from mxnet.optimizer import Adam
from mxnet.test_utils import get_mnist


BATCH_SIZE_PER_REPLICA = 512
# Global batch size: one BATCH_SIZE_PER_REPLICA slice per GPU (2 GPUs).
BATCH_SIZE = BATCH_SIZE_PER_REPLICA * 2
NUM_CLASSES = 10
EPOCHS = 10
# input image dimensions
IMG_ROWS, IMG_COLS = 28, 28

# Seed every RNG in play: MXNet's own generators (parameter initialisers,
# Dropout), Python's `random`, and NumPy's global generator, which
# NDArrayIter(shuffle=True) uses to shuffle the training samples. Without the
# NumPy seed the batch order differs on every run.
mx.random.seed(42)
random.seed(42)
np.random.seed(42)

mnist = get_mnist()

# Both iterators feed the network's 'data' input and the label expected by the
# SoftmaxOutput named 'softmax' (MXNet looks for 'softmax_label').
# last_batch_handle='pad' fills an incomplete final batch with samples from the
# start of the data; MNIST's test split (10k images) is not a multiple of
# BATCH_SIZE, so anything that ignores the iterator's `pad` sees duplicates.
train_data = NDArrayIter(
    data=mnist['train_data'],
    label=mnist['train_label'],
    batch_size=BATCH_SIZE,
    shuffle=True,
    last_batch_handle='pad',
    label_name='softmax_label',
)
test_data = NDArrayIter(
    data=mnist['test_data'],
    label=mnist['test_label'],
    batch_size=BATCH_SIZE,
    last_batch_handle='pad',
    label_name='softmax_label',
)

# Symbolic CNN classifier:
# conv3x3(32) -> relu -> conv3x3(64) -> relu -> maxpool 2x2 -> dropout(0.25)
# -> fc(128) -> relu -> dropout(0.5) -> fc(NUM_CLASSES) -> softmax.
data = mx.sym.Variable('data')

conv1 = mx.sym.Convolution(data=data, num_filter=32, kernel=(3, 3))
relu1 = mx.sym.Activation(data=conv1, act_type="relu")
conv2 = mx.sym.Convolution(data=relu1, num_filter=64, kernel=(3, 3))
relu2 = mx.sym.Activation(data=conv2, act_type="relu")
# mx.sym.Pooling's stride defaults to 1, so without an explicit stride a 2x2
# max-pool only shrinks the 24x24 maps to 23x23 and fc1 sees 64*23*23 inputs.
# stride=(2, 2) actually downsamples to 12x12 (64*12*12 inputs), making fc1
# ~3.7x smaller and every training step cheaper.
pool = mx.sym.Pooling(data=relu2, pool_type="max", kernel=(2, 2), stride=(2, 2))
dropout1 = mx.sym.Dropout(data=pool, p=0.25)
flatten = mx.sym.Flatten(data=dropout1)
fc1 = mx.sym.FullyConnected(data=flatten, num_hidden=128)
relu3 = mx.sym.Activation(data=fc1, act_type="relu")
dropout2 = mx.sym.Dropout(data=relu3, p=0.5)
fc2 = mx.sym.FullyConnected(data=dropout2, num_hidden=NUM_CLASSES)
# Named 'softmax' so the label input is 'softmax_label' (NDArrayIter's default).
net = mx.sym.SoftmaxOutput(data=fc2, name='softmax')
        
# Module.fit reports per-epoch accuracy/time and (via Speedometer) samples/sec
# through the `logging` module, which prints nothing until configured.
logging.basicConfig(level=logging.INFO)

# One GPU per replica; derived from the batch config so the per-device batch
# stays BATCH_SIZE_PER_REPLICA. NOTE(review): this net is small, so gradient
# sync across GPUs may dominate -- compare against a single GPU.
devices = [mx.gpu(i) for i in range(BATCH_SIZE // BATCH_SIZE_PER_REPLICA)]
mod = mx.mod.Module(net, context=devices)

# SoftmaxOutput (default normalization='null') produces gradients summed over
# the batch. Module applies rescale_grad=1/batch_size automatically only when
# given an optimizer *name*; for a ready-made instance it merely warns, so set
# it here explicitly.
adam = Adam(learning_rate=0.001, beta1=0.9, beta2=0.999, epsilon=1e-08,
            rescale_grad=1.0 / BATCH_SIZE)

start = time.perf_counter()
# Let fit() do bind / init_params / init_optimizer itself: calling them
# manually first makes fit() re-issue them and ignore them with warnings.
# (The timing therefore also includes the one-off bind/initialisation.)
mod.fit(
    train_data,
    initializer=mx.init.Xavier(magnitude=2.),
    optimizer=adam,
    kvstore='device',
    num_epoch=EPOCHS,
    batch_end_callback=mx.callback.Speedometer(BATCH_SIZE, 20),
)
elapsed = time.perf_counter() - start
print('elapsed: {:0.3f}'.format(elapsed))

# Test-set accuracy. Module.score() updates the metric with every full batch,
# ignoring the iterator's `pad`, so images repeated to fill the last batch
# would be counted twice. predict() strips that padding, returning exactly one
# row of class scores per test image, in iterator (unshuffled) order.
scores = mod.predict(test_data).asnumpy()
hits = scores.argmax(axis=1) == mnist['test_label']
print('validation acc: %s=%f' % ('accuracy', hits.mean()))