{
"cells": [
{
"cell_type": "markdown",
"metadata": {},
"source": [
"# Matrix Factorization\n",
"\n",
"In a recommendation system, there is a group of users and a set of items. Given that each user has rated some items in the system, we would like to predict how the users would rate the items that they have not yet rated, so that we can make recommendations to them.\n",
"\n",
"Matrix factorization is one of the most widely used algorithms in recommendation systems. It can be used to discover latent features underlying the interactions between two different kinds of entities.\n",
"\n",
"Assume we assign a k-dimensional vector to each user and a k-dimensional vector to each item such that the dot product of these two vectors gives the user's rating of that item. We can learn the user and item vectors directly, which amounts to a low-rank factorization of the user-item matrix, similar in spirit to SVD. We can also try to learn the latent features using multi-layer neural networks.\n",
"\n",
"In this tutorial, we will work through the steps to implement these ideas in MXNet."
]
},
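{
"cell_type": "markdown",
"metadata": {},
"source": [
"As a sketch of the model described above, write $p_u \\in \\mathbb{R}^k$ for the vector of user $u$ and $q_i \\in \\mathbb{R}^k$ for the vector of item $i$ (this notation is introduced here for illustration and is not used elsewhere in the tutorial). The predicted rating is then the dot product\n",
"\n",
"$$\\hat{r}_{ui} = p_u^\\top q_i = \\sum_{f=1}^{k} p_{uf}\\, q_{if}.$$\n",
"\n",
"Stacking the user vectors as the rows of a matrix $P$ and the item vectors as the rows of a matrix $Q$ gives $\\hat{R} = P Q^\\top$, an approximation of the user-item rating matrix with rank at most $k$."
]
},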
{
"cell_type": "markdown",
"metadata": {},
"source": [
"## Prepare Data\n",
"\n",
"We use the [MovieLens](http://grouplens.org/datasets/movielens/) data here, but the same approach applies to other datasets as well. Each row of this dataset contains a tuple of user id, movie id, rating, and timestamp; we will only use the first three fields. We first define a batch, which contains n tuples. It also provides name and shape information about the data and label to MXNet."
]
},
{
"cell_type": "code",
"execution_count": 1,
"metadata": {},
"outputs": [],
"source": [
"class Batch(object):\n",
"    \"\"\"A mini-batch of data and label arrays together with their names.\"\"\"\n",
"    def __init__(self, data_names, data, label_names, label):\n",
"        self.data = data\n",
"        self.label = label\n",
"        self.data_names = data_names\n",
"        self.label_names = label_names\n",
"\n",
"    @property\n",
"    def provide_data(self):\n",
"        # (name, shape) pairs describing each data array\n",
"        return [(n, x.shape) for n, x in zip(self.data_names, self.data)]\n",
"\n",
"    @property\n",
"    def provide_label(self):\n",
"        # (name, shape) pairs describing each label array\n",
"        return [(n, x.shape) for n, x in zip(self.label_names, self.label)]"
]
},
{
"cell_type": "markdown",
"metadata": {},
"source": [
"Then we define a data iterator, which returns a batch of tuples each time. "
]
},
{
"cell_type": "code",
"execution_count": 2,
"metadata": {},
"outputs": [],
"source": [
"import mxnet as mx\n",
"import random\n",
"\n",
"class DataIter(mx.io.DataIter):\n",
"    def __init__(self, fname, batch_size):\n",
"        super(DataIter, self).__init__()\n",
"        self.batch_size = batch_size\n",
"        self.data = []\n",
"        # each line holds user id, movie id, rating and timestamp, separated by tabs\n",
"        with open(fname) as f:\n",
"            for line in f:\n",
"                tks = line.strip().split('\\t')\n",
"                if len(tks) != 4:\n",
"                    continue\n",
"                self.data.append((int(tks[0]), int(tks[1]), float(tks[2])))\n",
"        self.provide_data = [('user', (batch_size, )), ('item', (batch_size, ))]\n",
"        self.provide_label = [('score', (self.batch_size, ))]\n",
"\n",
"    def __iter__(self):\n",
"        # integer division: a trailing partial batch is dropped\n",
"        for k in range(len(self.data) // self.batch_size):\n",
"            users = []\n",
"            items = []\n",
"            scores = []\n",
"            for i in range(self.batch_size):\n",
"                j = k * self.batch_size + i\n",
"                user, item, score = self.data[j]\n",
"                users.append(user)\n",
"                items.append(item)\n",
"                scores.append(score)\n",
"\n",
"            data_all = [mx.nd.array(users), mx.nd.array(items)]\n",
"            label_all = [mx.nd.array(scores)]\n",
"            data_names = ['user', 'item']\n",
"            label_names = ['score']\n",
"\n",
"            data_batch = Batch(data_names, data_all, label_names, label_all)\n",
"            yield data_batch\n",
"\n",
"    def reset(self):\n",
"        # reshuffle the ratings at the start of each epoch\n",
"        random.shuffle(self.data)"
]
},
{
"cell_type": "markdown",
"metadata": {},
"source": [
"Now we download the data and provide a function to obtain the data iterator:"
]
},
{
"cell_type": "code",
"execution_count": 3,
"metadata": {},
"outputs": [],
"source": [
"import os\n",
"import zipfile\n",
"try:\n",
"    from urllib.request import urlretrieve  # Python 3\n",
"except ImportError:\n",
"    from urllib import urlretrieve  # Python 2\n",
"\n",
"if not os.path.exists('ml-100k.zip'):\n",
"    urlretrieve('http://files.grouplens.org/datasets/movielens/ml-100k.zip', 'ml-100k.zip')\n",
"with zipfile.ZipFile(\"ml-100k.zip\", \"r\") as f:\n",
"    f.extractall(\"./\")\n",
"\n",
"def get_data(batch_size):\n",
"    # returns the (train, test) iterators\n",
"    return (DataIter('./ml-100k/u1.base', batch_size), DataIter('./ml-100k/u1.test', batch_size))"
]
},
{
"cell_type": "markdown",
"metadata": {},
"source": [
"Finally, we compute the largest user id and item id, each plus one. These values are used later as the input dimensions of the embedding layers."
]
},
{
"cell_type": "code",
"execution_count": 4,
"metadata": {},
"outputs": [
{
"data": {
"text/plain": [
"(944, 1683)"
]
},
"execution_count": 4,
"metadata": {},
"output_type": "execute_result"
}
],
"source": [
"def max_id(fname):\n",
"    mu = 0\n",
"    mi = 0\n",
"    with open(fname) as f:\n",
"        for line in f:\n",
"            tks = line.strip().split('\\t')\n",
"            if len(tks) != 4:\n",
"                continue\n",
"            mu = max(mu, int(tks[0]))\n",
"            mi = max(mi, int(tks[1]))\n",
"    # add 1 so that every id is a valid index into a table of this size\n",
"    return mu + 1, mi + 1\n",
"max_user, max_item = max_id('./ml-100k/u.data')\n",
"(max_user, max_item)"
]
},
{
"cell_type": "markdown",
"metadata": {},
"source": [
"## Optimization\n",
"\n",
"We first implement the RMSE (root-mean-square error) metric, which is commonly used to evaluate matrix factorization models."
]
},
{
"cell_type": "code",
"execution_count": 5,
"metadata": {},
"outputs": [],
"source": [
"import math\n",
"def RMSE(label, pred):\n",
"    ret = 0.0\n",
"    n = 0.0\n",
"    pred = pred.flatten()\n",
"    for i in range(len(label)):\n",
"        ret += (label[i] - pred[i]) * (label[i] - pred[i])\n",
"        n += 1.0\n",
"    return math.sqrt(ret / n)"
]
},
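{
"cell_type": "markdown",
"metadata": {},
"source": [
"For reference, with $y_i$ the true ratings, $\\hat{y}_i$ the predicted ratings, and $n$ the number of ratings passed to the function (symbols introduced here for illustration), the function above computes\n",
"\n",
"$$\\mathrm{RMSE} = \\sqrt{\\frac{1}{n} \\sum_{i=1}^{n} \\left(y_i - \\hat{y}_i\\right)^2}.$$"
]
},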
{
"cell_type": "markdown",
"metadata": {},
"source": [
"Then we define a general training function, which is adapted from the image classification example."
]
},
{
"cell_type": "code",
"execution_count": 12,
"metadata": {},
"outputs": [],
"source": [
"import logging\n",
"\n",
"def train(network, batch_size, num_epoch, learning_rate):\n",
"    train_iter, test_iter = get_data(batch_size)\n",
"    model = mx.mod.Module(symbol=network,\n",
"                          data_names=[x[0] for x in train_iter.provide_data],\n",
"                          label_names=[y[0] for y in train_iter.provide_label],\n",
"                          context=[mx.cpu()])\n",
"\n",
"    logging.basicConfig(level=logging.DEBUG)\n",
"\n",
"    model.fit(train_data=train_iter,\n",
"              eval_data=test_iter,\n",
"              num_epoch=num_epoch,\n",
"              optimizer='sgd',\n",
"              optimizer_params={'learning_rate': learning_rate, 'momentum': 0.9, 'wd': 0.0001},\n",
"              eval_metric=RMSE,\n",
"              # log speed and RMSE roughly every 20000 samples\n",
"              batch_end_callback=mx.callback.Speedometer(batch_size, 20000 // batch_size))\n",
"\n",
"    return model"
]
},
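{
"cell_type": "markdown",
"metadata": {},
"source": [
"As a quick check of the `Speedometer` frequency used above, assume that `u1.base` holds 80,000 ratings (the 80/20 split of the 100,000 ratings in ml-100k) and that the batch size is 64, as in the next section. Then\n",
"\n",
"$$\\left\\lfloor \\frac{80000}{64} \\right\\rfloor = 1250 \\text{ batches per epoch}, \\qquad \\left\\lfloor \\frac{20000}{64} \\right\\rfloor = 312,$$\n",
"\n",
"so training progress is logged every 312 batches, about four times per epoch."
]
},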
{
"cell_type": "code",
"execution_count": 13,
"metadata": {},
"outputs": [],
"source": [
"def predict(net, users, items, batch_size):\n",
"    assert users.shape == items.shape and users.ndim == 1\n",
"    data = {'user': users, 'item': items}\n",
"    data_iter = mx.io.NDArrayIter(data, batch_size=batch_size)\n",
"    return net.predict(data_iter)"
]
},
{
"cell_type": "markdown",
"metadata": {},
"source": [
"## Networks\n",
"\n",
"Now we try various networks. We first learn the latent vectors directly."
]
},
{
"cell_type": "code",
"execution_count": 14,
"metadata": {},
"outputs": [
{
"name": "stderr",
"output_type": "stream",
"text": [
"INFO:root:Epoch[0] Batch [312]\tSpeed: 67563.78 samples/sec\tRMSE=3.721034\n",
"INFO:root:Epoch[0] Batch [624]\tSpeed: 64852.64 samples/sec\tRMSE=3.682020\n",
"INFO:root:Epoch[0] Batch [936]\tSpeed: 62305.78 samples/sec\tRMSE=3.625233\n",
"INFO:root:Epoch[0] Batch [1248]\tSpeed: 68365.30 samples/sec\tRMSE=3.690150\n",
"INFO:root:Epoch[0] Train-RMSE=3.411721\n",
"INFO:root:Epoch[0] Time cost=1.227\n",
"INFO:root:Epoch[0] Validation-RMSE=3.717236\n",
"INFO:root:Epoch[1] Batch [312]\tSpeed: 64842.75 samples/sec\tRMSE=3.695260\n",
"INFO:root:Epoch[1] Batch [624]\tSpeed: 66291.53 samples/sec\tRMSE=3.687865\n",
"INFO:root:Epoch[1] Batch [936]\tSpeed: 62540.87 samples/sec\tRMSE=3.656047\n",
"INFO:root:Epoch[1] Batch [1248]\tSpeed: 67802.81 samples/sec\tRMSE=3.422298\n",
"INFO:root:Epoch[1] Train-RMSE=3.140008\n",
"INFO:root:Epoch[1] Time cost=1.233\n",
"INFO:root:Epoch[1] Validation-RMSE=3.310249\n",
"INFO:root:Epoch[2] Batch [312]\tSpeed: 64824.02 samples/sec\tRMSE=2.766653\n",
"INFO:root:Epoch[2] Batch [624]\tSpeed: 64551.38 samples/sec\tRMSE=2.157919\n",
"INFO:root:Epoch[2] Batch [936]\tSpeed: 63794.90 samples/sec\tRMSE=1.791587\n",
"INFO:root:Epoch[2] Batch [1248]\tSpeed: 67786.24 samples/sec\tRMSE=1.556971\n",
"INFO:root:Epoch[2] Train-RMSE=1.483133\n",
"INFO:root:Epoch[2] Time cost=1.235\n",
"INFO:root:Epoch[2] Validation-RMSE=1.589926\n",
"INFO:root:Epoch[3] Batch [312]\tSpeed: 64572.03 samples/sec\tRMSE=1.388250\n",
"INFO:root:Epoch[3] Batch [624]\tSpeed: 67723.48 samples/sec\tRMSE=1.295879\n",
"INFO:root:Epoch[3] Batch [936]\tSpeed: 61854.14 samples/sec\tRMSE=1.233358\n",
"INFO:root:Epoch[3] Batch [1248]\tSpeed: 62523.69 samples/sec\tRMSE=1.189739\n",
"INFO:root:Epoch[3] Train-RMSE=1.164325\n",
"INFO:root:Epoch[3] Time cost=1.256\n",
"INFO:root:Epoch[3] Validation-RMSE=1.252920\n",
"INFO:root:Epoch[4] Batch [312]\tSpeed: 58991.07 samples/sec\tRMSE=1.133154\n",
"INFO:root:Epoch[4] Batch [624]\tSpeed: 58834.63 samples/sec\tRMSE=1.108365\n",
"INFO:root:Epoch[4] Batch [936]\tSpeed: 51809.62 samples/sec\tRMSE=1.092431\n",
"INFO:root:Epoch[4] Batch [1248]\tSpeed: 68050.28 samples/sec\tRMSE=1.077801\n",
"INFO:root:Epoch[4] Train-RMSE=1.158075\n",
"INFO:root:Epoch[4] Time cost=1.367\n",
"INFO:root:Epoch[4] Validation-RMSE=1.136881\n",
"INFO:root:Epoch[5] Batch [312]\tSpeed: 63840.18 samples/sec\tRMSE=1.045013\n",
"INFO:root:Epoch[5] Batch [624]\tSpeed: 54559.02 samples/sec\tRMSE=1.041145\n",
"INFO:root:Epoch[5] Batch [936]\tSpeed: 56593.49 samples/sec\tRMSE=1.044665\n",
"INFO:root:Epoch[5] Batch [1248]\tSpeed: 48774.98 samples/sec\tRMSE=1.029294\n",
"INFO:root:Epoch[5] Train-RMSE=1.075923\n",
"INFO:root:Epoch[5] Time cost=1.450\n",
"INFO:root:Epoch[5] Validation-RMSE=1.084058\n",
"INFO:root:Epoch[6] Batch [312]\tSpeed: 57131.73 samples/sec\tRMSE=1.005409\n",
"INFO:root:Epoch[6] Batch [624]\tSpeed: 52827.36 samples/sec\tRMSE=1.016771\n",
"INFO:root:Epoch[6] Batch [936]\tSpeed: 62829.42 samples/sec\tRMSE=1.013194\n",
"INFO:root:Epoch[6] Batch [1248]\tSpeed: 55852.82 samples/sec\tRMSE=1.006081\n",
"INFO:root:Epoch[6] Train-RMSE=1.063851\n",
"INFO:root:Epoch[6] Time cost=1.412\n",
"INFO:root:Epoch[6] Validation-RMSE=1.054935\n",
"INFO:root:Epoch[7] Batch [312]\tSpeed: 67856.54 samples/sec\tRMSE=0.987511\n",
"INFO:root:Epoch[7] Batch [624]\tSpeed: 63437.23 samples/sec\tRMSE=0.997116\n",
"INFO:root:Epoch[7] Batch [936]\tSpeed: 67572.45 samples/sec\tRMSE=0.991542\n",
"INFO:root:Epoch[7] Batch [1248]\tSpeed: 63492.25 samples/sec\tRMSE=0.995827\n",
"INFO:root:Epoch[7] Train-RMSE=0.940500\n",
"INFO:root:Epoch[7] Time cost=1.228\n",
"INFO:root:Epoch[7] Validation-RMSE=1.038229\n",
"INFO:root:Epoch[8] Batch [312]\tSpeed: 65570.33 samples/sec\tRMSE=0.982947\n",
"INFO:root:Epoch[8] Batch [624]\tSpeed: 63110.60 samples/sec\tRMSE=0.979931\n",
"INFO:root:Epoch[8] Batch [936]\tSpeed: 64586.87 samples/sec\tRMSE=0.982295\n",
"INFO:root:Epoch[8] Batch [1248]\tSpeed: 53426.81 samples/sec\tRMSE=0.982971\n",
"INFO:root:Epoch[8] Train-RMSE=0.953352\n",
"INFO:root:Epoch[8] Time cost=1.313\n",
"INFO:root:Epoch[8] Validation-RMSE=1.028586\n",
"INFO:root:Epoch[9] Batch [312]\tSpeed: 59365.85 samples/sec\tRMSE=0.963168\n",
"INFO:root:Epoch[9] Batch [624]\tSpeed: 66323.45 samples/sec\tRMSE=0.976033\n",
"INFO:root:Epoch[9] Batch [936]\tSpeed: 69935.30 samples/sec\tRMSE=0.982855\n",
"INFO:root:Epoch[9] Batch [1248]\tSpeed: 66512.54 samples/sec\tRMSE=0.974915\n",
"INFO:root:Epoch[9] Train-RMSE=0.923108\n",
"INFO:root:Epoch[9] Time cost=1.233\n",
"INFO:root:Epoch[9] Validation-RMSE=1.019939\n"
]
}
],
"source": [
"# @@@ AUTOTEST_OUTPUT_IGNORED_CELL\n",
"def plain_net(k):\n",
"    # input\n",
"    user = mx.symbol.Variable('user')\n",
"    item = mx.symbol.Variable('item')\n",
"    score = mx.symbol.Variable('score')\n",
"    # user feature lookup\n",
"    user = mx.symbol.Embedding(data=user, input_dim=max_user, output_dim=k)\n",
"    # item feature lookup\n",
"    item = mx.symbol.Embedding(data=item, input_dim=max_item, output_dim=k)\n",
"    # predict by the inner product, which is an elementwise product followed by a sum\n",
"    pred = user * item\n",
"    pred = mx.symbol.sum_axis(data=pred, axis=1)\n",
"    pred = mx.symbol.Flatten(data=pred)\n",
"    # loss layer\n",
"    pred = mx.symbol.LinearRegressionOutput(data=pred, label=score)\n",
"    return pred\n",
"\n",
"model = train(plain_net(64), batch_size=64, num_epoch=10, learning_rate=.05)"
]
},
{
"cell_type": "code",
"execution_count": 15,
"metadata": {},
"outputs": [
{
"name": "stderr",
"output_type": "stream",
"text": [
"/path/to/.local/share/virtualenvs/mxnet/lib/python2.7/site-packages/mxnet/module/base_module.py:65: UserWarning: Data provided by label_shapes don't match names specified by label_names ([] vs. ['score'])\n",
" warnings.warn(msg)\n"
]
},
{
"ename": "AssertionError",
"evalue": "Shape of unspecified array arg:score changed. This can cause the new executor to not share parameters with the old one. Please check for error in network.If this is intended, set partial_shaping=True to suppress this warning.",
"output_type": "error",
"traceback": [
"\u001b[0;31m---------------------------------------------------------------------------\u001b[0m",
"\u001b[0;31mAssertionError\u001b[0m Traceback (most recent call last)",
"\u001b[0;32m<ipython-input-15-c6918a3c3dca>\u001b[0m in \u001b[0;36m<module>\u001b[0;34m()\u001b[0m\n\u001b[1;32m 5\u001b[0m \u001b[0musers\u001b[0m\u001b[0;34m=\u001b[0m\u001b[0mnp\u001b[0m\u001b[0;34m.\u001b[0m\u001b[0mrandom\u001b[0m\u001b[0;34m.\u001b[0m\u001b[0mrandint\u001b[0m\u001b[0;34m(\u001b[0m\u001b[0;36m0\u001b[0m\u001b[0;34m,\u001b[0m \u001b[0mmax_user\u001b[0m\u001b[0;34m,\u001b[0m \u001b[0msize\u001b[0m\u001b[0;34m=\u001b[0m\u001b[0;34m(\u001b[0m\u001b[0;36m16\u001b[0m\u001b[0;34m,\u001b[0m\u001b[0;34m)\u001b[0m\u001b[0;34m)\u001b[0m\u001b[0;34m,\u001b[0m\u001b[0;34m\u001b[0m\u001b[0m\n\u001b[1;32m 6\u001b[0m \u001b[0mitems\u001b[0m\u001b[0;34m=\u001b[0m\u001b[0mnp\u001b[0m\u001b[0;34m.\u001b[0m\u001b[0mrandom\u001b[0m\u001b[0;34m.\u001b[0m\u001b[0mrandint\u001b[0m\u001b[0;34m(\u001b[0m\u001b[0;36m0\u001b[0m\u001b[0;34m,\u001b[0m \u001b[0mmax_item\u001b[0m\u001b[0;34m,\u001b[0m \u001b[0msize\u001b[0m\u001b[0;34m=\u001b[0m\u001b[0;34m(\u001b[0m\u001b[0;36m16\u001b[0m\u001b[0;34m,\u001b[0m\u001b[0;34m)\u001b[0m\u001b[0;34m)\u001b[0m\u001b[0;34m,\u001b[0m\u001b[0;34m\u001b[0m\u001b[0m\n\u001b[0;32m----> 7\u001b[0;31m batch_size=4))\n\u001b[0m",
"\u001b[0;32m<ipython-input-13-3d209d85c5ce>\u001b[0m in \u001b[0;36mpredict\u001b[0;34m(net, users, items, batch_size)\u001b[0m\n\u001b[1;32m 3\u001b[0m \u001b[0mdata\u001b[0m \u001b[0;34m=\u001b[0m \u001b[0;34m{\u001b[0m\u001b[0;34m'user'\u001b[0m\u001b[0;34m:\u001b[0m\u001b[0musers\u001b[0m\u001b[0;34m,\u001b[0m \u001b[0;34m'item'\u001b[0m\u001b[0;34m:\u001b[0m\u001b[0mitems\u001b[0m\u001b[0;34m}\u001b[0m\u001b[0;34m\u001b[0m\u001b[0m\n\u001b[1;32m 4\u001b[0m \u001b[0mdata_iter\u001b[0m \u001b[0;34m=\u001b[0m \u001b[0mmx\u001b[0m\u001b[0;34m.\u001b[0m\u001b[0mio\u001b[0m\u001b[0;34m.\u001b[0m\u001b[0mNDArrayIter\u001b[0m\u001b[0;34m(\u001b[0m\u001b[0mdata\u001b[0m\u001b[0;34m,\u001b[0m \u001b[0mbatch_size\u001b[0m\u001b[0;34m=\u001b[0m\u001b[0mbatch_size\u001b[0m\u001b[0;34m)\u001b[0m\u001b[0;34m\u001b[0m\u001b[0m\n\u001b[0;32m----> 5\u001b[0;31m \u001b[0;32mreturn\u001b[0m \u001b[0mnet\u001b[0m\u001b[0;34m.\u001b[0m\u001b[0mpredict\u001b[0m\u001b[0;34m(\u001b[0m\u001b[0mdata_iter\u001b[0m\u001b[0;34m)\u001b[0m\u001b[0;34m\u001b[0m\u001b[0m\n\u001b[0m",
"\u001b[0;32m/path/to/.local/share/virtualenvs/mxnet/lib/python2.7/site-packages/mxnet/module/base_module.pyc\u001b[0m in \u001b[0;36mpredict\u001b[0;34m(self, eval_data, num_batch, merge_batches, reset, always_output_list)\u001b[0m\n\u001b[1;32m 350\u001b[0m \u001b[0;32mif\u001b[0m \u001b[0mnum_batch\u001b[0m \u001b[0;32mis\u001b[0m \u001b[0;32mnot\u001b[0m \u001b[0mNone\u001b[0m \u001b[0;32mand\u001b[0m \u001b[0mnbatch\u001b[0m \u001b[0;34m==\u001b[0m \u001b[0mnum_batch\u001b[0m\u001b[0;34m:\u001b[0m\u001b[0;34m\u001b[0m\u001b[0m\n\u001b[1;32m 351\u001b[0m \u001b[0;32mbreak\u001b[0m\u001b[0;34m\u001b[0m\u001b[0m\n\u001b[0;32m--> 352\u001b[0;31m \u001b[0mself\u001b[0m\u001b[0;34m.\u001b[0m\u001b[0mforward\u001b[0m\u001b[0;34m(\u001b[0m\u001b[0meval_batch\u001b[0m\u001b[0;34m,\u001b[0m \u001b[0mis_train\u001b[0m\u001b[0;34m=\u001b[0m\u001b[0mFalse\u001b[0m\u001b[0;34m)\u001b[0m\u001b[0;34m\u001b[0m\u001b[0m\n\u001b[0m\u001b[1;32m 353\u001b[0m \u001b[0mpad\u001b[0m \u001b[0;34m=\u001b[0m \u001b[0meval_batch\u001b[0m\u001b[0;34m.\u001b[0m\u001b[0mpad\u001b[0m\u001b[0;34m\u001b[0m\u001b[0m\n\u001b[1;32m 354\u001b[0m \u001b[0moutputs\u001b[0m \u001b[0;34m=\u001b[0m \u001b[0;34m[\u001b[0m\u001b[0mout\u001b[0m\u001b[0;34m[\u001b[0m\u001b[0;36m0\u001b[0m\u001b[0;34m:\u001b[0m\u001b[0mout\u001b[0m\u001b[0;34m.\u001b[0m\u001b[0mshape\u001b[0m\u001b[0;34m[\u001b[0m\u001b[0;36m0\u001b[0m\u001b[0;34m]\u001b[0m\u001b[0;34m-\u001b[0m\u001b[0mpad\u001b[0m\u001b[0;34m]\u001b[0m\u001b[0;34m.\u001b[0m\u001b[0mcopy\u001b[0m\u001b[0;34m(\u001b[0m\u001b[0;34m)\u001b[0m \u001b[0;32mfor\u001b[0m \u001b[0mout\u001b[0m \u001b[0;32min\u001b[0m \u001b[0mself\u001b[0m\u001b[0;34m.\u001b[0m\u001b[0mget_outputs\u001b[0m\u001b[0;34m(\u001b[0m\u001b[0;34m)\u001b[0m\u001b[0;34m]\u001b[0m\u001b[0;34m\u001b[0m\u001b[0m\n",
"\u001b[0;32m/path/to/.local/share/virtualenvs/mxnet/lib/python2.7/site-packages/mxnet/module/module.pyc\u001b[0m in \u001b[0;36mforward\u001b[0;34m(self, data_batch, is_train)\u001b[0m\n\u001b[1;32m 605\u001b[0m \u001b[0mnew_lshape\u001b[0m \u001b[0;34m=\u001b[0m \u001b[0mNone\u001b[0m\u001b[0;34m\u001b[0m\u001b[0m\n\u001b[1;32m 606\u001b[0m \u001b[0;34m\u001b[0m\u001b[0m\n\u001b[0;32m--> 607\u001b[0;31m \u001b[0mself\u001b[0m\u001b[0;34m.\u001b[0m\u001b[0mreshape\u001b[0m\u001b[0;34m(\u001b[0m\u001b[0mnew_dshape\u001b[0m\u001b[0;34m,\u001b[0m \u001b[0mnew_lshape\u001b[0m\u001b[0;34m)\u001b[0m\u001b[0;34m\u001b[0m\u001b[0m\n\u001b[0m\u001b[1;32m 608\u001b[0m \u001b[0;34m\u001b[0m\u001b[0m\n\u001b[1;32m 609\u001b[0m \u001b[0mself\u001b[0m\u001b[0;34m.\u001b[0m\u001b[0m_exec_group\u001b[0m\u001b[0;34m.\u001b[0m\u001b[0mforward\u001b[0m\u001b[0;34m(\u001b[0m\u001b[0mdata_batch\u001b[0m\u001b[0;34m,\u001b[0m \u001b[0mis_train\u001b[0m\u001b[0;34m)\u001b[0m\u001b[0;34m\u001b[0m\u001b[0m\n",
"\u001b[0;32m/path/to/.local/share/virtualenvs/mxnet/lib/python2.7/site-packages/mxnet/module/module.pyc\u001b[0m in \u001b[0;36mreshape\u001b[0;34m(self, data_shapes, label_shapes)\u001b[0m\n\u001b[1;32m 467\u001b[0m self.data_names, self.label_names, data_shapes, label_shapes)\n\u001b[1;32m 468\u001b[0m \u001b[0;34m\u001b[0m\u001b[0m\n\u001b[0;32m--> 469\u001b[0;31m \u001b[0mself\u001b[0m\u001b[0;34m.\u001b[0m\u001b[0m_exec_group\u001b[0m\u001b[0;34m.\u001b[0m\u001b[0mreshape\u001b[0m\u001b[0;34m(\u001b[0m\u001b[0mself\u001b[0m\u001b[0;34m.\u001b[0m\u001b[0m_data_shapes\u001b[0m\u001b[0;34m,\u001b[0m \u001b[0mself\u001b[0m\u001b[0;34m.\u001b[0m\u001b[0m_label_shapes\u001b[0m\u001b[0;34m)\u001b[0m\u001b[0;34m\u001b[0m\u001b[0m\n\u001b[0m\u001b[1;32m 470\u001b[0m \u001b[0;34m\u001b[0m\u001b[0m\n\u001b[1;32m 471\u001b[0m def init_optimizer(self, kvstore='local', optimizer='sgd',\n",
"\u001b[0;32m/path/to/.local/share/virtualenvs/mxnet/lib/python2.7/site-packages/mxnet/module/executor_group.pyc\u001b[0m in \u001b[0;36mreshape\u001b[0;34m(self, data_shapes, label_shapes)\u001b[0m\n\u001b[1;32m 352\u001b[0m \u001b[0;32mif\u001b[0m \u001b[0mself\u001b[0m\u001b[0;34m.\u001b[0m\u001b[0m_default_execs\u001b[0m \u001b[0;32mis\u001b[0m \u001b[0mNone\u001b[0m\u001b[0;34m:\u001b[0m\u001b[0;34m\u001b[0m\u001b[0m\n\u001b[1;32m 353\u001b[0m \u001b[0mself\u001b[0m\u001b[0;34m.\u001b[0m\u001b[0m_default_execs\u001b[0m \u001b[0;34m=\u001b[0m \u001b[0;34m[\u001b[0m\u001b[0mi\u001b[0m \u001b[0;32mfor\u001b[0m \u001b[0mi\u001b[0m \u001b[0;32min\u001b[0m \u001b[0mself\u001b[0m\u001b[0;34m.\u001b[0m\u001b[0mexecs\u001b[0m\u001b[0;34m]\u001b[0m\u001b[0;34m\u001b[0m\u001b[0m\n\u001b[0;32m--> 354\u001b[0;31m \u001b[0mself\u001b[0m\u001b[0;34m.\u001b[0m\u001b[0mbind_exec\u001b[0m\u001b[0;34m(\u001b[0m\u001b[0mdata_shapes\u001b[0m\u001b[0;34m,\u001b[0m \u001b[0mlabel_shapes\u001b[0m\u001b[0;34m,\u001b[0m \u001b[0mreshape\u001b[0m\u001b[0;34m=\u001b[0m\u001b[0mTrue\u001b[0m\u001b[0;34m)\u001b[0m\u001b[0;34m\u001b[0m\u001b[0m\n\u001b[0m\u001b[1;32m 355\u001b[0m \u001b[0;34m\u001b[0m\u001b[0m\n\u001b[1;32m 356\u001b[0m \u001b[0;32mdef\u001b[0m \u001b[0mset_params\u001b[0m\u001b[0;34m(\u001b[0m\u001b[0mself\u001b[0m\u001b[0;34m,\u001b[0m \u001b[0marg_params\u001b[0m\u001b[0;34m,\u001b[0m \u001b[0maux_params\u001b[0m\u001b[0;34m,\u001b[0m \u001b[0mallow_extra\u001b[0m\u001b[0;34m=\u001b[0m\u001b[0mFalse\u001b[0m\u001b[0;34m)\u001b[0m\u001b[0;34m:\u001b[0m\u001b[0;34m\u001b[0m\u001b[0m\n",
"\u001b[0;32m/path/to/.local/share/virtualenvs/mxnet/lib/python2.7/site-packages/mxnet/module/executor_group.pyc\u001b[0m in \u001b[0;36mbind_exec\u001b[0;34m(self, data_shapes, label_shapes, shared_group, reshape)\u001b[0m\n\u001b[1;32m 328\u001b[0m \u001b[0;32mif\u001b[0m \u001b[0mreshape\u001b[0m\u001b[0;34m:\u001b[0m\u001b[0;34m\u001b[0m\u001b[0m\n\u001b[1;32m 329\u001b[0m self.execs[i] = self._default_execs[i].reshape(\n\u001b[0;32m--> 330\u001b[0;31m allow_up_sizing=True, **dict(data_shapes_i + label_shapes_i))\n\u001b[0m\u001b[1;32m 331\u001b[0m \u001b[0;32melse\u001b[0m\u001b[0;34m:\u001b[0m\u001b[0;34m\u001b[0m\u001b[0m\n\u001b[1;32m 332\u001b[0m self.execs.append(self._bind_ith_exec(i, data_shapes_i, label_shapes_i,\n",
"\u001b[0;32m/path/to/.local/share/virtualenvs/mxnet/lib/python2.7/site-packages/mxnet/executor.pyc\u001b[0m in \u001b[0;36mreshape\u001b[0;34m(self, partial_shaping, allow_up_sizing, **kwargs)\u001b[0m\n\u001b[1;32m 428\u001b[0m \u001b[0;34m\"This can cause the new executor to not share parameters \"\u001b[0m \u001b[0;34m+\u001b[0m\u001b[0;31m \u001b[0m\u001b[0;31m\\\u001b[0m\u001b[0;34m\u001b[0m\u001b[0m\n\u001b[1;32m 429\u001b[0m \u001b[0;34m\"with the old one. Please check for error in network.\"\u001b[0m \u001b[0;34m+\u001b[0m\u001b[0;31m\\\u001b[0m\u001b[0;34m\u001b[0m\u001b[0m\n\u001b[0;32m--> 430\u001b[0;31m \"If this is intended, set partial_shaping=True to suppress this warning.\")\n\u001b[0m\u001b[1;32m 431\u001b[0m \u001b[0;34m\u001b[0m\u001b[0m\n\u001b[1;32m 432\u001b[0m \u001b[0mnew_aux_dict\u001b[0m \u001b[0;34m=\u001b[0m \u001b[0;34m{\u001b[0m\u001b[0;34m}\u001b[0m\u001b[0;34m\u001b[0m\u001b[0m\n",
"\u001b[0;31mAssertionError\u001b[0m: Shape of unspecified array arg:score changed. This can cause the new executor to not share parameters with the old one. Please check for error in network.If this is intended, set partial_shaping=True to suppress this warning."
]
}
],
"source": [
"import numpy as np\n",
"\n",
"print(predict(\n",
"    model,\n",
"    users=np.random.randint(0, max_user, size=(16,)),\n",
"    items=np.random.randint(0, max_item, size=(16,)),\n",
"    batch_size=4))"
]
}
],
"metadata": {
"kernelspec": {
"display_name": "Python 2",
"language": "python",
"name": "python2"
},
"language_info": {
"codemirror_mode": {
"name": "ipython",
"version": 2
},
"file_extension": ".py",
"mimetype": "text/x-python",
"name": "python",
"nbconvert_exporter": "python",
"pygments_lexer": "ipython2",
"version": "2.7.10"
}
},
"nbformat": 4,
"nbformat_minor": 1
}