I am trying to use the graph structure of MXNet to speed up some calculations, and I am trying to reproduce behavior that I have already implemented in PyTorch. However, I am confused about how to do this properly, whether with `gluon.Trainer` or some other method.
As an example, here is what I have working in PyTorch (slightly simplified to keep the example minimal), which I want to translate to MXNet:
```python
import torch
import torch.optim


def unconstrained_fit(objective, data, pdf, init_pars, tolerance):
    init_pars.requires_grad = True
    optimizer = torch.optim.Adam([init_pars])
    max_delta = None
    n_epochs = 10000
    for _ in range(n_epochs):
        loss = objective(init_pars, data, pdf)
        optimizer.zero_grad()
        loss.backward()
        # Snapshot the parameters so the size of this update can be measured
        init_old = init_pars.data.clone()
        optimizer.step()
        # Stop once no parameter moved by more than `tolerance`
        max_delta = (init_pars.data - init_old).abs().max()
        if max_delta < tolerance:
            break
    return init_pars
```
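To make explicit what `optimizer.step()` does in each pass of this loop, here is a framework-free sketch in NumPy: a hand-written Adam update followed by the same "largest parameter change below `tolerance`" stopping rule. It assumes PyTorch's default Adam hyperparameters (`lr=1e-3`, `betas=(0.9, 0.999)`, `eps=1e-8`), and the function name `adam_fit` and the `grad_fn` argument (which returns the gradient directly, standing in for `loss.backward()`) are hypothetical.

```python
import numpy as np


def adam_fit(grad_fn, init_pars, tolerance, lr=1e-3, betas=(0.9, 0.999),
             eps=1e-8, n_epochs=10000):
    """Plain-NumPy Adam loop with the max-|delta| stopping rule (illustrative)."""
    pars = np.asarray(init_pars, dtype=float).copy()
    m = np.zeros_like(pars)  # running mean of gradients (first moment)
    v = np.zeros_like(pars)  # running mean of squared gradients (second moment)
    beta1, beta2 = betas
    for t in range(1, n_epochs + 1):
        g = grad_fn(pars)
        m = beta1 * m + (1 - beta1) * g
        v = beta2 * v + (1 - beta2) * g**2
        # Bias-corrected moment estimates
        m_hat = m / (1 - beta1**t)
        v_hat = v / (1 - beta2**t)
        old = pars.copy()
        pars = pars - lr * m_hat / (np.sqrt(v_hat) + eps)
        # Same convergence check as the PyTorch loop
        if np.abs(pars - old).max() < tolerance:
            break
    return pars
```

For example, `adam_fit(lambda p: 2 * (p - np.array([1.0, -2.0])), np.zeros(2), tolerance=1e-8, lr=0.01)` minimizes a quadratic centred at `[1, -2]`. Because Adam normalizes each step by the running gradient scale, early steps have size close to `lr` regardless of the gradient's magnitude.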
As The Straight Dope points out in its PyTorch-to-MXNet cheatsheet, in MXNet a `gluon.Trainer` usually fills the role that an optimizer plays in PyTorch. However, I don't understand how to initialize the Trainer in my case. Normally one would write something along the lines of

```python
trainer = gluon.Trainer(net.collect_params(), 'adam')
```

but I don't have a neural network, only the function `objective` that I want to minimize, so I assume I need to collect the parameters myself. I am confused about how to do this properly, as the code below is obviously not correct.
```python
import mxnet as mx
from mxnet import gluon, autograd


def unconstrained_fit(self, objective, data, pdf, init_pars, tolerance):
    ctx = mx.cpu()
    # How to properly do this chunk?
    init_pars = mx.gluon.Parameter('init_pars',
                                   shape=init_pars.shape,
                                   init=init_pars.asnumpy().all)
    init_pars.initialize(ctx=ctx)
    optimizer = mx.optimizer.Adam()
    trainer = gluon.Trainer([init_pars], optimizer)
    ###
    max_delta = None
    n_epochs = 10000
    for _ in range(n_epochs):
        with autograd.record():
            loss = objective(init_pars, data, pdf)
        loss.backward()
        init_old = init_pars.data.clone()
        trainer.step(data.shape[0])
        max_delta = (init_pars.data - init_old).abs().max()
        if max_delta < tolerance:
            break
    return init_pars
```
I am clearly misunderstanding something basic, so a pointer to any clarifying documentation would be helpful. Even better would be a summary of why what I am doing is wrong.