Transformers have exhibited remarkable performance on various Natural Language Processing and Computer Vision tasks. Their success may be attributed to the self-attention mechanism, which captures the pairwise interactions between all of the tokens in an input. However, the standard self-attention mechanism has a time and memory complexity of $O(n^2)$ (where $n$ is the length of the input sequence), making it expensive to train on long input sequences.
The Nyströmformer is one of several efficient Transformer models that approximate standard self-attention with $O(n)$ complexity. Nyströmformer exhibits competitive performance on various downstream NLP and CV tasks while improving upon the efficiency of standard self-attention. The goal of this blog post is to give readers an overview of the Nyström method and how it can be adapted to approximate self-attention.
Nyström method for matrix approximation
At the heart of Nyströmformer is the Nyström method for matrix approximation. It allows us to approximate a matrix by sampling some of its rows and columns. Let's consider a matrix $P \in \mathbb{R}^{n \times n}$, which is expensive to compute in its entirety. So, instead, we approximate it using the Nyström method. We start by sampling $m$ rows and $m$ columns from $P$. We can then arrange the sampled rows and columns as follows:

$$P = \begin{bmatrix} A_P & B_P \\ F_P & C_P \end{bmatrix}$$
We now have four submatrices: $A_P$, $B_P$, $F_P$, and $C_P$, with sizes $m \times m$, $m \times (n-m)$, $(n-m) \times m$, and $(n-m) \times (n-m)$, respectively. The $m$ sampled columns are contained in $A_P$ and $F_P$, whereas the $m$ sampled rows are contained in $A_P$ and $B_P$. So, the entries of $A_P$, $B_P$, and $F_P$ are known to us, and we will estimate $C_P$. According to the Nyström method, $C_P$ is given by:

$$C_P = F_P A_P^+ B_P$$
Here, $+$ denotes the Moore-Penrose inverse (or pseudoinverse).
Thus, the Nyström approximation of $P$, denoted $\hat{P}$, can be written as:

$$\hat{P} = \begin{bmatrix} A_P & B_P \\ F_P & F_P A_P^+ B_P \end{bmatrix} = \begin{bmatrix} A_P \\ F_P \end{bmatrix} A_P^+ \begin{bmatrix} A_P & B_P \end{bmatrix}$$
As shown in the second line, $\hat{P}$ can be expressed as a product of three matrices. The reason for doing so will become clear later.
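To make this concrete, here is a minimal sketch in PyTorch (the matrix sizes and the choice of simply taking the first $m$ rows and columns are for illustration only) that builds the three-matrix Nyström approximation and checks how close it is to the original matrix:

```python
import torch

torch.manual_seed(0)
n, m = 256, 32

# A low-rank matrix: the Nyström method works best when P is (close to) low rank
X = torch.randn(n, 16)
P = X @ X.T  # n x n

# Treat the first m rows/columns as the "sampled" ones (illustration only)
A = P[:m, :m]   # m x m
B = P[:m, m:]   # m x (n - m)
F = P[m:, :m]   # (n - m) x m

# Nyström approximation: [A; F] A^+ [A  B]
left = torch.cat([A, F], dim=0)    # n x m
right = torch.cat([A, B], dim=1)   # m x n
P_hat = left @ torch.linalg.pinv(A) @ right

print(torch.norm(P - P_hat) / torch.norm(P))  # small relative error for low-rank P
```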
Can we approximate self-attention with the Nyström method?
Our goal is to ultimately approximate the softmax matrix in standard self-attention:

$$S = \mathrm{softmax}\left(\frac{QK^T}{\sqrt{d}}\right)$$
Here, $Q$ and $K$ denote the queries and keys respectively. Following the procedure discussed above, we would sample $m$ rows and columns from $S$, form four submatrices, and obtain $\hat{S}$:

$$\hat{S} = \begin{bmatrix} A_S \\ F_S \end{bmatrix} A_S^+ \begin{bmatrix} A_S & B_S \end{bmatrix}$$
But, what does it mean to sample a column from $S$? It means we select one element from each row. Recall how $S$ is calculated: the final operation is a row-wise softmax. To compute a single entry in a row, we need access to all the other entries in that row (for the denominator of the softmax). So, sampling one column requires us to know all the other columns in the matrix. Therefore, we cannot directly apply the Nyström method to approximate the softmax matrix.
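A tiny numerical sketch of this point (sizes are arbitrary): even a single entry of $S$ depends on an entire row of the score matrix $QK^T$, because every entry is normalized by a row-wise sum.

```python
import torch

torch.manual_seed(0)
n, d = 8, 4
Q, K = torch.randn(n, d), torch.randn(n, d)

scores = Q @ K.T / d**0.5          # n x n: already O(n^2) to form
S = torch.softmax(scores, dim=-1)

# One entry S[i, j] needs the sum over the *entire* row i of exp(scores):
i, j = 2, 5
denom = torch.exp(scores[i]).sum()               # requires all n scores of row i
print(S[i, j], torch.exp(scores[i, j]) / denom)  # same value
```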
How can we adapt the Nyström method to approximate self-attention?
Instead of sampling from $S$, the authors propose to sample $m$ landmarks (or Nyström points) from the queries $Q$ and keys $K$. We denote the query landmarks and key landmarks as $\tilde{Q}$ and $\tilde{K}$ respectively. $\tilde{Q}$ and $\tilde{K}$ can be used to construct three matrices corresponding to those in the Nyström approximation of $S$. We define the following matrices:

$$\tilde{F} = \mathrm{softmax}\left(\frac{Q\tilde{K}^T}{\sqrt{d}}\right), \quad \tilde{A} = \mathrm{softmax}\left(\frac{\tilde{Q}\tilde{K}^T}{\sqrt{d}}\right), \quad \tilde{B} = \mathrm{softmax}\left(\frac{\tilde{Q}K^T}{\sqrt{d}}\right)$$
The sizes of $\tilde{F}$, $\tilde{A}$, and $\tilde{B}$ are $n \times m$, $m \times m$, and $m \times n$, respectively.
We replace the three matrices in the Nyström approximation of $S$ with the new matrices we have defined to obtain an alternative Nyström approximation:

$$\hat{S} = \tilde{F} \tilde{A}^+ \tilde{B} = \mathrm{softmax}\left(\frac{Q\tilde{K}^T}{\sqrt{d}}\right) \mathrm{softmax}\left(\frac{\tilde{Q}\tilde{K}^T}{\sqrt{d}}\right)^+ \mathrm{softmax}\left(\frac{\tilde{Q}K^T}{\sqrt{d}}\right)$$
This is the Nyström approximation of the softmax matrix in the self-attention mechanism. We multiply this matrix with the values ($V$) to obtain a linear approximation of self-attention. Note that we never computed the product $QK^T$, avoiding the $O(n^2)$ complexity.
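Here is a minimal, single-head sketch of this approximation in plain PyTorch (the landmarks are just arbitrary rows here; sizes are examples). The key point is the order of multiplication: by grouping $(\tilde{F}\tilde{A}^+)$ with $(\tilde{B}V)$, every intermediate matrix has size $n \times m$, $m \times m$, or $m \times d$, so an $n \times n$ matrix is never materialized.

```python
import torch

def nystrom_attention(Q, K, V, Q_tilde, K_tilde):
    """Approximate softmax(QK^T / sqrt(d)) @ V using query/key landmarks."""
    d = Q.shape[-1]
    F_t = torch.softmax(Q @ K_tilde.T / d**0.5, dim=-1)        # n x m
    A_t = torch.softmax(Q_tilde @ K_tilde.T / d**0.5, dim=-1)  # m x m
    B_t = torch.softmax(Q_tilde @ K.T / d**0.5, dim=-1)        # m x n

    # Associate the products so that nothing of size n x n is ever formed
    return (F_t @ torch.linalg.pinv(A_t)) @ (B_t @ V)          # n x d

# Usage with example sizes (landmarks are just the first m rows for illustration)
n, m, d = 512, 32, 64
Q, K, V = torch.randn(n, d), torch.randn(n, d), torch.randn(n, d)
Q_tilde, K_tilde = Q[:m], K[:m]
out = nystrom_attention(Q, K, V, Q_tilde, K_tilde)
print(out.shape)  # torch.Size([512, 64])
```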
How do we select landmarks?
Instead of sampling rows from $Q$ and $K$, the authors propose to construct $\tilde{Q}$ and $\tilde{K}$ using segment-means. In this procedure, the $n$ tokens are grouped into $m$ segments, and the mean of each segment is computed, as sketched below. Ideally, $m$ is much smaller than $n$. According to experiments from the paper, using a small number of landmarks (such as 32 or 64) produces competitive performance compared to standard self-attention and other efficient attention mechanisms, even for long input sequences.
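Assuming the sequence length is divisible by the number of landmarks, segment-means boils down to a reshape followed by a mean (the sizes below are just example values):

```python
import torch

n, m, d = 512, 64, 96   # sequence length, number of landmarks, head dimension
Q = torch.randn(n, d)

# Group the n query vectors into m contiguous segments and average each segment
Q_tilde = Q.reshape(m, n // m, d).mean(dim=1)   # m x d query landmarks
print(Q_tilde.shape)  # torch.Size([64, 96])
```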
The overall algorithm is summarized in the following figure from the paper:
The three orange matrices above correspond to the three matrices we constructed using the key and query landmarks. Also, notice that there is a DConv box. This corresponds to a skip connection added to the values using a 1D depthwise convolution.
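For readers unfamiliar with the term, a depthwise convolution is simply a convolution whose groups parameter equals the number of channels, so each channel is filtered independently. Below is a generic PyTorch sketch of such a layer applied along the sequence dimension of the values; it is meant to illustrate the idea, not to reproduce the exact layer used in the Nyströmformer code.

```python
import torch
import torch.nn as nn

n, d = 512, 64                 # sequence length, head dimension (example values)
values = torch.randn(1, n, d)  # (batch, seq_len, dim)

# Depthwise 1D convolution: groups == channels, so each channel gets its own filter
dconv = nn.Conv1d(in_channels=d, out_channels=d, kernel_size=3,
                  padding=1, groups=d, bias=False)

# Conv1d expects (batch, channels, seq_len)
skip = dconv(values.transpose(1, 2)).transpose(1, 2)  # same shape as values
print(skip.shape)  # torch.Size([1, 512, 64])
```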
How is Nyströmformer implemented?
The original implementation of Nyströmformer can be found here and the HuggingFace implementation can be found here. Let's take a look at a few lines of code (with some comments added) from the HuggingFace implementation. Note that some details such as normalization, attention masking, and the depthwise convolution are omitted for simplicity.
```python
# Compute the keys, values, and queries
key_layer = self.transpose_for_scores(self.key(hidden_states))
value_layer = self.transpose_for_scores(self.value(hidden_states))
query_layer = self.transpose_for_scores(mixed_query_layer)

# Construct the query landmarks (Q-tilde) using segment-means
q_landmarks = query_layer.reshape(
    -1,
    self.num_attention_heads,
    self.num_landmarks,
    self.seq_len // self.num_landmarks,
    self.attention_head_size,
).mean(dim=-2)

# Construct the key landmarks (K-tilde) using segment-means
k_landmarks = key_layer.reshape(
    -1,
    self.num_attention_heads,
    self.num_landmarks,
    self.seq_len // self.num_landmarks,
    self.attention_head_size,
).mean(dim=-2)

# kernel_1 corresponds to F-tilde = softmax(Q K-tilde^T)
kernel_1 = torch.nn.functional.softmax(torch.matmul(query_layer, k_landmarks.transpose(-1, -2)), dim=-1)

# kernel_2 corresponds to A-tilde = softmax(Q-tilde K-tilde^T)
kernel_2 = torch.nn.functional.softmax(torch.matmul(q_landmarks, k_landmarks.transpose(-1, -2)), dim=-1)

# kernel_3 corresponds to B-tilde = softmax(Q-tilde K^T)
attention_scores = torch.matmul(q_landmarks, key_layer.transpose(-1, -2))
kernel_3 = torch.nn.functional.softmax(attention_scores, dim=-1)

# F-tilde multiplied by an (approximate) pseudoinverse of A-tilde
attention_probs = torch.matmul(kernel_1, self.iterative_inv(kernel_2))

# Multiply B-tilde with the values first, so no n x n matrix is ever formed
new_value_layer = torch.matmul(kernel_3, value_layer)
context_layer = torch.matmul(attention_probs, new_value_layer)
```
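The one piece that may look unfamiliar is self.iterative_inv, which approximates the Moore-Penrose pseudoinverse of the small $m \times m$ matrix $\tilde{A}$ using only matrix multiplications, which are GPU-friendly. As a rough illustration of the idea only (this is a plain Newton-Schulz iteration, not the exact update used in the paper or the HuggingFace code):

```python
import torch

def newton_schulz_pinv(A, n_iter=30):
    """Iteratively approximate pinv(A) using only matrix multiplications."""
    # Scale A^T so the iteration converges (||A||_1 * ||A||_inf >= sigma_max^2)
    Z = A.transpose(-1, -2) / (A.abs().sum(dim=-2).max() * A.abs().sum(dim=-1).max())
    I = torch.eye(A.shape[-1], device=A.device, dtype=A.dtype)
    for _ in range(n_iter):
        Z = Z @ (2 * I - A @ Z)
    return Z

A = torch.randn(32, 32)
print(torch.dist(newton_schulz_pinv(A), torch.linalg.pinv(A)))  # small
```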
Using Nyströmformer with HuggingFace
Nyströmformer for Masked Language Modeling (MLM) is available on HuggingFace. Currently, there are 4 checkpoints, corresponding to various sequence lengths: nystromformer-512, nystromformer-1024, nystromformer-2048, and nystromformer-4096. The number of landmarks, $m$, can be controlled using the num_landmarks parameter in the NystromformerConfig. Let's take a look at a minimal example of Nyströmformer for MLM:
```python
from transformers import AutoTokenizer, NystromformerForMaskedLM
import torch

tokenizer = AutoTokenizer.from_pretrained("uw-madison/nystromformer-512")
model = NystromformerForMaskedLM.from_pretrained("uw-madison/nystromformer-512")

inputs = tokenizer("Paris is the [MASK] of France.", return_tensors="pt")

with torch.no_grad():
    logits = model(**inputs).logits

mask_token_index = (inputs.input_ids == tokenizer.mask_token_id)[0].nonzero(as_tuple=True)[0]

predicted_token_id = logits[0, mask_token_index].argmax(axis=-1)
tokenizer.decode(predicted_token_id)
```
Output:
----------------------------------------------------------------------------------------------------
capital
Alternatively, we can use the pipeline API (which handles all of the complexity for us):
```python
from transformers import pipeline

unmasker = pipeline('fill-mask', model='uw-madison/nystromformer-512')
unmasker("Paris is the [MASK] of France.")
```
Output:
----------------------------------------------------------------------------------------------------
[{'score': 0.829957902431488,
'token': 1030,
'token_str': 'capital',
'sequence': 'paris is the capital of france.'},
{'score': 0.022157637402415276,
'token': 16081,
'token_str': 'birthplace',
'sequence': 'paris is the birthplace of france.'},
{'score': 0.01904447190463543,
'token': 197,
'token_str': 'name',
'sequence': 'paris is the name of france.'},
{'score': 0.017583081498742104,
'token': 1107,
'token_str': 'kingdom',
'sequence': 'paris is the kingdom of france.'},
{'score': 0.005948934704065323,
'token': 148,
'token_str': 'city',
'sequence': 'paris is the city of france.'}]
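Finally, if you want to experiment with the number of landmarks yourself, it can be set through the model configuration. A minimal sketch (the value 32 is just an example, and a model built this way is randomly initialized rather than pretrained):

```python
from transformers import NystromformerConfig, NystromformerForMaskedLM

# Example only: a Nyströmformer with a custom number of landmarks (m = 32)
config = NystromformerConfig(num_landmarks=32)
model = NystromformerForMaskedLM(config)
print(model.config.num_landmarks)  # 32
```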
Conclusion
Nyströmformer offers an efficient approximation to the standard self-attention mechanism, while outperforming other linear self-attention schemes. In this blog post, we gave a high-level overview of the Nyström method and how it can be leveraged to approximate self-attention. Readers interested in deploying or fine-tuning Nyströmformer for downstream tasks can find the HuggingFace documentation here.
