Transformer-based models have shown to be very useful for many NLP tasks. However, a major limitation of transformer-based models is their $O(n^2)$ time & memory complexity (where $n$ is the sequence length). Hence, it’s computationally very expensive to apply transformer-based models on long sequences. Several recent papers, e.g. Longformer, Performer, Reformer, Clustered attention, try to remedy this problem by approximating the full attention matrix. You can check out 🤗’s recent blog post in case you are unfamiliar with these models.
BigBird (introduced in the paper) is one of such recent models to address this issue. BigBird relies on block sparse attention instead of normal attention (i.e. BERT’s attention) and can handle sequences up to a length of 4096 at a much lower computational cost compared to BERT. It has achieved SOTA on various tasks involving very long sequences such as long document summarization and question-answering with long contexts.
A BigBird RoBERTa-like model is now available in 🤗Transformers. The goal of this post is to give the reader an in-depth understanding of the BigBird implementation & ease one’s life in using BigBird with 🤗Transformers. But, before going into more depth, it is important to remember that BigBird’s attention is an approximation of BERT’s full attention and therefore does not strive to be better than BERT’s full attention, but rather to be more efficient. It simply allows us to apply transformer-based models to much longer sequences since BERT’s quadratic memory requirement quickly becomes unbearable. Simply put, if we had unlimited compute & time, BERT’s attention would be preferred over block sparse attention (which we are going to discuss in this post).
If you wonder why we need more compute when working with longer sequences, this blog post is just right for you!
Some of the main questions one might have when working with standard BERT-like attention include:
- Do all tokens really have to attend to all other tokens?
- Why not compute attention only over important tokens?
- How to decide what tokens are important?
- How to attend to just a few tokens in a very efficient way?
In this blog post, we will try to answer those questions.
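What tokens should be attended to?
We will give a practical example of how attention works by considering the sentence “BigBird is now available in HuggingFace for extractive question answering”.
In BERT-like attention, every word would simply attend to all other tokens. Put mathematically, this would mean that each queried token $\text{query-token} \in \{\text{BigBird}, \text{is}, \text{now}, \text{available}, \text{in}, \text{HuggingFace}, \text{for}, \text{extractive}, \text{question}, \text{answering}\}$
would attend to the full list of $\text{key-tokens} = [\text{BigBird}, \text{is}, \text{now}, \text{available}, \text{in}, \text{HuggingFace}, \text{for}, \text{extractive}, \text{question}, \text{answering}]$.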
Let’s think about a sensible choice of key tokens that a queried token actually only should attend to by writing some pseudo-code.
We will assume that the token available is queried and build a sensible collection of key tokens to attend to.
>>> # let's consider the following sentence as an example
>>> example = ['BigBird', 'is', 'now', 'available', 'in', 'HuggingFace', 'for', 'extractive', 'question', 'answering']
>>> # further, let's assume we are trying to understand the representation of 'available'
>>> query_token = 'available'
>>> # we initialize an empty set and fill it with the key tokens of interest as we proceed in this section
>>> key_tokens = set()
Nearby tokens should be important because, in a sentence (sequence of words), the current word is highly dependent on neighboring past & future tokens. This intuition is the idea behind the concept of sliding attention.
>>> # with a window size of 3, 'available' attends to itself,
>>> # to one token on its left & to one token on its right
>>> sliding_tokens = ["now", "available", "in"]
>>> # let's update our collection with the above tokens
>>> key_tokens.update(sliding_tokens)
Long-range dependencies: For some tasks, it is crucial to capture long-range relationships between tokens. E.g., in question-answering the model needs to compare each token of the context to the whole question to be able to figure out which part of the context is useful for a correct answer. If most of the context tokens would just attend to other context tokens, but not to the question, it becomes much harder for the model to filter important context tokens from less important context tokens.
Now, BigBird proposes two ways of allowing long-term attention dependencies while staying computationally efficient.
- Global tokens: Introduce some tokens which will attend to every token and which are attended by every token. E.g.: “HuggingFace is building nice libraries for easy NLP”. Now, let’s say ‘building’ is defined as a global token, and the model needs to know the relation between ‘NLP’ & ‘HuggingFace’ for some task (Note: these 2 tokens are at two extremes). Having ‘building’ attend globally to all other tokens will probably help the model to associate ‘NLP’ with ‘HuggingFace’.
>>> # let's assume the first & last tokens are global
>>> global_tokens = ["BigBird", "answering"]
>>> # let's add the global tokens to our key tokens collection
>>> key_tokens.update(global_tokens)
- Random tokens: Select some tokens randomly which will transfer information by transferring to other tokens, which in turn can transfer to other tokens. This may reduce the cost of information travel from one token to another.
>>> # now we can choose one token randomly from our example sentence
>>> # let's assume 'is' got picked
>>> random_tokens = ["is"]
>>> # let's add the random tokens to our collection
>>> key_tokens.update(random_tokens)
>>> # now, 'available' attends only to these tokens instead of the whole sequence
>>> key_tokens
{'now', 'is', 'in', 'answering', 'available', 'BigBird'}
This way, the query token attends only to a subset of all possible tokens while yielding a good approximation of full attention. The same approach is used for all other queried tokens. But remember, the whole point here is to approximate BERT’s full attention as efficiently as possible. Simply making each queried token attend to all key tokens, as is done for BERT, can be computed very effectively as a sequence of matrix multiplications on modern hardware, like GPUs. However, a combination of sliding, global & random attention appears to imply sparse matrix multiplication, which is harder to implement efficiently on modern hardware.
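To make “the same approach for all other queried tokens” concrete, here is a minimal sketch that repeats the selection for every token of the example sentence. The helper `select_key_tokens` and its parameters (`window`, `num_random`, `seed`) are hypothetical names chosen for this illustration, and treating the first & last tokens as global is an assumption carried over from the walkthrough; this is not how the library implements it.

```python
import random

def select_key_tokens(tokens, query_index, window=1, num_random=1, seed=0):
    """Hypothetical helper: key tokens for one query (sliding + global + random)."""
    n = len(tokens)
    # sliding: the query itself plus `window` neighbours on each side
    sliding = set(range(max(0, query_index - window), min(n, query_index + window + 1)))
    # global: assumption for this sketch, the first & last tokens are global
    global_ = {0, n - 1}
    # random: sample from the positions that are not selected yet
    remaining = sorted(set(range(n)) - sliding - global_)
    rng = random.Random(seed)
    random_ = set(rng.sample(remaining, min(num_random, len(remaining))))
    return [tokens[i] for i in sorted(sliding | global_ | random_)]

example = ['BigBird', 'is', 'now', 'available', 'in', 'HuggingFace',
           'for', 'extractive', 'question', 'answering']

# every query token attends to a small subset instead of all 10 tokens
for i, token in enumerate(example):
    print(token, '->', select_key_tokens(example, i))
```

Each query ends up with at most 6 key tokens here, regardless of how long the sentence is.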
One of the major contributions of BigBird is the proposition of a block sparse attention mechanism that allows computing sliding, global & random attention effectively. Let’s look into it!
Understanding the need for global, sliding, random keys with Graphs
First, let’s get a better understanding of global, sliding & random attention using graphs and try to understand how the combination of these three attention mechanisms yields a very good approximation of standard BERT-like attention.


The above figure shows global (left), sliding (middle) & random (right) connections respectively as a graph. Each node corresponds to a token and each line represents an attention score. If no connection is made between 2 tokens, then the attention score is assumed to be 0.
BigBird block sparse attention is a combination of sliding, global & random connections (total 10 connections) as shown in the gif on the left, while a graph of normal attention (right) will have all 15 connections (note: total 6 nodes are present). You can simply think of normal attention as all the tokens attending globally.
Normal attention: The model can transfer information from one token to another token directly in a single layer since each token is queried over every other token and is attended by every other token. Let’s consider an example similar to what is shown in the above figures. If the model needs to associate ‘going’ with ‘now’, it can simply do that in a single layer since there is a direct connection joining both tokens.
Block sparse attention: If the model needs to share information between two nodes (or tokens), information will have to travel across various other nodes in the path for some of the tokens, since all the nodes are not directly connected in a single layer.
E.g., assuming the model needs to associate ‘going’ with ‘now’, then if only sliding attention is present, the flow of information among those 2 tokens is defined by the path: going -> am -> i -> now (i.e. it will have to travel over 2 other tokens). Hence, we may need multiple layers to capture the entire information of the sequence. Normal attention can capture this in a single layer. In an extreme case, this could mean that as many layers as input tokens are needed. If, however, we introduce some global tokens, information can travel via the path: going -> i -> now (which is shorter). If we in addition introduce random connections, it can travel via: going -> am -> now. With the help of random connections & global connections, information can travel very rapidly (with just a few layers) from one token to the next.
In case we have many global tokens, we may not need random connections since there will be multiple short paths through which information can travel. This is the idea behind keeping num_random_tokens = 0 when working with a variant of BigBird, called ETC (more on this in later sections).
In these graphics, we are assuming that the attention matrix is symmetric, i.e. $A_{ij} = A_{ji}$, since in a graph if some token A attends to token B, then B will also attend to A. You can see from the figure of the attention matrix shown in the next section that this assumption holds for most tokens in BigBird.
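The path lengths discussed in this section can be checked with a small breadth-first search over the attention graph. This is only a sketch under stated assumptions: the function `hops` is a hypothetical helper, the graph is undirected (matching the symmetry assumption), and the token positions going=0, am=1, i=2, now=3 are assumed from the paths in the text rather than read off the figure.

```python
from collections import deque

def hops(n, source, target, global_nodes=(), random_edges=()):
    """Hypothetical helper: fewest layers needed to move information from source to target."""
    # sliding attention: every token is connected to its direct neighbours
    neighbours = {i: {j for j in (i - 1, i + 1) if 0 <= j < n} for i in range(n)}
    # global attention: a global token is connected to every other token
    for g in global_nodes:
        for i in range(n):
            if i != g:
                neighbours[g].add(i)
                neighbours[i].add(g)
    # random attention: a few extra symmetric connections
    for a, b in random_edges:
        neighbours[a].add(b)
        neighbours[b].add(a)
    # breadth-first search returns the shortest path length
    distance = {source: 0}
    queue = deque([source])
    while queue:
        node = queue.popleft()
        if node == target:
            return distance[node]
        for nxt in neighbours[node]:
            if nxt not in distance:
                distance[nxt] = distance[node] + 1
                queue.append(nxt)
    return None  # target is unreachable

n = 6                      # 6 nodes, as in the figures
going, am, i, now = 0, 1, 2, 3
print(hops(n, going, now))                            # sliding only: going -> am -> i -> now
print(hops(n, going, now, global_nodes=[i]))          # 'i' is global: going -> i -> now
print(hops(n, going, now, random_edges=[(am, now)]))  # random edge: going -> am -> now
```

The number of hops is a rough proxy for the number of layers needed before the two tokens can exchange information.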
| Attention Type | global_tokens | sliding_tokens | random_tokens |
|---|---|---|---|
| original_full | n | 0 | 0 |
| block_sparse | 2 x block_size | 3 x block_size | num_random_blocks x block_size |
original_full represents BERT’s attention while block_sparse represents BigBird’s attention. Wondering what the block_size is? We will cover that in later sections. For now, consider it to be 1 for simplicity.
BigBird block sparse attention
BigBird block sparse attention is just an efficient implementation of what we discussed above. Each token is attending to some global tokens, sliding tokens & random tokens instead of attending to all other tokens. The authors hardcoded the attention matrix for multiple query components separately and used a cool trick to speed up training/inference on GPU and TPU.

Note: at the top, we have 2 extra sentences. As you can notice, every token is just shifted by one place in both sentences. This is how sliding attention is implemented. When q[i] is multiplied with k[i,0:3], we will get a sliding attention score for q[i] (where i is the index of the element in the sequence).
You can find the actual implementation of block_sparse attention here. This may look very scary 😨😨 now. But this article will surely ease your life in understanding the code.
Global Attention
For global attention, each query is simply attending to all the other tokens in the sequence & is attended by every other token. Let’s assume Vasudev (1st token) & them (last token) to be global (in the above figure). You can see that these tokens are directly connected to all other tokens (blue boxes).
# pseudo code
Q -> Query matrix (seq_length, head_dim)
K -> Key matrix (seq_length, head_dim)
# 1st & last token attend to all other tokens
Q[0] x [K[0], K[1], K[2], ......, K[n-1]]
Q[n-1] x [K[0], K[1], K[2], ......, K[n-1]]
# 1st & last token are attended by all other tokens
K[0] x [Q[0], Q[1], Q[2], ......, Q[n-1]]
K[n-1] x [Q[0], Q[1], Q[2], ......, Q[n-1]]
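The pseudo code above can be sketched in NumPy as follows. The sizes (`n = 8`, `head_dim = 4`) and the random matrices are placeholders for illustration, and picking the first & last positions as global follows the example in this section.

```python
import numpy as np

rng = np.random.default_rng(0)
n, head_dim = 8, 4
Q = rng.normal(size=(n, head_dim))  # query matrix (seq_length, head_dim)
K = rng.normal(size=(n, head_dim))  # key matrix (seq_length, head_dim)

global_idx = [0, n - 1]  # assumption: 1st & last token are global

# global queries attend to every key -> 2 full rows of the score matrix
global_rows = Q[global_idx] @ K.T   # shape (2, n)
# every query attends to the global keys -> 2 full columns of the score matrix
global_cols = Q @ K[global_idx].T   # shape (n, 2)

print(global_rows.shape, global_cols.shape)
```

Only `4 x n` dot products are needed here, compared with `n x n` for the full score matrix.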
Sliding Attention
The sequence of key tokens is copied 2 times, with each element shifted to the right in one of the copies and to the left in the other copy. Now if we multiply the query sequence vectors by these 3 sequence vectors, we will cover all the sliding tokens. Computational complexity is simply O(3xn) = O(n). Referring to the above picture, the orange boxes represent the sliding attention. You can see 3 sequences at the top of the figure with 2 of them shifted by one token (1 to the left, 1 to the right).
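# what we want to do
Q[i] x [K[i-1], K[i], K[i+1]] for i = 1:-1
# efficient implementation (assume element-wise dot products between the two sequences)
[Q[0], Q[1], Q[2], ......, Q[n-2], Q[n-1]] x [K[1], K[2], K[3], ......, K[n-1], K[0]]
[Q[0], Q[1], Q[2], ......, Q[n-1]] x [K[n-1], K[0], K[1], ......, K[n-2]]
[Q[0], Q[1], Q[2], ......, Q[n-1]] x [K[0], K[1], K[2], ......, K[n-1]]
# each query sequence is multiplied by only 3 key sequences, which keeps the window size at 3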
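The shifting trick can be sketched with `np.roll`. This is a token-level illustration with made-up sizes, not the library’s block-level code. Note that rolling wraps around at the two ends, so only rows 1 to n-2 correspond to a true sliding window.

```python
import numpy as np

rng = np.random.default_rng(0)
n, head_dim = 8, 4
Q = rng.normal(size=(n, head_dim))
K = rng.normal(size=(n, head_dim))

# two shifted copies of the keys, plus the unshifted keys
K_prev = np.roll(K, 1, axis=0)    # row i holds K[i-1] (row 0 wraps to K[n-1])
K_next = np.roll(K, -1, axis=0)   # row i holds K[i+1] (row n-1 wraps to K[0])

# row-wise dot products: 3 scores per query instead of n
sliding_scores = np.stack(
    [(Q * K_prev).sum(-1), (Q * K).sum(-1), (Q * K_next).sum(-1)],
    axis=-1,
)  # shape (n, 3)

print(sliding_scores.shape)
```

Column 0 holds the score with the left neighbour, column 1 the score with the token itself, and column 2 the score with the right neighbour.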
Random Attention
Random attention ensures that each query token will attend to a few random tokens as well. For the actual implementation, this means that the model gathers some tokens randomly and computes their attention score.
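# r1, r2, ..., r are random indices; note that they are different for each row
Q[1] x [K[r1], K[r2], ......, K[r]]
.
.
.
Q[n-2] x [K[r1], K[r2], ......, K[r]]
# the 1st & last tokens are left out since they are already global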
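A minimal NumPy sketch of this gather-then-score step is shown below. The sizes and `num_random = 2` are illustrative assumptions. For simplicity it samples random keys for every row, whereas the pseudo code skips the first & last tokens.

```python
import numpy as np

rng = np.random.default_rng(0)
n, head_dim, num_random = 8, 4, 2
Q = rng.normal(size=(n, head_dim))
K = rng.normal(size=(n, head_dim))

# a different set of random key indices for every query row: shape (n, num_random)
rand_idx = np.stack(
    [rng.choice(n, size=num_random, replace=False) for _ in range(n)]
)

# gather the selected keys: shape (n, num_random, head_dim)
gathered_keys = K[rand_idx]
# dot product of each query with its own gathered keys: shape (n, num_random)
random_scores = np.einsum("id,ird->ir", Q, gathered_keys)

print(random_scores.shape)
```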
Note: The current implementation further divides the sequence into blocks & each notation is defined w.r.t. blocks instead of tokens. Let’s discuss this in more detail in the next section.
Implementation
Recap: In regular BERT attention, a sequence of tokens, i.e. $X = x_1, x_2, \ldots, x_n$, is projected through a dense layer into $Q, K, V$ and the attention score $Z$ is calculated as $Z = \text{Softmax}(QK^T)$. In the case of BigBird block sparse attention, the same algorithm is used but only with some selected query & key vectors.
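As a reference point for the block sparse version, regular full attention can be sketched in a few lines of NumPy. The projection matrices `W_q`, `W_k`, `W_v` and all sizes are random placeholders, and the usual scaling of the scores by the square root of the head dimension is left out to keep the sketch close to the plain softmax over query-key products described in the recap.

```python
import numpy as np

def softmax(x, axis=-1):
    # subtract the row maximum for numerical stability
    x = x - x.max(axis=axis, keepdims=True)
    e = np.exp(x)
    return e / e.sum(axis=axis, keepdims=True)

rng = np.random.default_rng(0)
n, hidden = 6, 8
X = rng.normal(size=(n, hidden))             # one embedding per token

# dense projections into queries, keys & values (random weights for illustration)
W_q, W_k, W_v = (rng.normal(size=(hidden, hidden)) for _ in range(3))
Q, K, V = X @ W_q, X @ W_k, X @ W_v

Z = softmax(Q @ K.T)                          # (n, n): every query scores every key
token_representations = Z @ V                 # (n, hidden)

print(Z.shape, token_representations.shape)
```

The `(n, n)` score matrix is exactly what becomes too expensive for long sequences.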
Let’s have a look at how BigBird block sparse attention is implemented. To begin with, let’s assume $b, r, s, g$ represent block_size, num_random_blocks, num_sliding_blocks, num_global_blocks, respectively. Visually, we can illustrate the components of BigBird’s block sparse attention as follows:

Attention scores for $q_1, q_2, q_{3:n-2}, q_{n-1}, q_n$ are calculated separately as described below:
The attention score for $q_1$, represented by $a_1$ where $a_1 = \text{Softmax}(q_1 K^T)$, is nothing but the attention score between all the tokens in the 1st block and all the other tokens in the sequence.
$q_1$ represents the 1st block, $g_i$ represents the $i$-th block. We are simply performing the normal attention operation between $q_1$ & $g$ (i.e. all the keys).
For calculating the attention score for tokens in the second block, we are gathering the first three blocks, the last block, and the fifth block. Then we can compute $a_2 = \text{Softmax}(q_2 \cdot \text{concat}(k_1, k_2, k_3, k_5, k_n)^T)$.
I am representing tokens by $g, r, s$ just to represent their nature explicitly (i.e. showing global, random, sliding tokens), else they are $k$ only.
For calculating the attention score for $q_{3:n-2}$, we will gather the global, sliding & random keys and compute the normal attention operation over $q_{3:n-2}$ and the gathered keys. Note that sliding keys are gathered using the special shifting trick as discussed earlier in the sliding attention section.
For calculating the attention score for tokens in the previous to last block (i.e. $q_{n-1}$), we are gathering the first block, the last three blocks, and the third block. Then we can apply the formula $a_{n-1} = \text{Softmax}(q_{n-1} \cdot \text{concat}(k_1, k_3, k_{n-2}, k_{n-1}, k_n)^T)$. This is very similar to what we did for $q_2$.
The attention score for $q_n$ is represented by $a_n$ where $a_n = \text{Softmax}(q_n K^T)$, and is nothing but the attention score between all the tokens in the last block and all the other tokens in the sequence. This is very similar to what we did for $q_1$.
Let’s combine the above matrices to get the final attention matrix. This attention matrix can be used to get a representation of all the tokens.
blue -> global blocks, red -> random blocks, orange -> sliding blocks. This attention matrix is just for illustration. During the forward pass, we aren’t storing the white blocks, but are computing a weighted value matrix (i.e. the representation of each token) directly for each of the separated components as discussed above.
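For illustration only, the combined pattern can be materialised as a boolean mask. The function `block_sparse_mask` is a hypothetical helper written for this post’s description (first & last blocks global, 3 sliding blocks, random blocks drawn from what is left); the actual implementation never builds this full matrix.

```python
import numpy as np

def block_sparse_mask(num_blocks, block_size, num_random_blocks, seed=0):
    """Hypothetical helper: True where a query token is allowed to attend to a key token."""
    rng = np.random.default_rng(seed)
    mask = np.zeros((num_blocks, num_blocks), dtype=bool)
    # global: the first & last block attend to everything and are attended by everything
    mask[[0, -1], :] = True
    mask[:, [0, -1]] = True
    for i in range(1, num_blocks - 1):
        # sliding: previous, current and next block
        mask[i, i - 1 : i + 2] = True
        # random: extra blocks picked among those not selected yet
        candidates = np.flatnonzero(~mask[i])
        k = min(num_random_blocks, len(candidates))
        if k > 0:
            mask[i, rng.choice(candidates, size=k, replace=False)] = True
    # expand every block entry into a block_size x block_size tile
    return np.repeat(np.repeat(mask, block_size, axis=0), block_size, axis=1)

mask = block_sparse_mask(num_blocks=7, block_size=4, num_random_blocks=1)
print(mask.shape)          # token-level mask
print(mask.sum(axis=1))    # number of keys each query token attends to
```

With 7 blocks, tokens in the second block attend to 5 key blocks (three sliding blocks, the last block and one random block), which matches the gathering step described above.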
Now, we have covered the hardest part of block sparse attention, i.e. its implementation. Hopefully, you now have a better background to understand the actual code. Feel free to dive into it and to connect each part of the code with one of the components above.
Time & Memory complexity
| Attention Type | Sequence length | Time & Memory Complexity |
|---|---|---|
| original_full | 512 | T |
| | 1024 | 4 x T |
| | 4096 | 64 x T |
| block_sparse | 1024 | 2 x T |
| | 4096 | 8 x T |
Comparison of time & space complexity of BERT attention and BigBird block sparse attention.
The calculations behind these numbers:
BigBird time complexity = O(w x n + r x n + g x n)
BERT time complexity = O(n^2)
Assumptions:
w = 3 x 64
r = 3 x 64
g = 2 x 64
When seqlen = 512
=> **time complexity in BERT = 512^2**
When seqlen = 1024
=> time complexity in BERT = (2 x 512)^2
=> **time complexity in BERT = 4 x 512^2**
=> time complexity in BigBird = (8 x 64) x (2 x 512)
=> **time complexity in BigBird = 2 x 512^2**
When seqlen = 4096
=> time complexity in BERT = (8 x 512)^2
=> **time complexity in BERT = 64 x 512^2**
=> compute in BigBird = (8 x 64) x (8 x 512)
=> compute in BigBird = 8 x (512 x 512)
=> **time complexity in BigBird = 8 x 512^2**
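The same arithmetic can be reproduced with a few lines of Python, using the assumptions listed above (w = 3 x 64, r = 3 x 64, g = 2 x 64) and T defined as the cost of full attention at sequence length 512. The function names are only for illustration.

```python
block_size = 64
w = 3 * block_size   # sliding tokens
r = 3 * block_size   # random tokens
g = 2 * block_size   # global tokens

T = 512 ** 2         # cost of full attention at sequence length 512

def full_attention_cost(n):
    # every query attends to every key
    return n * n

def block_sparse_cost(n):
    # every query attends to w + r + g keys
    return (w + r + g) * n

for n in (512, 1024, 4096):
    print(n, full_attention_cost(n) / T, block_sparse_cost(n) / T)
```

Because w + r + g = 512 here, the block sparse cost grows linearly with the sequence length while the full attention cost grows quadratically.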
ITC vs ETC
The BigBird model can be trained using 2 different strategies: ITC & ETC. ITC (internal transformer construction) is simply what we discussed above. In ETC (extended transformer construction), some additional tokens are made global such that they will attend to / will be attended by all tokens.
ITC requires less compute since very few tokens are global, while at the same time the model can capture sufficient global information (also with the help of random attention). On the other hand, ETC can be very helpful for tasks in which we need a lot of global tokens, such as question-answering, for which the entire question should be attended to globally by the context to be able to relate the context correctly to the question.
Note: It is shown in the Big Bird paper that in many ETC experiments, the number of random blocks is set to 0. This is reasonable given our discussions above in the graph section.
The table below summarizes ITC & ETC:
| | ITC | ETC |
|---|---|---|
| Attention Matrix with global attention | (figure) | (figure) |
| global_tokens | 2 x block_size | extra_tokens + 2 x block_size |
| random_tokens | num_random_blocks x block_size | num_random_blocks x block_size |
| sliding_tokens | 3 x block_size | 3 x block_size |
Using BigBird with 🤗Transformers
You can use BigBirdModel just like any other 🤗 model. Let’s see some code below:
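from transformers import BigBirdModel
# load BigBird from its pretrained checkpoint; this uses the default configuration,
# i.e. attention_type = "block_sparse", num_random_blocks = 3, block_size = 64
model = BigBirdModel.from_pretrained("google/bigbird-roberta-base")
# these arguments can be changed with any checkpoint; they only change how many tokens each query token attends to
model = BigBirdModel.from_pretrained("google/bigbird-roberta-base", num_random_blocks=2, block_size=16)
# with attention_type = "original_full", BigBird relies on full attention of n^2 complexity
model = BigBirdModel.from_pretrained("google/bigbird-roberta-base", attention_type="original_full")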
There are a total of 3 checkpoints available in the 🤗Hub (at the point of writing this article): bigbird-roberta-base, bigbird-roberta-large, bigbird-base-trivia-itc. The first two checkpoints come from pretraining BigBirdForPreTraining with masked_lm loss, while the last one corresponds to the checkpoint after finetuning BigBirdForQuestionAnswering on the trivia-qa dataset.
Let’s have a look at the minimal code you can write (in case you like to use your own PyTorch trainer) to use 🤗’s BigBird model for fine-tuning on your tasks.
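from transformers import BigBirdForQuestionAnswering, BigBirdTokenizer
import torch

device = torch.device("cpu")
if torch.cuda.is_available():
    device = torch.device("cuda")

# initialize BigBird from pretrained weights; the question-answering head on top is randomly initialized
model = BigBirdForQuestionAnswering.from_pretrained("google/bigbird-roberta-base", block_size=64, num_random_blocks=3)
tokenizer = BigBirdTokenizer.from_pretrained("google/bigbird-roberta-base")
model.to(device)

# placeholders: replace these with your own objects
dataset = "torch.utils.data.DataLoader object"
optimizer = "torch.optim object"
epochs = ...

# very minimal training loop
for e in range(epochs):
    for batch in dataset:
        model.train()
        batch = {k: batch[k].to(device) for k in batch}
        # forward pass
        output = model(**batch)
        # back-propagation
        output["loss"].backward()
        optimizer.step()
        optimizer.zero_grad()

# save the final weights in a local directory (the empty string is a placeholder for the directory path)
model.save_pretrained("")

# push the weights to the 🤗Hub (the empty strings are placeholders for the directory path & the model id)
from huggingface_hub import ModelHubMixin
ModelHubMixin.push_to_hub("", model_id="")

# use the finetuned model for inference
query = ["How are you doing?", "How is life going?"]
context = ["", ""]  # placeholders: put your (long) contexts here
# padding=True pads the two examples to the same length so they can be stacked into tensors
batch = tokenizer(query, context, padding=True, return_tensors="pt")
batch = {k: batch[k].to(device) for k in batch}

model = BigBirdForQuestionAnswering.from_pretrained("")  # placeholder: directory of the saved weights
model.to(device)
with torch.no_grad():
    start_logits, end_logits = model(**batch).to_tuple()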
It’s important to keep the following points in mind while working with BigBird:
- Sequence length must be a multiple of block size, i.e. seqlen % block_size = 0. You need not worry since 🤗Transformers will automatically pad (to the smallest multiple of block size which is greater than the sequence length) if the batch sequence length is not a multiple of block_size.
- Currently, the HuggingFace version doesn’t support ETC and hence only the 1st & last block will be global.
- The current implementation doesn’t support num_random_blocks = 0.
- It’s recommended by the authors to set attention_type = "original_full" when sequence length < 1024.
- This must hold: seq_length > global_tokens + random_tokens + sliding_tokens + buffer_tokens, where global_tokens = 2 x block_size, sliding_tokens = 3 x block_size, random_tokens = num_random_blocks x block_size & buffer_tokens = num_random_blocks x block_size. In case you fail to do that, 🤗Transformers will automatically switch attention_type to original_full with a warning.
- When using BigBird as a decoder (or using BigBirdForCausalLM), attention_type should be original_full. But you need not worry, 🤗Transformers will automatically switch attention_type to original_full in case you forget to do that.
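The padding rule and the minimum-length condition from the list above can be checked with plain arithmetic. The two helpers below (`padded_length`, `can_use_block_sparse`) are hypothetical names written for this post; they are not part of 🤗Transformers and only mirror the formulas stated in the list.

```python
def padded_length(seq_length, block_size=64):
    """Smallest multiple of block_size that is >= seq_length."""
    return -(-seq_length // block_size) * block_size  # ceiling division

def can_use_block_sparse(seq_length, block_size=64, num_random_blocks=3):
    """Check seq_length > global + sliding + random + buffer tokens."""
    global_tokens = 2 * block_size
    sliding_tokens = 3 * block_size
    random_tokens = num_random_blocks * block_size
    buffer_tokens = num_random_blocks * block_size
    return seq_length > global_tokens + sliding_tokens + random_tokens + buffer_tokens

print(padded_length(1000))          # length after padding to a multiple of 64
print(can_use_block_sparse(512))    # too short for block_size=64, num_random_blocks=3
print(can_use_block_sparse(1024))   # long enough
```

With block_size = 64 and num_random_blocks = 3, the threshold works out to 704 tokens, so shorter inputs would fall back to original_full.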
What’s next?
@patrickvonplaten has made a really cool notebook on how to evaluate BigBirdForQuestionAnswering on the trivia-qa dataset. Feel free to play with BigBird using that notebook.
You will soon find a BigBird Pegasus-like model in the library for long document summarization💥.
End Notes
The original implementation of the block sparse attention matrix can be found here. You can find 🤗’s version here.







