How to Train BERT


Quick-fire guide to training a transformer

James Briggs Jun 15, 2021 · 8 min read

Form like this requires pretraining — image by author.

The success of transformer models is in large part thanks to the ability to take a model that has been pre-trained on gigantic datasets by the likes of Google and OpenAI, and apply it to our own use-cases.

Sometimes, this is all we need — we take the model and roll with it as is.

But at other times, we find that we really need to fine-tune the model. We need to
train it a little bit more on our specific use case.

Each transformer model is different, and fine-tuning for different use-cases is different too, so we’ll focus on fine-tuning the core BERT model, with some footnotes on how we could then modify it for a couple of the most common applications.

You can watch the video version of the article here:

How it Works
First things first, how does any of this work? I’m assuming some prior knowledge of transformers and BERT here; if you have no idea what I’m talking about, check out this article first.

The power of transformers stems from the common practice of transformer models being pre-trained to a very high standard by big companies like Google and OpenAI. To give a sense of what I mean by a very high standard: the estimated training costs of GPT-3 from OpenAI range from $4.6–12 million [1][2].

I don’t have a spare $12 million to spend on training models, do you?

Often, the raw pre-trained model is more than enough for our needs and we don’t need to worry about training further.

But sometimes, we may need to — fortunately transformers are built with this in
mind. For BERT, we can split the possibility of further training into two categories.

First, we have fine-tuning the core BERT model itself. This approach consists of
using the same training approach used by Google when training the original model
— which we’ll cover in more depth in a moment.

(Left) BERT with a classification head, and (right) BERT with a question-answering head.

Second, we can add different heads to our model, which give BERT new abilities.
These are extra layers at the end of our model that modify the outputs for different
use-cases. For example, we would use different heads for question-answering or
classification.

We will be focusing on fine-tuning the core BERT model in this article — which
allows us to fine-tune BERT to better understand the specific style of language in
our use-cases.

Fine-Tuning the Core


The core of BERT is trained using two methods, next sentence prediction (NSP) and
masked-language modeling (MLM).

1. Next Sentence Prediction consists of taking pairs of sentences as inputs to the model; some of these pairs will be true pairs, others will not.

Two consecutive sentences result in a ‘true pair’, anything else is not a true pair.

BERT’s task here is to accurately identify which pairs genuinely are pairs, and which are not.

Remember how I said we can train BERT using different heads? Well, NSP (and MLM) use special heads too. The head used here feeds output from a classifier token into a dense NN, outputting two classes.

Our classification head dense layer consumes the output from the [CLS] (classifier) token position, which is used in classification tasks.

The output from this [CLS] token is a 768-dimensional vector, which is passed to our dense NN layer with two nodes, our IsNextSentence and NotNextSentence classes.

A high-level view of the NSP task in BERT.

Those two outputs are our true/false predictions as to whether BERT believes
sentence B comes after sentence A. Index 0 tells us that BERT believes sentence B
does come after sentence A.

After training, the NSP head is discarded — all we keep are the fine-tuned weights
within the many BERT layers.
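If you’d like to see this head in action before we do any fine-tuning of our own, the off-the-shelf BertForNextSentencePrediction class in HuggingFace transformers wraps exactly this setup. Here is a minimal sketch (the example sentences are my own):

```python
import torch
from transformers import BertTokenizer, BertForNextSentencePrediction

tokenizer = BertTokenizer.from_pretrained('bert-base-uncased')
model = BertForNextSentencePrediction.from_pretrained('bert-base-uncased')

sentence_a = "The weather turned cold overnight."
sentence_b = "By morning there was frost on the windows."

# The tokenizer builds [CLS] A [SEP] B [SEP] for us.
inputs = tokenizer(sentence_a, sentence_b, return_tensors='pt')
outputs = model(**inputs)

# outputs.logits has shape [1, 2]: index 0 ~ IsNextSentence, index 1 ~ NotNextSentence.
print(torch.argmax(outputs.logits, dim=-1))  # 0 means BERT thinks B follows A
```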

2. Masked-Language Modeling consists of taking a chunk of text, masking a given number of tokens, and asking BERT to predict what the masked word is.

The original text is processed by a masking operation, which replaces random tokens with the [MASK] token.

15% of the words in each sequence are masked with the [MASK] token.

A classification head is attached to the model and each token will feed into a feedforward neural net, followed by a softmax function. The output dimensionality for each token is equal to the vocab size.

A high-level view of the MLM process.

That means that from each token position, we get an output prediction of the highest-probability token, which we translate into a specific word using our vocabulary.

During training, the predictions for tokens that are not masked are ignored when
calculating the loss function.

Again, as with NSP, the MLM head is discarded after training — leaving us with
optimized model weights.
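This head is also available off the shelf as BertForMaskedLM. A minimal sketch of what it does, using a toy sentence of my own:

```python
import torch
from transformers import BertTokenizer, BertForMaskedLM

tokenizer = BertTokenizer.from_pretrained('bert-base-uncased')
model = BertForMaskedLM.from_pretrained('bert-base-uncased')

inputs = tokenizer("The capital of France is [MASK].", return_tensors='pt')
outputs = model(**inputs)  # outputs.logits has shape [1, seq_len, vocab_size]

# Find the masked position and take the highest-probability token there.
mask_index = (inputs.input_ids[0] == tokenizer.mask_token_id).nonzero().item()
predicted_id = outputs.logits[0, mask_index].argmax().item()
print(tokenizer.decode([predicted_id]))  # expected: 'paris'
```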
In Code
We know how fine-tuning with NSP and MLM works, but how exactly do we apply
that in code?

Well, we can start by importing transformers, PyTorch, and our training data —
Meditations (find a copy of the training data here).
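In code, that looks something like this. I’m assuming the Meditations text has been saved locally as a plain-text file with one paragraph per line; the file name meditations.txt is just a placeholder:

```python
import torch
from transformers import BertTokenizer, BertForPreTraining

# BertForPreTraining gives us both the NSP and MLM heads in one model.
tokenizer = BertTokenizer.from_pretrained('bert-base-uncased')
model = BertForPreTraining.from_pretrained('bert-base-uncased')

# 'meditations.txt' is a placeholder for a local copy of the training text,
# with one paragraph per line.
with open('meditations.txt', 'r', encoding='utf-8') as fp:
    text = fp.read().split('\n')
```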

Now we have a list of paragraphs in text; some, but not all, contain multiple sentences, which we need when building our NSP training data.
Preparing For NSP
To prepare our data for NSP, we need to create a mix of non-random sentences
(where the two sentences were originally together) — and random sentences.

For this, we’ll create a bag of sentences extracted from text, from which we can then randomly select a sentence when creating a random NotNextSentence pair.
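Something like this will do, continuing from the text list we loaded above:

```python
# Split every paragraph on '. ' to get a flat bag of individual sentences.
bag = [sentence for paragraph in text
       for sentence in paragraph.split('. ') if sentence != '']
bag_size = len(bag)
```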

Our bag contains the same data as text but split by sentence — as identified through the use of period
characters.

After creating our bag we can go ahead and create our 50/50 random/non-random
NSP training data. For this, we will create a list of sentence As, sentence Bs, and
their respective IsNextSentence or NotNextSentence labels.
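A sketch of that logic might look like this:

```python
import random

sentence_a = []
sentence_b = []
label = []

for paragraph in text:
    sentences = [s for s in paragraph.split('. ') if s != '']
    if len(sentences) > 1:
        # Pick a random starting sentence (not the last one) as sentence A.
        start = random.randint(0, len(sentences) - 2)
        sentence_a.append(sentences[start])
        if random.random() >= 0.5:
            # Genuine pair: the sentence that actually follows -> label 0.
            sentence_b.append(sentences[start + 1])
            label.append(0)
        else:
            # Random pair: any sentence pulled from the bag -> label 1.
            sentence_b.append(random.choice(bag))
            label.append(1)
```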
Here, label 1 represents random sentences ( NotNextSentence ) and 0 represents non-random sentences ( IsNextSentence ).

Tokenization
We can now tokenize our data. As is typical with BERT models, we truncate/pad our
sequences to a length of 512 tokens.
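In code, that’s a single tokenizer call, something like:

```python
inputs = tokenizer(sentence_a, sentence_b, return_tensors='pt',
                   max_length=512, truncation=True, padding='max_length')
```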
There are a few things we should take note of here. Because we tokenized two
sentences, our tokenizer automatically applied 0 values to sentence A and 1 values
to sentence B in the token_type_ids tensor. The trailing zeros are aligned to the
padding tokens.

Secondly, in the input_ids tensor, the tokenizer automatically placed a SEP token (102) between these two sentences, marking the boundary between them.

BERT needs to see both of these when performing NSP.

NSP Labels
Our NSP labels must be placed within a tensor called next_sentence_label. We create this easily by taking our label variable and converting it into a torch.LongTensor, which must also be transposed using .T:
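In code, that is a single line:

```python
inputs['next_sentence_label'] = torch.LongTensor([label]).T
```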


Masking For MLM
For MLM, we need to clone our current input_ids tensor to create an MLM labels tensor; then we move on to masking ~15% of the tokens in the input_ids tensor. Now that we have that clone for our labels, we can mask tokens in input_ids.
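A sketch of both steps, using token IDs 101 (CLS), 102 (SEP), 0 (PAD), and 103 ([MASK]) for bert-base-uncased:

```python
# Clone input_ids before masking; the clone becomes our MLM labels.
inputs['labels'] = inputs.input_ids.detach().clone()

# Draw a random value per token and select ~15% of positions for masking,
# excluding the special CLS (101), SEP (102), and PAD (0) tokens.
rand = torch.rand(inputs.input_ids.shape)
mask_arr = (rand < 0.15) * (inputs.input_ids != 101) * \
           (inputs.input_ids != 102) * (inputs.input_ids != 0)

# Replace the selected positions with the [MASK] token ID (103).
for i in range(inputs.input_ids.shape[0]):
    selection = torch.flatten(mask_arr[i].nonzero()).tolist()
    inputs.input_ids[i, selection] = 103
```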
Note that there are a few rules we’ve added here: the additional logic when creating mask_arr ensures that we don’t mask any special tokens, such as CLS (101), SEP (102), or PAD (0) tokens.

Dataloader
All of our input and label tensors are ready; all we need to do now is format them into a PyTorch dataset object so that they can be loaded into a PyTorch DataLoader, which will feed batches of data into our model during training.
The dataloader expects the __len__ method for checking the total number of
samples within our dataset, and the __getitem__ method for extracting samples.
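A minimal version might look like this (the class name MeditationsDataset and the batch size of 16 are arbitrary choices of mine):

```python
class MeditationsDataset(torch.utils.data.Dataset):
    def __init__(self, encodings):
        # Store the dict of tensors produced by the tokenizer (plus our labels).
        self.encodings = encodings

    def __len__(self):
        # Total number of samples in the dataset.
        return self.encodings['input_ids'].shape[0]

    def __getitem__(self, idx):
        # Return a single sample as a dict of tensors.
        return {key: val[idx] for key, val in self.encodings.items()}

dataset = MeditationsDataset(inputs)
loader = torch.utils.data.DataLoader(dataset, batch_size=16, shuffle=True)
```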

Setup For Training


The last step before moving on to our training loop is preparing our model training setup.

We first check if we have a GPU available; if so, we move the model over to it for training. Then we put the model into training mode and initialize an Adam optimizer with weight decay (AdamW).
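Something like this (the learning rate of 5e-5 is a typical choice rather than a prescribed value):

```python
# Move the model to GPU if one is available.
device = torch.device('cuda') if torch.cuda.is_available() else torch.device('cpu')
model.to(device)

# Put the model into training mode (enables dropout etc.).
model.train()

# Adam with weight decay; torch.optim.AdamW is used here.
optim = torch.optim.AdamW(model.parameters(), lr=5e-5)
```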

Training

Finally, we’re on to training our model. We train for two epochs and use tqdm to create a progress bar for our training loop. Within the loop (sketched in full after this list) we:

Zero the gradients, so that we are not starting from the gradients calculated in the previous step.

Move all batch tensors to the selected device (GPU or CPU).

Feed everything into the model and extract the loss.

Use loss.backward() to calculate the gradients for each parameter.

Update the parameter weights based on those gradients.

Print relevant information to the progress bar ( loop ).
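Putting those steps together, a sketch of the loop looks like this:

```python
from tqdm import tqdm

epochs = 2
for epoch in range(epochs):
    loop = tqdm(loader, leave=True)
    for batch in loop:
        # Zero gradients from the previous step.
        optim.zero_grad()
        # Move all batch tensors to the selected device.
        input_ids = batch['input_ids'].to(device)
        token_type_ids = batch['token_type_ids'].to(device)
        attention_mask = batch['attention_mask'].to(device)
        next_sentence_label = batch['next_sentence_label'].to(device)
        labels = batch['labels'].to(device)
        # Forward pass; the model returns the combined NSP + MLM loss.
        outputs = model(input_ids=input_ids,
                        token_type_ids=token_type_ids,
                        attention_mask=attention_mask,
                        next_sentence_label=next_sentence_label,
                        labels=labels)
        loss = outputs.loss
        # Backpropagate to compute gradients, then update the weights.
        loss.backward()
        optim.step()
        # Show progress info on the tqdm bar.
        loop.set_description(f'Epoch {epoch}')
        loop.set_postfix(loss=loss.item())
```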

And that’s it, with that we’ve fine-tuned our model using both MLM and NSP!

That’s all for this article on fine-tuning BERT using masked-language modeling and
next sentence prediction. We’ve covered what MLM and NSP are, how they work,
and how we can fine-tune our models with them.

There’s a lot to fine-tuning BERT, but the concept and implementations are not too
complex — while being incredibly powerful.

Using what we’ve learned here, we can take the best models in NLP and fine-tune
them to our more domain-specific language use-cases — needing nothing more
than unlabelled text — often an easy data source to find.

I hope you enjoyed this article! If you have any questions, let me know via Twitter or
in the comments below. If you’d like more content like this, I post on YouTube too.

Thanks for reading!


References
[1] B. Dickson, The untold story of GPT-3 is the transformation of OpenAI (2020), TechTalks

[2] K. Wiggers, OpenAI’s massive GPT-3 model is impressive, but size isn’t everything (2020), VentureBeat

Jupyter Notebook

If you’re interested in learning more about the logic behind MLM and NSP, and transformers in general, check out my Transformers for NLP course:

🤖 70% Discount on the NLP With Transformers Course

*All images are by the author except where stated otherwise
