Clone module in Scala


#1

My software design requires multiple instances of the same module in Scala (org.apache.mxnet.module.Module). Is there a way to deep-copy a loaded module so that I don't need to read it from disk for each instance?

Thanks!


#2

You can use getParams() to get the parameters from the original module, then bind the second module and call setParams() on it with those parameters.
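A minimal Java sketch of the getParams()/setParams() approach, written against the MXNet Scala API as called from Java (the same style as the code in #3). It assumes the original module is already bound with initialized parameters and that you still hold its Symbol and Context. ModuleCloner and cloneForInference are hypothetical names, and the exact signatures should be checked against your MXNet version.

```java
import org.apache.mxnet.Context;
import org.apache.mxnet.NDArray;
import org.apache.mxnet.Symbol;
import org.apache.mxnet.module.Module;
import scala.Tuple2;
import scala.collection.immutable.Map;

// Hypothetical helper: builds a second Module from an already-loaded one
// without reading the checkpoint from disk again.
public final class ModuleCloner {
    private ModuleCloner() {}

    // Assumes `original` is bound and its parameters are initialized.
    public static Module cloneForInference(Module original, Symbol symbol, Context context) {
        Module copy = new Module.Builder(symbol).setContext(context).build();

        // Bind for inference (forTraining = false, inputsNeedGrad = false,
        // forceRebind = false) with the same input descriptions as the original.
        copy.bind(false, false, false, original.dataShapes());

        // Read the original's (argParams, auxParams) pair and set it on the copy;
        // setParams is called only after bind.
        Tuple2<Map<String, NDArray>, Map<String, NDArray>> params = original.getParams();
        // argParams, auxParams, allowMissing, forceInit, allowExtra
        copy.setParams(params._1(), params._2(), false, true, false);
        return copy;
    }
}
```

Usage would then be `Module second = ModuleCloner.cloneForInference(first, symbol, context);` once `first` has been loaded, bound, and initialized.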


#3

Just to clarify: would the params share memory? It's fine if they do (even preferable), as long as that doesn't introduce concurrency problems.

P.S. Is something like this OK?

          import java.util.ArrayList;
          import java.util.List;
          import org.apache.mxnet.DataDesc;
          import org.apache.mxnet.Model;
          import org.apache.mxnet.NDArray;
          import org.apache.mxnet.Symbol;
          import org.apache.mxnet.module.Module;
          import scala.Tuple3;
          import scala.collection.JavaConverters;
          import scala.collection.immutable.Map;

          // load the model artifacts once (symbol, arg params, aux params)
          final Tuple3<Symbol, Map<String, NDArray>, Map<String, NDArray>> modelArtifacts =
              Model.loadCheckpoint(modelPath, 0);

          final Symbol symbol = modelArtifacts._1();
          final Map<String, NDArray> argParams = modelArtifacts._2();
          final Map<String, NDArray> auxParams = modelArtifacts._3();

          // the input descriptions are identical for every copy, so build them once
          final List<DataDesc> dataDescs = getDataDescriptions();

          final List<Module> models = new ArrayList<>(numModelCopies);

          for (int i = 0; i < numModelCopies; ++i) {
              final Module module = new Module.Builder(symbol).setContext(context).build();

              // forTraining, inputsNeedGrad, forceRebind, dataDescs
              module.bind(false, false, false, JavaConverters.asScalaBufferConverter(dataDescs).asScala());

              // argParams, auxParams, allowMissing, forceInit, allowExtra
              module.setParams(argParams, auxParams, true, false, false);

              models.add(module);
          }