Does anyone know much about ConvRNN?
I need to use ConvGRU(symbol interface) to complete a task. But I just learn to use contrib.rnn.Conv2DGRUCell . How can I use the interface mx.rnn.ConvGRUCell()(code)?
my code detail:
def data_process(batch_size):
train_lst = ‘/home/feywell/demo/train.lst’
val_lst = ‘/home/feywell/demo/val.lst’
train_data = ImageSeqIter(
path_imglist=train_lst,
data_shape=(3,512,512),
label_shape=(3,512,512),
resize=512,
mean=np.array([[0.4914],[0.4822],[0.4465]]),
std=np.array([[0.2023],[0.1994],[0.2010]]),
data_name=‘data’,
label_name=‘label’,
batch_size=batch_size,
rand_crop = False,
rand_mirror = False,
)valid_data = ImageSeqIter(
path_imglist=val_lst,
data_shape=(3,512,512),
label_shape=(3,512,512),
resize=512,
mean=np.array([[0.4914],[0.4822],[0.4465]]),
std=np.array([[0.2023],[0.1994],[0.2010]]),
data_name=‘data’,
label_name=‘label’,
batch_size=batch_size,
rand_crop = False,
rand_mirror = False
)return train_data,valid_data
train_loader,val_loader = data_process(4)test net
input_shape = (4,3,512,512)
ctx = mx.gpu()
data = mx.sym.Variable(‘data’)
states = mx.sym.Variable(‘states’)net = ConvGRUCell(input_shape=input_shape, num_hidden=12,i2h_kernel=(3,3), h2h_kernel=(3,3),i2h_pad=(1,1))
print(net)
output,states = net(data,states)
print(output)
print(output.list_arguments())
print(states)
model = mx.mod.Module(symbol=output, context=ctx,label_names=None)model.fit(train_loader, # train data
eval_data=val_loader, # validation data
optimizer=‘sgd’, # use SGD to train
optimizer_params={‘learning_rate’:0.1}, # use fixed learning rate
eval_metric=‘acc’, # report accuracy during training
batch_end_callback = mx.callback.Speedometer(4, 100), # output progress for each 100 data batches
num_epoch=10) # train for at most 10 dataset passes
error like following:
<main.ConvGRUCell object at 0x2b122c5bbfd0>
[‘data’, ‘ConvGRU_i2h_weight’, ‘ConvGRU_i2h_bias’, ‘states’, ‘ConvGRU_h2h_weight’, ‘ConvGRU_h2h_bias’]
[]RuntimeErrorTraceback (most recent call last)
in ()
27 eval_metric=‘acc’, # report accuracy during training
28 batch_end_callback = mx.callback.Speedometer(4, 100), # output progress for each 100 data batches
—> 29 num_epoch=10) # train for at most 10 dataset passes
30 # model = mx.mod.Module(output, data_names=[‘data’,], label_names=None, context=mx.gpu())
31 print(model)/anaconda2/lib/python2.7/site-packages/mxnet-1.0.0-py2.7.egg/mxnet/module/base_module.pyc in fit(self, train_data, eval_data, eval_metric, epoch_end_callback, batch_end_callback, kvstore, optimizer, optimizer_params, eval_end_callback, eval_batch_end_callback, initializer, arg_params, aux_params, allow_missing, force_rebind, force_init, begin_epoch, num_epoch, validation_metric, monitor)
458
459 self.bind(data_shapes=train_data.provide_data, label_shapes=train_data.provide_label,
→ 460 for_training=True, force_rebind=force_rebind)
461 if monitor is not None:
462 self.install_monitor(monitor)anaconda2/lib/python2.7/site-packages/mxnet-1.0.0-py2.7.egg/mxnet/module/module.pyc in bind(self, data_shapes, label_shapes, for_training, inputs_need_grad, force_rebind, shared_module, grad_req)
427 fixed_param_names=self._fixed_param_names,
428 grad_req=grad_req, group2ctxs=self._group2ctxs,
→ 429 state_names=self._state_names)
430 self._total_exec_bytes = self._exec_group._total_exec_bytes
431 if shared_module is not None:
/anaconda2/lib/python2.7/site-packages/mxnet-1.0.0-py2.7.egg/mxnet/module/executor_group.pyc in init(self, symbol, contexts, workload, data_shapes, label_shapes, param_names, for_training, inputs_need_grad, shared_group, logger, fixed_param_names, grad_req, state_names, group2ctxs)
262 self.num_outputs = len(self.symbol.list_outputs())
263
–> 264 self.bind_exec(data_shapes, label_shapes, shared_group)
265
266 def decide_slices(self, data_shapes):/anaconda2/lib/python2.7/site-packages/mxnet-1.0.0-py2.7.egg/mxnet/module/executor_group.pyc in bind_exec(self, data_shapes, label_shapes, shared_group, reshape)
358 else:
359 self.execs.append(self._bind_ith_exec(i, data_shapes_i, label_shapes_i,
→ 360 shared_group))
361
362 self.data_shapes = data_shapes
/anaconda2/lib/python2.7/site-packages/mxnet-1.0.0-py2.7.egg/mxnet/module/executor_group.pyc in _bind_ith_exec(self, i, data_shapes, label_shapes, shared_group)
636 type_dict=input_types, shared_arg_names=self.param_names,
637 shared_exec=shared_exec, group2ctx=group2ctx,
–> 638 shared_buffer=shared_data_arrays, **input_shapes)
639 self._total_exec_bytes += int(executor.debug_str().split(’\n’)[-3].split()[1])
640 return executoranaconda2/lib/python2.7/site-packages/mxnet-1.0.0-py2.7.egg/mxnet/symbol/symbol.pyc in simple_bind(self, ctx, grad_req, type_dict, stype_dict, group2ctx, shared_arg_names, shared_exec, shared_buffer, **kwargs)
1513 error_msg += “%s: %s\n” % (k, v)
1514 error_msg += “%s” % e
→ 1515 raise RuntimeError(error_msg)
1516
1517 # update shared_bufferRuntimeError: simple_bind error. Arguments:
data: (4, 3, 512, 512)
label: (4, 3, 512, 512)
Error in operator ConvGRU_t0_h2h: [20:55:03] src/operator/nn/./convolution-inl.h:625: Check failed: dtype != -1 (-1 vs. -1) First input must have specified typeStack trace returned 10 entries:
[bt] (0) /anaconda2/lib/python2.7/site-packages/mxnet-1.0.0-py2.7.egg/mxnet/libmxnet.so(_ZN4dmlc10StackTraceB5cxx11Ev+0x48) [0x2b11487cbc68]
[bt] (1) /anaconda2/lib/python2.7/site-packages/mxnet-1.0.0-py2.7.egg/mxnet/libmxnet.so(_ZN4dmlc15LogMessageFatalD1Ev+0x18) [0x2b11487cc678]
[bt] (2) /anaconda2/lib/python2.7/site-packages/mxnet-1.0.0-py2.7.egg/mxnet/libmxnet.so(ZNK5mxnet2op15ConvolutionProp9InferTypeEPSt6vectorIiSaIiEES5_S5+0x990) [0x2b1148953ab0]
[bt] (3) /anaconda2/lib/python2.7/site-packages/mxnet-1.0.0-py2.7.egg/mxnet/libmxnet.so(+0x2ed4735) [0x2b114aed4735]
[bt] (4) /anaconda2/lib/python2.7/site-packages/mxnet-1.0.0-py2.7.egg/mxnet/libmxnet.so(+0x2ccacf8) [0x2b114accacf8]
[bt] (5)/anaconda2/lib/python2.7/site-packages/mxnet-1.0.0-py2.7.egg/mxnet/libmxnet.so(+0x2cd1d61) [0x2b114acd1d61]
[bt] (6) /anaconda2/lib/python2.7/site-packages/mxnet-1.0.0-py2.7.egg/mxnet/libmxnet.so(_ZN5mxnet4exec9InferTypeEON4nnvm5GraphEOSt6vectorIiSaIiEERKNSt7__cxx1112basic_stringIcSt11char_traitsIcESaIcEEE+0x11f) [0x2b114acd2bcf]
[bt] (7) /anaconda2/lib/python2.7/site-packages/mxnet-1.0.0-py2.7.egg/mxnet/libmxnet.so(_ZN5mxnet4exec13GraphExecutor4InitEN4nnvm6SymbolERKNS_7ContextERKSt3mapINSt7__cxx1112basic_stringIcSt11char_traitsIcESaIcEEES4_St4lessISD_ESaISt4pairIKSD_S4_EEERKSt6vectorIS4_SaIS4_EESR_SR_RKSt13unordered_mapISD_NS2_6TShapeESt4hashISD_ESt8equal_toISD_ESaISG_ISH_ST_EEERKSS_ISD_iSV_SX_SaISG_ISH_iEEES17_RKSN_INS_9OpReqTypeESaIS18_EERKSt13unordered_setISD_SV_SX_SaISD_EEPSN_INS_7NDArrayESaIS1I_EES1L_S1L_PSS_ISD_S1I_SV_SX_SaISG_ISH_S1I_EEEPNS_8ExecutorERKSS_INS2_9NodeEntryES1I_NS2_13NodeEntryHashENS2_14NodeEntryEqualESaISG_IKS1S_S1I_EEE+0x7d5) [0x2b114acb68a5]
[bt] (8) /anaconda2/lib/python2.7/site-packages/mxnet-1.0.0-py2.7.egg/mxnet/libmxnet.so(ZN5mxnet8Executor10SimpleBindEN4nnvm6SymbolERKNS_7ContextERKSt3mapINSt7__cxx1112basic_stringIcSt11char_traitsIcESaIcEEES3_St4lessISC_ESaISt4pairIKSC_S3_EEERKSt6vectorIS3_SaIS3_EESQ_SQ_RKSt13unordered_mapISC_NS1_6TShapeESt4hashISC_ESt8equal_toISC_ESaISF_ISG_SS_EEERKSR_ISC_iSU_SW_SaISF_ISG_iEEES16_RKSM_INS_9OpReqTypeESaIS17_EERKSt13unordered_setISC_SU_SW_SaISC_EEPSM_INS_7NDArrayESaIS1H_EES1K_S1K_PSR_ISC_S1H_SU_SW_SaISF_ISG_S1H_EEEPS0+0xcd) [0x2b114acb714d]
[bt] (9) /anaconda2/lib/python2.7/site-packages/mxnet-1.0.0-py2.7.egg/mxnet/libmxnet.s
Thanks!