Hi @thomelane,
Thanks for your help. When I use the `slice_axis` operator, I get an error. Here is my test data:

test.csv:
```
1,2,1
1,3,0
1,4,0
1,5,0
2,3,1
2,2,0
2,3,0
```
The first column is the member ID, the second column is the movie ID, and the third column is the class label.
```python
import mxnet as mx

def simple_mf_net(n_users, n_items):
    data = mx.symbol.Variable("data")
    # used to be mx.symbol.Variable("softmax_label")
    y_true = mx.symbol.slice_axis(data, axis=1, begin=2, end=3)
    # used to be mx.symbol.Variable("member")
    user = mx.symbol.slice_axis(data, axis=1, begin=0, end=1)
    user = mx.symbol.Embedding(name='member_embedding',
                               data=user, input_dim=n_users, output_dim=64)
    # used to be mx.symbol.Variable("movie")
    movie = mx.symbol.slice_axis(data, axis=1, begin=1, end=2)
    movie = mx.symbol.Embedding(name='movie_embedding',
                                data=movie, input_dim=n_items, output_dim=64)
    dot = user * movie
    dot = mx.symbol.sum_axis(dot, axis=1)
    dot = mx.symbol.Flatten(dot)
    return mx.symbol.LinearRegressionOutput(data=dot, label=y_true)

n_users = 2
n_movies = 4
model = mx.module.Module(context=[mx.cpu()],
                         symbol=simple_mf_net(n_users, n_movies))
csv_iter = mx.io.CSVIter(data_csv='test.csv', data_shape=(3,), batch_size=7)
model.fit(csv_iter,
          num_epoch=1,
          eval_metric=['rmse'],
          optimizer='adam')
```
This is the error I get:
```
...
RuntimeError: simple_bind error. Arguments:
data: (7, 3)
softmax_label: (7,)
Error in operator linearregressionoutput17: Shape inconsistent, Provided=[7,1], inferred shape=[7,64]
```
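Tracing the shapes suggests where the mismatch comes from. Each `slice_axis` call returns shape (7, 1), so each `Embedding` returns (7, 1, 64). `sum_axis(dot, axis=1)` then sums over that length-1 axis rather than over the 64 embedding dimensions, leaving (7, 64), while the label `y_true` is (7, 1). The NumPy sketch below mimics that forward pass with random stand-in embedding tables (an assumption, not trained weights). The tables are sized to the largest id + 1 because the ids in test.csv start at 1 and go up to 2 (members) and 5 (movies), whereas `input_dim=n_users=2` and `input_dim=n_movies=4` would only cover ids 0 to 1 and 0 to 3.

```python
import numpy as np

rng = np.random.default_rng(0)

# Member id, movie id and label columns from test.csv, each already sliced to shape (7, 1)
user_ids = np.array([[1], [1], [1], [1], [2], [2], [2]])
movie_ids = np.array([[2], [3], [4], [5], [3], [2], [3]])
y_true = np.array([[1], [0], [0], [0], [1], [0], [0]], dtype=float)

# Hypothetical stand-in embedding tables, sized max id + 1 so every id is a valid row
user_table = rng.normal(size=(user_ids.max() + 1, 64))
movie_table = rng.normal(size=(movie_ids.max() + 1, 64))

# An embedding lookup on a (7, 1) index array yields (7, 1, 64)
user = user_table[user_ids]
movie = movie_table[movie_ids]
dot = user * movie
print(dot.shape)  # (7, 1, 64)

# What the original net does: sum over axis 1, the length-1 axis
wrong = dot.sum(axis=1)
print(wrong.shape)  # (7, 64), does not match y_true's (7, 1)

# Summing over the embedding axis instead gives one score per row
pred = dot.sum(axis=2)
print(pred.shape)  # (7, 1), matches y_true
```

If this reading is right, changing `axis=1` to `axis=2` in the `sum_axis` call should produce (7, 1), which the following `Flatten` leaves unchanged and which matches the label. This is an untested sketch, not a verified fix for the MXNet symbol.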
I'm not sure which operator I should be using here: `slice_axis` or `slice`? Basically, I'd like:
- the `user` variable to use the first column of the CSV file
- the `movie` variable to use the second column
- the `y_true` variable to use the third column
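Either operator should be able to pick out a single column. As a rough NumPy analogue (the MXNet calls in the comments are assumptions based on the operator signatures, not tested here), `slice_axis(data, axis=1, begin=0, end=1)` and `slice(data, begin=(None, 0), end=(None, 1))` both correspond to `data[:, 0:1]`, which keeps the column axis, whereas plain integer indexing drops it.

```python
import numpy as np

data = np.array([[1, 2, 1], [1, 3, 0], [1, 4, 0], [1, 5, 0],
                 [2, 3, 1], [2, 2, 0], [2, 3, 0]])

# Analogue of slice_axis(data, axis=1, begin=0, end=1): one range on one axis
via_slice_axis = data[:, 0:1]
# Analogue of slice(data, begin=(None, 0), end=(None, 1)): one range per axis, None = whole axis
via_slice = data[(slice(None), slice(0, 1))]

# Both keep the column axis, so the column comes out with shape (7, 1)
print(via_slice_axis.shape, via_slice.shape)  # (7, 1) (7, 1)

# Integer indexing drops the axis instead, giving shape (7,)
print(data[:, 0].shape)  # (7,)
```

Keeping the axis is what makes the sliced label (7, 1), which is why the output of the network also needs to end up as (7, 1).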
Thanks!