Understanding BigBird’s Block Sparse Attention

By Vasudev Gupta
Transformer-based models have proven to be very useful for many NLP tasks. However, a major limitation of transformer-based models is their O(n^2) time & memory complexity (where n is the sequence length). Hence, it is computationally very expensive to apply transformer-based models to long sequences (n > 512).

BigBird (introduced in this paper) is one of such recent models to address this issue. BigBird relies on block sparse attention instead of normal attention (i.e. BERT's attention) and can handle sequences up to a length of 4096 at a much lower computational cost compared to BERT. It has achieved SOTA on various tasks involving very long sequences, such as long document summarization and question-answering with long contexts.

BigBird RoBERTa-like model is now available in 🤗Transformers. The goal of this post is to give the reader an in-depth understanding of BigBird's implementation & ease one's life in using BigBird with 🤗Transformers. But, before going into more depth, it is important to remember that BigBird's attention is an approximation of BERT's full attention and therefore does not strive to be better than BERT's full attention, but rather to be more efficient. It simply allows applying transformer-based models to much longer sequences, since BERT's quadratic memory requirement quickly becomes unbearable. Simply put, if we had ∞ compute & ∞ time, BERT's attention would be preferred over block sparse attention (which we are going to discuss in this post).

If you wonder why we need more compute when working with longer sequences, this blog post is just right for you!


Some of the main questions one might have when working with standard BERT-like attention include:

  • Do all tokens really have to attend to all other tokens?
  • Why not compute attention only over the important tokens?
  • How to decide which tokens are important?
  • How to attend to just a few tokens in a very efficient way?

In this blog post, we will try to answer those questions.



What tokens must be attended to?

We will give a practical example of how attention works by considering the sentence "BigBird is now available in HuggingFace for extractive question answering".
In BERT-like attention, every word would simply attend to all other tokens. Put mathematically, this would mean that each queried token query-token ∈ {BigBird, is, now, available, in, HuggingFace, for, extractive, question, answering} would attend to the full list of key tokens.

Let's think about a sensible choice of key tokens that a queried token actually should attend to, by writing some pseudo-code.
We will assume that the token available is queried and build a sensible list of key tokens to attend to.

>>> # let's consider the following sentence as an example
>>> example = ['BigBird', 'is', 'now', 'available', 'in', 'HuggingFace', 'for', 'extractive', 'question', 'answering']

>>> # further, let's assume we are trying to understand the representation of 'available', i.e.
>>> query_token = 'available'

>>> # we will initialize an empty list and fill in the tokens of interest as we proceed through this section
>>> key_tokens = [] # => currently the 'available' token doesn't have anything to attend to

Nearby tokens should be important because, in a sentence (sequence of words), the current word is highly dependent on neighboring past & future tokens. This intuition is the idea behind the concept of sliding attention.

>>> # considering a sliding window of size 3, the tokens to the immediate left & right of 'available' matter
>>> # left token: 'now'; right token: 'in'
>>> sliding_tokens = ["now", "available", "in"]

>>> # let's update our collection with the above tokens
>>> key_tokens.extend(sliding_tokens)

Long-range dependencies: For some tasks, it is crucial to capture long-range relationships between tokens. E.g., in question-answering the model needs to compare each token of the context to the whole question to be able to figure out which part of the context is useful for a correct answer. If most of the context tokens would just attend to other context tokens, but not to the question, it becomes much harder for the model to filter important context tokens from less important ones.

Now, BigBird proposes two ways of allowing long-term attention dependencies while staying computationally efficient.

  • Global tokens: Introduce some tokens which will attend to every token and which are attended to by every token. Eg: "HuggingFace is building nice libraries for easy NLP". Now, let's say 'building' is defined as a global token, and the model needs to know the relation between 'NLP' & 'HuggingFace' for some task (Note: these 2 tokens are at the two extremes); having 'building' attend globally to all other tokens will probably help the model to associate 'NLP' with 'HuggingFace'.
>>> # let's assume the 1st & last tokens to be global, then
>>> global_tokens = ["BigBird", "answering"]

>>> # fill up the global tokens in our key tokens collection
>>> key_tokens.extend(global_tokens)
  • Random tokens: Select some tokens randomly which will transfer information by passing it on to other tokens, which in turn can transfer it to yet other tokens. This may reduce the cost of information travel from one token to another.
>>> # now we can choose r tokens randomly from the sequence; let's choose 'is', assuming r = 1
>>> random_tokens = ["is"] # these tokens are chosen completely at random

>>> # fill in the random tokens
>>> key_tokens.extend(random_tokens)

>>> # it is now reasonable to assume the query token only needs to attend to these key tokens
>>> key_tokens
['now', 'available', 'in', 'BigBird', 'answering', 'is']


This way, the query token attends only to a subset of all possible tokens while yielding a good approximation of full attention. The same approach is used for all other queried tokens. But remember, the whole point here is to approximate BERT's full attention as efficiently as possible. Simply making each queried token attend to all key tokens, as is done for BERT, can be computed very effectively as a sequence of matrix multiplications on modern hardware, like GPUs. However, a combination of sliding, global & random attention appears to imply sparse matrix multiplication, which is harder to implement efficiently on modern hardware.
One of the major contributions of BigBird is the proposal of a block sparse attention mechanism that allows computing sliding, global & random attention effectively. Let's look into it!



Understanding the need for global, sliding & random keys with graphs

First, let's get a better understanding of global, sliding & random attention using graphs and try to understand how the combination of these three attention mechanisms yields a good approximation of standard BERT-like attention.



The above figure shows global (left), sliding (middle) & random (right) connections respectively as a graph. Each node corresponds to a token and each line represents an attention score. If no connection is made between 2 tokens, then the attention score is assumed to be 0.


BigBird block sparse attention is a combination of sliding, global & random connections (total 10 connections), as shown in the gif on the left, while a graph of normal attention (right) will have all 15 connections (note: 6 nodes are present in total). You can simply think of normal attention as all tokens attending globally^1.

Normal attention: The model can transfer information from one token to another directly in a single layer, since each token is queried over every other token and is attended to by every other token. Let's consider an example similar to what is shown in the figures above. If the model needs to associate 'going' with 'now', it can simply do that in a single layer since there is a direct connection joining both tokens.

Block sparse attention: If the model needs to share information between two nodes (or tokens), information will have to travel across various other nodes on the path for some of the tokens, since not all nodes are directly connected in a single layer.
E.g., assuming the model needs to associate 'going' with 'now', then if only sliding attention is present, the flow of information between those 2 tokens is defined by the path: going -> am -> i -> now (i.e. it will have to travel over 2 other tokens). Hence, we may need multiple layers to capture the entire information of the sequence, whereas normal attention can capture this in a single layer. In an extreme case, this could mean that as many layers as input tokens are needed. If, however, we introduce some global tokens, information can travel via the path: going -> i -> now (which is shorter). If we additionally introduce random connections, it can travel via: going -> am -> now. With the help of random connections & global connections, information can travel very rapidly (with just a few layers) from one token to the next.
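The layer-counting argument above can be checked with a tiny breadth-first search over toy attention graphs (the token indices and edges here are illustrative choices of mine, not taken from the paper):

```python
from collections import deque

def min_layers(edges, src, dst, n):
    # each BFS hop corresponds to one attention layer passing information along
    adj = {i: set() for i in range(n)}
    for a, b in edges:
        adj[a].add(b)
        adj[b].add(a)
    dist = {src: 0}
    q = deque([src])
    while q:
        u = q.popleft()
        for v in adj[u]:
            if v not in dist:
                dist[v] = dist[u] + 1
                q.append(v)
    return dist.get(dst)

# toy tokens: 0 = 'going', 1 = 'am', 2 = 'i', 3 = 'now'
sliding_only = [(0, 1), (1, 2), (2, 3)]
assert min_layers(sliding_only, 0, 3, 4) == 3  # going -> am -> i -> now

with_global = sliding_only + [(2, 0)]  # pretend 'i' is global, adding a direct edge to 'going'
assert min_layers(with_global, 0, 3, 4) == 2   # going -> i -> now
```

Each extra shortcut edge (global or random) shortens the longest path, which is exactly why fewer layers suffice.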

In case we have many global tokens, we may not need random connections, since there will be multiple short paths through which information can travel. This is the idea behind setting num_random_tokens = 0 when working with a variant of BigBird called ETC (more on this in later sections).

^1 In normal attention, every token can be thought of as a global token, since every token attends to, and is attended to by, every other token.

Attention Type  | global_tokens  | sliding_tokens | random_tokens
original_full   | n              | 0              | 0
block_sparse    | 2 x block_size | 3 x block_size | num_random_blocks x block_size

original_full represents BERT's attention, while block_sparse represents BigBird's attention. Wondering what the block_size is? We will cover that in later sections. For now, consider it to be 1 for simplicity.



BigBird block sparse attention

BigBird block sparse attention is just an efficient implementation of what we discussed above. Each token attends to some global tokens, sliding tokens, & random tokens instead of attending to all other tokens. The authors hardcoded the attention matrix for multiple query components separately, and used a cool trick to speed up training/inference on GPUs and TPUs.

BigBird block sparse attention
Note: at the top, we have 2 extra sentences. As you can notice, every token is just shifted by one place in both sentences. This is how sliding attention is implemented. When q[i] is multiplied with k[i,0:3], we get a sliding attention score for q[i] (where i is the index of the element in the sequence).

You can find the actual implementation of block_sparse attention here. This may look very scary 😨😨 at first. But this article will certainly ease your life in understanding the code.



Global Attention

For global attention, each query simply attends to all other tokens in the sequence & is attended to by every other token. Let's assume Vasudev (1st token) & them (last token) to be global (in the above figure). You can see that these tokens are directly connected to all other tokens (blue boxes).



# pseudo code

Q -> Query matrix (seq_length, head_dim)
K -> Key matrix (seq_length, head_dim)

# 1st & last tokens attend to all other tokens
Q[0] x [K[0], K[1], K[2], ......, K[n-1]]
Q[n-1] x [K[0], K[1], K[2], ......, K[n-1]]

# 1st & last tokens are attended to by all other tokens
K[0] x [Q[0], Q[1], Q[2], ......, Q[n-1]]
K[n-1] x [Q[0], Q[1], Q[2], ......, Q[n-1]]
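The pseudo code above can be sketched in PyTorch (the shapes and names are my own toy choices, not the 🤗 implementation):

```python
import torch

seq_len, head_dim = 8, 4
Q = torch.randn(seq_len, head_dim)  # query matrix (seq_length, head_dim)
K = torch.randn(seq_len, head_dim)  # key matrix (seq_length, head_dim)

# 1st & last queries attend to every key -> scores of shape (2, seq_len)
global_q_scores = torch.stack([Q[0], Q[-1]]) @ K.T

# every query attends to the 1st & last keys -> scores of shape (seq_len, 2)
global_k_scores = Q @ torch.stack([K[0], K[-1]]).T

assert global_q_scores.shape == (2, seq_len)
assert global_k_scores.shape == (seq_len, 2)
```

Both products are small dense matrix multiplications, which is why global attention costs only O(n) per global token.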



Sliding Attention

The sequence of key tokens is copied 2 times, with each element shifted to the right in one of the copies and to the left in the other. Now if we multiply the query sequence vectors by these 3 sequence vectors, we will cover all the sliding tokens. The computational complexity is simply O(3 x n) = O(n). Referring to the above picture, the orange boxes represent the sliding attention. You can see 3 sequences at the top of the figure, with 2 of them shifted by one token (1 to the left, 1 to the right).


# what we want to achieve
Q[i] x [K[i-1], K[i], K[i+1]] for i = 1:-1

# efficient implementation with 3 shifted copies of K (assume element-wise multiplication over the sequence dimension 👇)
[Q[0], Q[1], Q[2], ......, Q[n-2], Q[n-1]] x [K[1], K[2], K[3], ......, K[n-1], K[0]]
[Q[0], Q[1], Q[2], ......, Q[n-1]] x [K[n-1], K[0], K[1], ......, K[n-2]]
[Q[0], Q[1], Q[2], ......, Q[n-1]] x [K[0], K[1], K[2], ......, K[n-1]]
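The shifting trick can be sketched in PyTorch with torch.roll (a toy sketch under my own shapes, not the actual block_sparse code):

```python
import torch

seq_len, head_dim = 8, 4
Q = torch.randn(seq_len, head_dim)
K = torch.randn(seq_len, head_dim)

# three copies of K: shifted right, unshifted, shifted left
K_right = torch.roll(K, shifts=1, dims=0)   # aligns K[i-1] with Q[i]
K_left = torch.roll(K, shifts=-1, dims=0)   # aligns K[i+1] with Q[i]

# element-wise multiply & sum over head_dim -> O(3n) sliding scores
scores = torch.stack([
    (Q * K_right).sum(-1),  # Q[i] . K[i-1]
    (Q * K).sum(-1),        # Q[i] . K[i]
    (Q * K_left).sum(-1),   # Q[i] . K[i+1]
], dim=-1)  # shape: (seq_len, 3)

# sanity check against the naive dot products for an interior token
i = 3
naive = torch.stack([Q[i] @ K[i - 1], Q[i] @ K[i], Q[i] @ K[i + 1]])
assert torch.allclose(scores[i], naive)
```

Only the boundary tokens wrap around due to the roll; in BigBird those positions are handled by global attention anyway.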





Random Attention

Random attention ensures that every query token will also attend to a few random tokens. For the actual implementation, this means that the model gathers some tokens randomly and computes their attention scores.


# r1, r2, ..., r are some random indices; Note: they are different for each row 👇
Q[1] x [K[r1], K[r2], ......, K[r]]
.
.
.
Q[n-2] x [K[r1], K[r2], ......, K[r]]

# the 0th & (n-1)th tokens are left out, since they are already covered by global attention
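The gathering step can be sketched as follows (r and the index-tensor shapes are illustrative choices of mine):

```python
import torch

torch.manual_seed(0)
seq_len, head_dim, r = 10, 4, 3
Q = torch.randn(seq_len, head_dim)
K = torch.randn(seq_len, head_dim)

# for every query except the global 0th & last, draw r random key indices (different per row)
rand_idx = torch.randint(0, seq_len, (seq_len - 2, r))

# fancy indexing gathers the random keys: (seq_len - 2, r, head_dim)
rand_K = K[rand_idx]

# dot product of each query with its own r random keys: (seq_len - 2, r)
scores = torch.einsum("qd,qrd->qr", Q[1:-1], rand_K)
assert scores.shape == (seq_len - 2, r)
```

Since r is a constant, the cost of random attention stays O(n).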


Note: The current implementation further divides the sequence into blocks & each notation is defined with respect to blocks instead of tokens. Let's discuss this in more detail in the next section.
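The division into blocks mentioned in the note amounts to a simple reshape (the sizes here are illustrative):

```python
import torch

seq_len, block_size, head_dim = 16, 4, 2
K = torch.randn(seq_len, head_dim)

# the sequence length must be a multiple of block_size for this to work
K_blocks = K.view(seq_len // block_size, block_size, head_dim)

assert K_blocks.shape == (4, 4, 2)
assert torch.equal(K_blocks[1, 0], K[4])  # the 2nd block starts at token index 4
```

All of the global/sliding/random bookkeeping is then done over these blocks rather than individual tokens.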



Implementation

Recap: In regular BERT attention, a sequence of tokens, i.e. X = x_1, x_2, ...., x_n, is projected through a dense layer into Q, K, V, and the attention score Z is calculated as Z = Softmax(Q K^T). In the case of BigBird block sparse attention, the same algorithm is used, but only with some selected query & key vectors.

Let's have a look at how BigBird block sparse attention is implemented. To begin with, let's assume b, r, s, g represent block_size, num_random_blocks, num_sliding_blocks, and num_global_blocks, respectively. Visually, BigBird's block sparse attention with b = 4, r = 1, g = 2, s = 3 can be illustrated as in the figures below.

Attention scores for q_1, q_2, q_{3:n-2}, q_{n-1}, and q_n are calculated separately, as described below:


The attention score for q_1, represented by a_1 = Softmax(q_1 * K^T), is nothing but the attention score between all the tokens in the 1st block and all the other tokens in the sequence.

BigBird block sparse attention for q_1


For calculating the attention score for tokens in the second block, we gather the first three blocks, the last block, and the fifth block. Then we can compute a_2 = Softmax(q_2 * concat(k_1, k_2, k_3, k_5, k_7)^T).
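For illustration, the gathering & softmax for the second block can be sketched like this (7 blocks, with the 5th block standing in as the random block; the shapes are toy values of mine, not the 🤗 code):

```python
import torch

num_blocks, block_size, head_dim = 7, 4, 2
K_blocks = torch.randn(num_blocks, block_size, head_dim)  # k_1 ... k_7 as K_blocks[0] ... K_blocks[6]
q2 = torch.randn(block_size, head_dim)

# gather the first three blocks, the random (5th) block & the last block
gathered = torch.cat([K_blocks[0], K_blocks[1], K_blocks[2], K_blocks[4], K_blocks[6]])

# a_2 = Softmax(q_2 * concat(k_1, k_2, k_3, k_5, k_7)^T)
a2 = torch.softmax(q2 @ gathered.T, dim=-1)
assert a2.shape == (block_size, 5 * block_size)
```

The gathered keys form one dense matrix, so this stays a single fast matrix multiplication.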

BigBird block sparse attention

Here, the tokens are labeled g, r, and s just to denote their nature explicitly (global, random, and sliding, respectively).


For calculating the attention score for q_{3:n-2}, we gather the global, sliding, and random keys & compute the attention score over q_{3:n-2} and the gathered keys all at once. Note that the sliding keys are gathered using the shifting trick discussed earlier in the sliding attention section.

BigBird block sparse attention


For calculating the attention score for tokens in the second-to-last block (i.e. q_{n-1}), we proceed analogously to the second block: we gather the first block, the last three blocks, and a random block, then apply the same softmax over the concatenated keys.

BigBird block sparse attention


The attention score for q_n, represented by a_n = Softmax(q_n * K^T), is nothing but the attention score between all the tokens in the last block and all the other tokens in the sequence.

BigBird block sparse attention


Let's combine the above matrices to get the final attention matrix. This attention matrix can be used to get a representation of all the tokens.

BigBird block sparse attention

blue -> global blocks, red -> random blocks, orange -> sliding blocks. This attention matrix is just for illustration. During the forward pass, we aren't storing the white blocks, but are computing a weighted value matrix (i.e. the representation of each token) directly for each separate component, as discussed above.

Now, we have covered the hardest part of block sparse attention, i.e. its implementation. Hopefully, you now have a better background to understand the actual code. Feel free to dive into it and to connect each part of the code with one of the components above.



Time & Memory complexity

Attention Type | Sequence length | Time & Memory Complexity
original_full  | 512             | T
               | 1024            | 4 x T
               | 4096            | 64 x T
block_sparse   | 1024            | 2 x T
               | 4096            | 8 x T

Comparison of time & space complexity of BERT attention and BigBird block sparse attention.

Expand this snippet if you want to see the calculations:
BigBird time complexity = O(w x n + r x n + g x n)
BERT time complexity = O(n^2)

Assumptions:
    w = 3 x 64
    r = 3 x 64
    g = 2 x 64

When seqlen = 512
=> **time complexity in BERT = 512^2**

When seqlen = 1024
=> time complexity in BERT = (2 x 512)^2
=> **time complexity in BERT = 4 x 512^2**

=> time complexity in BigBird = (8 x 64) x (2 x 512)
=> **time complexity in BigBird = 2 x 512^2**

When seqlen = 4096
=> time complexity in BERT = (8 x 512)^2
=> **time complexity in BERT = 64 x 512^2**

=> time complexity in BigBird = (8 x 64) x (8 x 512)
=> time complexity in BigBird = 8 x (512 x 512)
=> **time complexity in BigBird = 8 x 512^2**
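The calculation in the snippet can be reproduced with a few lines of Python (w, r, g follow the assumptions above):

```python
def bert_cost(n):
    # full attention: every token attends to every token
    return n * n

def bigbird_cost(n, w=3 * 64, r=3 * 64, g=2 * 64):
    # sliding + random + global keys per query, each linear in n
    return (w + r + g) * n

unit = bert_cost(512)  # T
assert bert_cost(1024) == 4 * unit
assert bert_cost(4096) == 64 * unit
assert bigbird_cost(1024) == 2 * unit
assert bigbird_cost(4096) == 8 * unit
```

The gap widens with sequence length: BERT grows quadratically while BigBird grows linearly.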



ITC vs ETC

The BigBird model can be trained using 2 different strategies: ITC & ETC. ITC (internal transformer construction) is simply what we discussed above. In ETC (extended transformer construction), some additional tokens are made global, such that they will attend to / be attended to by all tokens.

ITC requires less compute since very few tokens are global, while at the same time the model can capture sufficient global information (also with the help of random attention). On the other hand, ETC can be very helpful for tasks that need a lot of global tokens, such as question-answering, for which the entire question should be attended to globally by the context to be able to relate the context correctly to the question.

Note: It is shown in the BigBird paper that in many ETC experiments, the number of random blocks is set to 0. This is reasonable given our discussion in the graph section above.

The table below summarizes ITC & ETC:

                                       | ITC | ETC
Attention matrix with global attention | A = \begin{bmatrix} 1 & 1 & 1 & 1 & 1 & 1 & 1 \\ 1 & & & & & & 1 \\ 1 & & & & & & 1 \\ 1 & & & & & & 1 \\ 1 & & & & & & 1 \\ 1 & & & & & & 1 \\ 1 & 1 & 1 & 1 & 1 & 1 & 1 \end{bmatrix} | B = \begin{bmatrix} 1 & 1 & 1 & 1 & 1 & 1 & 1 & 1 & 1 \\ 1 & 1 & 1 & 1 & 1 & 1 & 1 & 1 & 1 \\ 1 & 1 & 1 & 1 & 1 & 1 & 1 & 1 & 1 \\ 1 & 1 & 1 & & & & & & 1 \\ 1 & 1 & 1 & & & & & & 1 \\ 1 & 1 & 1 & & & & & & 1 \\ 1 & 1 & 1 & & & & & & 1 \\ 1 & 1 & 1 & & & & & & 1 \\ 1 & 1 & 1 & 1 & 1 & 1 & 1 & 1 & 1 \end{bmatrix}
global_tokens   | 2 x block_size                 | extra_tokens + 2 x block_size
random_tokens   | num_random_blocks x block_size | num_random_blocks x block_size
sliding_tokens  | 3 x block_size                 | 3 x block_size
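The global patterns of the two matrices in the table can be generated with a small helper (the indices and sizes follow the 7x7 / 9x9 illustrations above):

```python
import torch

def global_attention_mask(n, global_idx):
    # 1 wherever the query or the key is a global token, 0 elsewhere
    m = torch.zeros(n, n, dtype=torch.int)
    m[global_idx, :] = 1  # global tokens attend to everything
    m[:, global_idx] = 1  # everything attends to global tokens
    return m

# ITC (matrix A): only the 1st & last tokens are global
A = global_attention_mask(7, [0, 6])
assert A.sum().item() == 24

# ETC (matrix B): the extra tokens at the front are global too
B = global_attention_mask(9, [0, 1, 2, 8])
assert B.sum().item() == 56
```

These masks only cover the global connections; sliding & random connections would be added on top.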



Using BigBird with 🤗Transformers

You can use BigBirdModel just like any other 🤗 model. Let's see some code below:

from transformers import BigBirdModel

# loading BigBird from its pretrained checkpoint
model = BigBirdModel.from_pretrained("google/bigbird-roberta-base")
# This will init the model with the default configuration, i.e.
# attention_type = "block_sparse", num_random_blocks = 3, block_size = 64

# you can change these arguments freely for any checkpoint; they only change
# the number of tokens each query token attends to
model = BigBirdModel.from_pretrained("google/bigbird-roberta-base", num_random_blocks=2, block_size=16)

# by setting attention_type to `original_full`, BigBird will rely on full attention of O(n^2) complexity
model = BigBirdModel.from_pretrained("google/bigbird-roberta-base", attention_type="original_full")

There are in total 3 checkpoints available on the 🤗 Hub (at the time of writing this article): bigbird-roberta-base, bigbird-roberta-large, bigbird-base-trivia-itc. The first two checkpoints come from pretraining BigBirdForPretraining with the masked_lm loss, while the last one corresponds to the checkpoint after finetuning BigBirdForQuestionAnswering on the trivia-qa dataset.

Let's have a look at the minimal code you can write (in case you like to use your own PyTorch trainer) to use 🤗's BigBird model for fine-tuning on your tasks.



# let's consider a task similar to question-answering
from transformers import BigBirdForQuestionAnswering, BigBirdTokenizer
import torch

device = torch.device("cpu")
if torch.cuda.is_available():
    device = torch.device("cuda")

# let's initialize BigBird & its tokenizer from the pretrained checkpoint
model = BigBirdForQuestionAnswering.from_pretrained("google/bigbird-roberta-base", block_size=64, num_random_blocks=3)
tokenizer = BigBirdTokenizer.from_pretrained("google/bigbird-roberta-base")
model.to(device)

dataset = "torch.utils.data.DataLoader object"
optimizer = "torch.optim object"
epochs = ...

# very minimal training loop
for e in range(epochs):
    for batch in dataset:
        model.train()
        batch = {k: batch[k].to(device) for k in batch}

        # forward pass
        output = model(**batch)

        # back-propagation
        output["loss"].backward()
        optimizer.step()
        optimizer.zero_grad()

# let's save the final weights in a local directory
model.save_pretrained("")

# let's push our weights to the 🤗 Hub
from huggingface_hub import ModelHubMixin
ModelHubMixin.push_to_hub("", model_id="")

# using the finetuned model for inference
question = ["How are you doing?", "How is life going?"]
context = ["", ""]
batch = tokenizer(question, context, return_tensors="pt")
batch = {k: batch[k].to(device) for k in batch}

model = BigBirdForQuestionAnswering.from_pretrained("")
model.to(device)
with torch.no_grad():
    start_logits, end_logits = model(**batch).to_tuple()
    # now decode the start_logits & end_logits with whatever strategy you want
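One simple decoding strategy for start_logits & end_logits is a plain argmax over positions (a sketch of mine, ignoring invalid spans and batch handling):

```python
import torch

def decode_span(tokens, start_logits, end_logits):
    # pick the most likely start & end positions independently
    start = int(torch.argmax(start_logits))
    end = int(torch.argmax(end_logits))
    return " ".join(tokens[start : end + 1])

# toy example with made-up tokens & logits
tokens = ["big", "bird", "is", "efficient"]
start_logits = torch.tensor([0.1, 2.0, 0.3, 0.1])
end_logits = torch.tensor([0.0, 0.1, 3.0, 0.2])
assert decode_span(tokens, start_logits, end_logits) == "bird is"
```

A production decoder would additionally enforce start <= end and score all valid (start, end) pairs.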
    




It is important to keep the following points in mind while working with BigBird:

  • The sequence length must be a multiple of the block size, i.e. seqlen % block_size = 0. You need not worry, since 🤗Transformers will automatically pad the input (to the smallest multiple of block size which is greater than the sequence length) if the batch sequence length is not a multiple of block_size.
  • Currently, the HuggingFace version doesn't support ETC and hence only the 1st & last blocks will be global.
  • Current implementation doesn’t support num_random_blocks = 0.
  • It is advised by the authors to set attention_type = "original_full" when the sequence length is < 1024.
  • This must hold: seq_length > global_tokens + random_tokens + sliding_tokens + buffer_tokens, where global_tokens = 2 x block_size, sliding_tokens = 3 x block_size, random_tokens = num_random_blocks x block_size & buffer_tokens = num_random_blocks x block_size. If you fail to do this, 🤗Transformers will automatically switch attention_type to original_full with a warning.
  • When using BigBird as a decoder (or using BigBirdForCausalLM), attention_type should be original_full. But you need not worry, 🤗Transformers will automatically switch attention_type to original_full in case you forget to do that.
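The automatic padding in the first point boils down to rounding the sequence length up to the next multiple of block_size:

```python
def pad_to_block_multiple(seq_len, block_size=64):
    # smallest multiple of block_size that is >= seq_len
    return ((seq_len + block_size - 1) // block_size) * block_size

assert pad_to_block_multiple(4096) == 4096  # already a multiple
assert pad_to_block_multiple(4000) == 4032  # padded up by 32 tokens
```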



What’s next?

@patrickvonplaten has made a really cool notebook on how to evaluate BigBirdForQuestionAnswering on the trivia-qa dataset. Feel free to play with BigBird using that notebook.

You will soon find a BigBird Pegasus-like model in the library for long document summarization 💥.



End Notes

The original implementation of the block sparse attention matrix can be found here. You can find 🤗's version here.


