Parameter 'XXX' was not initialized on context cpu(0). It was only initialized on [gpu(0), gpu(1), gpu(2), gpu(3)]


#1

Dear all,

I am trying to implement a convolutional network (that includes custom constant arrays) with data parallelization (4 GPUs), and I am getting an error I cannot work out how to solve. With one GPU the code works fine; with many GPUs it breaks down.

To cut a long story short: in the convolution operation, instead of using the weights (3x3) directly, I “deproject” them onto a higher-dimensional grid (say 9x9), so there is an nd.dot(weight, deprojection_matrix) somewhere in the code. In my network I provide these deprojection matrices through a custom initializer so that they get initialized on all contexts. This is how I define the contexts:

import os
import mxnet as mx

print("------------------------------------------")
print("mxnet version: {}".format(mx.__version__))
print("------------------------------------------")

#  ************ Run on all available node GPUs ****************
gpus = [int(x) for x in os.environ["CUDA_VISIBLE_DEVICES"].split(',')]
ctx = [mx.gpu(i) for i in gpus]  # DOESN'T WORK for 4 gpus
# ctx = [mx.gpu(0)]              # WORKS
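
To make the “deprojection” concrete, here is a minimal standalone sketch of that contraction (the shapes are illustrative only, not the actual model):

from mxnet import nd

# Illustrative shapes only: a 3x3 kernel "deprojected" onto a 9x9 grid.
nfilters, nchannels = 8, 4
weight = nd.ones((nfilters, nchannels, 3 * 3))       # learnable 3x3 weights, flattened
deprojection_matrix = nd.ones((3 * 3, 9, 9))         # fixed B-spline deprojection matrix
proj_weight = nd.dot(weight, deprojection_matrix)    # contracts the last axis of weight
print(proj_weight.shape)                             # -> (8, 4, 9, 9)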

This is my forward/backward step, with a delay rate (gradient accumulation), following the Gluon example for data parallelization. I think this is where the problem lies.

delay_rate = 8 # delay rate for averaging the gradients 
def forward_backward_step(_iteration, _nbatch, _net, _data, _label):
    with autograd.record():
        # Here jacc_idx is my loss, the jaccard index 
        losses = [ 1.0 - jacc_idx(_net(X),Y) for X, Y in zip(_data, _label)]
    for l in losses: # Evaluate gradients in each ctx 
        l.backward()

    # trainer.step aggregates the gradients across ALL devices and then updates the parameters. <3 Gluon!
    if (_iteration % delay_rate == 0):
        trainer.step(_nbatch * delay_rate)  # _data are from split_and_load, therefore the 2nd axis is the batch size.
        for param in _net.collect_params().values():
            param.zero_grad()

    return losses


def avg_jacc(_losses):
    tsum = 0.0
    for l in _losses:
        tsum += nd.mean(l).asscalar()

    return tsum / float(len(_losses))
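
For reference, the outer loop that feeds forward_backward_step looks roughly like this (simplified; train_loader is a stand-in name for my actual data iterator):

from mxnet import nd, gluon

for i, (batch_data, batch_label) in enumerate(train_loader):
    # Split the batch across all devices in ctx (one slice per GPU)
    data = gluon.utils.split_and_load(batch_data, ctx_list=ctx, batch_axis=0)
    label = gluon.utils.split_and_load(batch_label, ctx_list=ctx, batch_axis=0)

    losses = forward_backward_step(i, Nbatch, mynet, data, label)
    nd.waitall()  # wait for the asynchronous work on all devices to finish

    print("iter {}: avg (1 - jaccard) = {:.4f}".format(i, avg_jacc(losses)))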

This is the custom initializer:


@mx.init.register
class BSplineInit(mx.initializer.Initializer):

    def __init__(self,
                 kernel_size=3,
                 kernel_effective_size=5,
                 degree=2,
                 **kwards):
        mx.initializer.Initializer.__init__(self, **kwards)

        self.kernel_eff_ = kernel_effective_size
        # BSpline_kernel is a custom helper class (not shown here)
        myBSpline = BSpline_kernel(kernel_size=kernel_size, degree=degree)
        self.Bijkl = nd.array(myBSpline.get_spline_kernel(self.kernel_eff_))

    # This is where I put the BSpline initialization of the weights into the arr variable
    def _init_weight(self, name, arr):
        arr[:] = self.Bijkl

And this is the custom convolution operation that uses the deprojection matrix:

class Conv2DS(Block):
    def __init__(self, nfilters, nchannels=0, kernel_size=3, kernel_effective_size=5,
                 use_bias=False, padding=(0, 0), **kwards):
        Block.__init__(self, **kwards)

        self.nchannels = nchannels
        self.nfilters = nfilters
        self.kernel_size = kernel_size
        self.kernel_eff = kernel_effective_size
        self.use_bias = use_bias

        # Some custom operation that creates a "deprojection" matrix, for now a simple random NDArray
        # Ensures padding = 'SAME' for ODD kernel selection
        p0 = int((self.kernel_eff - 1) / 2)
        p1 = int((self.kernel_eff - 1) / 2)
        self.pad = (p0, p1)

        with self.name_scope():
            self.weight = self.params.get(
                'weight', allow_deferred_init=True,
                shape=(self.nfilters, self.nchannels, self.kernel_size * self.kernel_size))

            if self.use_bias:
                self.bias = self.params.get(
                    'bias', allow_deferred_init=True,
                    init=mx.init.Zero(),
                    shape=(self.nfilters,))
    
    def forward(self, _x, _Bijkl):
        """
        _Bijkl is the deprojection_matrix 
        """
        
        proj_weight = nd.dot(self.weight.data(), _Bijkl.data())        
        if self.use_bias:
            conv = nd.Convolution(data=_x,
                                  weight=proj_weight,
                                  bias=self.bias.data(),
                                  no_bias=False,
                                  num_filter=self.nfilters,
                                  kernel=[self.kernel_eff,self.kernel_eff],
                                  pad = self.pad)
        
        else : 
            conv = nd.Convolution(data=_x,
                                  weight=proj_weight,
                                  no_bias=True,
                                  num_filter=self.nfilters,
                                  kernel=[self.kernel_eff,self.kernel_eff],
                                  pad = self.pad)

        
        return conv    
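
As a sanity check, this is roughly how the layer is called standalone on a single GPU (the shapes are made up, and the deprojection parameter is faked with ones just to exercise forward):

import mxnet as mx
from mxnet import nd, gluon

ctx_single = mx.gpu(0)

conv = Conv2DS(nfilters=16, nchannels=8, kernel_size=3, kernel_effective_size=7)
conv.collect_params().initialize(mx.init.Xavier(), ctx=ctx_single)

# Fake deprojection parameter with the matching (k*k, kef, kef) shape;
# in the real network it is created with the BSplineInit initializer above.
bijkl = gluon.Parameter('bijkl_fake', shape=(3 * 3, 7, 7), grad_req='null')
bijkl.initialize(mx.init.One(), ctx=ctx_single)

x = nd.random.uniform(shape=(2, 8, 32, 32), ctx=ctx_single)
out = conv(x, bijkl)
print(out.shape)  # (2, 16, 32, 32): padding is 'SAME' for the odd effective kernel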
    
  

And this is part of the network (a modified U-Net architecture) that makes use of these constant “deprojection” matrices:


class ResUNet_sincept_d5(Block):
    def __init__(self, nfilters_init,  NClasses, kernel= [3]*4, kernel_eff = [7,19,31], degree=2,  verbose=True, norm_type = 'BatchNorm', **kwards):
        Block.__init__(self,**kwards)

        self.model_name = "ResUNet_sincept_d5"
        self.depth = 5
        self.nfilters = nfilters_init # Initial number of filters 
        self.NClasses = NClasses

        with self.name_scope():

            # This is a list of the deprojection matrices
            self.Bijkl = [] # gluon.nn.Sequential()
            for i in range(len(kernel_eff)):
                k = kernel[i+1]
                kef = kernel_eff[i]
                self.Bijkl += [self.params.get('bijkl_'+str(i)+'_',allow_deferred_init=True,
                                    init = BSplineInit(k, kef, degree),
                                    grad_req='null',
                                    differentiable=False,
                                    shape=(k*k ,kef,kef))]
             
               # many more terms 

    # Example usage: 
    def forward(self,_input):

        # First convolution 
        conv1 = self.conv_first_normed(_input)
        conv1 = nd.relu(conv1)

        # This is how I am using them: 
        Dn1 = self.Dn1(conv1, self.Bijkl)
        pool1 = self.pool1(Dn1)
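
The network is then created and initialized on all devices in ctx along these lines (the hyperparameter values here are placeholders, not my actual ones):

mynet = ResUNet_sincept_d5(nfilters_init=32, NClasses=2)
# Parameters registered with their own init (e.g. the bijkl_* constants with
# BSplineInit) use that initializer instead of the default Xavier.
mynet.collect_params().initialize(mx.init.Xavier(), ctx=ctx)
trainer = gluon.Trainer(mynet.collect_params(), 'adam', {'learning_rate': 1e-3})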

Now, as I said, this model runs fine with ctx = [mx.gpu(0)]; however, when using all four GPUs I get the following error:

Traceback (most recent call last):
  File "phaino_resunet_parallel.py", line 187, in <module>
    losses = forward_backward_step( i, Nbatch, mynet, data, label)
  File "phaino_resunet_parallel.py", line 130, in forward_backward_step
    losses = [ 1.0 - jacc_idx(_net(X),Y) for X, Y in zip(_data, _label)]
  File "phaino_resunet_parallel.py", line 130, in <listcomp>
    losses = [ 1.0 - jacc_idx(_net(X),Y) for X, Y in zip(_data, _label)]
  File "/home/dia021/Software/mxnet/gluon/block.py", line 414, in __call__
    return self.forward(*args)
  File "/home/dia021/Software/phaino/models/resunet_d5_sincept.py", line 403, in forward
    Dn1 = self.Dn1(conv1, self.Bijkl)
  File "/home/dia021/Software/mxnet/gluon/block.py", line 414, in __call__
    return self.forward(*args)
  File "/home/dia021/Software/phaino/models/resunet_d5_sincept.py", line 269, in forward
    x = x + self.net[1+i](_xl,Bijkl[i])
  File "/home/dia021/Software/mxnet/gluon/block.py", line 414, in __call__
    return self.forward(*args)
  File "/home/dia021/Software/phaino/models/resunet_d5_sincept.py", line 212, in forward
    x = self.conv1(x,Bijkl)
  File "/home/dia021/Software/mxnet/gluon/block.py", line 414, in __call__
    return self.forward(*args)
  File "/home/dia021/Software/phaino/models/resunet_d5_sincept.py", line 153, in forward
    proj_weight = nd.dot(self.weight.data(), _Bijkl.data())        
  File "/home/dia021/Software/mxnet/gluon/parameter.py", line 390, in data
    return self._check_and_get(self._data, ctx)
  File "/home/dia021/Software/mxnet/gluon/parameter.py", line 175, in _check_and_get
    self.name, str(ctx), str(self._ctx_list)))
RuntimeError: Parameter 'resunet_sincept_d50_resnet_sincept_unit0_resnet_sblock0__conv1_weight' was not initialized on context cpu(0). It was only initialized on [gpu(0), gpu(1), gpu(2), gpu(3)].

I cannot understand the problem and would be indebted for some help. Thank you for your time.


#2

From the error, it looks like MXNet is expecting resunet_sincept_d50_resnet_sincept_unit0_resnet_sblock0__conv1_weight to be initialized on cpu. This could be because the other operand is initialized on cpu.

I see that nd.dot(self.weight.data(), _Bijkl.data()) caused the exception. Is it possible _Bijkl.data() is on cpu? You could just print the context of those variables at the beginning of forward to confirm:

print(_Bijkl.data().context)
print(self.weight.data().context)

#3

Thank you very much for your reply @indu. I did the test you suggested and added the following lines to the forward call of the Conv2DS layer:

print ("=========================")
print ("Context of Variables")
print ("weight::{0}, Bijkl::{1}".format(self.weight.data().context, _Bijkl.data().context ))
print ("=========================")

When I use ctx = [mx.gpu(0)] I get the correct context printed on screen. When I use

gpus = [int(x) for x in os.environ["CUDA_VISIBLE_DEVICES"].split(',')]
ctx = [mx.gpu(i)  for i in gpus]

I get the following error (which I can’t really make much sense of :frowning:):

=========================
Context of Variables
Traceback (most recent call last):
  File "phaino_resunet_parallel.py", line 187, in <module>
    losses = forward_backward_step( i, Nbatch, mynet, data, label)
  File "phaino_resunet_parallel.py", line 130, in forward_backward_step
    losses = [ 1.0 - jacc_idx(_net(X),Y) for X, Y in zip(_data, _label)]
  File "phaino_resunet_parallel.py", line 130, in <listcomp>
    losses = [ 1.0 - jacc_idx(_net(X),Y) for X, Y in zip(_data, _label)]
  File "/home/dia021/Software/mxnet/gluon/block.py", line 414, in __call__
    return self.forward(*args)
  File "/home/dia021/Software/phaino/models/resunet_d5_sincept.py", line 409, in forward
[11:29:39] src/operator/nn/./cudnn/./cudnn_algoreg-inl.h:107: Running performance tests to find the best convolution algorithm, this can take a while... (setting env variable MXNET_CUDNN_AUTOTUNE_DEFAULT to 0 to disable)
    Dn1 = self.Dn1(conv1, self.Bijkl)
  File "/home/dia021/Software/mxnet/gluon/block.py", line 414, in __call__
    return self.forward(*args)
  File "/home/dia021/Software/phaino/models/resunet_d5_sincept.py", line 275, in forward
    x = x + self.net[1+i](_xl,Bijkl[i])
  File "/home/dia021/Software/mxnet/gluon/block.py", line 414, in __call__
    return self.forward(*args)
  File "/home/dia021/Software/phaino/models/resunet_d5_sincept.py", line 218, in forward
    x = self.conv1(x,Bijkl)
  File "/home/dia021/Software/mxnet/gluon/block.py", line 414, in __call__
    return self.forward(*args)
  File "/home/dia021/Software/phaino/models/resunet_d5_sincept.py", line 155, in forward
    print ("weight::{0}, Bijkl::{1}".format(self.weight.data().context, _Bijkl.data().context ))
  File "/home/dia021/Software/mxnet/gluon/parameter.py", line 390, in data
    return self._check_and_get(self._data, ctx)
  File "/home/dia021/Software/mxnet/gluon/parameter.py", line 175, in _check_and_get
    self.name, str(ctx), str(self._ctx_list)))
RuntimeError: Parameter 'resunet_sincept_d50_resnet_sincept_unit0_resnet_sblock0__conv1_weight' was not initialized on context cpu(0). It was only initialized on [gpu(0), gpu(1), gpu(2), gpu(3)].

Do you think it is possible that this is an MXNet bug? I am going to try replacing dot with nd.FullyConnected and see what happens.

Again, thank you very much for the help.


#4

So it seems this must somehow be connected to the MXNet nd.dot operator. I replaced it with F.FullyConnected (and changed all Blocks --> HybridBlocks as well), and now everything runs on 4 GPUs. I’ll try to construct a minimal reproducible example to see whether this is a bug and report it.
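
Roughly, the idea is to express the same contraction as a FullyConnected call; a quick standalone check with made-up shapes (not the full HybridBlock code) that it matches nd.dot:

import mxnet as mx
from mxnet import nd

nfilters, nchannels, k, kef = 8, 4, 3, 7
weight = nd.random.uniform(shape=(nfilters, nchannels, k * k))
bijkl = nd.random.uniform(shape=(k * k, kef, kef))

# Original projection
proj_dot = nd.dot(weight, bijkl)                              # (8, 4, 7, 7)

# Same contraction via FullyConnected: rows are (filter, channel) pairs and
# the FC "weight" is the reshaped/transposed deprojection matrix.
w2d = weight.reshape((-1, k * k))                             # (32, 9)
b2d = bijkl.reshape((k * k, kef * kef)).transpose()           # (49, 9)
proj_fc = nd.FullyConnected(data=w2d, weight=b2d, no_bias=True,
                            num_hidden=kef * kef)
proj_fc = proj_fc.reshape((nfilters, nchannels, kef, kef))

print(nd.max(nd.abs(proj_dot - proj_fc)).asscalar())          # should be ~0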

Many thanks for all the help