Approximating self-attention in linear time and memory via the Nyström method


Transformers have exhibited remarkable performance on various Natural Language Processing and Computer Vision tasks. Their success may be attributed to the self-attention mechanism, which captures the pairwise interactions between all the tokens in an input. However, standard self-attention has a time and memory complexity of $O(n^2)$ (where $n$ is the length of the input sequence), making it expensive to apply to long input sequences.

The Nyströmformer is one of many efficient Transformer models that approximate standard self-attention, with $O(n)$ complexity. Nyströmformer exhibits competitive performance on various downstream NLP and CV tasks while improving upon the efficiency of standard self-attention. The aim of this blog post is to give readers an overview of the Nyström method and how it can be adapted to approximate self-attention.



Nyström method for matrix approximation

At the heart of Nyströmformer is the Nyström method for matrix approximation. It allows us to approximate a matrix by sampling some of its rows and columns. Let's consider an $n \times n$ matrix $P$ that is expensive to compute in its entirety. Instead of computing it in full, we sample $m$ of its rows and columns ($m \ll n$) and arrange $P$ as a block matrix:

Representing P as a block matrix

We now have four submatrices: $A_P$, $B_P$, $F_P$, and $C_P$. The sampled rows and columns give us $A_P$, $B_P$, and $F_P$ directly, while the remaining block $C_P$ is approximated as:

$$C_P = F_P A_P^+ B_P$$

Here, $+$ denotes the Moore-Penrose inverse (or pseudoinverse).
Thus, the Nyström approximation of $P$, denoted $\hat{P}$, is:

$$\hat{P} = \begin{bmatrix} A_P & B_P \\ F_P & F_P A_P^+ B_P \end{bmatrix} = \begin{bmatrix} A_P \\ F_P \end{bmatrix} A_P^+ \begin{bmatrix} A_P & B_P \end{bmatrix}$$

As shown by the second equality, $\hat{P}$ can be expressed as the product of three matrices of sizes $n \times m$, $m \times m$, and $m \times n$, so the full $n \times n$ matrix never needs to be formed.
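
A minimal sketch with toy dimensions (the variable names and the low-rank construction are just for illustration) shows how the pieces fit together:

import torch

n, m = 6, 2
# A toy low-rank matrix that we pretend is expensive to compute in full.
P = torch.randn(n, m) @ torch.randn(m, n)

# "Sample" the first m rows and columns (the sampling strategy is a separate question).
A = P[:m, :m]    # m x m
B = P[:m, m:]    # m x (n - m)
F = P[m:, :m]    # (n - m) x m

# Nyström estimate of the unsampled block and of the full matrix.
C_hat = F @ torch.linalg.pinv(A) @ B
P_hat = torch.cat([torch.cat([A, B], dim=1),
                   torch.cat([F, C_hat], dim=1)], dim=0)

print(torch.linalg.norm(P - P_hat))  # ~0 here, since rank(P) <= m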



Can we approximate self-attention with the Nyström method?

Our ultimate goal is to approximate the softmax matrix in standard self-attention:

$$S = \mathrm{softmax}\left(\frac{QK^T}{\sqrt{d}}\right)$$

Here, $Q$ and $K$ denote the queries and keys respectively. Following the procedure discussed above, we would sample $m$ rows and columns from $S$, form the four submatrices, and obtain $\hat{S}$:

$$\hat{S} = \begin{bmatrix} A_S & B_S \\ F_S & F_S A_S^+ B_S \end{bmatrix} = \begin{bmatrix} A_S \\ F_S \end{bmatrix} A_S^+ \begin{bmatrix} A_S & B_S \end{bmatrix}$$

But what does it mean to sample a column from $S$? It means selecting one element from each row. Recall how $S$ is computed: the final operation is a row-wise softmax. To compute a single entry in a row, we need access to all the other entries in that row (for the softmax denominator). So sampling one column requires knowing all the other columns of the matrix, and we therefore cannot apply the Nyström method directly to the softmax matrix.
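
A tiny sketch (toy shapes, hypothetical variable names) makes the issue concrete:

import torch

n, d = 8, 4
Q, K = torch.randn(n, d), torch.randn(n, d)
scores = Q @ K.T / d ** 0.5                 # the full n x n score matrix

# A single entry S[i, j] of the softmax matrix:
i, j = 0, 0
S_ij = torch.exp(scores[i, j]) / torch.exp(scores[i]).sum()
assert torch.allclose(S_ij, torch.softmax(scores, dim=-1)[i, j])
# The denominator sums over the *entire* row i, so even one column of S
# requires every column of the score matrix.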



How can we adapt the Nyström method to approximate self-attention?

Instead of sampling from $S$, the authors propose to sample landmarks (or Nyström points) from the queries and keys. We denote the query landmarks and key landmarks as $\tilde{Q}$ and $\tilde{K}$ respectively. They are used to construct three matrices corresponding to those in the Nyström approximation of $P$:

$$\tilde{F} = \mathrm{softmax}\left(\frac{Q\tilde{K}^T}{\sqrt{d}}\right) \qquad \tilde{A} = \mathrm{softmax}\left(\frac{\tilde{Q}\tilde{K}^T}{\sqrt{d}}\right)^+ \qquad \tilde{B} = \mathrm{softmax}\left(\frac{\tilde{Q}K^T}{\sqrt{d}}\right)$$

The sizes of $\tilde{F}$, $\tilde{A}$, and $\tilde{B}$ are $n \times m$, $m \times m$, and $m \times n$ respectively. Multiplying the three matrices, in analogy with the Nyström approximation above, gives:

$$\hat{S} = \tilde{F}\tilde{A}\tilde{B} = \mathrm{softmax}\left(\frac{Q\tilde{K}^T}{\sqrt{d}}\right)\,\mathrm{softmax}\left(\frac{\tilde{Q}\tilde{K}^T}{\sqrt{d}}\right)^+\,\mathrm{softmax}\left(\frac{\tilde{Q}K^T}{\sqrt{d}}\right)$$

This is the Nyström approximation of the softmax matrix in the self-attention mechanism. We multiply this matrix with the values $V$ to obtain a linear approximation of self-attention. Note that we never compute the full product $QK^T$, which is how we avoid the $O(n^2)$ complexity.
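
As an illustration, here is a minimal sketch of the approximation as a standalone function. It takes the landmarks as inputs (how to choose them is the topic of the next section) and uses torch.linalg.pinv for the pseudoinverse, whereas the official implementation approximates it iteratively (the iterative_inv call in the code excerpt further below):

import torch

def nystrom_attention(Q, K, V, Q_t, K_t):
    """Approximate softmax(Q K^T / sqrt(d)) V using m landmarks Q_t, K_t."""
    d = Q.shape[-1]
    F_t = torch.softmax(Q @ K_t.transpose(-1, -2) / d ** 0.5, dim=-1)    # n x m
    A_t = torch.softmax(Q_t @ K_t.transpose(-1, -2) / d ** 0.5, dim=-1)  # m x m
    B_t = torch.softmax(Q_t @ K.transpose(-1, -2) / d ** 0.5, dim=-1)    # m x n
    # Multiply right to left so that no n x n matrix is ever materialized.
    return F_t @ (torch.linalg.pinv(A_t) @ (B_t @ V))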



How do we select landmarks?

Instead of sampling $m$ rows from $Q$ and $K$, the authors propose to construct $\tilde{Q}$ and $\tilde{K}$ using segment means: the $n$ queries (and keys) are split into $m$ segments, and the mean of each segment serves as a landmark.
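
A sketch of the segment-means construction (assuming, as the code excerpt below also does, that the sequence length n is divisible by m), plugged into the nystrom_attention sketch from above:

import torch

def segment_means(X, m):
    """Split the n rows of X into m equal segments and average each segment."""
    n, d = X.shape
    return X.reshape(m, n // m, d).mean(dim=1)    # m x d

n, m, d = 512, 32, 64
Q, K, V = torch.randn(n, d), torch.randn(n, d), torch.randn(n, d)
out = nystrom_attention(Q, K, V, segment_means(Q, m), segment_means(K, m))
print(out.shape)    # torch.Size([512, 64])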

The overall algorithm is summarized in the following figure from the paper:

Efficient self-attention with the Nyström method

The three orange matrices above correspond to the three matrices constructed from the key and query landmarks. Also notice the DConv box: it corresponds to a skip connection added to the values using a 1D depthwise convolution.
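
As a rough, single-head sketch of what such a skip connection could look like (the actual implementation operates per attention head and has its own kernel size, so treat the shapes and kernel_size here as assumptions):

import torch
import torch.nn as nn

batch, n, d, kernel_size = 2, 512, 64, 33
values = torch.randn(batch, n, d)
attn_output = torch.randn(batch, n, d)        # output of the Nyström attention

# Depthwise 1D convolution over the sequence dimension of the values.
dconv = nn.Conv1d(
    in_channels=d, out_channels=d,
    kernel_size=kernel_size, padding=kernel_size // 2,
    groups=d,                                 # depthwise: one filter per channel
    bias=False,
)

# Conv1d expects (batch, channels, length), so convolve along the sequence axis.
skip = dconv(values.transpose(1, 2)).transpose(1, 2)
output = attn_output + skip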



How is Nyströmformer implemented?

The original implementation of Nyströmformer can be found here, and the HuggingFace implementation can be found here. Let's take a look at a few lines of code (with some comments added) from the HuggingFace implementation. Note that some details such as normalization, attention masking, and the depthwise convolution are left out for simplicity.


# Project the hidden states into per-head queries, keys, and values.
key_layer = self.transpose_for_scores(self.key(hidden_states))
value_layer = self.transpose_for_scores(self.value(hidden_states))
query_layer = self.transpose_for_scores(mixed_query_layer)

# Query landmarks (Q_tilde): segment means over groups of seq_len // num_landmarks queries.
q_landmarks = query_layer.reshape(
    -1,
    self.num_attention_heads,
    self.num_landmarks,
    self.seq_len // self.num_landmarks,
    self.attention_head_size,
).mean(dim=-2)

# Key landmarks (K_tilde), constructed the same way.
k_landmarks = key_layer.reshape(
    -1,
    self.num_attention_heads,
    self.num_landmarks,
    self.seq_len // self.num_landmarks,
    self.attention_head_size,
).mean(dim=-2)

# kernel_1 = F_tilde = softmax(Q K_tilde^T), shape (..., n, m)
kernel_1 = torch.nn.functional.softmax(torch.matmul(query_layer, k_landmarks.transpose(-1, -2)), dim=-1)
# kernel_2 = softmax(Q_tilde K_tilde^T), shape (..., m, m); pseudoinverted below to give A_tilde
kernel_2 = torch.nn.functional.softmax(torch.matmul(q_landmarks, k_landmarks.transpose(-1, -2)), dim=-1)

# kernel_3 = B_tilde = softmax(Q_tilde K^T), shape (..., m, n)
attention_scores = torch.matmul(q_landmarks, key_layer.transpose(-1, -2))
kernel_3 = torch.nn.functional.softmax(attention_scores, dim=-1)

# Combine the three matrices and V without ever forming an n x n matrix.
# iterative_inv approximates the Moore-Penrose pseudoinverse of kernel_2.
attention_probs = torch.matmul(kernel_1, self.iterative_inv(kernel_2))   # F_tilde A_tilde
new_value_layer = torch.matmul(kernel_3, value_layer)                    # B_tilde V
context_layer = torch.matmul(attention_probs, new_value_layer)           # S_hat V



Using Nyströmformer with HuggingFace

Nyströmformer for Masked Language Modeling (MLM) is available on HuggingFace. Currently, there are four checkpoints, corresponding to various sequence lengths: nystromformer-512, nystromformer-1024, nystromformer-2048, and nystromformer-4096. The number of landmarks, $m$, can be controlled via the num_landmarks parameter in the NystromformerConfig (see the configuration sketch after the examples below). Let's take a look at a minimal example of Nyströmformer for MLM:

from transformers import AutoTokenizer, NystromformerForMaskedLM
import torch

tokenizer = AutoTokenizer.from_pretrained("uw-madison/nystromformer-512")
model = NystromformerForMaskedLM.from_pretrained("uw-madison/nystromformer-512")

inputs = tokenizer("Paris is the [MASK] of France.", return_tensors="pt")

with torch.no_grad():
    logits = model(**inputs).logits


mask_token_index = (inputs.input_ids == tokenizer.mask_token_id)[0].nonzero(as_tuple=True)[0]

predicted_token_id = logits[0, mask_token_index].argmax(axis=-1)
tokenizer.decode(predicted_token_id)
Output:
----------------------------------------------------------------------------------------------------
capital

Alternatively, we can use the pipeline API (which handles all of the complexity for us):

from transformers import pipeline
unmasker = pipeline('fill-mask', model='uw-madison/nystromformer-512')
unmasker("Paris is the [MASK] of France.")
Output:
----------------------------------------------------------------------------------------------------
[{'score': 0.829957902431488,
  'token': 1030,
  'token_str': 'capital',
  'sequence': 'paris is the capital of france.'},
{'score': 0.022157637402415276,
  'token': 16081,
  'token_str': 'birthplace',
  'sequence': 'paris is the birthplace of france.'},
{'score': 0.01904447190463543,
  'token': 197,
  'token_str': 'name',
  'sequence': 'paris is the name of france.'},
{'score': 0.017583081498742104,
  'token': 1107,
  'token_str': 'kingdom',
  'sequence': 'paris is the kingdom of france.'},
{'score': 0.005948934704065323,
  'token': 148,
  'token_str': 'city',
  'sequence': 'paris is the city of france.'}]
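
Finally, as mentioned above, the number of landmarks is a configuration parameter. A minimal sketch of instantiating a model with a custom $m$ (a model created from a fresh config like this is randomly initialized and would still need to be trained):

from transformers import NystromformerConfig, NystromformerForMaskedLM

# Only the number of landmarks is changed; all other hyperparameters keep their defaults.
config = NystromformerConfig(num_landmarks=32)
model = NystromformerForMaskedLM(config)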



Conclusion

Nyströmformer offers an efficient approximation to standard self-attention, while outperforming other linear self-attention schemes. In this blog post, we gave a high-level overview of the Nyström method and how it can be leveraged to approximate self-attention. Readers interested in deploying or fine-tuning Nyströmformer for downstream tasks can find the HuggingFace documentation here.


