BatchNorm, C++, am I doing something wrong?


#1

Hi,

I’m dong something for which I want to use the BatchNormalise operator with global statistics enabled. However, after having some trouble with it, I created a test to try and make sure I was using the operator correctly.

That test is itself causing me issues, so I was hoping someone could look it over and tell me whether I’m doing something wrong or whether there’s a problem with BatchNorm.

The test generates a set of 2D tensors, with the first axis being the batch size and the second the data per sample. The computation graph consists of just the BatchNorm op, and all I want to see is that the output looks sensible and that the global mean and variance are being learnt. Very quickly, however, the output starts to contain nan values, and the stored Mean and Var never change between batches (Var even has negative entries, as in the output below).
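
For reference, my understanding is that with use_global_stats enabled the operator should, per feature, normalise using the stored moving statistics rather than the per-batch ones, i.e. something like the sketch below (just the formula as I understand it, not MXNet’s actual implementation; all names are illustrative):

#include <cmath>

// Expected per-feature transform when use_global_stats is true (my understanding):
//   y = gamma * (x - moving_mean) / sqrt(moving_var + eps) + beta
float expectedBatchNormOutput( float x, float gamma, float beta,
                               float movingMean, float movingVar, float eps )
{
	return gamma * ( x - movingMean ) / std::sqrt( movingVar + eps ) + beta;
}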

I’m using the C++ API and have checked the code out against tag 1.2.0.

My test code is:

// MXNet
#include "mxnet-cpp/MxNetCpp.h"
#include <iostream>
#include <cstdlib>    // std::rand
#include <map>
#include <vector>
#include <string>
using std::cout;
using std::endl;

namespace mx = mxnet::cpp;

#define BATCHSIZE 2
#define VECSIZE 8

int main(void)
{
	mx::Context ctx = mx::Context::cpu();    // set the context we're using, whether gpu or cpu.
	std::map< std::string, mx::NDArray > args;
	
	// Test functionality of BatchNorm.
	//
	// Provide some input
	args["Input"]    = mx::NDArray( mx::Shape( BATCHSIZE, VECSIZE ), ctx );
	
	
	// The computation graph will be just the batch norm of the input data.
	// We want to learn and use the global statistics.
	auto input = mx::Symbol::Variable("Input");
	auto bnGamma  = mx::Symbol::Variable( "batchNorm-Gamma");
	auto bnBeta   = mx::Symbol::Variable( "batchNorm-Beta");
	auto bnMean   = mx::Symbol::Variable( "batchNorm-Mean");
	auto bnVar    = mx::Symbol::Variable( "batchNorm-Var");
	float eps = 0.01;
	float momentum = 0.9;
	bool fixGamma = false;
	bool useGlobal = true;
	bool output    = true;
	auto net = mx::BatchNorm( "batchNorm", input, bnGamma, bnBeta, bnMean, bnVar, eps, momentum, fixGamma, useGlobal, output );
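	// (As far as I can tell, the positional parameters after the aux symbols map to
	//  eps, momentum, fix_gamma, use_global_stats and output_mean_var respectively.)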
	
	
	
	// Need an optimiser for learning (no hyper-parameters set, so Adam's defaults apply).
	mx::Optimizer* opt = mx::OptimizerRegistry::Find("adam");
	
	
	// Set up our "network"
	net.InferArgsMap(ctx, &args, args);
	auto *exec = net.SimpleBind(ctx, args);
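	// (My understanding is that SimpleBind infers the remaining shapes from "Input"
	//  and allocates the argument, gradient and aux-state arrays for the executor.)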
	
	
	// initialise gamma and beta
	auto argNames = net.ListArguments();
	auto auxNames = net.ListAuxiliaryStates();
	
	
	auto uniformInitialiser = mx::Uniform( 1.0 );
	uniformInitialiser( "batchNorm-Gamma", &args[ "batchNorm-Gamma" ] );
	uniformInitialiser( "batchNorm-Beta", &args[ "batchNorm-Beta" ] );
	    	
	
	// check names and sizes before we start.
	cout << "args:" << endl;
	for( auto ai = argNames.begin(); ai != argNames.end(); ++ai )
	{
		auto shape = args[ *ai ].GetShape();
		cout << "\t" << *ai << " ";
		for( unsigned sc = 0; sc < shape.size(); ++sc )
		{
			cout << shape[sc] << " ";
		}
		cout << endl;
	}
	cout << "auxs:" << endl;
	for( auto ai = auxNames.begin(); ai != auxNames.end(); ++ai )
	{
		auto shape = args[ *ai ].GetShape();
		cout << "\t" << *ai;
		for( unsigned sc = 0; sc < shape.size(); ++sc )
		{
			cout << shape[sc] << " ";
		}
		cout << endl;
	}
	  	
	
	// do some batches...
	for( unsigned bc = 0; bc < 10; ++bc )
	{
		
		// create some random data as our input.
		std::vector<float> data(BATCHSIZE*VECSIZE);
		int indx = 0;
		cout << "batch:" << endl << "\t";
		for( unsigned rc = 0; rc < BATCHSIZE; ++rc )
		{
			for( unsigned vc = 0; vc < VECSIZE; ++vc )
			{
				data[indx] = rand()%10;
				cout << data[indx] << " ";
				++indx;
			}
			cout << endl << "\t";
		}
		cout << endl;
		
		mx::NDArray dataArray( data, mx::Shape( BATCHSIZE, VECSIZE ), mx::Context::cpu() );
		dataArray.CopyTo( &args["Input"] );
		
		
		
		
		// do the forward and backward passes.
		exec->Forward(false);    // the argument is is_train.
		exec->Backward();
		
		
		
		// update the learnable arguments (skip the input data).
		for( unsigned ac = 0; ac < argNames.size(); ++ac )
		{
			if( argNames[ac].compare("Input") != 0 )
				opt->Update(ac, exec->arg_arrays[ac], exec->grad_arrays[ac]);
		}
		
		// make sure all pending operations have finished before reading values back.
		mx::NDArray::WaitAll();
		
		
		//
		// print out current values of things....
		//
		
		// first, the output...
		auto outShape = exec->outputs[0].GetShape();
		cout << "output: ";
		for( unsigned sc = 0; sc < outShape.size(); ++sc )
		{
			cout << outShape[sc] << " ";
		}
		cout << endl << "\t";
		float *outData = (float*)exec->outputs[0].GetData();
		for( unsigned c0 = 0; c0 < outShape[0]; ++c0 )
		{
			for( unsigned c1 = 0; c1 < outShape[1]; ++c1 )
			{
				cout << *outData << " ";
				++outData;
			}
			cout << endl << "\t";
		}
		cout << endl;
		
		// Gamma
		auto gammaShape = args["batchNorm-Gamma"].GetShape();
		cout << "Gamma: ";
		for( unsigned sc = 0; sc < gammaShape.size(); ++sc )
		{
			cout << gammaShape[sc] << " ";
		}
		cout << endl << "\t";
		float *gamData = (float*)args["batchNorm-Gamma"].GetData();
		for( unsigned c0 = 0; c0 < gammaShape[0]; ++c0 )
		{
			cout << *gamData << " ";
			++gamData;
		}
		cout << endl;
		
		
		// Beta
		auto betaShape = args["batchNorm-Beta"].GetShape();
		cout << "Beta: ";
		for( unsigned sc = 0; sc < betaShape.size(); ++sc )
		{
			cout << betaShape[sc] << " ";
		}
		cout << endl << "\t";
		float *betaData = (float*)args["batchNorm-Beta"].GetData();
		for( unsigned c0 = 0; c0 < betaShape[0]; ++c0 )
		{
			cout << *betaData << " ";
			++betaData;
		}
		cout << endl;
		
		
		
		// Mean
		auto auxDict = exec->aux_dict();
		auto meanShape = auxDict["batchNorm-Mean"].GetShape();
		cout << "Mean: ";
		for( unsigned sc = 0; sc < meanShape.size(); ++sc )
		{
			cout << meanShape[sc] << " ";
		}
		cout << endl << "\t";
		float *meanData = (float*)auxDict["batchNorm-Mean"].GetData();
		for( unsigned c0 = 0; c0 < meanShape[0]; ++c0 )
		{
			cout << *meanData << " ";
			++meanData;
		}
		cout << endl;
		
		
		auto varShape = auxDict["batchNorm-Var"].GetShape();
		cout << "Var: ";
		for( unsigned sc = 0; sc < varShape.size(); ++sc )
		{
			cout << varShape[sc] << " ";
		}
		cout << endl << "\t";
		float *varData = (float*)auxDict["batchNorm-Var"].GetData();
		for( unsigned c0 = 0; c0 < varShape[0]; ++c0 )
		{
			cout << *varData << " ";
			++varData;
		}
		cout << endl;
		
		
		cout << " ------  end of batch ---------- " << endl;
		
		
		
	}

	delete exec;
}

An example of the output:

args:
Input 2 8
batchNorm-Gamma 8
batchNorm-Beta 8
auxs:
batchNorm-Mean
batchNorm-Var
batch:
7 9 3 8 0 2 4 8
3 9 0 5 2 2 7 3

output: 2 8
nan 3.15909 nan 14.421 -1.7852 -3.6966 -3.54867 7.63761
nan 3.15909 nan 9.16948 -0.467756 -3.6966 -5.68227 4.24182

Gamma: 8
-0.280984 -0.280984 -0.280984 -0.280984 -0.280984 -0.280984 -0.280984 -0.280984
Beta: 8
0.333533 0.30628 0.341276 -0.658181 -0.579235 -0.283696 -0.742147 0.501372
Mean: 8
0.571268 -2.75796 1.07628 -0.614133 1.83076 -1.14681 0.053838 -2.50748
Var: 8
-0.59165 0.858605 -0.227942 0.201315 0.350055 0.536052 1.51944 1.90409
------ end of batch ----------
batch:
7 9 0 2 3 9 9 7
0 3 9 8 6 5 7 6

output: 2 8
nan 3.15851 nan 3.91765 0.190964 -11.2885 -7.10468 6.95845
nan 1.70304 nan 14.4202 2.16713 -6.95028 -5.68227 6.27929

Gamma: 8
-0.280984 -0.280984 -0.280984 -0.280984 -0.280984 -0.280984 -0.280984 -0.280984
Beta: 8
0.333533 0.30628 0.341276 -0.658181 -0.579235 -0.283696 -0.742147 0.501372
Mean: 8
0.571268 -2.75796 1.07628 -0.614133 1.83076 -1.14681 0.053838 -2.50748
Var: 8
-0.59165 0.858605 -0.227942 0.201315 0.350055 0.536052 1.51944 1.90409
------ end of batch ----------