In the gluon tutorial of pix2pix:
http://gluon.mxnet.io/chapter14_generative-adversarial-networks/pixel2pixel.html
the U-net is defined in such a way that we can't save its parameters using net.save_params(...)
I get the following error:
ValueError: Prefix dsds is to be striped before saving, but Parameter unetcenterblock0_encoderblock0_conv0_weight does not start with dsds. If you are using Block.save_params, This may be due to your Block shares parameters from other Blocks or you forgot to use with name_scope()
during init. Consider switching to Block.collect_params.save and Block.collect_params.load instead.
A minimal reproducible example:
def param_init(param, ctx):
    """Initialize one parameter on *ctx* following the pix2pix scheme.

    Conv weights ~ N(0, 0.02), conv biases zero; BatchNorm params start at
    zero, then gamma is resampled from N(1, 0.02).
    """
    name = param.name
    if 'conv' in name:
        if 'weight' in name:
            param.initialize(init=mx.init.Normal(0.02), ctx=ctx)
        else:
            param.initialize(init=mx.init.Zero(), ctx=ctx)
    elif 'batchnorm' in name:
        param.initialize(init=mx.init.Zero(), ctx=ctx)
        # Resample gamma from a normal distribution (mean 1, std 0.02).
        if 'gamma' in name:
            param.set_data(nd.random_normal(1, 0.02, param.data().shape))
def network_init(net, ctx):
    """Run param_init over every parameter collected from *net*."""
    for p in net.collect_params().values():
        param_init(p, ctx)
class EncoderBlock(nn.HybridBlock):
    """Downsampling half-block: strided 4x4 conv + LeakyReLU, plus BatchNorm
    on every block except the innermost (center) one."""

    def __init__(self, out_channels, in_channels, is_center_block=False):
        super(EncoderBlock, self).__init__()
        with self.name_scope():
            # Layer creation order is significant: it fixes parameter names.
            layers = [
                nn.Conv2D(channels=out_channels, kernel_size=4, strides=2,
                          padding=1, in_channels=in_channels, use_bias=False),
                nn.LeakyReLU(alpha=0.2),
            ]
            if not is_center_block:
                layers.append(nn.BatchNorm(momentum=0.1,
                                           in_channels=out_channels))
            self.model = nn.HybridSequential()
            with self.model.name_scope():
                for layer in layers:
                    self.model.add(layer)

    def hybrid_forward(self, F, x):
        return self.model(x)
class DecoderBlock(nn.HybridBlock):
    """Upsampling half-block: ReLU, transposed 4x4 conv, then BatchNorm."""

    def __init__(self, out_channels, in_channels):
        super(DecoderBlock, self).__init__()
        with self.name_scope():
            # Layer creation order is significant: it fixes parameter names.
            layers = [
                nn.Activation(activation='relu'),
                nn.Conv2DTranspose(channels=out_channels, kernel_size=4,
                                   strides=2, padding=1,
                                   in_channels=in_channels, use_bias=False),
                nn.BatchNorm(momentum=0.1, in_channels=out_channels),
            ]
            self.model = nn.HybridSequential()
            with self.model.name_scope():
                for layer in layers:
                    self.model.add(layer)

    def hybrid_forward(self, F, x):
        return self.model(x)
class UnetCenterBlock(nn.HybridBlock):
    """Innermost U-net stage: encode then decode, and concatenate the result
    with the block's input along the channel axis (skip connection)."""

    def __init__(self, inner_channels, outer_channels):
        super(UnetCenterBlock, self).__init__()
        with self.name_scope():
            stages = [
                EncoderBlock(in_channels=outer_channels,
                             out_channels=inner_channels,
                             is_center_block=True),
                DecoderBlock(in_channels=inner_channels,
                             out_channels=outer_channels),
            ]
            self.model = nn.HybridSequential()
            with self.model.name_scope():
                for stage in stages:
                    self.model.add(stage)

    def hybrid_forward(self, F, x):
        # Channel-wise skip connection with the unmodified input.
        return F.concat(self.model(x), x, dim=1)
class UnetWrapBlock(nn.HybridBlock):
    """Wraps an inner U-net stage with one encoder/decoder pair.

    NOTE(review): `inner_block` is constructed by the caller, outside this
    block's name_scope — its parameter prefix therefore does not match the
    wrapper's, which is what breaks `save_params` on the assembled net.
    """

    def __init__(self, inner_channels, outer_channels, inner_block=None,
                 use_dropout=False, concat_input_and_output=True):
        super(UnetWrapBlock, self).__init__()
        self.concat_input_and_output = concat_input_and_output
        with self.name_scope():
            stages = [
                EncoderBlock(in_channels=outer_channels,
                             out_channels=inner_channels),
                inner_block,
                # Decoder sees 2x channels because the inner block concats
                # its own input onto its output.
                DecoderBlock(in_channels=inner_channels * 2,
                             out_channels=outer_channels),
            ]
            if use_dropout:
                stages.append(nn.Dropout(rate=0.5))
            self.model = nn.HybridSequential()
            with self.model.name_scope():
                for stage in stages:
                    self.model.add(stage)

    def hybrid_forward(self, F, x):
        if self.concat_input_and_output:
            return F.concat(self.model(x), x, dim=1)
        else:
            return self.model(x)
# Assemble the U-net: a center block wrapped by one outer encoder/decoder pair.
net = UnetCenterBlock(3, 3)
net = UnetWrapBlock(3, 3, inner_block=net, concat_input_and_output=False)
ctx = mx.gpu()
param_filename = "unet.params"
network_init(net, ctx)
# FIX: `net.save_params(...)` raises the ValueError quoted above because the
# inner UnetCenterBlock was created OUTSIDE UnetWrapBlock's name_scope, so its
# parameter names (unetcenterblock0_...) do not start with the outer block's
# prefix that save_params tries to strip. Saving the ParameterDict directly
# keeps the full parameter names and sidesteps the prefix check, exactly as
# the error message recommends. Load the file back with
# `net.collect_params().load(param_filename, ctx=ctx)`.
net.collect_params().save(param_filename)