Question about the Clojure BERT example

I have a question about the Clojure MXNet example that uses BERT for sentence-pair classification:

I’m not sure whether I should post questions about a contrib package here, but I think the question is general enough to apply to MXNet in any language.

In the example, the author uses the following code to build a classifier on top of BERT:

(defn fine-tune-model
  "msymbol: the pretrained network symbol
   num-classes: the number of classes for the fine-tune datasets
   dropout: the dropout rate"
  [msymbol {:keys [num-classes dropout]}]
  (as-> msymbol data
    ;; randomly zero activations of the BERT output during training
    (sym/dropout {:data data :p dropout})
    ;; one dense layer producing num-classes scores
    (sym/fully-connected "fc-finetune" {:data data :num-hidden num-classes})
    ;; softmax over the classes (also serves as the training loss)
    (sym/softmax-output "softmax" {:data data})))

Basically, starting from the BERT-base model (in msymbol), it plugs a fully-connected classifier (a single-hidden-layer MLP, I assume?) on top of it.
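For readers less familiar with MXNet symbols, here is a minimal NumPy sketch (an illustration, not the Clojure package's API) of what the three layers compute. Note that `fully-connected` with `:num-hidden num-classes` maps the features straight to class scores, so the head as written is a single linear layer followed by softmax (multinomial logistic regression) rather than an MLP with a hidden layer. The function name `finetune_head` and the 768-dimensional input (BERT-base's hidden size) are assumptions for the example.

```python
import numpy as np

def finetune_head(features, weight, bias, dropout=0.1, train=False, rng=None):
    """Dropout -> fully connected (num_classes outputs) -> softmax.

    features: (batch, hidden) BERT features
    weight:   (num_classes, hidden); bias: (num_classes,)
    """
    if train and dropout > 0:
        if rng is None:
            rng = np.random.default_rng(0)
        # Inverted dropout: zero each activation with probability `dropout`,
        # then rescale the survivors so the expected value is unchanged.
        mask = rng.random(features.shape) >= dropout
        features = features * mask / (1.0 - dropout)
    # Single linear layer: hidden -> num_classes scores (no hidden layer, no nonlinearity).
    logits = features @ weight.T + bias
    # Numerically stable softmax over the class axis.
    logits = logits - logits.max(axis=1, keepdims=True)
    exp = np.exp(logits)
    return exp / exp.sum(axis=1, keepdims=True)

rng = np.random.default_rng(1)
x = rng.standard_normal((2, 768))          # batch of 2 feature vectors
W = rng.standard_normal((2, 768)) * 0.01   # 2 classes
b = np.zeros(2)
probs = finetune_head(x, W, b)
print(probs.shape)
# (2, 2)
```

At inference time dropout does nothing, which is why `train` defaults to `False` here.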

My question is the following: usually, the output embedding BERT produces for the [CLS] token is the one used for classification, but in this case I don’t see where the author feeds only the [CLS] output into the classifier. It looks like they are feeding the whole output layer from BERT into it. Other packages have a built-in BertClassifier class of some kind that does this for you.

So, is there some magic going on in this example that I’m not seeing, or is the author just plugging the whole output layer of BERT into a classifier (when you would typically use only the [CLS] token)?
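One detail that bears on this question: MXNet's `FullyConnected` operator (which, as far as I know, the Clojure `sym/fully-connected` wraps) flattens every axis except the first by default (`flatten=True`). So if the symbol passed to `fine-tune-model` ended with the full per-token output, the classifier would silently see all time steps concatenated rather than raising an error. The NumPy sketch below compares the input width and parameter count of the two options; the helper `fc_input_features`, the batch size of 4, the sequence length of 128 and the hidden size of 768 are illustrative assumptions.

```python
import numpy as np

def fc_input_features(bert_out, use_cls=True):
    """Return the 2-D (batch, features) matrix a fully connected layer would see.

    bert_out: (batch, seq_len, hidden) per-token outputs in NTC layout.
    """
    if use_cls:
        # Keep only time step 0, the [CLS] position.
        return bert_out[:, 0, :]
    # Collapse every axis except the batch axis, as FullyConnected does with flatten=True.
    return bert_out.reshape(bert_out.shape[0], -1)

batch, seq_len, hidden, num_classes = 4, 128, 768, 2
bert_out = np.zeros((batch, seq_len, hidden), dtype=np.float32)

for use_cls in (True, False):
    feats = fc_input_features(bert_out, use_cls)
    # weights (num_classes x features) plus one bias per class
    n_params = feats.shape[1] * num_classes + num_classes
    print("use_cls" if use_cls else "flatten", feats.shape, n_params)
# use_cls (4, 768) 1538
# flatten (4, 98304) 196610
```

Besides the much larger weight matrix, the flattened variant ties the head to one fixed sequence length, since a different `seq_len` changes the input width.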

Thank you very much

Hi @setzer22,

Yes, you’re correct that using the [CLS] token is recommended for classification tasks, but this assumes you’re fine-tuning the whole network (or at least the last transformer layer). I’m not familiar with Clojure, so I can’t tell at first glance, but if only the head network is being trained here, it would be better to use all outputs instead of just the [CLS] token.
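The answer leaves open how to "use all outputs". One common option, shown here only as an assumption and not as what the example does, is to average the per-token outputs while ignoring padding. This is a NumPy sketch; `masked_mean_pool` and `valid_length` (the number of real tokens in each sequence) are hypothetical names for the illustration.

```python
import numpy as np

def masked_mean_pool(seq_out, valid_length):
    """Average each sequence's per-token outputs, skipping padding positions.

    seq_out: (batch, seq_len, hidden) encoder outputs in NTC layout.
    valid_length: (batch,) number of real (non-padding) tokens per sequence.
    """
    seq_len = seq_out.shape[1]
    # mask[i, t] is 1.0 for real tokens and 0.0 for padding.
    mask = (np.arange(seq_len)[None, :] < valid_length[:, None]).astype(seq_out.dtype)
    summed = (seq_out * mask[:, :, None]).sum(axis=1)
    # Guard against division by zero for an all-padding sequence.
    return summed / np.maximum(valid_length, 1)[:, None]

seq_out = np.arange(12, dtype=float).reshape(2, 3, 2)  # 2 sequences, 3 steps, 2 dims
valid_length = np.array([2, 3])                         # first sequence has 1 pad token
pooled = masked_mean_pool(seq_out, valid_length)
print(pooled)
# [[1. 2.]
#  [8. 9.]]
```

Unlike flattening, pooling keeps the head's input width at the hidden size regardless of sequence length.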

I just checked GluonNLP, and I think the slice might be included in the symbol file being used, rather than being added in Clojure. Assuming the output is in NTC (batch, time, channel) layout, the first time step is sliced out to give the [CLS] token's output. Here is the code I used to generate a symbol file:
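Which index holds the [CLS] output depends on the layout: in NTC (batch, time, channel) it is time step 0 along axis 1, while in TNC it is the first entry along axis 0. A small NumPy check, with `cls_output` as a hypothetical helper and random data standing in for real BERT outputs:

```python
import numpy as np

def cls_output(seq_out, layout="NTC"):
    """Return the first-time-step ([CLS]) vector for every example in the batch."""
    if layout == "NTC":  # (batch, time, channels)
        return seq_out[:, 0, :]
    if layout == "TNC":  # (time, batch, channels)
        return seq_out[0]
    raise ValueError(f"unsupported layout: {layout}")

ntc = np.random.default_rng(0).standard_normal((4, 128, 768))
tnc = ntc.transpose(1, 0, 2)  # same data, time axis first

cls_ntc = cls_output(ntc, "NTC")
cls_tnc = cls_output(tnc, "TNC")
print(cls_ntc.shape, np.array_equal(cls_ntc, cls_tnc))
# (4, 768) True
```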

git clone
cd gluon-nlp
pip install -e .
pip install mxnet --pre
cd scripts/bert/
python --task classification --output_dir folder/goes/here --seq_length 128

Thanks! This makes a lot of sense