I am trying to use the graph structure of MXNet to speed up some calculations, and I am currently trying to mimic behavior that I have already implemented in PyTorch. However, I am confused about how to do this properly, whether with gluon.Trainer or some other method.
To explain with an example: the following is what I have working in PyTorch (slightly modified to give the simplest example), and it is what I want to translate to MXNet.
import torch.optim

def unconstrained_fit(objective, data, pdf, init_pars, tolerance):
    init_pars.requires_grad = True
    optimizer = torch.optim.Adam([init_pars])
    max_delta = None
    n_epochs = 10000
    for _ in range(n_epochs):
        loss = objective(init_pars, data, pdf)
        optimizer.zero_grad()
        loss.backward()
        init_old = init_pars.data.clone()
        optimizer.step()
        max_delta = (init_pars.data - init_old).abs().max()
        if max_delta < tolerance:
            break
    return init_pars
As The Straight Dope points out in its PyTorch-to-MXNet cheatsheet, in MXNet one would usually use a Trainer where one would use an optimizer in PyTorch. However, I don't understand how to properly initialize the Trainer in my case. Normally one would do something along the lines of
trainer = gluon.Trainer(net.collect_params(), 'adam')
I assume that I will need to collect the parameters myself, as I don't have a neural network but rather an objective function that I want to minimize. I am confused about how to do this properly, as the code below is obviously not correct.
import mxnet as mx
from mxnet import gluon, autograd

def unconstrained_fit(self, objective, data, pdf, init_pars, tolerance):
    ctx = mx.cpu()
    # How to properly do this chunk?
    init_pars = mx.gluon.Parameter('init_pars',
                                   shape=init_pars.shape,
                                   init=init_pars.asnumpy().all)
    init_pars.initialize(ctx=ctx)
    optimizer = mx.optimizer.Adam()
    trainer = gluon.Trainer([init_pars], optimizer)
    ###
    max_delta = None
    n_epochs = 10000
    for _ in range(n_epochs):
        with autograd.record():
            loss = objective(init_pars, data, pdf)
        loss.backward()
        init_old = init_pars.data.clone()
        trainer.step(data.shape)
        max_delta = (init_pars.data - init_old).abs().max()
        if max_delta < tolerance:
            break
    return init_pars
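After reading the gluon.Parameter documentation, my best guess at a corrected version is to build a ParameterDict by hand, seed it with the starting values, and hand that to the Trainer. Here is a sketch of that guess (the parameter name 'pars' is arbitrary, and I am assuming init_pars comes in as an mx.nd.NDArray):

import mxnet as mx
from mxnet import gluon, autograd

def unconstrained_fit(objective, data, pdf, init_pars, tolerance):
    ctx = mx.cpu()
    # Wrap the free parameters in a ParameterDict so Trainer can manage them
    params = gluon.ParameterDict()
    pars = params.get('pars', shape=init_pars.shape)
    params.initialize(ctx=ctx)
    pars.set_data(init_pars)  # overwrite the default init with the starting values
    trainer = gluon.Trainer(params, 'adam')

    n_epochs = 10000
    for _ in range(n_epochs):
        with autograd.record():
            loss = objective(pars.data(), data, pdf)
        loss.backward()
        init_old = pars.data().copy()
        # batch_size=1: the loss is a scalar over the whole dataset,
        # so the gradient should not be rescaled
        trainer.step(1)
        max_delta = (pars.data() - init_old).abs().max().asscalar()
        if max_delta < tolerance:
            break
    return pars.data()

Is something like this the intended pattern, or is Trainer simply the wrong tool when there is no Block involved?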
I am clearly misunderstanding something basic, so a pointer to something clarifying would be helpful. Even better would be if someone who understands what I am asking could summarize why what I am doing is wrong.
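For completeness, the other route I can imagine is to bypass Trainer entirely and drive mx.optimizer by hand with autograd, roughly like the sketch below (again assuming init_pars is a plain mx.nd.NDArray; the helper name unconstrained_fit_manual is mine). I assume Trainer exists precisely so that one does not have to do this, though:

import mxnet as mx
from mxnet import autograd

def unconstrained_fit_manual(objective, data, pdf, init_pars, tolerance):
    pars = init_pars.copy()
    pars.attach_grad()  # track gradients on a plain NDArray
    opt = mx.optimizer.Adam()
    state = opt.create_state(0, pars)  # Adam's (mean, var) state for parameter index 0
    for _ in range(10000):
        with autograd.record():
            loss = objective(pars, data, pdf)
        loss.backward()
        old = pars.copy()
        opt.update(0, pars, pars.grad, state)  # in-place Adam update
        if (pars - old).abs().max().asscalar() < tolerance:
            break
    return pars

If that manual loop is actually the expected approach here, that would also be a useful answer.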