Hi @simomaur,
I just tried this out and my weights do change. I’m working with a LeNet model for MNIST classification on a single GPU to create a simple case for us to test this out. I’m also using MXNet version 1.2.0.
My code first trains a HybridBlock model (the ‘pre-training’ stage), exports the weights and network architecture, loads it back in as a SymbolBlock and continues training (the ‘fine-tuning’ stage). In this final stage, I print out the weights of a specific layer (e.g. “hybridsequential0_dense1_weight”) before and after the fine-tuning, and they are different. I train for a single epoch in each stage.
I’d be interested to see how your setup differs from mine, so please let me know if you spot the difference. One thing I notice is that you’re loading parameters differently: I don’t separate the arg and aux params, so try changing that (assuming you’re not loading an ONNX model).
sym = mx.sym.load('lenet-symbol.json')
deserialized_net = gluon.nn.SymbolBlock(outputs=sym, inputs=mx.sym.var('data'))
deserialized_net.collect_params().load('lenet-0001.params', ctx=ctx)
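As a side note, if you’re on a newer MXNet release (1.3.0 or later, I believe), gluon.nn.SymbolBlock.imports wraps these three steps into one call. A minimal sketch, assuming that version is available:
# Sketch assuming MXNet >= 1.3.0, where SymbolBlock.imports exists.
# It loads the symbol file plus the single params file (arg and aux together).
deserialized_net = gluon.nn.SymbolBlock.imports(
    'lenet-symbol.json',   # architecture saved by net.export()
    ['data'],              # names of the input symbols
    'lenet-0001.params',   # combined parameters file
    ctx=ctx)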
My complete example is as follows:
from __future__ import print_function
import mxnet as mx
from mxnet import nd, autograd, gluon
from mxnet.gluon.data.vision import transforms
import numpy as np
# Build a simple convolutional network
def build_lenet(net):
    with net.name_scope():
        # First convolution
        net.add(gluon.nn.Conv2D(channels=20, kernel_size=5, activation='relu'))
        net.add(gluon.nn.MaxPool2D(pool_size=2, strides=2))
        # Second convolution
        net.add(gluon.nn.Conv2D(channels=50, kernel_size=5, activation='relu'))
        net.add(gluon.nn.MaxPool2D(pool_size=2, strides=2))
        # Flatten the output before the fully connected layers
        net.add(gluon.nn.Flatten())
        # First fully connected layer with 512 neurons
        net.add(gluon.nn.Dense(512, activation="relu"))
        # Second fully connected layer with as many neurons as the number of classes
        net.add(gluon.nn.Dense(num_outputs))
    return net
# Train a given model using MNIST data
def train_model(model, inspect_param=None):
    # Use cross entropy loss
    softmax_cross_entropy = gluon.loss.SoftmaxCrossEntropyLoss()
    # Use Adam optimizer
    model_params = model.collect_params()
    trainer = gluon.Trainer(model_params, 'adam', {'learning_rate': .001})
    if inspect_param: print("Starting values: {}".format(model_params[inspect_param].data()))
    # Train for one epoch
    for epoch in range(1):
        # Iterate through the images and labels in the training data
        for batch_num, (data, label) in enumerate(train_data):
            # Get the images and labels
            data = data.as_in_context(ctx)
            label = label.as_in_context(ctx)
            # Ask autograd to record the forward pass
            with autograd.record():
                # Run the forward pass
                output = model(data)
                # Compute the loss
                loss = softmax_cross_entropy(output, label)
            # Compute gradients
            loss.backward()
            # Update parameters
            trainer.step(data.shape[0])
            # Print loss once in a while
            if batch_num % 50 == 0:
                curr_loss = nd.mean(loss).asscalar()
                print("Epoch: %d; Batch %d; Loss %f" % (epoch, batch_num, curr_loss))
    if inspect_param: print("Finishing values: {}".format(model_params[inspect_param].data()))
print("Stage 0: setup")
ctx = mx.gpu() if mx.test_utils.list_gpus() else mx.cpu()
num_inputs = 784
num_outputs = 10
batch_size = 64
train_data = gluon.data.DataLoader(gluon.data.vision.MNIST(train=True).transform_first(transforms.ToTensor()),
                                   batch_size, shuffle=True)
print("Stage 1: pre-train model")
net = build_lenet(gluon.nn.HybridSequential())
net.hybridize()
net.collect_params().initialize(mx.init.Xavier(), ctx=ctx)
train_model(net)
print("Stage 2: save model")
net.export("lenet", epoch=1)
print("Stage 3: load model")
sym = mx.sym.load('lenet-symbol.json')
deserialized_net = gluon.nn.SymbolBlock(outputs=sym, inputs=mx.sym.var('data'))
deserialized_net.collect_params().load('lenet-0001.params', ctx=ctx)
print("Stage 4: fine-tuning")
train_model(deserialized_net, inspect_param="hybridsequential0_dense1_weight")
And my output shows the weights of “hybridsequential0_dense1_weight” changing:
Stage 0: setup
Stage 1: pre-train model
Epoch: 0; Batch 0; Loss 2.280598
Epoch: 0; Batch 50; Loss 0.337641
Epoch: 0; Batch 100; Loss 0.210663
Epoch: 0; Batch 150; Loss 0.111363
Epoch: 0; Batch 200; Loss 0.115460
Epoch: 0; Batch 250; Loss 0.116974
Epoch: 0; Batch 300; Loss 0.062213
Epoch: 0; Batch 350; Loss 0.093351
Epoch: 0; Batch 400; Loss 0.048531
Epoch: 0; Batch 450; Loss 0.058199
Epoch: 0; Batch 500; Loss 0.070505
Epoch: 0; Batch 550; Loss 0.041834
Epoch: 0; Batch 600; Loss 0.059377
Epoch: 0; Batch 650; Loss 0.053189
Epoch: 0; Batch 700; Loss 0.017252
Epoch: 0; Batch 750; Loss 0.027031
Epoch: 0; Batch 800; Loss 0.005970
Epoch: 0; Batch 850; Loss 0.095129
Epoch: 0; Batch 900; Loss 0.013246
Stage 2: save model
Stage 3: load model
Stage 4: fine-tuning
Starting values:
[[-0.1280341 0.0091968 0.0778511 ..., -0.08914039 0.05862502
-0.1317472 ]
[-0.0269692 0.04817459 0.10786245 ..., -0.08930212 -0.03623529
0.10491013]
[ 0.08288375 -0.09571287 -0.03237296 ..., 0.06498121 0.08183339
-0.04402027]
...,
[ 0.11866657 0.05870249 -0.03701728 ..., 0.04017895 0.0422365
-0.08687742]
[-0.06393342 0.09521717 -0.04818693 ..., -0.06093055 0.03657702
-0.04621929]
[ 0.04993053 -0.04644486 0.01558479 ..., 0.0980444 -0.05688887
-0.11486824]]
<NDArray 10x512 @gpu(0)>
Epoch: 0; Batch 0; Loss 0.002895
Epoch: 0; Batch 50; Loss 0.017847
Epoch: 0; Batch 100; Loss 0.047905
Epoch: 0; Batch 150; Loss 0.025534
Epoch: 0; Batch 200; Loss 0.039929
Epoch: 0; Batch 250; Loss 0.143031
Epoch: 0; Batch 300; Loss 0.021253
Epoch: 0; Batch 350; Loss 0.041503
Epoch: 0; Batch 400; Loss 0.059370
Epoch: 0; Batch 450; Loss 0.096992
Epoch: 0; Batch 500; Loss 0.016127
Epoch: 0; Batch 550; Loss 0.051358
Epoch: 0; Batch 600; Loss 0.010862
Epoch: 0; Batch 650; Loss 0.052501
Epoch: 0; Batch 700; Loss 0.036348
Epoch: 0; Batch 750; Loss 0.028365
Epoch: 0; Batch 800; Loss 0.077076
Epoch: 0; Batch 850; Loss 0.094364
Epoch: 0; Batch 900; Loss 0.022251
Finishing values:
[[-0.13683593 -0.01088019 0.08661763 ..., -0.08079872 0.06584116
-0.17972106]
[-0.00970386 0.06387555 0.12025371 ..., -0.08949943 -0.04318072
0.11020051]
[ 0.08180709 -0.12747169 -0.04584555 ..., 0.0604701 0.08880944
-0.04270781]
...,
[ 0.11766645 0.05050235 -0.08242993 ..., 0.04004622 0.04245359
-0.12155535]
[-0.09150746 0.10231887 -0.09706685 ..., -0.06136883 0.04366267
-0.04674431]
[ 0.04564926 -0.03896103 -0.01093464 ..., 0.09354883 -0.05801368
-0.13949625]]
<NDArray 10x512 @gpu(0)>
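Rather than eyeballing the printed matrices, you can also check programmatically that the weights moved during fine-tuning. A minimal sketch, reusing the deserialized_net and layer name from above (drop the inspect_param argument if you go this route):
# Snapshot the layer's weights before fine-tuning, then compare afterwards.
param = deserialized_net.collect_params()['hybridsequential0_dense1_weight']
before = param.data().copy()
train_model(deserialized_net)
after = param.data()
# A non-zero maximum absolute difference confirms the parameters were updated.
print("max abs change:", nd.max(nd.abs(after - before)).asscalar())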