
DocBERT: BERT for Document Classification

Ashutosh Adhikari, Achyudh Ram, Raphael Tang, and Jimmy Lin


David R. Cheriton School of Computer Science
University of Waterloo
{adadhika, arkeshav, r33tang, jimmylin}@uwaterloo.ca

Abstract

We present, to our knowledge, the first application of BERT to document classification. A few characteristics of the task might lead one to think that BERT is not the most appropriate model: syntactic structures matter less for content categories, documents can often be longer than typical BERT input, and documents often have multiple labels. Nevertheless, we show that a straightforward classification model using BERT is able to achieve the state of the art across four popular datasets. To address the computational expense associated with BERT inference, we distill knowledge from BERTlarge to small bidirectional LSTMs, reaching BERTbase parity on multiple datasets using 30× fewer parameters. The primary contribution of our paper is improved baselines that can provide the foundation for future work.

1 Introduction

Until recently, the dominant paradigm in approaching natural language processing (NLP) tasks has been to concentrate on neural architecture design, using only task-specific data and word embeddings such as GloVe (Pennington et al., 2014). The NLP community is, however, witnessing a dramatic paradigm shift toward the pre-trained deep language representation model, which achieves the state of the art in question answering, sentiment classification, and similarity modeling, to name a few. Bidirectional Encoder Representations from Transformers (BERT; Devlin et al., 2019) represents one of the latest developments in this line of work. It outperforms its predecessors, ELMo (Peters et al., 2018) and GPT (Radford et al., 2018), by a wide margin on multiple NLP tasks.

This approach consists of two stages: first, BERT is pre-trained on vast amounts of text with an unsupervised objective of masked language modeling and next-sentence prediction. Next, this pre-trained network is fine-tuned on task-specific, labeled data.

BERT, however, has not yet been fine-tuned for document classification. Why is this worth exploring? For one, modeling syntactic structure has arguably been less important for document classification than for typical BERT tasks such as natural language inference and paraphrasing. This claim is supported by our observation that logistic regression and support vector machines are exceptionally strong document classification baselines. For another, documents often have multiple labels across many classes, which is again uncharacteristic of the tasks that BERT examines.

In this paper, we first describe fine-tuning BERT for document classification to establish state-of-the-art results on four popular datasets. This increase in model quality, however, comes at a heavy computational expense: BERT contains hundreds of millions of parameters, while the previous baseline uses fewer than four million and performs inference forty times faster.

To alleviate this computational burden, we apply knowledge distillation (Hinton et al., 2015) to transfer knowledge from BERTlarge, the larger BERT variant, to the previous, much smaller state-of-the-art BiLSTM. As a result of this procedure, together with a few additional tricks for effective knowledge transfer, we achieve results comparable to BERTbase, the smaller BERT variant, using a model with 30× fewer parameters.

Our contributions in this paper are twofold: first, we establish state-of-the-art results for document classification by simply fine-tuning BERT; second, we demonstrate that BERT can be distilled into a much simpler neural model that provides competitive accuracy at a far more modest computational cost.
2 Background and Related Work

Over the last few years, neural network-based architectures have dominated the task of document classification. Liu et al. (2017a) develop XML-CNN for addressing this problem's multi-label nature, which they call extreme classification. XML-CNN is based on the popular KimCNN (Kim, 2014), except with wider convolutional filters, adaptive dynamic max-pooling (Chen et al., 2015; Johnson and Zhang, 2015), and an additional bottleneck layer to better capture the features of large documents. Another popular model, the Hierarchical Attention Network (HAN; Yang et al., 2016), explicitly models hierarchical information from documents to extract meaningful features, incorporating word- and sentence-level encoders (with attention) to classify documents. Yang et al. (2018) propose a generative approach for multi-label document classification, using encoder–decoder sequence generation models (SGMs) to generate labels for each document. In contrast to the previous papers, Adhikari et al. (2019) propose LSTMreg, a simple, properly regularized single-layer BiLSTM, which represents the current state of the art.

In the current paradigm of pre-trained models, methods like BERT (Devlin et al., 2019) and XLNet (Yang et al., 2019) have been shown to achieve the state of the art in a variety of tasks, including question answering, named entity recognition, and natural language inference. However, these models have a prohibitively large number of parameters and require substantial computational resources, even to carry out a single inference pass. Similar concerns related to high inference latency or heavy run-time memory requirements have led to a myriad of works, such as error-based weight-pruning (LeCun et al., 1990) and, more recently, model sparsification and channel pruning (Louizos et al., 2018; Liu et al., 2017b).

Knowledge distillation (KD; Ba and Caruana, 2014; Hinton et al., 2015) has been shown to be an effective compression technique that "distills" information learned by a larger model (the teacher) into a smaller model (the student). KD uses the class probabilities produced by a pre-trained teacher, the soft targets, to train a student model over a transfer set (the examples over which distillation takes place). Being model agnostic, the approach is suitable for our study, as it enables the transfer of knowledge between different types of architectures, unlike most other model compression techniques.

3 Our Approach

To adapt the BERTbase and BERTlarge models for document classification, we follow Devlin et al. (2019) and introduce a fully-connected layer over the final hidden state corresponding to the [CLS] input token. During fine-tuning, we optimize the entire model end-to-end, with the additional softmax classifier parameters W ∈ R^(K×H), where H is the dimension of the hidden state vectors and K is the number of classes. We minimize the cross-entropy and binary cross-entropy loss for single-label and multi-label tasks, respectively.
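As a concrete illustration of this classification layer, the following is a minimal PyTorch sketch (our own naming and simplifications, not the Hedwig implementation): a single fully-connected layer W ∈ R^(K×H) applied to the final hidden state of the [CLS] token, with cross-entropy or binary cross-entropy loss depending on the task. The stand-in tensor for the [CLS] hidden states and the toy dimensions (H = 768, K = 90) are illustrative only.

import torch
import torch.nn as nn
import torch.nn.functional as F

class DocClassifierHead(nn.Module):
    """Single fully-connected layer over BERT's final [CLS] hidden state."""
    def __init__(self, hidden_size, num_classes, multi_label):
        super().__init__()
        self.fc = nn.Linear(hidden_size, num_classes)  # W in R^(K x H), plus a bias
        self.multi_label = multi_label

    def forward(self, cls_hidden, labels):
        logits = self.fc(cls_hidden)                   # shape: (batch, K)
        if self.multi_label:
            # binary cross-entropy over independent labels (multi-label datasets)
            loss = F.binary_cross_entropy_with_logits(logits, labels.float())
        else:
            # standard cross-entropy over mutually exclusive classes (single-label datasets)
            loss = F.cross_entropy(logits, labels)
        return logits, loss

# Toy usage: H = 768 (BERTbase), K = 90 (Reuters). In the real model, cls_hidden
# would be BERT's final hidden state for [CLS]; here it is random for illustration.
head = DocClassifierHead(hidden_size=768, num_classes=90, multi_label=True)
cls_hidden = torch.randn(8, 768)
labels = torch.randint(0, 2, (8, 90))
logits, loss = head(cls_hidden, labels)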
Next, we distill knowledge from the fine-tuned BERTlarge into the much smaller LSTMreg, which represents the previous state of the art (Adhikari et al., 2019). We perform KD by using the training examples, along with minor augmentations, to form the transfer set. In accordance with Hinton et al. (2015), we combine the two objectives of classification using the target labels (Lclassification) and distillation (Ldistill) using the soft targets, for each example of the transfer set. We use the target labels to minimize the standard cross-entropy or binary cross-entropy loss, depending on the type of dataset (multi-label or single-label). For the distillation objective, we minimize the Kullback–Leibler (KL) divergence KL(p||q), where p and q are the class probabilities produced by the student and the teacher models, respectively. We define the final objective as:

L = Lclassification + λ · Ldistill    (1)

where λ ∈ R weighs the losses' contributions to the final objective.
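A sketch of this combined objective for the single-label (softmax) case is given below; the function and variable names are ours, the multi-label case would use the binary analogues, and the default λ of 1.0 is purely illustrative (the values we actually use are given in Section 4).

import torch
import torch.nn.functional as F

def kd_objective(student_logits, teacher_logits, labels, lam=1.0):
    # L_classification: cross-entropy against the hard target labels
    l_classification = F.cross_entropy(student_logits, labels)
    # L_distill: KL(p || q), where p are the student's class probabilities and
    # q are the teacher's soft targets; summed over classes, averaged over the batch
    p = F.softmax(student_logits, dim=-1)
    log_p = F.log_softmax(student_logits, dim=-1)
    log_q = F.log_softmax(teacher_logits, dim=-1)
    l_distill = (p * (log_p - log_q)).sum(dim=-1).mean()
    return l_classification + lam * l_distill          # Equation (1)

# Toy usage with random logits for a 10-class single-label task.
student_logits = torch.randn(8, 10, requires_grad=True)
teacher_logits = torch.randn(8, 10)                    # would come from the fine-tuned BERTlarge
labels = torch.randint(0, 10, (8,))
loss = kd_objective(student_logits, teacher_logits, labels, lam=1.0)
loss.backward()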
4 Experimental Setup

Using Hedwig (https://github.com/castorini/hedwig), an open-source deep learning toolkit with a number of implementations of document classification models, we compare the fine-tuned BERT models against HAN, KimCNN, XML-CNN, SGM, and LSTMreg. For simple yet competitive baselines, we run the default logistic regression (LR) and support vector machine (SVM) implementations from Scikit-Learn (Pedregosa et al., 2011), trained on the tf–idf vectors of the documents.
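For reference, the LR and SVM baselines amount to the following Scikit-Learn sketch over tf–idf vectors; the toy corpus and labels are invented for illustration, and the multi-label datasets would additionally require a one-vs-rest wrapper (e.g., sklearn.multiclass.OneVsRestClassifier) around the classifiers.

from sklearn.feature_extraction.text import TfidfVectorizer
from sklearn.linear_model import LogisticRegression
from sklearn.svm import LinearSVC

# Tiny illustrative corpus; the real experiments use the datasets in Section 4.1.
train_docs = ["grain exports rise sharply", "oil prices fall", "wheat harvest strong"]
train_labels = ["grain", "oil", "grain"]

vectorizer = TfidfVectorizer()
X_train = vectorizer.fit_transform(train_docs)

lr = LogisticRegression().fit(X_train, train_labels)    # default LR baseline
svm = LinearSVC().fit(X_train, train_labels)            # default SVM baseline

X_test = vectorizer.transform(["corn and wheat exports"])
print(lr.predict(X_test), svm.predict(X_test))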
We use Nvidia Tesla V100 and P100 GPUs for fine-tuning BERT and run the rest of the experiments on RTX 2080 Ti and GTX 1080 GPUs. We use PyTorch 0.4.1 as the backend framework, and Scikit-learn 0.19.2 for computing the tf–idf vectors and implementing LR and SVMs.

4.1 Datasets

We use the following four datasets to evaluate BERT: Reuters-21578 (Reuters; Apté et al., 1994), the arXiv Academic Paper dataset (AAPD; Yang et al., 2018), IMDB reviews, and Yelp 2014 reviews. Reuters and AAPD are multi-label datasets, while documents in IMDB and Yelp '14 contain only a single label.

For Reuters, we use the standard ModApté splits (Apté et al., 1994); for AAPD, we use the splits provided by Yang et al. (2018); for IMDB and Yelp, following Yang et al. (2016), we randomly sample 80% of the data for training and 10% each for validation and test. We summarize the statistics of the datasets used in our study in Table 1.

Dataset      C    N          W      S
Reuters      90   10,789     144.3  6.6
AAPD         54   55,840     167.3  1.0
IMDB         10   135,669    393.8  14.4
Yelp 2014    5    1,125,386  148.8  9.1

Table 1: Summary of the datasets. C denotes the number of classes in the dataset, N the number of samples, and W and S the average number of words and sentences per document, respectively.

4.2 Training and Hyperparameters

While fine-tuning BERT, we optimize the number of epochs, batch size, learning rate, and maximum sequence length (MSL), the number of tokens that documents are truncated to. We observe that model quality is quite sensitive to the number of epochs, and thus the setting must be tailored for each dataset. We train on Reuters, AAPD, and IMDB for 30, 20, and 4 epochs, respectively. Due to resource constraints, we train on Yelp for only one epoch. As is the case with Devlin et al. (2019), we find that choosing a batch size of 16, learning rate of 2×10^-5, and MSL of 512 tokens yields optimal performance on the validation sets of all datasets. More details are provided in the appendix.
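To make these settings concrete, a schematic fine-tuning loop under our own assumptions is shown below: the model is assumed to expose the (logits, loss) interface sketched in Section 3, plain Adam stands in for whatever optimizer and schedule the actual Hedwig implementation uses, and the data loader is assumed to yield token IDs already truncated to the MSL.

import torch

EPOCHS = {"Reuters": 30, "AAPD": 20, "IMDB": 4, "Yelp2014": 1}   # per-dataset epochs
BATCH_SIZE, LEARNING_RATE, MAX_SEQ_LEN = 16, 2e-5, 512           # applied when tokenizing and batching

def fine_tune(model, train_loader, dataset_name):
    optimizer = torch.optim.Adam(model.parameters(), lr=LEARNING_RATE)
    model.train()
    for _ in range(EPOCHS[dataset_name]):
        for input_ids, labels in train_loader:   # input_ids truncated to MAX_SEQ_LEN tokens
            optimizer.zero_grad()
            _, loss = model(input_ids, labels)   # model returns (logits, loss)
            loss.backward()
            optimizer.step()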
For distillation, we train the LSTMreg model to capture the learned representations from BERTlarge using the objective shown in Equation (1). We use a batch size of 128 for the multi-label tasks and 64 for the single-label tasks. We find the learning rates and dropout rates used in Adhikari et al. (2019) to be optimal even for the distillation process.

To build an effective transfer set for distillation, as suggested by Hinton et al. (2015), we augment the training splits of the datasets by applying POS-guided word swapping and random masking (Tang et al., 2019). The transfer sets for Reuters, IMDB, and AAPD are 3×, 4×, and 4× the size of their training splits, respectively, whereas for Yelp 2014 we use only the original training split (1×, i.e., no data augmentation) due to computational restrictions. We use a λ of 1 for the multi-label datasets and 4 for the single-label datasets.
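As an illustration of this augmentation step, here is a rough sketch in the spirit of Tang et al. (2019); the masking and swapping probabilities are arbitrary, and the POS-guided replacement is simplified to swapping randomly chosen positions, so this is not the exact procedure used in our experiments.

import random

def augment(tokens, p_mask=0.1, p_swap=0.1):
    """Return a perturbed copy of a tokenized document for the transfer set."""
    out = list(tokens)
    for i in range(len(out)):
        r = random.random()
        if r < p_mask:
            out[i] = "[MASK]"                     # random masking
        elif r < p_mask + p_swap and len(out) > 1:
            j = random.randrange(len(out))        # simplified stand-in for a POS-guided swap
            out[i], out[j] = out[j], out[i]
    return out

# Each training document can be augmented several times (e.g., 3-4x) to grow the transfer set.
print(augment("the central bank raised interest rates again".split()))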
5 Results and Discussion

We report the mean F1 scores for multi-label datasets and accuracy for single-label datasets, along with the corresponding standard deviation, across five runs in Table 2. We copy values for rows 1–9 from Adhikari et al. (2019). Due to resource limitations, we report the scores from only a single run for BERTbase and BERTlarge.

Consistent with Devlin et al. (2019), BERTlarge achieves state-of-the-art results on all four datasets, followed by BERTbase (see Table 2, rows 10 and 11). The considerably simpler LSTMreg model (row 9) achieves high scores, coming close to the quality of BERTbase. However, it is worth noting that all of the models above row 10 take only a fraction of the time and memory required for training the BERT models.

Surprisingly, the distilled LSTMreg (KD-LSTMreg, row 12) achieves parity with BERTbase on average for Reuters, AAPD, and IMDB. In fact, it outperforms BERTbase (on both dev and test) in at least one of the five runs. For Yelp, we see that KD-LSTMreg reduces the difference between BERTbase and LSTMreg, but not to the same extent as on the other datasets.

To put things in perspective, Table 3 reports the inference times on the validation sets of all the datasets. We calculate the inference times with batch size 128 for all the datasets on a single RTX 2080 Ti. The relative speedup achieved by KD-LSTMreg is at least around 40× with respect to BERTbase. Additionally, Figure 1 compares the number of parameters and prediction quality on the validation sets. These plots convey the effectiveness of the KD-LSTMreg model with different numbers of hidden units: 32, 64, 128, 256, and 512. We find that KD-LSTMreg, with just 256 hidden units (i.e., ∼1% of the parameters of BERTbase), attains parity with BERTbase on Reuters, while for AAPD, 512 hidden units (∼3% of the parameters of BERTbase) are enough to overtake BERTbase.
#   Model           Reuters                AAPD                   IMDB                    Yelp '14
                    Val. F1    Test F1     Val. F1    Test F1     Val. Acc.   Test Acc.   Val. Acc.   Test Acc.
1   LR              77.0       74.8        67.1       64.9        43.1        43.4        61.1        60.9
2   SVM             89.1       86.1        71.1       69.1        42.5        42.4        59.7        59.6
3   KimCNN Repl.    83.5 ±0.4  80.8 ±0.3   54.5 ±1.4  51.4 ±1.3   42.9 ±0.3   42.7 ±0.4   66.5 ±0.1   66.1 ±0.6
4   KimCNN Orig.    –          –           –          –           –           37.68       –           61.08
5   XML-CNN Repl.   88.8 ±0.5  86.2 ±0.3   70.2 ±0.7  68.7 ±0.4   –           –           –           –
6   HAN Repl.       87.6 ±0.5  85.2 ±0.6   70.2 ±0.2  68.0 ±0.6   51.8 ±0.3   51.2 ±0.3   68.2 ±0.1   67.9 ±0.1
7   HAN Orig.       –          –           –          –           –           49.43       –           70.53
8   SGM Orig.       82.5 ±0.4  78.8 ±0.9   –          71.02       –           –           –           –
9   LSTMreg         89.1 ±0.8  87.0 ±0.5   73.1 ±0.4  70.5 ±0.5   53.4 ±0.2   52.8 ±0.3   69.0 ±0.1   68.7 ±0.1
10  BERTbase        90.5       89.0        75.3       73.4        54.4        54.2        72.1        72.0
11  BERTlarge       92.3       90.7        76.6       75.2        56.0        55.6        72.6        72.5
12  KD-LSTMreg      91.0 ±0.2  88.9 ±0.2   75.4 ±0.2  72.9 ±0.3   54.5 ±0.1   53.7 ±0.3   69.7 ±0.1   69.4 ±0.1

Table 2: Results for each model on the validation and test sets. Best values are bolded. Repl. reports the mean of five runs from our reimplementations; Orig. refers to point estimates from Yang et al. (2018), Yang et al. (2016), and Tang et al. (2015). KD-LSTMreg represents the distilled LSTMreg using the fine-tuned BERTlarge.

[Figure 1 plots validation F1/accuracy against the number of hidden units (32 to 512) for KD-LSTM, BERT Base, and BERT Large on Reuters, AAPD, IMDB, and Yelp14.]

Figure 1: Effectiveness of KD-LSTMreg vs. BERTbase and BERTlarge

Dataset     LSTMreg      BERTbase
Reuters     0.5 (1×)     30.3 (60×)
AAPD        0.3 (1×)     15.8 (50×)
IMDB        6.8 (1×)     243.6 (40×)
Yelp '14    20.6 (1×)    1829.9 (90×)

Table 3: Comparison of inference latencies (seconds) on validation sets with batch size 128.

6 Conclusion and Future Work

In this paper we improve the baselines for document classification by fine-tuning BERT. We also use the knowledge learned by BERT models to improve the effectiveness of a single-layered, lightweight BiLSTM model, LSTMreg, using knowledge distillation. In fact, we show that the distilled LSTMreg model achieves BERTbase parity on a majority of datasets, resulting in over 30× compression in terms of the number of parameters and at least 40× faster inference times.

For future work, it would be interesting to study the effects of distillation over a range of neural-network architectures. Alternatively, formulating specific model compression techniques in the context of transformer models deserves exploration.
Acknowledgments

This research was supported by the Natural Sciences and Engineering Research Council (NSERC) of Canada, and enabled by computational resources provided by Compute Ontario and Compute Canada.

References

Ashutosh Adhikari, Achyudh Ram, Raphael Tang, and Jimmy Lin. 2019. Rethinking complex neural network architectures for document classification. In Proceedings of the 2019 Conference of the North American Chapter of the Association for Computational Linguistics: Human Language Technologies, Volume 1 (Long and Short Papers), pages 4046–4051.

Chidanand Apté, Fred Damerau, and Sholom M. Weiss. 1994. Automated learning of decision rules for text categorization. ACM Transactions on Information Systems, 12(3):233–251.

Jimmy Ba and Rich Caruana. 2014. Do deep nets really need to be deep? In Advances in Neural Information Processing Systems 27.

Yubo Chen, Liheng Xu, Kang Liu, Daojian Zeng, and Jun Zhao. 2015. Event extraction via dynamic multi-pooling convolutional neural networks. In Proceedings of the 53rd Annual Meeting of the Association for Computational Linguistics and the 7th International Joint Conference on Natural Language Processing (Volume 1: Long Papers), pages 167–176, Beijing, China.

Jacob Devlin, Ming-Wei Chang, Kenton Lee, and Kristina Toutanova. 2019. BERT: pre-training of deep bidirectional transformers for language understanding. In Proceedings of the 2019 Conference of the North American Chapter of the Association for Computational Linguistics: Human Language Technologies, Volume 1 (Long and Short Papers), pages 4171–4186.

Geoffrey E. Hinton, Oriol Vinyals, and Jeffrey Dean. 2015. Distilling the knowledge in a neural network. arxiv/1503.02531.

Rie Johnson and Tong Zhang. 2015. Effective use of word order for text categorization with convolutional neural networks. In Proceedings of the 2015 Conference of the North American Chapter of the Association for Computational Linguistics: Human Language Technologies, pages 103–112.

Yoon Kim. 2014. Convolutional neural networks for sentence classification. In Proceedings of the 2014 Conference on Empirical Methods in Natural Language Processing, pages 1746–1751.

Yann LeCun, John S. Denker, and Sara A. Solla. 1990. Optimal brain damage. In Advances in Neural Information Processing Systems 2, pages 598–605.

Jingzhou Liu, Wei-Cheng Chang, Yuexin Wu, and Yiming Yang. 2017a. Deep learning for extreme multi-label text classification. In Proceedings of the 40th International ACM SIGIR Conference on Research and Development in Information Retrieval, pages 115–124.

Zhuang Liu, Jianguo Li, Zhiqiang Shen, Gao Huang, Shoumeng Yan, and Changshui Zhang. 2017b. Learning efficient convolutional networks through network slimming. 2017 IEEE International Conference on Computer Vision, pages 2755–2763.

Christos Louizos, Max Welling, and Diederik P. Kingma. 2018. Learning sparse neural networks through l0 regularization. arxiv/1712.01312.

Fabian Pedregosa, Gaël Varoquaux, Alexandre Gramfort, Vincent Michel, Bertrand Thirion, Olivier Grisel, Mathieu Blondel, Peter Prettenhofer, Ron Weiss, Vincent Dubourg, Jake Vanderplas, Alexandre Passos, David Cournapeau, Matthieu Brucher, Matthieu Perrot, and Édouard Duchesnay. 2011. Scikit-learn: machine learning in Python. Journal of Machine Learning Research, 12:2825–2830.

Jeffrey Pennington, Richard Socher, and Christopher Manning. 2014. GloVe: global vectors for word representation. In Proceedings of the 2014 Conference on Empirical Methods in Natural Language Processing, pages 1532–1543.

Matthew Peters, Mark Neumann, Mohit Iyyer, Matt Gardner, Christopher Clark, Kenton Lee, and Luke Zettlemoyer. 2018. Deep contextualized word representations. In Proceedings of the 2018 Conference of the North American Chapter of the Association for Computational Linguistics: Human Language Technologies, Volume 1 (Long Papers), pages 2227–2237.

Alec Radford, Karthik Narasimhan, Tim Salimans, and Ilya Sutskever. 2018. Improving language understanding by generative pre-training.

Duyu Tang, Bing Qin, and Ting Liu. 2015. Document modeling with gated recurrent neural network for sentiment classification. In Proceedings of the 2015 Conference on Empirical Methods in Natural Language Processing, pages 1422–1432.

Raphael Tang, Yao Lu, Linqing Liu, Lili Mou, Olga Vechtomova, and Jimmy Lin. 2019. Distilling task-specific knowledge from BERT into simple neural networks. arxiv/1903.12136.
Pengcheng Yang, Xu Sun, Wei Li, Shuming Ma, Wei Wu, and Houfeng Wang. 2018. SGM: sequence generation model for multi-label classification. In Proceedings of the 27th International Conference on Computational Linguistics, pages 3915–3926.

Zhilin Yang, Zihang Dai, Yiming Yang, Jaime G. Carbonell, Ruslan Salakhutdinov, and Quoc V. Le. 2019. XLNet: generalized autoregressive pretraining for language understanding. arxiv/1906.08237.

Zichao Yang, Diyi Yang, Chris Dyer, Xiaodong He, Alex Smola, and Eduard Hovy. 2016. Hierarchical attention networks for document classification. In Proceedings of the 2016 Conference of the North American Chapter of the Association for Computational Linguistics: Human Language Technologies, pages 1480–1489.
A Appendix

A.1 Hyperparameter Analysis

MSL analysis. A decrease in the maximum sequence length (MSL) corresponds to only a minor loss in F1 on Reuters (see the top-left subplot in Figure 2), possibly because Reuters has shorter documents. On IMDB (top-right subplot in Figure 2), lowering the MSL corresponds to a drastic fall in accuracy, suggesting that the entire document is necessary for this dataset.

On the one hand, these results appear obvious. On the other hand, one can argue that, since IMDB contains longer documents, truncating tokens may hurt less. The top two subplots in Figure 2 show that this is not the case: truncating to even 256 tokens causes accuracy to fall below that of the much smaller LSTMreg (see Table 2). From these results, we conclude that any amount of truncation is detrimental in document classification, but the level of degradation may differ.
[Figure 2 shows four subplots of validation scores: the top two plot F1/accuracy against MSL (128, 256, 512) for Reuters and IMDB; the bottom two plot F1 against the number of fine-tuning epochs ("Theirs (4)" vs. "Ours (30)" for Reuters and "Ours (20)" for AAPD).]

Figure 2: Results on the validation set from varying the MSL and the number of epochs.

Epoch analysis. The bottom two subplots in Figure 2 illustrate the F1 score of BERT fine-tuned with different numbers of epochs for AAPD and Reuters. Contrary to Devlin et al. (2019), who achieve the state of the art on small datasets with only a few epochs of fine-tuning, we find that smaller datasets require many more epochs to converge. On both datasets (see Figure 2), we see a significant drop in model quality when the BERT models are fine-tuned for only four epochs, as suggested in the original paper. On Reuters, using four epochs results in an F1 worse than even logistic regression (Table 2, row 1).
