How to pass a static vector into MakeLoss?



I would like to pass a vector of the same length as the label into the loss function and use that vector to compute the loss on a particular subset of the observations.

Feel free to skip ahead to the minimal reproducible code; otherwise, here is some background.

I have a highly skewed binary classification outcome; the problem can be thought of as ‘yield’. Given the data available at the time of an initial observation, how likely is that observation to eventually yield? However, there is a gatekeeper funnel that an observation must pass through before it can yield. Not every observation that makes it through the funnel yields, but knowing whether it passed the funnel is extremely useful to the model. Because the prediction is made before the funnel, I cannot include funnel status as a feature. Instead, I would like to penalize the model more heavily when it predicts ‘yield’ for a post-funnel observation that did not actually yield.
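Concretely, with $p_i$ the sigmoid output, $y_i \in \{0,1\}$ the label and $s_i \in \{0,1\}$ the funnel indicator for observation $i$, the loss sketched in the code below amounts to:

$$
\mathcal{L}_i = -\Big[\, y_i \log p_i + (1 - y_i)\log(1 - p_i) + s_i\,(1 - y_i)\log(1 - p_i) \,\Big]
$$

With $s_i = 0$ this is ordinary binary cross-entropy; with $s_i = 1$ and $y_i = 0$ the penalty for predicting ‘yield’ is doubled. The extra term always vanishes for yielders, because $1 - y_i = 0$.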

Here is the (errant) code that roughly outlines what I’m trying to do. Everything runs until the last line, where it tries to make a prediction. In short, I’m looking for a way to make the part of the loss function that uses ‘indicatorSym’ work as intended.

I’m relatively new to this framework, so please forgive any lack of basic knowledge. Any help would be greatly appreciated!
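For checking numbers by hand, here is a plain-R sketch of the per-observation loss that the `loss` symbol below is meant to compute. The helper names `custom_loss` and `bce` are illustrative only. The second call shows that with the indicator set to 0 the loss reduces to plain binary cross-entropy, which is the objective `mx.symbol.LogisticRegressionOutput` corresponds to.

```r
# Reference (plain R) version of the per-observation loss the 'loss' symbol
# is meant to compute.
#   p   : predicted probability of 'yield' (the sigmoid output)
#   y   : observed label, 0 or 1
#   ind : funnel indicator, 1 if the observation passed the funnel
custom_loss <- function(p, y, ind) {
  -(y * log(p) +
      (1 - y) * log(1 - p) +
      ind * (1 - y) * log(1 - p))  # extra penalty for post-funnel non-yielders
}

# Ordinary binary cross-entropy, for comparison
bce <- function(p, y) -(y * log(p) + (1 - y) * log(1 - p))

p   <- c(0.9, 0.9, 0.2)
y   <- c(1,   0,   0)
ind <- c(1,   1,   0)

custom_loss(p, y, ind)      # second observation is penalised twice as hard
custom_loss(p, y, 0 * ind)  # same values as bce(p, y)
```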

library(mxnet)

set.seed(123)
df <- data.frame(a=rnorm(100),
                 b=rnorm(100),
                 c=rnorm(100),
                 y=sample(c(0,1),size=100,prob=c(.92,.08),replace=TRUE))

# Funnel indicator: always 1 when y == 1; otherwise 1 with probability 0.2
indicator <- apply(as.matrix(df$y), MARGIN = 1, FUN = function(y) {
  ifelse(y == 1, 1, sample(c(0, 1), size = 1, prob = c(.8, .2)))
})

#Define Train/Test Sets
test.ind = seq(1, length(df$y), 5)    # 1 pt in 5 used for testing
train.x = data.matrix(df[-test.ind,-4])
train.y = df[-test.ind, 4]
test.x = data.matrix(df[test.ind,-4])
test.y = df[test.ind, 4]

#NN Architecture
data <- mx.symbol.Variable("data")
label <- mx.symbol.Variable("label")
indicatorSym <- mx.symbol.Variable("indicatorSym")
fc1 <- mx.symbol.FullyConnected(data, num_hidden = 14, name = "fc1")
tanh1 <- mx.symbol.Activation(fc1, act_type = "tanh", name = "tanh1")
fc2 <- mx.symbol.FullyConnected(tanh1, num_hidden = 1, name = "fc2")
sigOut <- mx.symbol.Activation(fc2, act_type = "sigmoid", name = "sigOut")

# 'sig' was undefined here; the sigmoid output symbol is 'sigOut'
loss <- mx.symbol.MakeLoss(mx.symbol.negative(
          ### Cross entropy
          label * mx.symbol.log(mx.symbol.Reshape(sigOut, shape = 0)) +
          (1 - label) * mx.symbol.log(1 - mx.symbol.Reshape(sigOut, shape = 0)) +
          ### Additional cost using 'indicatorSym'...
          ### Commenting out the next line should give the same output as mx.symbol.LogisticRegressionOutput
          indicatorSym * (1 - label) * mx.symbol.log(1 - mx.symbol.Reshape(sigOut, shape = 0))
        ), name = "loss")

#Build Model
mx.set.seed(0)
model <- mx.model.FeedForward.create(loss, X = train.x, y = train.y,
                                      ctx = mx.cpu(),
                                      num.round = 5,
                                      array.batch.size = 10,
                                      optimizer = "rmsprop",
                                      verbose = TRUE,
                                      array.layout = "rowmajor",
                                      batch.end.callback = NULL,
                                      epoch.end.callback = NULL,
                                      # Attempt to supply 'indicator' as the input for indicatorSym
                                      {"indicatorSym" = indicator}
                                      )

#Produce output
internals = internals(model$symbol)
fc_symbol = internals[[match("sigOut_output", outputs(internals))]]
model2 <- list(symbol = fc_symbol,
               arg.params = model$arg.params,
               aux.params = model$aux.params)
class(model2) <- "MXFeedForwardModel"
#Produces error here...
predCustom <- predict(model2,test.x)
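Note that `indicator` above is built for all 100 rows, while training uses only the 80 rows of `train.x`, so whatever ends up feeding it to `indicatorSym` has to be split with the same `test.ind`. One possible workaround, offered here as an untested assumption rather than a documented MXNet feature, is to append the indicator to the feature matrix as an extra column so it travels through the same batches, and slice it back off inside the network (for example with `mx.symbol.slice_axis`, whose axis convention in the R binding would need checking). The base-R half of that idea is sketched below; names such as `train.x.aug` and `test.x.aug` are hypothetical.

```r
# Sketch only: keep the indicator aligned with the train/test split by
# carrying it as an extra (4th) column of the input matrix.
set.seed(123)
n <- 100
x <- matrix(rnorm(n * 3), ncol = 3, dimnames = list(NULL, c("a", "b", "c")))
y <- sample(c(0, 1), size = n, prob = c(.92, .08), replace = TRUE)

# Vectorised indicator: 1 whenever y == 1, otherwise 1 with probability 0.2
indicator <- ifelse(y == 1, 1, rbinom(n, size = 1, prob = 0.2))

test.ind <- seq(1, n, 5)  # 1 point in 5 held out, as in the original code

# Split the indicator with the SAME index as the features and labels
train.x.aug <- cbind(x[-test.ind, ], indicator = indicator[-test.ind])
train.y     <- y[-test.ind]

# At prediction time the indicator is unknown; a column of zeros keeps the
# input width at 4 (assuming the network slices off the last column before
# fc1, so sigOut never sees it)
test.x.aug <- cbind(x[test.ind, ], indicator = 0)
test.y     <- y[test.ind]
```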