Transformer-based encoder-decoder models were proposed in Vaswani et
al. (2017) and have recently
experienced a surge of interest, e.g. Lewis et al.
(2019), Raffel et al.
(2019), Zhang et al.
(2020), Zaheer et al.
(2020), Yan et al.
(2020).
Similar to BERT and GPT2, massive pre-trained encoder-decoder models
have been shown to significantly boost performance on a variety of
sequence-to-sequence tasks (Lewis et al. (2019), Raffel et al. (2019)).
However, because of the large
computational cost attached to pre-training encoder-decoder models, the
development of such models is mainly limited to large companies and
institutes.
In Leveraging Pre-trained Checkpoints for Sequence Generation Tasks
(2020), Sascha Rothe, Shashi
Narayan and Aliaksei Severyn initialize encoder-decoder models with
pre-trained encoder-only and/or decoder-only checkpoints (e.g. BERT,
GPT2) to skip the costly pre-training. The authors show that such
warm-started encoder-decoder models yield results competitive with large
pre-trained encoder-decoder models, such as
T5 and
Pegasus, on multiple
sequence-to-sequence tasks at a fraction of the training cost.
In this notebook, we will explain in detail how encoder-decoder models
can be warm-started, give practical tips based on Rothe et al.
(2020), and finally go over a
complete code example showing how to warm-start encoder-decoder models
with 🤗Transformers.
This notebook is split into 4 parts:
- Introduction – Short summary of pre-trained language models in
NLP and the need for warm-starting encoder-decoder models.
- Warm-starting encoder-decoder models (Theory) – Illustrative
explanation of how encoder-decoder models are warm-started.
- Warm-starting encoder-decoder models (Evaluation) – Summary of
Leveraging Pre-trained Checkpoints for Sequence Generation
Tasks (2020): which model
combinations are effective for warm-starting encoder-decoder models, and how
does it differ from task to task?
- Warm-starting encoder-decoder models with 🤗Transformers
(Practice) – Complete code example showcasing in detail how to
use the EncoderDecoderModel framework to warm-start
transformer-based encoder-decoder models.
It is highly recommended (probably even necessary) to have read this
blog
post
about transformer-based encoder-decoder models.
Let’s start by giving some background on warm-starting encoder-decoder
models.
Introduction
Recently, pre-trained language models have revolutionized the
field of natural language processing (NLP).
The first pre-trained language models were based on recurrent neural
networks (RNN) as proposed by Dai et al.
(2015). Dai et al. showed that
pre-training an RNN-based model on unlabelled data and subsequently
fine-tuning it on a specific task yields better results than
training a randomly initialized model directly on such a task. However,
it was only in 2018 that pre-trained language models became widely
accepted in NLP. ELMO by Peters et
al. and ULMFit by Howard et
al. were the first pre-trained
language models to significantly improve the state-of-the-art on an array
of natural language understanding (NLU) tasks. Just a couple of months
later, OpenAI and Google published transformer-based pre-trained
language models, called GPT by Radford et
al.
and BERT by Devlin et al.
respectively. The improved efficiency of transformer-based language
models over RNNs allowed GPT2 and BERT to be pre-trained on massive
amounts of unlabeled text data. Once pre-trained, BERT and GPT were
shown to require very little fine-tuning to shatter state-of-the-art results
on more than a dozen NLU tasks.
The capability of pre-trained language models to effectively transfer
task-agnostic knowledge to task-specific knowledge turned out to be
a great catalyst for NLU. Whereas engineers and researchers previously
had to train a language model from scratch, now publicly available
checkpoints of large pre-trained language models can be fine-tuned at a
fraction of the cost and time. This can save millions in industry and
allows for faster prototyping and better benchmarks in research.
Pre-trained language models have established a new level of performance
on NLU tasks, and more and more research has been built upon leveraging
such pre-trained language models for improved NLU systems. However,
standalone BERT and GPT models have been less successful for
sequence-to-sequence tasks, e.g. text summarization, machine
translation, sentence rephrasing, etc.
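Sequence-to-sequence tasks are defined as a mapping from an input
sequence $\mathbf{X}_{1:n}$ to an output sequence $\mathbf{Y}_{1:m}$ of
a-priori unknown output length $m$. Hence, a sequence-to-sequence
model should define the conditional probability distribution of the
output sequence $\mathbf{Y}_{1:m}$ conditioned on the input sequence $\mathbf{X}_{1:n}$:
$$ p_{\theta_{\text{model}}}(\mathbf{Y}_{1:m} | \mathbf{X}_{1:n}). $$
Without loss of generality, an input word sequence of $n$ words is
hereby represented by the vector sequence $\mathbf{X}_{1:n} = \mathbf{x}_1, \ldots, \mathbf{x}_n$ and an output
sequence of $m$ words as $\mathbf{Y}_{1:m} = \mathbf{y}_1, \ldots, \mathbf{y}_m$.
Let’s see how BERT and GPT2 would be fit to model sequence-to-sequence
tasks.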
BERT
BERT is an encoder-only model, which maps an input sequence $\mathbf{X}_{1:n}$ to a contextualized encoded sequence $\mathbf{\overline{X}}_{1:n}$:
$$ f_{\theta_{\text{BERT}}}: \mathbf{X}_{1:n} \to \mathbf{\overline{X}}_{1:n}. $$
BERT’s contextualized encoded sequence $\mathbf{\overline{X}}_{1:n}$
can then further be processed by a classification layer for NLU
classification tasks, such as sentiment analysis, natural language
inference, etc. To do so, the classification layer, i.e. typically a
pooling layer followed by a feed-forward layer, is added as a final
layer on top of BERT to map the contextualized encoded sequence $\mathbf{\overline{X}}_{1:n}$ to a class $c$:
$$ f_{\theta_{\text{p,c}}}: \mathbf{\overline{X}}_{1:n} \to c. $$
It has been shown that adding a pooling and classification layer,
defined as $\theta_{\text{p,c}}$, on top of a pre-trained BERT model $\theta_{\text{BERT}}$ and subsequently fine-tuning the complete model $\{\theta_{\text{p,c}}, \theta_{\text{BERT}}\}$ can yield
state-of-the-art performance on a variety of NLU tasks, cf. BERT
by Devlin et al.
Let’s visualize BERT.
The BERT model is shown in grey. The model stacks multiple BERT
blocks, each of which consists of bi-directional self-attention
layers (shown in the lower part of the red box) and two feed-forward
layers (shown in the upper part of the red box).
Each BERT block makes use of bi-directional self-attention to
process an input sequence (shown
in light grey) to a more “refined” contextualized output sequence (shown in slightly darker grey). The contextualized output sequence of the final BERT block,
i.e. the contextualized encoded sequence $\mathbf{\overline{X}}_{1:n}$, can then be mapped to a single
output class by adding a task-specific classification layer (shown
in orange) as explained above.
Encoder-only models can only map an input sequence to an output
sequence of a priori known output length. In conclusion, the output
dimension does not depend on the input sequence, which makes it
disadvantageous and impractical to use encoder-only models for
sequence-to-sequence tasks.
As for all encoder-only models, BERT’s architecture corresponds
exactly to the architecture of the encoder part of transformer-based
encoder-decoder models as shown in the “Encoder” section in the
Encoder-Decoder
notebook.
GPT2
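GPT2 is a decoder-only model, which makes use of uni-directional
(i.e. “causal”) self-attention to define a mapping from an input
sequence $\mathbf{Y}_{0:m-1}$ to a “next-word” logit vector
sequence $\mathbf{L}_{1:m}$:
$$ f_{\theta_{\text{GPT2}}}: \mathbf{Y}_{0:m-1} \to \mathbf{L}_{1:m}. $$
By processing the logit vectors $\mathbf{L}_{1:m}$ with the softmax
operation, the model can define the probability distribution of the word
sequence $\mathbf{Y}_{1:m}$. To be exact, the probability distribution
of the word sequence $\mathbf{Y}_{1:m}$ can be factorized into
$m$ conditional “next word” distributions:
$$ p_{\theta_{\text{GPT2}}}(\mathbf{Y}_{1:m}) = \prod_{i=1}^{m} p_{\theta_{\text{GPT2}}}(\mathbf{y}_i | \mathbf{Y}_{0:i-1}). $$
Here, $p_{\theta_{\text{GPT2}}}(\mathbf{y}_i | \mathbf{Y}_{0:i-1})$
represents the probability distribution of the next word $\mathbf{y}_i$
given all previous words $\mathbf{y}_0, \ldots, \mathbf{y}_{i-1}$
and is defined as the softmax operation applied on the logit vector $\mathbf{l}_i$. To summarize, the following equations hold true.
$$ p_{\theta_{\text{GPT2}}}(\mathbf{y}_i | \mathbf{Y}_{0:i-1}) = \textbf{Softmax}(\mathbf{l}_i) = \textbf{Softmax}(f_{\theta_{\text{GPT2}}}(\mathbf{Y}_{0:i-1})). $$
For more detail, please refer to the
decoder section
of the encoder-decoder blog post.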
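The factorization into “next word” distributions can be made concrete with a small sketch. The code below is an illustration only, not GPT2 itself: the logit vectors are hand-written toy values over a hypothetical 3-word vocabulary, and `softmax` and `sequence_log_prob` are helper names introduced here. It applies the softmax to each logit vector and sums the log-probabilities of the observed words, which is the logarithm of the product of the conditional distributions.

```python
import math

def softmax(logits):
    """Turn a logit vector into a probability distribution."""
    m = max(logits)  # subtract the maximum for numerical stability
    exps = [math.exp(l - m) for l in logits]
    total = sum(exps)
    return [e / total for e in exps]

def sequence_log_prob(logit_vectors, token_ids):
    """Sum of log p(y_i | y_0, ..., y_{i-1}).

    logit_vectors[i] is assumed to be the logit vector the model produced
    for the position of token_ids[i].
    """
    log_prob = 0.0
    for logits, token_id in zip(logit_vectors, token_ids):
        probs = softmax(logits)
        log_prob += math.log(probs[token_id])
    return log_prob

# Toy values: a vocabulary of 3 words and a sequence of 2 words.
toy_logits = [[2.0, 0.0, 0.0], [0.0, 1.0, 0.0]]
toy_tokens = [0, 1]
print(sequence_log_prob(toy_logits, toy_tokens))
```

Working in log space avoids numerical underflow when many probabilities smaller than one are multiplied.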
Let’s visualize GPT2 now as well.
Analogous to BERT, GPT2 consists of a stack of GPT2 blocks. In
contrast to a BERT block, a GPT2 block makes use of uni-directional
self-attention to process some input vectors (shown in light blue on the
bottom right) to an output vector sequence (shown in darker blue on
the top right). In addition to the GPT2 block stack, the model also has
a linear layer, called LM Head, which maps the output vectors of the
final GPT2 block to the logit vectors $\mathbf{l}_1, \ldots, \mathbf{l}_m$. As mentioned earlier, a logit
vector $\mathbf{l}_i$ can then be used to sample a new input vector $\mathbf{y}_i$.
GPT2 is mainly used for open-domain text generation. First, an input
prompt $\mathbf{Y}_{0:i-1}$ is fed to the model to yield the conditional
distribution $p_{\theta_{\text{GPT2}}}(\mathbf{y} | \mathbf{Y}_{0:i-1})$. Then the
next word $\mathbf{y}_i$ is sampled from the distribution (represented
by the grey arrows in the graph above) and consequently appended to the
input. In an auto-regressive fashion the word $\mathbf{y}_{i+1}$ can
then be sampled from $p_{\theta_{\text{GPT2}}}(\mathbf{y} | \mathbf{Y}_{0:i})$ and so on.
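The auto-regressive loop described above can be sketched in a few lines. This is a minimal illustration under stated assumptions: `toy_next_word_distribution` is a made-up stand-in for a language model (it is not GPT2 and only looks at the last token), and words are represented by integer ids from a hypothetical 4-word vocabulary.

```python
import random

def toy_next_word_distribution(prefix, vocab_size=4):
    """Hypothetical stand-in for a language model: returns p(y | prefix).

    For simplicity the probabilities depend only on the last token.
    """
    last = prefix[-1]
    weights = [1.0 + ((last + k) % vocab_size) for k in range(vocab_size)]
    total = sum(weights)
    return [w / total for w in weights]

def generate(prompt, num_new_tokens, seed=0):
    """Sample one token at a time and append it to the input."""
    rng = random.Random(seed)
    sequence = list(prompt)
    for _ in range(num_new_tokens):
        probs = toy_next_word_distribution(sequence)
        next_token = rng.choices(range(len(probs)), weights=probs, k=1)[0]
        sequence.append(next_token)  # the sampled word becomes part of the input
    return sequence

print(generate(prompt=[0, 2], num_new_tokens=5))
```

Each iteration conditions on everything generated so far, which is why generation cannot be parallelized over output positions.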
GPT2 is therefore well-suited for language generation, but less so for
conditional generation. By setting the input prompt $\mathbf{Y}_{0:i-1}$ equal to the sequence input $\mathbf{X}_{1:n}$,
GPT2 can very well be used for conditional generation. However, the
model architecture has a fundamental drawback compared to the
encoder-decoder architecture as explained in Raffel et al.
(2019) on page 17. In short,
uni-directional self-attention forces the model’s representation of the
sequence input $\mathbf{X}_{1:n}$ to be unnecessarily limited, since the representation of $\mathbf{x}_i$ cannot depend on the later inputs $\mathbf{x}_{i+1}, \ldots, \mathbf{x}_n$.
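The difference between bi-directional and uni-directional self-attention comes down to which positions each token may attend to. The sketch below builds both attention masks for a sequence of length 3; `attention_mask` is a helper name introduced here for illustration and is not part of any library.

```python
def attention_mask(seq_len, causal):
    """mask[i][j] is True if position i may attend to position j."""
    return [
        [(j <= i) if causal else True for j in range(seq_len)]
        for i in range(seq_len)
    ]

bi_directional = attention_mask(3, causal=False)   # BERT-style
uni_directional = attention_mask(3, causal=True)   # GPT2-style

for row in uni_directional:
    print(row)
```

In the causal mask the first row contains a single `True`, so the representation of the first input cannot use any information from later inputs. This is the limitation discussed above.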
Encoder-Decoder
Because encoder-only models require the output length to be known a
priori, they seem unfit for sequence-to-sequence tasks. Decoder-only
models can function well for sequence-to-sequence tasks, but also have
certain architectural limitations as explained above.
The current predominant approach to tackle sequence-to-sequence tasks
is transformer-based encoder-decoder models, often also called
seq2seq transformer models. Encoder-decoder models were introduced in
Vaswani et al. (2017) and since then
have been shown to perform better on sequence-to-sequence tasks than
stand-alone language models (i.e. decoder-only models), e.g. Raffel
et al. (2020). In essence, an
encoder-decoder model is the combination of a stand-alone encoder,
such as BERT, and a stand-alone decoder model, such as GPT2. For more
details on the exact architecture of transformer-based encoder-decoder
models, please refer to this blog
post.
Now, we know that freely available checkpoints of large pre-trained
stand-alone encoder and decoder models, such as BERT and GPT, can
boost performance and reduce training cost for many NLU tasks. We also
know that encoder-decoder models are essentially the combination of
stand-alone encoder and decoder models. This naturally brings up the
question of how one can leverage stand-alone model checkpoints for
encoder-decoder models and which model combinations are most performant
on certain sequence-to-sequence tasks.
In 2020, Sascha Rothe, Shashi Narayan, and Aliaksei Severyn investigated
exactly this question in their paper Leveraging Pre-trained
Checkpoints for Sequence Generation
Tasks. The paper offers a great
analysis of different encoder-decoder model combinations and fine-tuning
techniques, which we will study in more detail later.
Composing an encoder-decoder model of pre-trained stand-alone model
checkpoints is defined as warm-starting the encoder-decoder model. The
following sections show how warm-starting an encoder-decoder model works
in theory, how one can put the theory into practice with 🤗Transformers,
and also give practical tips for better performance.
A pre-trained language model is defined as a neural network:
- that has been trained on unlabeled text data, i.e. in a
task-agnostic, unsupervised fashion, and
- that processes a sequence of input words into a context-dependent
embedding. E.g. the continuous bag-of-words and skip-gram
model from Mikolov et al. (2013)
is not considered a pre-trained language model because the
embeddings are context-agnostic.
Fine-tuning is defined as the task-specific training of a
model that has been initialized with the weights of a pre-trained
language model.
The input vector $\mathbf{y}_0$ corresponds hereby to the $\text{BOS}$ embedding vector required to predict the very first output
word $\mathbf{y}_1$.
Without loss of generality, we exclude the normalization layers
to not clutter the equations and illustrations.
For more detail on why uni-directional self-attention is used for
“decoder-only” models, such as GPT2, and how sampling works exactly,
please refer to the
decoder section
of the encoder-decoder blog post.
Warm-starting encoder-decoder models (Theory)
Having read the introduction, we are now familiar with encoder-only
and decoder-only models. We have noticed that the encoder-decoder
model architecture is essentially a composition of a stand-alone
encoder model and a stand-alone decoder model, which led us to the
question of how one can warm-start encoder-decoder models from
stand-alone model checkpoints.
There are multiple possibilities to warm-start an encoder-decoder model.
One can
1. initialize both the encoder and decoder part from an encoder-only
model checkpoint, e.g. BERT,
2. initialize the encoder part from an encoder-only model checkpoint,
e.g. BERT, and the decoder part from a decoder-only
checkpoint, e.g. GPT2,
3. initialize only the encoder part with an encoder-only model
checkpoint, or
4. initialize only the decoder part with a decoder-only model
checkpoint.
In the following, we will put the focus on possibilities 1. and 2.
Possibilities 3. and 4. are trivial after having understood the first
two.
Recap Encoder-Decoder Model
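First, let’s do a quick recap of the encoder-decoder architecture.
The encoder (shown in green) is a stack of encoder blocks. Each
encoder block consists of a bi-directional self-attention layer
and two feed-forward layers. The decoder (shown in orange) is a
stack of decoder blocks, followed by a dense layer, called LM Head.
Each decoder block consists of a uni-directional self-attention
layer, a cross-attention layer, and two feed-forward layers.
The encoder maps the input sequence $\mathbf{X}_{1:n}$ to a
contextualized encoded sequence $\mathbf{\overline{X}}_{1:n}$ in the
exact same way BERT does. The decoder then maps the contextualized
encoded sequence $\mathbf{\overline{X}}_{1:n}$ and a target sequence $\mathbf{Y}_{0:m-1}$ to the logit vectors $\mathbf{L}_{1:m}$. Analogous
to GPT2, the logits are then used to define the distribution of the
target sequence $\mathbf{Y}_{1:m}$ conditioned on the input sequence $\mathbf{X}_{1:n}$ by means of a softmax operation.
To put it into mathematical terms, first, the conditional distribution
is factorized into $m$ conditional distributions of the next word $\mathbf{y}_i$ by the chain rule of probability.
$$ p_{\theta_{\text{enc, dec}}}(\mathbf{Y}_{1:m} | \mathbf{X}_{1:n}) = p_{\theta_{\text{dec}}}(\mathbf{Y}_{1:m} | \mathbf{\overline{X}}_{1:n}) = \prod_{i=1}^{m} p_{\theta_{\text{dec}}}(\mathbf{y}_i | \mathbf{Y}_{0:i-1}, \mathbf{\overline{X}}_{1:n}), \text{ with } \mathbf{\overline{X}}_{1:n} = f_{\theta_{\text{enc}}}(\mathbf{X}_{1:n}). $$
Each “next-word” conditional distribution is thereby defined by the
softmax of the logit vector as follows.
$$ p_{\theta_{\text{dec}}}(\mathbf{y}_i | \mathbf{Y}_{0:i-1}, \mathbf{\overline{X}}_{1:n}) = \textbf{Softmax}(\mathbf{l}_i). $$
For more detail, please refer to the Encoder-Decoder
notebook.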
Warm-starting Encoder-Decoder with BERT
Let’s now illustrate how a pre-trained BERT model can be used to
warm-start the encoder-decoder model. BERT’s pre-trained weight
parameters are used to initialize both the encoder’s weight parameters
and the decoder’s weight parameters. To do so, BERT’s
architecture is compared to the encoder’s architecture and all layers
of the encoder that also exist in BERT will be initialized with the
pre-trained weight parameters of the respective layers. All layers of
the encoder that do not exist in BERT will simply have their weight
parameters be randomly initialized.
Let’s visualize.
We can see that the encoder architecture corresponds 1-to-1 to BERT’s
architecture. The weight parameters of the bi-directional
self-attention layer and the two feed-forward layers of all
encoder blocks are initialized with the weight parameters of the
respective BERT blocks. This is illustrated by way of example for the second
encoder block (red boxes at the bottom), whose self-attention and feed-forward weight parameters are set to the corresponding weight
parameters of the second BERT block at
initialization.
Before fine-tuning, the encoder therefore behaves exactly like a
pre-trained BERT model. Assuming the input sequence $\mathbf{x}_1, \ldots, \mathbf{x}_n$ (shown in green) passed to the
encoder is equal to the input sequence $\mathbf{x}_1^{\text{BERT}}, \ldots, \mathbf{x}_n^{\text{BERT}}$ (shown
in grey) passed to BERT, this means that the respective output vector
sequences $\mathbf{\overline{x}}_1, \ldots, \mathbf{\overline{x}}_n$
(shown in darker green) and $\mathbf{\overline{x}}_1^{\text{BERT}}, \ldots, \mathbf{\overline{x}}_n^{\text{BERT}}$
(shown in darker grey) also have to be equal.
Next, let’s illustrate how the decoder is warm-started.
The architecture of the decoder is different from BERT’s architecture
in three ways.
- First, the decoder has to be conditioned on the contextualized
encoded sequence $\mathbf{\overline{X}}_{1:n}$ by means of
cross-attention layers. Consequently, randomly initialized
cross-attention layers are added between the self-attention layer
and the two feed-forward layers in each BERT block. This is
shown by way of example for the second block
and illustrated
by the newly added fully connected graph in red in the lower red box
on the right. This necessarily changes the behavior of each modified
BERT block so that an input vector now
yields a random output vector (highlighted by the
red border around the output vector).
- Second, BERT’s bi-directional self-attention layers have to be
changed to uni-directional self-attention layers to comply with
auto-regressive generation. Because both the bi-directional and the
uni-directional self-attention layer are based on the same key,
query and value projection weights, the decoder’s
self-attention layer weights can be initialized with BERT’s
self-attention layer weights. E.g. the query, key and value weight
parameters of the decoder’s uni-directional self-attention layer
are initialized with those of BERT’s bi-directional self-attention
layer. However, in uni-directional self-attention each token only
attends to all previous tokens, so that the decoder’s
self-attention layers yield different output vectors than BERT’s
self-attention layers even though they share the same weights.
Compare e.g. the decoder’s causally connected graph in the right
box with BERT’s fully connected graph in the left box.
- Third, the decoder outputs a sequence of logit vectors $\mathbf{L}_{1:m}$
in order to define the conditional probability
distribution
$p_{\theta_{\text{dec}}}(\mathbf{Y}_{1:m} | \mathbf{\overline{X}}_{1:n})$.
As a result, an LM Head layer is added on top of the last decoder
block. The weight parameters of the LM Head layer usually
correspond to the weight parameters of the word embedding
and thus are not randomly initialized.
This is illustrated at the top by the initialization
of the LM Head from BERT’s word embedding weights.
To conclude, when warm-starting the decoder from a pre-trained BERT
model only the cross-attention layer weights are randomly initialized.
All other weights, including those of the self-attention layer and LM
Head, are initialized with BERT’s pre-trained weight parameters.
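The initialization rule summarized above can be sketched with plain dictionaries. This is a toy illustration under loud assumptions: the parameter names (`embeddings.word`, `block_1.cross_attention.query`, and so on) and the two-element “weights” are hypothetical and do not correspond to the actual BERT checkpoint format or to how 🤗Transformers implements warm-starting. Every decoder weight that also exists in the checkpoint is copied, everything else (here only the cross-attention weights) is randomly initialized, and the LM Head is tied to the word embedding.

```python
import random

def warm_start_decoder(decoder_param_names, bert_checkpoint, seed=0):
    """Copy every decoder weight that also exists in the BERT checkpoint and
    randomly initialize the rest (here: the cross-attention weights)."""
    rng = random.Random(seed)
    decoder_params = {}
    randomly_initialized = []
    for name in decoder_param_names:
        if name in bert_checkpoint:
            decoder_params[name] = list(bert_checkpoint[name])  # leveraged
        else:
            decoder_params[name] = [rng.gauss(0.0, 0.02) for _ in range(2)]
            randomly_initialized.append(name)
    # The LM head is tied to the word embedding, so it is not random either.
    decoder_params["lm_head.weight"] = decoder_params["embeddings.word"]
    return decoder_params, randomly_initialized

# Hypothetical two-dimensional "weights" standing in for real tensors.
bert_checkpoint = {
    "embeddings.word": [0.1, 0.2],
    "block_1.self_attention.query": [0.3, 0.4],
    "block_1.feed_forward": [0.5, 0.6],
}
decoder_param_names = [
    "embeddings.word",
    "block_1.self_attention.query",
    "block_1.cross_attention.query",  # does not exist in BERT
    "block_1.feed_forward",
]
decoder_params, randomly_initialized = warm_start_decoder(
    decoder_param_names, bert_checkpoint
)
print(randomly_initialized)
```

Note that copying the self-attention weights does not by itself make the layer causal; the uni-directional behavior comes from the attention mask applied at run time.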
Having warm-started the encoder-decoder model, the weights are then
fine-tuned on a sequence-to-sequence downstream task, such as
summarization.
Warm-starting Encoder-Decoder with BERT and GPT2
Instead of warm-starting both the encoder and decoder with a BERT
checkpoint, we can leverage the BERT checkpoint for the encoder
and a GPT2 checkpoint for the decoder. At first glance, a decoder-only
GPT2 checkpoint seems to be better suited to warm-start the decoder
because it has already been trained on causal language modeling and uses
uni-directional self-attention layers.
Let’s illustrate how a GPT2 checkpoint can be used to warm-start the
decoder.
We can see that the decoder is more similar to GPT2 than it is to BERT. The
weight parameters of the decoder’s LM Head can directly be initialized
with GPT2’s LM Head weight parameters.
In addition, the blocks of the decoder and GPT2 both make use of
uni-directional self-attention so that the output vectors of the
decoder’s self-attention layer are equivalent to GPT2’s output vectors
assuming the input vectors are the same. In contrast to the
BERT-initialized decoder, the GPT2-initialized decoder, therefore, keeps
the causally connected graph of the self-attention layer as can be seen in
the red boxes on the bottom.
Nevertheless, the GPT2-initialized decoder also has to be conditioned
on the contextualized encoded sequence $\mathbf{\overline{X}}_{1:n}$. Analogous to the
BERT-initialized decoder, randomly initialized weight parameters for the
cross-attention layer are therefore added to each decoder block. This is
illustrated e.g. for the second decoder block.
Even though GPT2 resembles the decoder part of an encoder-decoder model
more than BERT, a GPT2-initialized decoder will also yield random logit
vectors $\mathbf{L}_{1:m}$ without fine-tuning due to the randomly
initialized cross-attention layers in every decoder block. It would be
interesting to investigate whether a GPT2-initialized decoder yields
better results or can be fine-tuned more efficiently.
Encoder-Decoder Weight Sharing
In Raffel et al. (2020), the
authors show that a randomly-initialized encoder-decoder model that
shares the encoder’s weights with the decoder, and therefore reduces
the memory footprint by half, performs only slightly worse than its
“non-shared” version. Sharing the encoder’s weights with the decoder
means that all layers of the decoder that are found at the same position
in the encoder share the same weight parameters, i.e. the same node in
the network’s computation graph.
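E.g. the query, key, and value projection matrices of the
self-attention layer in the third encoder block, defined as $\mathbf{W}^{\text{self-attn}, 3}_{\text{enc}, q}$, $\mathbf{W}^{\text{self-attn}, 3}_{\text{enc}, k}$, $\mathbf{W}^{\text{self-attn}, 3}_{\text{enc}, v}$, are identical to the
respective query, key, and value projection matrices of the
self-attention layer in the third decoder block:
$$ \mathbf{W}^{\text{self-attn}, 3}_{q} = \mathbf{W}^{\text{self-attn}, 3}_{\text{enc}, q} \equiv \mathbf{W}^{\text{self-attn}, 3}_{\text{dec}, q}, $$
$$ \mathbf{W}^{\text{self-attn}, 3}_{k} = \mathbf{W}^{\text{self-attn}, 3}_{\text{enc}, k} \equiv \mathbf{W}^{\text{self-attn}, 3}_{\text{dec}, k}, $$
$$ \mathbf{W}^{\text{self-attn}, 3}_{v} = \mathbf{W}^{\text{self-attn}, 3}_{\text{enc}, v} \equiv \mathbf{W}^{\text{self-attn}, 3}_{\text{dec}, v}. $$
As a result, the key projection weights $\mathbf{W}^{\text{self-attn}, 3}_{k}$
receive two gradient contributions in every backward propagation pass: once when the
gradient is backpropagated through the third decoder block and once when
the gradient is backpropagated through the third encoder block.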
In the same way, we can warm-start an encoder-decoder model by sharing
the encoder weights with the decoder. Being able to share the weights
between the encoder and decoder requires the decoder architecture
(excluding the cross-attention weights) to be identical to the encoder
architecture. Therefore, encoder-decoder weight sharing is only
relevant if the encoder-decoder model is warm-started from a single
encoder-only pre-trained checkpoint.
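Why a shared weight is touched twice during backpropagation can be seen in a scalar toy model. The sketch below is an illustration only and is far simpler than a transformer: a single weight `w` is used once by a one-line “encoder” and once by a one-line “decoder”, and the function names are introduced here for the example. The total gradient is the sum of the contribution flowing through the decoder use and the contribution flowing through the encoder use, which is checked against a numerical derivative.

```python
def shared_weight_gradient(w, x):
    """Forward and backward pass for y = w * (w * x) with one shared weight."""
    h = w * x  # "encoder" uses w
    y = w * h  # "decoder" reuses the same w
    grad_from_decoder = h      # dy/dw through the decoder use (h held fixed)
    grad_from_encoder = w * x  # dy/dh * dh/dw through the encoder use
    return y, grad_from_decoder + grad_from_encoder

def numerical_gradient(w, x, eps=1e-6):
    """Central finite difference of y = w * (w * x) with respect to w."""
    f = lambda v: v * (v * x)
    return (f(w + eps) - f(w - eps)) / (2 * eps)

y, grad = shared_weight_gradient(w=3.0, x=2.0)
print(y, grad, numerical_gradient(3.0, 2.0))
```

With untied weights the two contributions would update two separate parameters. With sharing, both land on the same parameter, which is also why the number of stored weights is reduced.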
Great! That was the theory about warm-starting encoder-decoder models.
Let’s now take a look at some results.
Without loss of generality, we exclude the normalization layers
to not clutter the equations and illustrations.
For more detail on how self-attention layers function, please
refer to this
section of the
transformer-based encoder-decoder model blog post for the encoder part
(and this section
for the decoder part respectively).
Warm-starting encoder-decoder models (Evaluation)
In this section, we will summarize the findings on warm-starting
encoder-decoder models as presented in Leveraging Pre-trained
Checkpoints for Sequence Generation
Tasks by Sascha Rothe, Shashi
Narayan, and Aliaksei Severyn. The authors compared the performance of
warm-started encoder-decoder models to randomly initialized
encoder-decoder models on multiple sequence-to-sequence tasks, notably
summarization, translation, sentence splitting, and sentence
fusion.
To be more precise, the publicly available pre-trained checkpoints of
BERT, RoBERTa, and GPT2 were leveraged in different
variations to warm-start an encoder-decoder model. E.g. a
BERT-initialized encoder was paired with a BERT-initialized decoder
yielding a BERT2BERT model, or a RoBERTa-initialized encoder was paired
with a GPT2-initialized decoder to yield a RoBERTa2GPT2 model.
Additionally, the effect of sharing the encoder and decoder weights (as
explained in the previous section) was investigated for RoBERTa, i.e.
RoBERTaShare, and for BERT, i.e. BERTShare. Randomly or partly
randomly initialized encoder-decoder models were used as a baseline,
such as a fully randomly initialized encoder-decoder model, coined
Rnd2Rnd, or a BERT-initialized decoder paired with a randomly
initialized encoder, defined as Rnd2BERT.
The following table shows a complete list of all investigated model
variants including the number of randomly initialized weights, i.e.
“random”, and the number of weights initialized from the respective
pre-trained checkpoints, i.e. “leveraged”. All models are based on a
12-layer architecture with 768-dim hidden size embeddings, corresponding
to the bert-base-cased, bert-base-uncased, roberta-base, and
gpt2 checkpoints in the 🤗Transformers model hub.
| Model | random | leveraged | total |
|---|---|---|---|
| Rnd2Rnd | 221M | 0 | 221M |
| Rnd2BERT | 112M | 109M | 221M |
| BERT2Rnd | 112M | 109M | 221M |
| Rnd2GPT2 | 114M | 125M | 238M |
| BERT2BERT | 26M | 195M | 221M |
| BERTShare | 26M | 109M | 135M |
| RoBERTaShare | 26M | 126M | 152M |
| BERT2GPT2 | 26M | 234M | 260M |
| RoBERTa2GPT2 | 26M | 250M | 276M |
The model Rnd2Rnd, which is based on the BERT2BERT architecture,
contains 221M weight parameters, all of which are randomly initialized.
The other two “BERT-based” baselines Rnd2BERT and BERT2Rnd have
roughly half of their weights, i.e. 112M parameters, randomly
initialized. The other 109M weight parameters are leveraged from the
pre-trained bert-base-uncased checkpoint for the encoder or decoder
part respectively. The models BERT2BERT, BERT2GPT2, and
RoBERTa2GPT2 have all of their encoder weight parameters leveraged
(from bert-base-uncased, roberta-base respectively) and most of the
decoder weight parameters as well (from gpt2,
bert-base-uncased respectively). 26M decoder weight parameters, which
correspond to the 12 cross-attention layers, are thereby randomly
initialized. RoBERTa2GPT2 and BERT2GPT2 are compared to the Rnd2GPT2
baseline. Also, it should be noted that the shared model variants
BERTShare and RoBERTaShare have significantly fewer parameters
because all encoder weight parameters are shared with the respective
decoder weight parameters.
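The parameter counts discussed above can be cross-checked with a few lines of arithmetic. The numbers below are copied from the table (in millions); the helper names `random_fraction` and `is_consistent` are introduced here for illustration. Because the table reports rounded values, the check allows a difference of 1M between “random + leveraged” and “total” (this is needed for Rnd2GPT2, where 114M + 125M is listed as 238M).

```python
# (random, leveraged, total) parameter counts in millions
PARAM_COUNTS = {
    "Rnd2Rnd": (221, 0, 221),
    "Rnd2BERT": (112, 109, 221),
    "BERT2Rnd": (112, 109, 221),
    "Rnd2GPT2": (114, 125, 238),
    "BERT2BERT": (26, 195, 221),
    "BERTShare": (26, 109, 135),
    "RoBERTaShare": (26, 126, 152),
    "BERT2GPT2": (26, 234, 260),
    "RoBERTa2GPT2": (26, 250, 276),
}

def random_fraction(model_name):
    """Share of weights that are randomly initialized."""
    random_m, _, total_m = PARAM_COUNTS[model_name]
    return random_m / total_m

def is_consistent(model_name, tolerance_m=1):
    """Check random + leveraged == total up to rounding."""
    random_m, leveraged_m, total_m = PARAM_COUNTS[model_name]
    return abs(random_m + leveraged_m - total_m) <= tolerance_m

for name in PARAM_COUNTS:
    print(f"{name:13s} random share: {random_fraction(name):.0%}  "
          f"consistent: {is_consistent(name)}")
```

The fully warm-started variants keep the randomly initialized share at roughly a tenth of the model, whereas the half-random baselines start with about half of their weights untrained.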
Experiments
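The above models were trained and evaluated on four sequence-to-sequence
tasks of increasing complexity: sentence-level fusion, sentence-level
splitting, translation, and abstractive summarization. The following
table shows which datasets were used for each task.
| Seq2Seq Task | Datasets |
|---|---|
| Sentence Fusion | DiscoFuse |
| Sentence Splitting | WikiSplit |
| Translation | WMT14 |
| Summarization | CNN/Dailymail, BBC XSum, Gigaword |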
Depending on the task, a slightly different training regime was used.
E.g. according to the size of the dataset and the specific task, the
number of training steps ranges from 200K to 500K, the batch size is set
to either 128 or 256, the input length ranges from 128 to 512 and the
output length varies between 32 and 128. It should be emphasized, however,
that within each task, all models were trained and evaluated using the
same hyperparameters to ensure a fair comparison. For more information
on the task-specific hyperparameter settings, the reader is advised to
see the Experiments section in the
paper.
We will now give a condensed overview of the results for each task.
Sentence Fusion and -Splitting (DiscoFuse, WikiSplit)
Sentence Fusion is the task of combining multiple sentences into a
single coherent sentence. E.g. the two sentences:
As a run-blocker, Zeitler moves relatively well. Zeitler too often
struggles at the point of contact in space.
should be connected with a fitting linking word, such as:
As a run-blocker, Zeitler moves relatively well. Nonetheless, he
too often struggles at the point of contact in space.
As can be seen, the linking word “nonetheless” provides a coherent
transition from the first sentence to the second. A model that is
capable of generating such a linking word has arguably learned to infer
that the two sentences above contrast with each other.
The inverse task is called Sentence splitting and consists of
splitting a single complex sentence into multiple simpler ones that
together retain the same meaning. Sentence splitting is considered an
important task in text simplification, cf. Botha et al.
(2018).
For example, the sentence:
Street Rod is the first in a series of two games released for the PC
and Commodore 64 in 1989
can be simplified into
Street Rod is the first in a series of two games . It was released
for the PC and Commodore 64 in 1989
It can be seen that the long sentence tries to convey two important
pieces of information. One is that the game was the first of two games
being released for the PC, and the second is the year in which it was
released. Sentence splitting, therefore, requires the model to
understand which part of the sentence should be divided into two
sentences, making the task more difficult than sentence fusion.
A common metric to evaluate the performance of models on sentence fusion
resp. -splitting tasks is SARI (Wu et al.
(2016)), which is broadly
based on the F1-score of label and model output.
Let’s see how the models perform on sentence fusion and -splitting.
| Model | 100% DiscoFuse (SARI) | 10% DiscoFuse (SARI) | 100% WikiSplit (SARI) |
|---|---|---|---|
| Rnd2Rnd | 86.9 | 81.5 | 61.7 |
| Rnd2BERT | 87.6 | 82.1 | 61.8 |
| BERT2Rnd | 89.3 | 86.1 | 63.1 |
| Rnd2GPT2 | 86.5 | 81.4 | 61.3 |
| BERT2BERT | 89.3 | 86.1 | 63.2 |
| BERTShare | 89.2 | 86.0 | 63.5 |
| RoBERTaShare | 89.7 | 86.0 | 63.4 |
| BERT2GPT2 | 88.4 | 84.1 | 62.4 |
| RoBERTa2GPT2 | 89.9 | 87.1 | 63.2 |
| — | — | — | — |
| RoBERTaShare (large) | 90.3 | 87.7 | 63.8 |
The first two columns show the performance of the encoder-decoder models
on the DiscoFuse evaluation data. The first column states the results of
encoder-decoder models trained on all (100%) of the training data, while
the second column shows the results of the models trained only on 10% of
the training data. We observe that warm-started models perform
significantly better than the randomly initialized baseline models
Rnd2Rnd, Rnd2BERT, and Rnd2GPT2. A warm-started RoBERTa2GPT2
model trained only on 10% of the training data is on par with an
Rnd2Rnd model trained on 100% of the training data. Interestingly, the
BERT2Rnd baseline performs equally well as a fully warm-started
BERT2BERT model, which indicates that warm-starting the encoder part
is more effective than warm-starting the decoder part. The best results
are obtained by RoBERTa2GPT2, followed by RoBERTaShare. Sharing
encoder and decoder weight parameters does seem to slightly increase the
model’s performance.
On the more difficult sentence splitting task, a similar pattern
emerges. Warm-started encoder-decoder models significantly outperform
encoder-decoder models whose encoder is randomly initialized, and
encoder-decoder models with shared weight parameters yield better
results than those with uncoupled weight parameters. On sentence
splitting the BERTShare model yields the best performance, closely
followed by RoBERTaShare.
In addition to the 12-layer model variants, the authors also trained and
evaluated a 24-layer RoBERTaShare (large) model which outperforms all
12-layer models significantly.
Machine Translation (WMT14)
Next, the authors evaluated warm-started encoder-decoder models on the
probably most common benchmark in machine translation (MT): the En → De and De → En WMT14 dataset. In this notebook, we
present the results on the newstest2014 eval dataset. Because the
benchmark requires the model to understand both an English and a German
vocabulary, the BERT-initialized encoder-decoder models were warm-started
from the multilingual pre-trained checkpoint
bert-base-multilingual-cased. Because there is no publicly available
multilingual RoBERTa checkpoint, RoBERTa-initialized encoder-decoder
models were excluded for MT. GPT2-initialized models were initialized
from the gpt2 pre-trained checkpoint as in the previous experiment.
The translation results are reported using the BLEU-4 score metric.
| Model | En → De (BLEU-4) | De → En (BLEU-4) |
|---|---|---|
| Rnd2Rnd | 26.0 | 29.1 |
| Rnd2BERT | 27.2 | 30.4 |
| BERT2Rnd | 30.1 | 32.7 |
| Rnd2GPT2 | 19.6 | 23.2 |
| BERT2BERT | 30.1 | 32.7 |
| BERTShare | 29.6 | 32.6 |
| BERT2GPT2 | 23.2 | 31.4 |
| — | — | — |
| BERT2Rnd (large, custom) | 31.7 | 34.2 |
| BERTShare (large, custom) | 30.5 | 33.8 |
Again, we observe a significant performance boost by warm-starting the
encoder part, with BERT2Rnd and BERT2BERT yielding the best results
on both the En → De and De → En tasks. GPT2-initialized
models perform significantly worse even than the Rnd2Rnd
baseline on En → De. Taking into account that the gpt2
checkpoint was trained only on English text, it is not very surprising
that BERT2GPT2 and Rnd2GPT2 models have difficulties generating
German translations. This hypothesis is supported by the competitive
results (e.g. 31.4 vs. 32.7) of BERT2GPT2 on the De → En
task, for which GPT2’s vocabulary matches the English output format.
Contrary to the results obtained on sentence fusion and sentence
splitting, sharing encoder and decoder weight parameters does not yield
a performance boost in MT. Possible reasons for this as stated by the
authors include
- the encoder-decoder model capacity is an important factor in MT,
and
- the encoder and decoder have to deal with different grammar and
vocabulary
Since the bert-base-multilingual-cased checkpoint was trained on more
than 100 languages, its vocabulary is probably undesirably large for
En → De and De → En MT. Thus, the authors pre-trained a
large BERT encoder-only checkpoint on the English and German subset of
the Wikipedia dump and subsequently used it to warm-start a BERT2Rnd
and BERTShare encoder-decoder model. Thanks to the improved
vocabulary, another significant performance boost is observed, with
BERT2Rnd (large, custom) significantly outperforming all other models.
Summarization (CNN/Dailymail, BBC XSum, Gigaword)
Finally, the encoder-decoder models were evaluated on the arguably most
challenging sequence-to-sequence task: summarization. The authors
picked three summarization datasets with different characteristics for
evaluation: Gigaword (headline generation), BBC XSum (extreme
summarization), and CNN/Dailymail (abstractive summarization).
The Gigaword dataset contains sentence-level abstractive summarizations,
requiring the model to learn sentence-level understanding, abstraction,
and eventually paraphrasing. A typical data sample in Gigaword, such as
“*venezuelan president hugo chavez said thursday he has ordered a probe
into a suspected coup plot allegedly involving active and retired
military officers .*”,
would have a corresponding headline as its label, e.g.:
“chavez orders probe into suspected coup plot”.
The BBC XSum dataset consists of much longer article-like text inputs
with the labels being mostly single sentence summarizations. This
dataset requires the model not only to learn document-level inference
but also a high level of abstractive paraphrasing. Some data samples of
the BBC XSum dataset are shown
here.
For the CNN/Dailymail dataset, documents, which are of similar length
to those in the BBC XSum dataset, have to be summarized to
bullet-point story highlights. The labels therefore often consist of
multiple sentences. Besides document-level understanding, the
CNN/Dailymail dataset requires models to be good at copying the most
salient information. Some examples can be viewed
here.
The models are evaluated using the Rouge
metric; the Rouge-2
scores are shown below.
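To give an intuition for what a Rouge-2 score measures, here is a minimal, dependency-free sketch of bigram overlap between a reference summary and a prediction. It is only an illustration under simplifying assumptions (lowercasing and whitespace tokenization, no stemming); the helper names `bigrams` and `rouge2` are made up for this sketch, and it is not the evaluation script used by the authors.

```python
from collections import Counter


def bigrams(text):
    # Simplifying assumption: lowercase and split on whitespace.
    tokens = text.lower().split()
    return Counter(zip(tokens, tokens[1:]))


def rouge2(reference, prediction):
    """Bigram-overlap precision, recall and F1 (hypothetical helper)."""
    ref, pred = bigrams(reference), bigrams(prediction)
    # Counter intersection clips each bigram to its minimum count.
    overlap = sum((ref & pred).values())
    precision = overlap / max(sum(pred.values()), 1)
    recall = overlap / max(sum(ref.values()), 1)
    if overlap == 0:
        f1 = 0.0
    else:
        f1 = 2 * precision * recall / (precision + recall)
    return {"precision": precision, "recall": recall, "f1": f1}


reference = "chavez orders probe into suspected coup plot"
prediction = "Chavez orders probe into coup plot"
scores = rouge2(reference, prediction)
print(scores)
```

Four of the five predicted bigrams also occur among the six reference bigrams, so precision is 4/5 and recall is 4/6. Because the sketch lowercases its input, the capitalized "Chavez" still matches.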
Alright, let’s take a look at the results.
| Model | CNN/Dailymail (Rouge-2) | BBC XSum (Rouge-2) | Gigaword (Rouge-2) |
|---|---|---|---|
| Rnd2Rnd | 14.00 | 10.23 | 18.71 |
| Rnd2BERT | 15.55 | 11.52 | 18.91 |
| BERT2Rnd | 17.76 | 15.83 | 19.26 |
| Rnd2GPT2 | 8.81 | 8.77 | 18.39 |
| BERT2BERT | 17.84 | 15.24 | 19.68 |
| BERTShare | 18.10 | 16.12 | 19.81 |
| RoBERTaShare | 18.95 | 17.50 | 19.70 |
| BERT2GPT2 | 4.96 | 8.37 | 18.23 |
| RoBERTa2GPT2 | 14.72 | 5.20 | 19.21 |
| — | — | — | — |
| RoBERTaShare (large) | 18.91 | 18.79 | 19.78 |
We observe again that warm-starting the encoder part gives a significant
improvement over models with randomly initialized encoders, which is
especially visible for document-level abstraction tasks, i.e.
CNN/Dailymail and BBC XSum. This shows that tasks requiring a high level
of abstraction benefit more from a pre-trained encoder part than those
requiring only sentence-level abstraction. Except for Gigaword,
GPT2-based encoder-decoder models seem to be unfit for summarization.
Furthermore, the shared encoder-decoder models are the best performing
models for summarization. RoBERTaShare and BERTShare are the best
performing models on all datasets, with the margin being especially
significant on the BBC XSum dataset, on which RoBERTaShare (large)
outperforms BERT2BERT and BERT2Rnd by ca. 3 Rouge-2 points and
Rnd2Rnd by more than 8 Rouge-2 points. As brought forward by the
authors, “this is probably because the BBC summary sentences follow a
distribution that is similar to that of the sentences in the document,
whereas this is not necessarily the case for the Gigaword headlines and
the CNN/DailyMail bullet-point highlights”. Intuitively this means
that in BBC XSum, the input sentences processed by the encoder are very
similar in structure to the single sentence summary processed by the
decoder, i.e. same length, similar choice of words, similar syntax.
Conclusion
Alright, let’s draw a conclusion and try to derive some practical tips.
- We have observed on all tasks that a warm-started encoder part gives
a significant performance boost compared to encoder-decoder models
having a randomly initialized encoder. On the other hand,
warm-starting the decoder seems to be less important, with
BERT2BERT being on par with BERT2Rnd on most tasks. An intuitive
reason would be that since a BERT- or RoBERTa-initialized encoder
part has none of its weight parameters randomly initialized, the
encoder can fully exploit the acquired knowledge of BERT’s or
RoBERTa’s pre-trained checkpoints, respectively. In contrast, the
warm-started decoder always has parts of its weight parameters
randomly initialized, which possibly makes it much harder to
effectively leverage the knowledge acquired by the checkpoint used
to initialize the decoder.
- Next, we noticed that it is often beneficial to share encoder and
decoder weights, especially if the target distribution is similar to
the input distribution (e.g. BBC XSum). However, for datasets
whose target data distribution differs more significantly from the
input data distribution and for which model capacity is known
to play an important role, e.g. WMT14, encoder-decoder weight
sharing seems to be disadvantageous.
- Finally, we have seen that it is very important that the vocabulary
of the pre-trained “stand-alone” checkpoints fits the vocabulary
required to solve the sequence-to-sequence task. E.g. a
warm-started BERT2GPT2 encoder-decoder will perform poorly on En →
De MT because GPT2 was pre-trained on English whereas the
target language is German. The overall poor performance of the
BERT2GPT2, Rnd2GPT2, and RoBERTa2GPT2 models compared to BERT2BERT,
BERTShare, and RoBERTaShare suggests that it is more effective
to have a shared vocabulary. Also, it shows that initializing the
decoder part with a pre-trained GPT2 checkpoint is not more
effective than initializing it with a pre-trained BERT checkpoint,
despite GPT2 being more similar to the decoder in its architecture.
For each of the above tasks, the most performant models were ported to
🤗Transformers and can be accessed here.
To retrieve BLEU-4 scores, a script from the Tensorflow Official
Transformer implementation https://github.com/tensorflow/models/tree/master/official/nlp/transformer
was used. Note that, differently from
the tensor2tensor/utils/get_ende_bleu.sh used by Vaswani et al.
(2017), this script does not split noun compounds, but utf-8 quotes were
normalized to ascii quotes after having noted that the pre-processed
training set contains only ascii quotes.
Model capacity is an informal definition of how good the model is
at modeling complex patterns. It is also sometimes defined as the
ability of a model to learn from more and more data. Model capacity is
broadly measured by the number of trainable parameters: the more
parameters, the higher the model capacity.
Warm-starting encoder-decoder models with 🤗Transformers (Practice)
We have explained the theory of warm-starting encoder-decoder models,
analyzed empirical results on multiple datasets, and have derived
practical conclusions. Let’s now walk through a complete code example
showcasing how a BERT2BERT model can be warm-started and
subsequently fine-tuned on the CNN/Dailymail summarization task. We
will be leveraging the 🤗datasets and 🤗Transformers libraries.
In addition, the following list provides a condensed version of this and
other notebooks on warm-starting other combinations of encoder-decoder
models.
- for BERT2BERT on CNN/Dailymail (a condensed version of this
notebook), click here.
- for RoBERTaShare on BBC XSum, click here.
- for BERT2Rnd on WMT14 En → De, click here.
- for RoBERTa2GPT2 on DiscoFuse, click here.
Note: This notebook only uses a few training, validation, and test
data samples for demonstration purposes. To fine-tune an encoder-decoder
model on the full training data, the user should change the training and
data preprocessing parameters accordingly as highlighted by the
comments.
Data Preprocessing
In this section, we show how the data can be pre-processed for training.
More importantly, we try to give the reader some insight into the
process of deciding how to preprocess the data.
We will need datasets and transformers to be installed.
!pip install datasets==1.0.2
!pip install transformers==4.2.1
Let’s start by downloading the CNN/Dailymail dataset.
import datasets
train_data = datasets.load_dataset("cnn_dailymail", "3.0.0", split="train")
Alright, let’s get a first impression of the dataset. Alternatively,
the dataset can also be visualized using the awesome datasets
viewer
online.
train_data.info.description
Our input is called article and our labels are called highlights.
Let’s now print out the first example of the training data to get a
feeling for the data.
import pandas as pd
from IPython.display import display, HTML
from datasets import ClassLabel

# show the first training example as an HTML table
df = pd.DataFrame(train_data[:1])
del df["id"]
# replace class label ids by their names, if there are any
for column, typ in train_data.features.items():
    if isinstance(typ, ClassLabel):
        df[column] = df[column].transform(lambda i: typ.names[i])
display(HTML(df.to_html()))
OUTPUT:
-------
Article:
"""It's official: U.S. President Barack Obama wants lawmakers to weigh in on whether to make use of military force in Syria. Obama sent a letter to the heads of the House and Senate on Saturday night, hours after announcing that he believes military motion against Syrian targets is the best step to take over the alleged use of chemical weapons. The proposed laws from Obama asks Congress to approve the usage of military force "to discourage, disrupt, prevent and degrade the potential for future uses of chemical weapons or other weapons of mass destruction." It is a step that is about to show a global crisis right into a fierce domestic political battle. There are key questions looming over the controversy: What did U.N. weapons inspectors find in Syria? What happens if Congress votes no? And the way will the Syrian government react? In a televised address from the White House Rose Garden earlier Saturday, the president said he would take his case to Congress, not because he has to -- but because he desires to. "While I imagine I actually have the authority to perform this military motion without specific congressional authorization, I do know that the country shall be stronger if we take this course, and our actions shall be even more practical," he said. "We should always have this debate, because the problems are too big for business as usual." Obama said top congressional leaders had agreed to schedule a debate when the body returns to Washington on September 9. The Senate Foreign Relations Committee will hold a hearing over the matter on Tuesday, Sen. Robert Menendez said. Transcript: Read Obama's full remarks . Syrian crisis: Latest developments . U.N. inspectors leave Syria . Obama's remarks got here shortly after U.N. inspectors left Syria, carrying evidence that can determine whether chemical weapons were utilized in an attack early last week in a Damascus suburb. 
"The aim of the sport here, the mandate, could be very clear -- and that's to determine whether chemical weapons were used -- and never by whom," U.N. spokesman Martin Nesirky told reporters on Saturday. But who used the weapons within the reported toxic gas attack in a Damascus suburb on August 21 has been a key point of world debate over the Syrian crisis. Top U.S. officials have said there isn't any doubt that the Syrian government was behind it, while Syrian officials have denied responsibility and blamed jihadists fighting with the rebels. British and U.S. intelligence reports say the attack involved chemical weapons, but U.N. officials have stressed the importance of waiting for an official report from inspectors. The inspectors will share their findings with U.N. Secretary-General Ban Ki-moon Ban, who has said he desires to wait until the U.N. team's final report is accomplished before presenting it to the U.N. Security Council. The Organization for the Prohibition of Chemical Weapons, which nine of the inspectors belong to, said Saturday that it could take up to 3 weeks to investigate the evidence they collected. "It needs time to have the ability to investigate the knowledge and the samples," Nesirky said. He noted that Ban has repeatedly said there isn't any alternative to a political solution to the crisis in Syria, and that "a military solution will not be an option." Bergen: Syria is an issue from hell for the U.S. Obama: 'This menace have to be confronted' Obama's senior advisers have debated the following steps to take, and the president's comments Saturday got here amid mounting political pressure over the situation in Syria. Some U.S. lawmakers have called for immediate motion while others warn of moving into what could turn into a quagmire. Some global leaders have expressed support, however the British Parliament's vote against military motion earlier this week was a blow to Obama's hopes of getting strong backing from key NATO allies. 
On Saturday, Obama proposed what he said could be a limited military motion against Syrian President Bashar al-Assad. Any military attack wouldn't be open-ended or include U.S. ground forces, he said. Syria's alleged use of chemical weapons earlier this month "is an assault on human dignity," the president said. A failure to reply with force, Obama argued, "could lead on to escalating use of chemical weapons or their proliferation to terrorist groups who would do our people harm. In a world with many dangers, this menace have to be confronted." Syria missile strike: What would occur next? Map: U.S. and allied assets around Syria . Obama decision got here Friday night . On Friday night, the president made a last-minute decision to seek the advice of lawmakers. What is going to occur in the event that they vote no? It's unclear. A senior administration official told CNN that Obama has the authority to act without Congress -- even when Congress rejects his request for authorization to make use of force. Obama on Saturday continued to shore up support for a strike on the al-Assad government. He spoke by phone with French President Francois Hollande before his Rose Garden speech. "The 2 leaders agreed that the international community must deliver a resolute message to the Assad regime -- and others who would think about using chemical weapons -- that these crimes are unacceptable and those that violate this international norm shall be held accountable by the world," the White House said. Meanwhile, as uncertainty loomed over how Congress would weigh in, U.S. military officials said they remained on the ready. 5 key assertions: U.S. intelligence report on Syria . Syria: Who wants what after chemical weapons horror . Reactions mixed to Obama's speech . A spokesman for the Syrian National Coalition said that the opposition group was upset by Obama's announcement. 
"Our fear now could be that the dearth of motion could embolden the regime and so they repeat his attacks in a more serious way," said spokesman Louay Safi. "So we're quite concerned." Some members of Congress applauded Obama's decision. House Speaker John Boehner, Majority Leader Eric Cantor, Majority Whip Kevin McCarthy and Conference Chair Cathy McMorris Rodgers issued a press release Saturday praising the president. "Under the Structure, the responsibility to declare war lies with Congress," the Republican lawmakers said. "We're glad the president is looking for authorization for any military motion in Syria in response to serious, substantive questions being raised." Greater than 160 legislators, including 63 of Obama's fellow Democrats, had signed letters calling for either a vote or not less than a "full debate" before any U.S. motion. British Prime Minister David Cameron, whose own try to get lawmakers in his country to support military motion in Syria failed earlier this week, responded to Obama's speech in a Twitter post Saturday. "I understand and support Barack Obama's position on Syria," Cameron said. An influential lawmaker in Russia -- which has stood by Syria and criticized america -- had his own theory. "The primary reason Obama is popping to the Congress: the military operation didn't get enough support either on this planet, amongst allies of the US or in america itself," Alexei Pushkov, chairman of the international-affairs committee of the Russian State Duma, said in a Twitter post. In america, scattered groups of anti-war protesters across the country took to the streets Saturday. "Like many other Americans...we're just bored with america getting involved and invading and bombing other countries," said Robin Rosecrans, who was amongst tons of at a Los Angeles demonstration. What do Syria's neighbors think? Why Russia, China, Iran stand by Assad . Syria's government unfazed . 
After Obama's speech, a military and political analyst on Syrian state TV said Obama is "embarrassed" that Russia opposes military motion against Syria, is "crying for help" for somebody to come back to his rescue and is facing two defeats -- on the political and military levels. Syria's prime minister appeared unfazed by the saber-rattling. "The Syrian Army's status is on maximum readiness and fingers are on the trigger to confront all challenges," Wael Nader al-Halqi said during a gathering with a delegation of Syrian expatriates from Italy, in response to a banner on Syria State TV that was broadcast prior to Obama's address. An anchor on Syrian state television said Obama "gave the impression to be preparing for an aggression on Syria based on repeated lies." A top Syrian diplomat told the state television network that Obama was facing pressure to take military motion from Israel, Turkey, some Arabs and right-wing extremists in america. "I believe he has done well by doing what Cameron did by way of taking the difficulty to Parliament," said Bashar Jaafari, Syria's ambassador to the United Nations. Each Obama and Cameron, he said, "climbed to the highest of the tree and do not know the best way to get down." The Syrian government has denied that it used chemical weapons within the August 21 attack, saying that jihadists fighting with the rebels used them in an effort to show global sentiments against it. British intelligence had put the number of individuals killed within the attack at greater than 350. On Saturday, Obama said "all told, well over 1,000 people were murdered." U.S. Secretary of State John Kerry on Friday cited a death toll of 1,429, greater than 400 of them children. No explanation was offered for the discrepancy. Iran: U.S. military motion in Syria would spark 'disaster' Opinion: Why strikes in Syria are a nasty idea ."""
Summary:
"""Syrian official: Obama climbed to the top of the tree, "doesn't know how to get down"\nObama sends a letter to the heads of the House and Senate .\nObama to seek congressional approval on military action against Syria .\nAim is to determine whether CW were used, not by whom, says U.N. spokesman"""
The input data seems to consist of short news articles. Interestingly,
the labels appear to be bullet-point-like summaries. At this point, one
should probably take a look at a couple of other examples to get a
better feeling for the data.
One should also notice here that the text is case-sensitive. This
means that we have to be careful if we want to use case-insensitive
models. As CNN/Dailymail is a summarization dataset, the model will be
evaluated using the ROUGE metric. Checking the description of ROUGE
in 🤗datasets, cf. here, we can
see that the metric is case-insensitive, meaning that upper case
letters will be normalized to lower case letters during evaluation.
Thus, we can safely leverage uncased checkpoints, such as
bert-base-uncased.
Cool! Next, let’s get a sense of the length of input data and labels.
As models compute length in token-length, we will make use of the
bert-base-uncased tokenizer to compute the article and summary length.
First, we load the tokenizer.
from transformers import BertTokenizerFast
tokenizer = BertTokenizerFast.from_pretrained("bert-base-uncased")
Next, we make use of .map() to compute the length of the article and
its summary. Since we know that the maximum length that
bert-base-uncased can process amounts to 512, we are also interested
in the percentage of input samples being longer than the maximum length.
Similarly, we compute the percentage of summaries that are longer than
64 and 128, respectively.
We can define the .map() function as follows.
def map_to_length(x):
    x["article_len"] = len(tokenizer(x["article"]).input_ids)
    x["article_longer_512"] = int(x["article_len"] > 512)
    x["summary_len"] = len(tokenizer(x["highlights"]).input_ids)
    x["summary_longer_64"] = int(x["summary_len"] > 64)
    x["summary_longer_128"] = int(x["summary_len"] > 128)
    return x
It should be sufficient to look at the first 10000 samples. We can speed
up the mapping by using multiple processes with num_proc=4.
sample_size = 10000
data_stats = train_data.select(range(sample_size)).map(map_to_length, num_proc=4)
Having computed the length for the first 10000 samples, we should now
average them together. For this, we can make use of the .map()
function with batched=True and batch_size=-1 to have access to all
10000 samples within the .map() function.
def compute_and_print_stats(x):
    if len(x["article_len"]) == sample_size:
        print(
            "Article Mean: {}, %-Articles > 512:{}, Summary Mean:{}, %-Summary > 64:{}, %-Summary > 128:{}".format(
                sum(x["article_len"]) / sample_size,
                sum(x["article_longer_512"]) / sample_size,
                sum(x["summary_len"]) / sample_size,
                sum(x["summary_longer_64"]) / sample_size,
                sum(x["summary_longer_128"]) / sample_size,
            )
        )

output = data_stats.map(
    compute_and_print_stats,
    batched=True,
    batch_size=-1,
)
OUTPUT:
-------
Article Mean: 847.6216, %-Articles > 512:0.7355, Summary Mean:57.7742, %-Summary > 64:0.3185, %-Summary > 128:0.0
We can see that on average an article contains 848 tokens, with ca. 3/4
of the articles being longer than the model’s max_length of 512. The
summary is on average about 58 tokens long. Over 30% of our 10000-sample
summaries are longer than 64 tokens, but none are longer than 128
tokens.
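The same kind of aggregation can also be written without `.map()`. The following is a small sketch on a plain Python list of token counts; the helper name `length_stats` and the toy numbers are invented for illustration and are not taken from the dataset.

```python
def length_stats(lengths, thresholds):
    """Mean length and share of samples longer than each threshold.

    Hypothetical helper for illustration; not part of 🤗datasets.
    """
    n = len(lengths)
    stats = {"mean": sum(lengths) / n}
    for threshold in thresholds:
        longer = sum(length > threshold for length in lengths)
        stats[f"share_longer_{threshold}"] = longer / n
    return stats


# Made-up summary lengths (in tokens), not real CNN/Dailymail statistics.
toy_summary_lengths = [40, 50, 60, 70, 80]
stats = length_stats(toy_summary_lengths, thresholds=(64, 128))
print(stats)
```

With real data, `lengths` could be a column of the mapped dataset, e.g. `data_stats["summary_len"]`.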
bert-base-uncased is limited to 512 tokens, which means we would have to
cut possibly important information from the article. Because most of the
important information is often found at the beginning of articles and
because we want to be computationally efficient, we decide to stick to
bert-base-uncased with a max_length of 512 in this notebook. This
choice is not optimal but has been shown to yield good
results on CNN/Dailymail.
Alternatively, one could leverage long-range sequence models, such as
Longformer, as the
encoder.
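To make the effect of the 512-token limit concrete, the sketch below mimics head truncation and right padding on plain lists of token ids. It is a simplified stand-in for what the tokenizer does when called with `truncation=True` and `padding="max_length"`, not the tokenizer's actual code. The helper name `truncate_and_pad` is hypothetical; the ids 101, 102 and 0 are the `[CLS]`, `[SEP]` and `[PAD]` ids of bert-base-uncased.

```python
CLS_ID, SEP_ID, PAD_ID = 101, 102, 0  # special token ids of bert-base-uncased


def truncate_and_pad(token_ids, max_length):
    """Keep the first tokens, wrap them in [CLS] ... [SEP], pad on the right."""
    kept = token_ids[: max_length - 2]  # reserve two slots for [CLS] and [SEP]
    input_ids = [CLS_ID] + kept + [SEP_ID]
    attention_mask = [1] * len(input_ids)
    num_padding = max_length - len(input_ids)
    input_ids = input_ids + [PAD_ID] * num_padding
    attention_mask = attention_mask + [0] * num_padding
    return input_ids, attention_mask


# A made-up "article" of 848 token ids: everything after the first 510 is cut.
long_article = list(range(1000, 1848))
ids, mask = truncate_and_pad(long_article, max_length=512)

# A short input is padded instead of truncated.
short_ids, short_mask = truncate_and_pad([2000, 2001, 2002], max_length=8)
print(len(ids), short_ids, short_mask)
```

Everything after the first 510 article tokens is dropped, which is why the choice relies on the most important information appearing early in the article.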
Regarding the summary length, we can see that a length of 128 already
includes all of the summary labels. 128 is easily within the limits of
bert-base-uncased, so we decide to limit the generation to 128.
Again, we will make use of the .map() function, this time to
transform each training batch into a batch of model inputs.
"article" and "highlights" are tokenized and prepared as the
Encoder’s "input_ids" and Decoder’s "decoder_input_ids"
respectively.
"labels" are shifted automatically to the left for language modeling
training.
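As a generic illustration of what this shift means, the sketch below pairs each decoder input token with the token that should be predicted next. This is the general idea of teacher forcing, written with a hypothetical helper on a toy sequence; it is not the library's internal implementation, and the mapping function in this notebook does not need to do it by hand.

```python
def shift_for_next_token_prediction(token_ids):
    """Pair every input position with the token that follows it."""
    decoder_inputs = token_ids[:-1]  # what the decoder sees
    targets = token_ids[1:]          # what it should predict at each position
    return decoder_inputs, targets


# Toy sequence: [CLS] hello world [SEP] in bert-base-uncased ids.
summary_ids = [101, 7592, 2088, 102]
decoder_inputs, targets = shift_for_next_token_prediction(summary_ids)

for seen, expected in zip(decoder_inputs, targets):
    print(f"after token {seen} the model should predict {expected}")
```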
Lastly, it is very important to remember to ignore the loss of the
padded labels. In 🤗Transformers this can be done by setting the label to
-100. Great, let’s write down our mapping function then.
encoder_max_length=512
decoder_max_length=128

def process_data_to_model_inputs(batch):
    # tokenize the inputs and labels
    inputs = tokenizer(batch["article"], padding="max_length", truncation=True, max_length=encoder_max_length)
    outputs = tokenizer(batch["highlights"], padding="max_length", truncation=True, max_length=decoder_max_length)

    batch["input_ids"] = inputs.input_ids
    batch["attention_mask"] = inputs.attention_mask
    batch["labels"] = outputs.input_ids.copy()

    # set padded label positions to -100 so that the loss ignores them
    batch["labels"] = [[-100 if token == tokenizer.pad_token_id else token for token in labels] for labels in batch["labels"]]

    return batch
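The effect of the -100 masking can be sketched without any library. The toy example below assumes a pad id of 0 (the `pad_token_id` of bert-base-uncased) and uses invented per-token loss values; `mask_padding` and `masked_mean` are hypothetical helpers that only imitate how a loss with `ignore_index=-100` (the default of PyTorch's cross-entropy loss) skips masked positions.

```python
PAD_ID = 0           # pad_token_id of bert-base-uncased
IGNORE_INDEX = -100  # label value skipped by the loss


def mask_padding(label_ids):
    """Replace padding ids by the ignore index."""
    return [IGNORE_INDEX if token == PAD_ID else token for token in label_ids]


def masked_mean(per_token_loss, labels):
    """Average the loss over positions whose label is not ignored."""
    kept = [loss for loss, label in zip(per_token_loss, labels) if label != IGNORE_INDEX]
    return sum(kept) / len(kept)


# A summary of 4 real tokens padded to length 8.
padded_summary = [101, 7592, 2088, 102, 0, 0, 0, 0]
labels = mask_padding(padded_summary)

# Invented per-token losses; the large values sit on padding positions.
toy_losses = [0.5, 1.0, 1.5, 1.0, 9.0, 9.0, 9.0, 9.0]
loss_masked = masked_mean(toy_losses, labels)
loss_unmasked = sum(toy_losses) / len(toy_losses)
print(labels, loss_masked, loss_unmasked)
```

Without the masking, the padding positions would dominate the average, and the model would mostly be trained to predict padding.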
In this notebook, we train and evaluate the model just on a few training
examples for demonstration and set the batch_size to 4 to prevent
out-of-memory issues.
The following line reduces the training data to only the first 32
examples. The cell can be commented out or not run for a full training
run. Good results were obtained with a batch_size of 16.
train_data = train_data.select(range(32))
Alright, let’s prepare the training data.
batch_size=4
train_data = train_data.map(
    process_data_to_model_inputs,
    batched=True,
    batch_size=batch_size,
    remove_columns=["article", "highlights", "id"]
)
Taking a look at the processed training dataset, we can see that the
column names article, highlights, and id have been replaced by the
arguments expected by the EncoderDecoderModel.
train_data
OUTPUT:
-------
Dataset(features: {'attention_mask': Sequence(feature=Value(dtype='int64', id=None), length=-1, id=None), 'decoder_attention_mask': Sequence(feature=Value(dtype='int64', id=None), length=-1, id=None), 'decoder_input_ids': Sequence(feature=Value(dtype='int64', id=None), length=-1, id=None), 'input_ids': Sequence(feature=Value(dtype='int64', id=None), length=-1, id=None), 'labels': Sequence(feature=Value(dtype='int64', id=None), length=-1, id=None)}, num_rows: 32)
So far, the data was manipulated using Python’s List format. Let’s
convert the data to PyTorch Tensors to be trained on GPU.
train_data.set_format(
    type="torch", columns=["input_ids", "attention_mask", "labels"],
)
Awesome, the data processing of the training data is finished.
Analogously, we can do the same for the validation data.
First, we load 10% of the validation dataset:
val_data = datasets.load_dataset("cnn_dailymail", "3.0.0", split="validation[:10%]")
For demonstration purposes, the validation data is then reduced to just
8 samples,
val_data = val_data.select(range(8))
the mapping function is applied,
val_data = val_data.map(
    process_data_to_model_inputs,
    batched=True,
    batch_size=batch_size,
    remove_columns=["article", "highlights", "id"]
)
and, finally, the validation data can also be converted to PyTorch tensors.
val_data.set_format(
    type="torch", columns=["input_ids", "attention_mask", "labels"],
)
Great! Now we can move to warm-starting the EncoderDecoderModel.
Warm-starting the Encoder-Decoder Model
This section explains how an Encoder-Decoder model can be warm-started
using the bert-base-uncased checkpoint.
Let’s start by importing the EncoderDecoderModel. For more detailed
information about the EncoderDecoderModel class, the reader is advised
to take a look at the
documentation.
from transformers import EncoderDecoderModel
In contrast to other model classes in 🤗Transformers, the
EncoderDecoderModel class has two methods to load pre-trained weights,
namely:
- the “standard” .from_pretrained(...) method is derived from the
general PreTrainedModel.from_pretrained(...) method and thus
corresponds exactly to the one of other model classes. The
function expects a single model identifier, e.g.
.from_pretrained("google/bert2bert_L-24_wmt_de_en") and will load
a single .pt checkpoint file into the EncoderDecoderModel class.
- a special .from_encoder_decoder_pretrained(...) method, which can
be used to warm-start an encoder-decoder model from two model
identifiers: one for the encoder and one for the decoder. The first
model identifier is thereby used to load the encoder via
AutoModel.from_pretrained(...) (see doc
here),
and the second model identifier is used to load the decoder via
AutoModelForCausalLM (see doc
here).
Alright, let’s warm-start our BERT2BERT model. As mentioned earlier,
we will warm-start both the encoder and decoder with the
"bert-base-uncased" checkpoint.
bert2bert = EncoderDecoderModel.from_encoder_decoder_pretrained("bert-base-uncased", "bert-base-uncased")
OUTPUT:
-------
"""Some weights of the model checkpoint at bert-base-uncased were not used when initializing BertLMHeadModel: ['cls.seq_relationship.weight', 'cls.seq_relationship.bias']
- This IS expected if you are initializing BertLMHeadModel from the checkpoint of a model trained on another task or with another architecture (e.g. initializing a BertForSequenceClassification model from a BertForPretraining model).
- This IS NOT expected if you are initializing BertLMHeadModel from the checkpoint of a model that you expect to be exactly identical (initializing a BertForSequenceClassification model from a BertForSequenceClassification model).
Some weights of BertLMHeadModel weren't initialized from the model checkpoint at bert-base-uncased and are newly initialized: ['bert.encoder.layer.0.crossattention.self.query.weight', 'bert.encoder.layer.0.crossattention.self.query.bias', 'bert.encoder.layer.0.crossattention.self.key.weight', 'bert.encoder.layer.0.crossattention.self.key.bias', 'bert.encoder.layer.0.crossattention.self.value.weight', 'bert.encoder.layer.0.crossattention.self.value.bias', 'bert.encoder.layer.0.crossattention.output.dense.weight', 'bert.encoder.layer.0.crossattention.output.dense.bias', 'bert.encoder.layer.0.crossattention.output.LayerNorm.weight', 'bert.encoder.layer.0.crossattention.output.LayerNorm.bias', 'bert.encoder.layer.1.crossattention.self.query.weight', 'bert.encoder.layer.1.crossattention.self.query.bias', 'bert.encoder.layer.1.crossattention.self.key.weight', 'bert.encoder.layer.1.crossattention.self.key.bias', 'bert.encoder.layer.1.crossattention.self.value.weight', 'bert.encoder.layer.1.crossattention.self.value.bias', 'bert.encoder.layer.1.crossattention.output.dense.weight', 'bert.encoder.layer.1.crossattention.output.dense.bias', 'bert.encoder.layer.1.crossattention.output.LayerNorm.weight', 'bert.encoder.layer.1.crossattention.output.LayerNorm.bias', 'bert.encoder.layer.2.crossattention.self.query.weight', 'bert.encoder.layer.2.crossattention.self.query.bias', 'bert.encoder.layer.2.crossattention.self.key.weight', 'bert.encoder.layer.2.crossattention.self.key.bias', 'bert.encoder.layer.2.crossattention.self.value.weight', 'bert.encoder.layer.2.crossattention.self.value.bias', 'bert.encoder.layer.2.crossattention.output.dense.weight', 'bert.encoder.layer.2.crossattention.output.dense.bias', 'bert.encoder.layer.2.crossattention.output.LayerNorm.weight', 'bert.encoder.layer.2.crossattention.output.LayerNorm.bias', 'bert.encoder.layer.3.crossattention.self.query.weight', 'bert.encoder.layer.3.crossattention.self.query.bias', 
'bert.encoder.layer.3.crossattention.self.key.weight', 'bert.encoder.layer.3.crossattention.self.key.bias', 'bert.encoder.layer.3.crossattention.self.value.weight', 'bert.encoder.layer.3.crossattention.self.value.bias', 'bert.encoder.layer.3.crossattention.output.dense.weight', 'bert.encoder.layer.3.crossattention.output.dense.bias', 'bert.encoder.layer.3.crossattention.output.LayerNorm.weight', 'bert.encoder.layer.3.crossattention.output.LayerNorm.bias', 'bert.encoder.layer.4.crossattention.self.query.weight', 'bert.encoder.layer.4.crossattention.self.query.bias', 'bert.encoder.layer.4.crossattention.self.key.weight', 'bert.encoder.layer.4.crossattention.self.key.bias', 'bert.encoder.layer.4.crossattention.self.value.weight', 'bert.encoder.layer.4.crossattention.self.value.bias', 'bert.encoder.layer.4.crossattention.output.dense.weight', 'bert.encoder.layer.4.crossattention.output.dense.bias', 'bert.encoder.layer.4.crossattention.output.LayerNorm.weight', 'bert.encoder.layer.4.crossattention.output.LayerNorm.bias', 'bert.encoder.layer.5.crossattention.self.query.weight', 'bert.encoder.layer.5.crossattention.self.query.bias', 'bert.encoder.layer.5.crossattention.self.key.weight', 'bert.encoder.layer.5.crossattention.self.key.bias', 'bert.encoder.layer.5.crossattention.self.value.weight', 'bert.encoder.layer.5.crossattention.self.value.bias', 'bert.encoder.layer.5.crossattention.output.dense.weight', 'bert.encoder.layer.5.crossattention.output.dense.bias', 'bert.encoder.layer.5.crossattention.output.LayerNorm.weight', 'bert.encoder.layer.5.crossattention.output.LayerNorm.bias', 'bert.encoder.layer.6.crossattention.self.query.weight', 'bert.encoder.layer.6.crossattention.self.query.bias', 'bert.encoder.layer.6.crossattention.self.key.weight', 'bert.encoder.layer.6.crossattention.self.key.bias', 'bert.encoder.layer.6.crossattention.self.value.weight', 'bert.encoder.layer.6.crossattention.self.value.bias', 'bert.encoder.layer.6.crossattention.output.dense.weight', 
'bert.encoder.layer.6.crossattention.output.dense.bias', 'bert.encoder.layer.6.crossattention.output.LayerNorm.weight', 'bert.encoder.layer.6.crossattention.output.LayerNorm.bias', 'bert.encoder.layer.7.crossattention.self.query.weight', 'bert.encoder.layer.7.crossattention.self.query.bias', 'bert.encoder.layer.7.crossattention.self.key.weight', 'bert.encoder.layer.7.crossattention.self.key.bias', 'bert.encoder.layer.7.crossattention.self.value.weight', 'bert.encoder.layer.7.crossattention.self.value.bias', 'bert.encoder.layer.7.crossattention.output.dense.weight', 'bert.encoder.layer.7.crossattention.output.dense.bias', 'bert.encoder.layer.7.crossattention.output.LayerNorm.weight', 'bert.encoder.layer.7.crossattention.output.LayerNorm.bias', 'bert.encoder.layer.8.crossattention.self.query.weight', 'bert.encoder.layer.8.crossattention.self.query.bias', 'bert.encoder.layer.8.crossattention.self.key.weight', 'bert.encoder.layer.8.crossattention.self.key.bias', 'bert.encoder.layer.8.crossattention.self.value.weight', 'bert.encoder.layer.8.crossattention.self.value.bias', 'bert.encoder.layer.8.crossattention.output.dense.weight', 'bert.encoder.layer.8.crossattention.output.dense.bias', 'bert.encoder.layer.8.crossattention.output.LayerNorm.weight', 'bert.encoder.layer.8.crossattention.output.LayerNorm.bias', 'bert.encoder.layer.9.crossattention.self.query.weight', 'bert.encoder.layer.9.crossattention.self.query.bias', 'bert.encoder.layer.9.crossattention.self.key.weight', 'bert.encoder.layer.9.crossattention.self.key.bias', 'bert.encoder.layer.9.crossattention.self.value.weight', 'bert.encoder.layer.9.crossattention.self.value.bias', 'bert.encoder.layer.9.crossattention.output.dense.weight', 'bert.encoder.layer.9.crossattention.output.dense.bias', 'bert.encoder.layer.9.crossattention.output.LayerNorm.weight', 'bert.encoder.layer.9.crossattention.output.LayerNorm.bias', 'bert.encoder.layer.10.crossattention.self.query.weight', 
'bert.encoder.layer.10.crossattention.self.query.bias', 'bert.encoder.layer.10.crossattention.self.key.weight', 'bert.encoder.layer.10.crossattention.self.key.bias', 'bert.encoder.layer.10.crossattention.self.value.weight', 'bert.encoder.layer.10.crossattention.self.value.bias', 'bert.encoder.layer.10.crossattention.output.dense.weight', 'bert.encoder.layer.10.crossattention.output.dense.bias', 'bert.encoder.layer.10.crossattention.output.LayerNorm.weight', 'bert.encoder.layer.10.crossattention.output.LayerNorm.bias', 'bert.encoder.layer.11.crossattention.self.query.weight', 'bert.encoder.layer.11.crossattention.self.query.bias', 'bert.encoder.layer.11.crossattention.self.key.weight', 'bert.encoder.layer.11.crossattention.self.key.bias', 'bert.encoder.layer.11.crossattention.self.value.weight', 'bert.encoder.layer.11.crossattention.self.value.bias', 'bert.encoder.layer.11.crossattention.output.dense.weight', 'bert.encoder.layer.11.crossattention.output.dense.bias', 'bert.encoder.layer.11.crossattention.output.LayerNorm.weight', 'bert.encoder.layer.11.crossattention.output.LayerNorm.bias']"""
You should probably TRAIN this model on a down-stream task to be able to use it for predictions and inference."""
For once, we should take a good look at the warning here. We can see
that two weights corresponding to a "cls" layer were not used. This
should not be a problem because we don’t need BERT’s CLS layer for
sequence-to-sequence tasks. Also, we notice that a lot of weights are
“newly” or randomly initialized. When taking a closer look, these
weights all correspond to the cross-attention layer, which is exactly
what we would expect after having read the theory above.
Let’s take a closer look at the model.
bert2bert
OUTPUT:
-------
EncoderDecoderModel(
(encoder): BertModel(
(embeddings): BertEmbeddings(
(word_embeddings): Embedding(30522, 768, padding_idx=0)
(position_embeddings): Embedding(512, 768)
(token_type_embeddings): Embedding(2, 768)
(LayerNorm): LayerNorm((768,), eps=1e-12, elementwise_affine=True)
(dropout): Dropout(p=0.1, inplace=False)
)
(encoder): BertEncoder(
(layer): ModuleList(
(0): BertLayer(
(attention): BertAttention(
(self): BertSelfAttention(
(query): Linear(in_features=768, out_features=768, bias=True)
(key): Linear(in_features=768, out_features=768, bias=True)
(value): Linear(in_features=768, out_features=768, bias=True)
(dropout): Dropout(p=0.1, inplace=False)
)
(output): BertSelfOutput(
(dense): Linear(in_features=768, out_features=768, bias=True)
(LayerNorm): LayerNorm((768,), eps=1e-12, elementwise_affine=True)
(dropout): Dropout(p=0.1, inplace=False)
)
)
(intermediate): BertIntermediate(
(dense): Linear(in_features=768, out_features=3072, bias=True)
)
(output): BertOutput(
(dense): Linear(in_features=3072, out_features=768, bias=True)
(LayerNorm): LayerNorm((768,), eps=1e-12, elementwise_affine=True)
(dropout): Dropout(p=0.1, inplace=False)
)
),
...
,
(11): BertLayer(
(attention): BertAttention(
(self): BertSelfAttention(
(query): Linear(in_features=768, out_features=768, bias=True)
(key): Linear(in_features=768, out_features=768, bias=True)
(value): Linear(in_features=768, out_features=768, bias=True)
(dropout): Dropout(p=0.1, inplace=False)
)
(output): BertSelfOutput(
(dense): Linear(in_features=768, out_features=768, bias=True)
(LayerNorm): LayerNorm((768,), eps=1e-12, elementwise_affine=True)
(dropout): Dropout(p=0.1, inplace=False)
)
)
(intermediate): BertIntermediate(
(dense): Linear(in_features=768, out_features=3072, bias=True)
)
(output): BertOutput(
(dense): Linear(in_features=3072, out_features=768, bias=True)
(LayerNorm): LayerNorm((768,), eps=1e-12, elementwise_affine=True)
(dropout): Dropout(p=0.1, inplace=False)
)
)
)
)
(pooler): BertPooler(
(dense): Linear(in_features=768, out_features=768, bias=True)
(activation): Tanh()
)
)
(decoder): BertLMHeadModel(
(bert): BertModel(
(embeddings): BertEmbeddings(
(word_embeddings): Embedding(30522, 768, padding_idx=0)
(position_embeddings): Embedding(512, 768)
(token_type_embeddings): Embedding(2, 768)
(LayerNorm): LayerNorm((768,), eps=1e-12, elementwise_affine=True)
(dropout): Dropout(p=0.1, inplace=False)
)
(encoder): BertEncoder(
(layer): ModuleList(
(0): BertLayer(
(attention): BertAttention(
(self): BertSelfAttention(
(query): Linear(in_features=768, out_features=768, bias=True)
(key): Linear(in_features=768, out_features=768, bias=True)
(value): Linear(in_features=768, out_features=768, bias=True)
(dropout): Dropout(p=0.1, inplace=False)
)
(output): BertSelfOutput(
(dense): Linear(in_features=768, out_features=768, bias=True)
(LayerNorm): LayerNorm((768,), eps=1e-12, elementwise_affine=True)
(dropout): Dropout(p=0.1, inplace=False)
)
)
(crossattention): BertAttention(
(self): BertSelfAttention(
(query): Linear(in_features=768, out_features=768, bias=True)
(key): Linear(in_features=768, out_features=768, bias=True)
(value): Linear(in_features=768, out_features=768, bias=True)
(dropout): Dropout(p=0.1, inplace=False)
)
(output): BertSelfOutput(
(dense): Linear(in_features=768, out_features=768, bias=True)
(LayerNorm): LayerNorm((768,), eps=1e-12, elementwise_affine=True)
(dropout): Dropout(p=0.1, inplace=False)
)
)
(intermediate): BertIntermediate(
(dense): Linear(in_features=768, out_features=3072, bias=True)
)
(output): BertOutput(
(dense): Linear(in_features=3072, out_features=768, bias=True)
(LayerNorm): LayerNorm((768,), eps=1e-12, elementwise_affine=True)
(dropout): Dropout(p=0.1, inplace=False)
)
),
...,
(11): BertLayer(
(attention): BertAttention(
(self): BertSelfAttention(
(query): Linear(in_features=768, out_features=768, bias=True)
(key): Linear(in_features=768, out_features=768, bias=True)
(value): Linear(in_features=768, out_features=768, bias=True)
(dropout): Dropout(p=0.1, inplace=False)
)
(output): BertSelfOutput(
(dense): Linear(in_features=768, out_features=768, bias=True)
(LayerNorm): LayerNorm((768,), eps=1e-12, elementwise_affine=True)
(dropout): Dropout(p=0.1, inplace=False)
)
)
(crossattention): BertAttention(
(self): BertSelfAttention(
(query): Linear(in_features=768, out_features=768, bias=True)
(key): Linear(in_features=768, out_features=768, bias=True)
(value): Linear(in_features=768, out_features=768, bias=True)
(dropout): Dropout(p=0.1, inplace=False)
)
(output): BertSelfOutput(
(dense): Linear(in_features=768, out_features=768, bias=True)
(LayerNorm): LayerNorm((768,), eps=1e-12, elementwise_affine=True)
(dropout): Dropout(p=0.1, inplace=False)
)
)
(intermediate): BertIntermediate(
(dense): Linear(in_features=768, out_features=3072, bias=True)
)
(output): BertOutput(
(dense): Linear(in_features=3072, out_features=768, bias=True)
(LayerNorm): LayerNorm((768,), eps=1e-12, elementwise_affine=True)
(dropout): Dropout(p=0.1, inplace=False)
)
)
)
)
)
(cls): BertOnlyMLMHead(
(predictions): BertLMPredictionHead(
(transform): BertPredictionHeadTransform(
(dense): Linear(in_features=768, out_features=768, bias=True)
(LayerNorm): LayerNorm((768,), eps=1e-12, elementwise_affine=True)
)
(decoder): Linear(in_features=768, out_features=30522, bias=True)
)
)
)
)
We see that bert2bert.encoder is an instance of BertModel and that
bert2bert.decoder is an instance of BertLMHeadModel. However, both instances
are now combined into a single torch.nn.Module and can thus be saved
as a single .pt checkpoint file.
Let’s try it out using the usual .save_pretrained(...) method.
bert2bert.save_pretrained("bert2bert")
Similarly, the model can be reloaded using the standard
.from_pretrained(...) method.
bert2bert = EncoderDecoderModel.from_pretrained("bert2bert")
Awesome. Let’s also check the config.
bert2bert.config
OUTPUT:
-------
EncoderDecoderConfig {
"_name_or_path": "bert2bert",
"architectures": [
"EncoderDecoderModel"
],
"decoder": {
"_name_or_path": "bert-base-uncased",
"add_cross_attention": true,
"architectures": [
"BertForMaskedLM"
],
"attention_probs_dropout_prob": 0.1,
"bad_words_ids": null,
"bos_token_id": null,
"chunk_size_feed_forward": 0,
"decoder_start_token_id": null,
"do_sample": false,
"early_stopping": false,
"eos_token_id": null,
"finetuning_task": null,
"gradient_checkpointing": false,
"hidden_act": "gelu",
"hidden_dropout_prob": 0.1,
"hidden_size": 768,
"id2label": {
"0": "LABEL_0",
"1": "LABEL_1"
},
"initializer_range": 0.02,
"intermediate_size": 3072,
"is_decoder": true,
"is_encoder_decoder": false,
"label2id": {
"LABEL_0": 0,
"LABEL_1": 1
},
"layer_norm_eps": 1e-12,
"length_penalty": 1.0,
"max_length": 20,
"max_position_embeddings": 512,
"min_length": 0,
"model_type": "bert",
"no_repeat_ngram_size": 0,
"num_attention_heads": 12,
"num_beams": 1,
"num_hidden_layers": 12,
"num_return_sequences": 1,
"output_attentions": false,
"output_hidden_states": false,
"pad_token_id": 0,
"prefix": null,
"pruned_heads": {},
"repetition_penalty": 1.0,
"return_dict": false,
"sep_token_id": null,
"task_specific_params": null,
"temperature": 1.0,
"tie_encoder_decoder": false,
"tie_word_embeddings": true,
"tokenizer_class": null,
"top_k": 50,
"top_p": 1.0,
"torchscript": false,
"type_vocab_size": 2,
"use_bfloat16": false,
"use_cache": true,
"vocab_size": 30522,
"xla_device": null
},
"encoder": {
"_name_or_path": "bert-base-uncased",
"add_cross_attention": false,
"architectures": [
"BertForMaskedLM"
],
"attention_probs_dropout_prob": 0.1,
"bad_words_ids": null,
"bos_token_id": null,
"chunk_size_feed_forward": 0,
"decoder_start_token_id": null,
"do_sample": false,
"early_stopping": false,
"eos_token_id": null,
"finetuning_task": null,
"gradient_checkpointing": false,
"hidden_act": "gelu",
"hidden_dropout_prob": 0.1,
"hidden_size": 768,
"id2label": {
"0": "LABEL_0",
"1": "LABEL_1"
},
"initializer_range": 0.02,
"intermediate_size": 3072,
"is_decoder": false,
"is_encoder_decoder": false,
"label2id": {
"LABEL_0": 0,
"LABEL_1": 1
},
"layer_norm_eps": 1e-12,
"length_penalty": 1.0,
"max_length": 20,
"max_position_embeddings": 512,
"min_length": 0,
"model_type": "bert",
"no_repeat_ngram_size": 0,
"num_attention_heads": 12,
"num_beams": 1,
"num_hidden_layers": 12,
"num_return_sequences": 1,
"output_attentions": false,
"output_hidden_states": false,
"pad_token_id": 0,
"prefix": null,
"pruned_heads": {},
"repetition_penalty": 1.0,
"return_dict": false,
"sep_token_id": null,
"task_specific_params": null,
"temperature": 1.0,
"tie_encoder_decoder": false,
"tie_word_embeddings": true,
"tokenizer_class": null,
"top_k": 50,
"top_p": 1.0,
"torchscript": false,
"type_vocab_size": 2,
"use_bfloat16": false,
"use_cache": true,
"vocab_size": 30522,
"xla_device": null
},
"is_encoder_decoder": true,
"model_type": "encoder_decoder"
}
The config is similarly composed of an encoder config and a decoder
config, both of which are instances of BertConfig in our case. However,
the overall config is of type EncoderDecoderConfig and is therefore
saved as a single .json file.
In conclusion, one should remember that once an EncoderDecoderModel
object is instantiated, it provides the same functionality as any other
encoder-decoder model in 🤗Transformers, e.g. BART, T5, ProphetNet, …
The only difference is that an EncoderDecoderModel provides the
additional from_encoder_decoder_pretrained(...) function allowing the
model class to be warm-started from any two encoder and decoder
checkpoints.
On a side note, if one wants to create a shared encoder-decoder
model, the parameter tie_encoder_decoder=True can additionally be
passed as follows:
shared_bert2bert = EncoderDecoderModel.from_encoder_decoder_pretrained("bert-base-cased", "bert-base-cased", tie_encoder_decoder=True)
As a comparison, we can see that the tied model has far fewer
parameters, as expected.
print(f"\n\nNum Params. Shared: {shared_bert2bert.num_parameters()}, Non-Shared: {bert2bert.num_parameters()}")
OUTPUT:
-------
Num Params. Shared: 137298244, Non-Shared: 247363386
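The two numbers above can be reproduced by hand from the layer shapes in the model printout. The following is a minimal sketch under stated assumptions, not the library's own counting code. The helper names linear and bert_param_counts are hypothetical. The LM head's output projection is assumed to share its weight matrix with the word embeddings (tie_word_embeddings is true in the config), so only its bias is counted. With tie_encoder_decoder=True, every encoder weight is assumed to be shared with the decoder except the pooler, which the decoder does not have. The vocabulary size 28996 of bert-base-cased is not shown in this notebook.

```python
def linear(n_in, n_out):
    """Parameters of a dense layer: weight matrix plus bias."""
    return n_in * n_out + n_out


def bert_param_counts(vocab_size, hidden=768, layers=12, intermediate=3072,
                      max_positions=512, type_vocab=2):
    """Return (encoder, decoder, pooler) parameter counts for a BERT-base layout."""
    layer_norm = 2 * hidden  # scale and shift vectors
    embeddings = (vocab_size + max_positions + type_vocab) * hidden + layer_norm
    # query, key, value and output projections, followed by a LayerNorm
    attention = 4 * linear(hidden, hidden) + layer_norm
    feed_forward = linear(hidden, intermediate) + linear(intermediate, hidden) + layer_norm
    pooler = linear(hidden, hidden)
    # LM head: transform + LayerNorm + output bias.
    # Assumption: the output projection weight is tied to the word embeddings.
    lm_head = linear(hidden, hidden) + layer_norm + vocab_size

    encoder = embeddings + layers * (attention + feed_forward) + pooler
    # each decoder layer has an extra cross-attention block; the decoder has no pooler
    decoder = embeddings + layers * (2 * attention + feed_forward) + lm_head
    return encoder, decoder, pooler


# Non-shared model: bert-base-uncased encoder + decoder (vocab_size 30522, see config)
enc, dec, _ = bert_param_counts(vocab_size=30522)
non_shared = enc + dec

# Shared model: built from bert-base-cased (vocab_size 28996, an outside fact).
# Assumption: only the encoder's pooler is not shared with the decoder.
_, dec_cased, pooler_cased = bert_param_counts(vocab_size=28996)
shared = dec_cased + pooler_cased

print(f"Shared: {shared}, Non-Shared: {non_shared}")
# Shared: 137298244, Non-Shared: 247363386
```

Under these assumptions the sketch matches both reported counts. Note that the shared model here was built from bert-base-cased while bert2bert was built from bert-base-uncased, so a small part of the gap comes from the smaller vocabulary rather than from weight tying.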
In this notebook, we will however train a non-shared Bert2Bert model,
so we continue with bert2bert and not shared_bert2bert.
del shared_bert2bert
We have warm-started a bert2bert model, but we have not yet defined all
the relevant parameters used for beam search decoding.
Let’s start by setting the special tokens. bert-base-cased does not
have a decoder_start_token_id or eos_token_id, so we will use its
cls_token_id and sep_token_id respectively. Also, we should define a
pad_token_id on the config and make sure the correct vocab_size is
set.
bert2bert.config.decoder_start_token_id = tokenizer.cls_token_id
bert2bert.config.eos_token_id = tokenizer.sep_token_id
bert2bert.config.pad_token_id = tokenizer.pad_token_id
bert2bert.config.vocab_size = bert2bert.config.encoder.vocab_size
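To see why these token ids matter at generation time, here is a toy sketch in plain Python. It is not the generate(...) implementation, and it uses greedy decoding instead of beam search for brevity. The scripted "model" and the helper names greedy_decode, pad_batch, and make_scripted_model are made up for illustration. The ids 101, 102, and 0 are the [CLS], [SEP], and [PAD] ids of BERT's vocabulary.

```python
CLS_ID, SEP_ID, PAD_ID = 101, 102, 0  # BERT's [CLS], [SEP], [PAD] token ids


def greedy_decode(next_token_fn, decoder_start_token_id, eos_token_id, max_length):
    """Generate tokens one at a time until EOS or max_length is reached."""
    tokens = [decoder_start_token_id]  # decoding always starts from this token
    while len(tokens) < max_length:
        next_token = next_token_fn(tokens)
        tokens.append(next_token)
        if next_token == eos_token_id:  # the EOS token ends the sequence
            break
    return tokens


def pad_batch(sequences, pad_token_id):
    """Right-pad finished sequences so they can be stacked into one batch."""
    longest = max(len(seq) for seq in sequences)
    return [seq + [pad_token_id] * (longest - len(seq)) for seq in sequences]


def make_scripted_model(script):
    """Hypothetical stand-in for a decoder: emits `script`, then EOS."""
    def next_token_fn(tokens):
        step = len(tokens) - 1  # number of tokens generated so far
        return script[step] if step < len(script) else SEP_ID
    return next_token_fn


short = greedy_decode(make_scripted_model([7, 8]), CLS_ID, SEP_ID, max_length=10)
long = greedy_decode(make_scripted_model([7, 8, 9, 10]), CLS_ID, SEP_ID, max_length=10)
batch = pad_batch([short, long], PAD_ID)

print(batch)
# [[101, 7, 8, 102, 0, 0], [101, 7, 8, 9, 10, 102]]
```

Without a decoder_start_token_id there is nothing to condition the first step on, without an eos_token_id a sequence can only end by hitting max_length, and without a pad_token_id sequences of different lengths cannot be returned as one batch.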
Next, let’s define all parameters related to beam search decoding.
Since bart-large-cnn yields good results on CNN/Dailymail, we will
just copy its beam search decoding parameters.
For more details on what each of these parameters does, please take a
look at this blog post or the docs.
bert2bert.config.max_length = 142
bert2bert.config.min_length = 56
bert2bert.config.no_repeat_ngram_size = 3
bert2bert.config.early_stopping = True
bert2bert.config.length_penalty = 2.0
bert2bert.config.num_beams = 4
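As one illustration of these parameters, no_repeat_ngram_size = 3 means that no 3-gram may occur twice in a generated sequence. The following is a simplified sketch of the idea in plain Python. The function name banned_next_tokens is hypothetical, and the library's own implementation operates on batched tensors of beam hypotheses rather than on a single list.

```python
def banned_next_tokens(generated, ngram_size):
    """Return the tokens that would complete an already-seen n-gram.

    `generated` is the list of token ids produced so far for one hypothesis.
    """
    # Not enough tokens yet to form a full n-gram with the next token
    if ngram_size <= 0 or len(generated) + 1 < ngram_size:
        return set()

    prefix_len = ngram_size - 1
    followers = {}  # (n-1)-token prefix -> tokens that followed it so far
    for i in range(len(generated) - ngram_size + 1):
        prefix = tuple(generated[i:i + prefix_len])
        followers.setdefault(prefix, set()).add(generated[i + prefix_len])

    # The last n-1 generated tokens form the prefix of the next n-gram
    current_prefix = tuple(generated[len(generated) - prefix_len:])
    return followers.get(current_prefix, set())


# (5, 6, 7) already occurred, so after "... 5, 6" the token 7 is banned
print(banned_next_tokens([5, 6, 7, 8, 5, 6], ngram_size=3))  # {7}
# (7, 8) has not been followed by anything yet, so nothing is banned
print(banned_next_tokens([5, 6, 7, 8], ngram_size=3))  # set()
```

During decoding, the scores of the banned tokens would be set to negative infinity before the next token is chosen, so that the repeated n-gram can never be produced.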
Alright, let’s now start fine-tuning the warm-started BERT2BERT
model.
Fine-Tuning Warm-Started Encoder-Decoder Models
In this section, we will show how one can make use of the
Seq2SeqTrainer to fine-tune a warm-started encoder-decoder model.
Let’s first import the Seq2SeqTrainer and its training arguments
Seq2SeqTrainingArguments.
from transformers import Seq2SeqTrainer, Seq2SeqTrainingArguments
In addition, we need a few Python packages to make the
Seq2SeqTrainer work.
!pip install git-python==1.0.3
!pip install rouge_score
!pip install sacrebleu
The Seq2SeqTrainer extends 🤗Transformers’ Trainer for encoder-decoder
models. In short, it allows using the generate(...) function during
evaluation, which is necessary to validate the performance of
encoder-decoder models on most sequence-to-sequence tasks, such as
summarization.
For more information on the Trainer, one should read through
this short tutorial.
Let’s begin by configuring the Seq2SeqTrainingArguments.
The argument predict_with_generate should be set to True, so that
the Seq2SeqTrainer runs generate(...) on the validation data and
passes the generated output as predictions to the
compute_metrics(...) function which we will define later. The
additional arguments are derived from TrainingArguments and are
documented here.
For a complete training run, one should change those arguments as
needed. The small logging_steps, save_steps, and eval_steps values used
below are only suited to a quick demo run.
For more information on the Seq2SeqTrainer, the reader is advised to
take a look at the code.
training_args = Seq2SeqTrainingArguments(
    predict_with_generate=True,
    evaluation_strategy="steps",
    per_device_train_batch_size=batch_size,
    per_device_eval_batch_size=batch_size,
    fp16=True,
    output_dir="./",
    logging_steps=2,
    save_steps=10,
    eval_steps=4,
)
Also, we need to define a function to correctly compute the ROUGE score
during validation. Since we activated predict_with_generate, the
compute_metrics(...) function expects predictions that were obtained
using the generate(...) function. Like most summarization tasks,
CNN/Dailymail is typically evaluated using the ROUGE score.
Let’s first load the ROUGE metric using the 🤗datasets library.
rouge = datasets.load_metric("rouge")
Next, we will define the compute_metrics(...) function. The rouge
metric computes the score from two lists of strings. Thus we decode both
the predictions and labels – making sure that -100 is correctly
replaced by the pad_token_id and removing all special characters by
setting skip_special_tokens=True.
def compute_metrics(pred):
    labels_ids = pred.label_ids
    pred_ids = pred.predictions

    # decode the generated token ids into strings
    pred_str = tokenizer.batch_decode(pred_ids, skip_special_tokens=True)
    # -100 marks ignored label positions; map it back to the pad token before decoding
    labels_ids[labels_ids == -100] = tokenizer.pad_token_id
    label_str = tokenizer.batch_decode(labels_ids, skip_special_tokens=True)

    rouge_output = rouge.compute(predictions=pred_str, references=label_str, rouge_types=["rouge2"])["rouge2"].mid

    return {
        "rouge2_precision": round(rouge_output.precision, 4),
        "rouge2_recall": round(rouge_output.recall, 4),
        "rouge2_fmeasure": round(rouge_output.fmeasure, 4),
    }
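For intuition about what the ROUGE-2 precision, recall, and F-measure returned above measure, here is a simplified sketch in plain Python. It is not the implementation behind datasets.load_metric("rouge"): it tokenizes by whitespace, skips stemming, and scores a single prediction/reference pair instead of aggregating over a dataset. The function names bigram_counts and rouge2 are chosen for illustration.

```python
from collections import Counter


def bigram_counts(text):
    """Count the bigrams of a lowercased, whitespace-tokenized string."""
    tokens = text.lower().split()
    return Counter(zip(tokens, tokens[1:]))


def rouge2(prediction, reference):
    """Return (precision, recall, fmeasure) based on bigram overlap."""
    pred, ref = bigram_counts(prediction), bigram_counts(reference)
    # Counter intersection keeps the minimum count of each shared bigram
    overlap = sum((pred & ref).values())
    precision = overlap / max(sum(pred.values()), 1)  # share of predicted bigrams that are correct
    recall = overlap / max(sum(ref.values()), 1)      # share of reference bigrams that were found
    if overlap == 0:
        return 0.0, 0.0, 0.0
    fmeasure = 2 * precision * recall / (precision + recall)
    return precision, recall, fmeasure


# 3 of the 5 bigrams on each side match: "the cat", "on the", "the mat"
p, r, f = rouge2("the cat sat on the mat", "the cat lay on the mat")
print(round(p, 4), round(r, 4), round(f, 4))  # 0.6 0.6 0.6
```

A summary that is much longer than its reference tends to have high recall and low precision, which is why the F-measure is usually the number that is reported.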
Great, now we can pass all arguments to the Seq2SeqTrainer and start
finetuning. Executing the following cell will take ca. 10 minutes ☕.
Finetuning BERT2BERT on the complete CNN/Dailymail training data
takes ca. 8h on a single TITAN RTX GPU.
trainer = Seq2SeqTrainer(
    model=bert2bert,
    tokenizer=tokenizer,
    args=training_args,
    compute_metrics=compute_metrics,
    train_dataset=train_data,
    eval_dataset=val_data,
)
trainer.train()
Awesome, we should now be fully equipped to finetune a warm-started
encoder-decoder model. To check the result of our fine-tuning, let’s
take a look at the saved checkpoints.
!ls
OUTPUT:
-------
bert2bert checkpoint-20 runs seq2seq_trainer.py
checkpoint-10 __pycache__ sample_data seq2seq_training_args.py
Finally, we can load the checkpoint as usual via the
EncoderDecoderModel.from_pretrained(...) method.
dummy_bert2bert = EncoderDecoderModel.from_pretrained("./checkpoint-20")
Evaluation
In a final step, we might want to evaluate the BERT2BERT model on the
test data.
To start, instead of loading the dummy model, let’s load a BERT2BERT
model that was finetuned on the full training dataset. Also, we load its
tokenizer, which is just a copy of bert-base-cased’s tokenizer.
from transformers import BertTokenizer
bert2bert = EncoderDecoderModel.from_pretrained("patrickvonplaten/bert2bert_cnn_daily_mail").to("cuda")
tokenizer = BertTokenizer.from_pretrained("patrickvonplaten/bert2bert_cnn_daily_mail")
Next, we load just 2% of CNN/Dailymail’s test data. For the full
evaluation, one should obviously use 100% of the data.
test_data = datasets.load_dataset("cnn_dailymail", "3.0.0", split="test[:2%]")
Now, we can again leverage 🤗dataset’s handy map() function to
generate a summary for each test sample.
For each data sample we:
- first, tokenize the "article",
- second, generate the output token ids, and
- third, decode the output token ids to obtain our predicted summary.
def generate_summary(batch):
    # pad and truncate every article to BERT's maximum input length of 512 tokens
    inputs = tokenizer(batch["article"], padding="max_length", truncation=True, max_length=512, return_tensors="pt")
    input_ids = inputs.input_ids.to("cuda")
    attention_mask = inputs.attention_mask.to("cuda")

    outputs = bert2bert.generate(input_ids, attention_mask=attention_mask)

    # special tokens such as [CLS], [SEP] and [PAD] are dropped when decoding
    output_str = tokenizer.batch_decode(outputs, skip_special_tokens=True)

    batch["pred_summary"] = output_str

    return batch
Let’s run the map function to obtain the results dictionary that has
the model’s predicted summary stored for each sample. Executing the
following cell may take ca. 10min ☕.
batch_size = 16
results = test_data.map(generate_summary, batched=True, batch_size=batch_size, remove_columns=["article"])
Finally, we compute the ROUGE score.
rouge.compute(predictions=results["pred_summary"], references=results["highlights"], rouge_types=["rouge2"])["rouge2"].mid
OUTPUT:
-------
Score(precision=0.10389454113300968, recall=0.1564771201053348, fmeasure=0.12175271663717585)
That’s it. We have shown how to warm-start a BERT2BERT model and
fine-tune/evaluate it on the CNN/Dailymail dataset.
The fully trained BERT2BERT model is uploaded to the 🤗model hub under
patrickvonplaten/bert2bert_cnn_daily_mail.
The model achieves a ROUGE-2 score of 18.22 on the full evaluation
data, which is even slightly better than reported in the paper.
For some summarization examples, the reader is advised to use the online
inference API of the model, here.
Thanks a lot to Sascha Rothe, Shashi Narayan, and Aliaksei Severyn from
Google Research, and Victor Sanh, Sylvain Gugger, and Thomas Wolf from
🤗Hugging Face for proof-reading and giving very much appreciated
feedback.






