I have a question about the Clojure MXNet example that uses BERT for sentence-pair classification: https://github.com/apache/incubator-mxnet/tree/master/contrib/clojure-package/examples/bert
I'm not sure whether questions about a contrib package belong here, but I think this question is general enough that it applies to MXNet in any language.
In the example, the author uses the following code to build a classifier on top of BERT:
(defn fine-tune-model
  "msymbol: the pretrained network symbol
   num-classes: the number of classes for the fine-tune datasets
   dropout: the dropout rate amount"
  [msymbol {:keys [num-classes dropout]}]
  (as-> msymbol data
    (sym/dropout {:data data :p dropout})
    (sym/fully-connected "fc-finetune" {:data data :num-hidden num-classes})
    (sym/softmax-output "softmax" {:data data})))
Basically, starting from the BERT-base model (in msymbol), it attaches dropout, a fully-connected layer, and a softmax output on top of it (effectively a single linear classifier, I assume?).
My question is the following: usually, BERT produces an output embedding for the [CLS] token that is meant to be used for classification, but in this code I don't see where the author selects just that [CLS] position before the classifier. It looks like they are feeding the whole output of BERT into it. Other packages provide a built-in BertClassifier class of some kind that handles this for you.
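To make concrete what I mean by "using just the [CLS] token": BERT's encoder output has shape (batch, seq_len, hidden), and the usual recipe keeps only sequence position 0 before the classifier. A minimal NumPy sketch of that slicing (illustrative stand-in only; the shapes and the weight matrix are made up, and this is not MXNet code from the example):

```python
import numpy as np

# Toy stand-in for a BERT encoder output: (batch, seq_len, hidden).
# All shapes here are illustrative assumptions, not from the example.
batch, seq_len, hidden, num_classes = 2, 8, 4, 3
rng = np.random.default_rng(0)
encoded = rng.standard_normal((batch, seq_len, hidden))

# What I'd expect for classification: keep only the [CLS] position
# (index 0 along the sequence axis) before the fully-connected layer.
cls_embedding = encoded[:, 0, :]   # shape (batch, hidden)

# A single fully-connected layer applied to the [CLS] slice.
w = rng.standard_normal((hidden, num_classes))
logits = cls_embedding @ w         # shape (batch, num_classes)
```

The part I can't find in the Clojure example is anything corresponding to that `[:, 0, :]` selection.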
So, is there some magic going on in this example that I'm not seeing, or is the author really feeding BERT's entire output into the classifier (when typically you should use only the [CLS] token's embedding)?
Thank you very much