Decoding Strategies in Large Language Models
📚 Background
🏃‍♂️ Greedy Search
⚖️ Beam Search
🎲 Top-k sampling
🔬 Nucleus sampling
Conclusion

📚 Background

The tokenizer (Byte-Pair Encoding in this instance) translates each token in the input text into a corresponding token ID. GPT-2 then uses these token IDs as input and tries to predict the next most likely token. Finally, the model generates logits, which are converted into probabilities using a softmax function.

For instance, the model assigns a probability of 17% to the token for “of” being the next token after “I have a dream”. This output essentially represents a ranked list of potential next tokens in the sequence. More formally, we denote this probability as P(of | I have a dream) = 17%.
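To make the logits-to-probabilities step concrete, here is a minimal sketch in plain Python. The five-token vocabulary and the logit values are made up for illustration (they are not GPT-2's real outputs), and the `softmax` helper is written by hand rather than taken from a library.

```python
import math

def softmax(logits):
    """Convert raw logits into probabilities that sum to 1."""
    # Subtracting the max logit avoids overflow in exp() and does not change the result
    m = max(logits)
    exps = [math.exp(x - m) for x in logits]
    total = sum(exps)
    return [e / total for e in exps]

# Hypothetical vocabulary and made-up logits for the token after "I have a dream"
vocab = [" of", ",", ".", " that", " to"]
logits = [2.0, 1.5, 1.0, 0.5, -1.0]

probs = softmax(logits)

# Rank candidate next tokens from most to least probable
ranked = sorted(zip(vocab, probs), key=lambda pair: pair[1], reverse=True)
for token, p in ranked:
    print(f"{token!r}: {p:.2%}")
```

Because softmax is monotonic, the ranking of tokens by probability is the same as their ranking by logit; softmax only rescales the scores into a distribution.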

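Autoregressive models like GPT predict the next token in a sequence based on the preceding tokens. Consider a sequence of tokens w = (w₁, w₂, …, wₜ). The joint probability of this sequence P(w) can be broken down as:

$$P(w) = P(w_1, w_2, \ldots, w_t) = P(w_1)\,P(w_2 \mid w_1)\,P(w_3 \mid w_1, w_2) \cdots P(w_t \mid w_1, \ldots, w_{t-1}) = \prod_{i=1}^{t} P(w_i \mid w_1, \ldots, w_{i-1})$$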

For each token wᵢ in the sequence, P(wᵢ | w₁, w₂, …, wᵢ₋₁) represents the conditional probability of wᵢ given all the preceding tokens (w₁, w₂, …, wᵢ₋₁). GPT-2 calculates this conditional probability for each of the 50,257 tokens in its vocabulary.
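As a sketch of the chain rule above, the hypothetical helper `sequence_log_prob` below sums the log of each conditional probability, which equals the log of their product. The conditional probabilities are invented numbers. Working in log space also avoids floating-point underflow, which is one reason decoding code usually tracks log probabilities:

```python
import math

def sequence_log_prob(conditional_probs):
    """Joint log-probability of a sequence from its per-token conditionals.

    conditional_probs[i] is P(w_i | w_1, ..., w_{i-1}). By the chain rule the
    joint probability is their product, so the log-probability is their sum.
    """
    return sum(math.log(p) for p in conditional_probs)

# Made-up conditional probabilities for a five-token sequence
conditionals = [0.17, 0.10, 0.40, 0.03, 0.50]

log_p = sequence_log_prob(conditionals)
print(math.exp(log_p))         # matches the direct product below
print(math.prod(conditionals))

# With many tokens, the direct product underflows while the log-sum stays finite
long_seq = [0.01] * 400
print(math.prod(long_seq))     # 0.0 (underflow)
print(sequence_log_prob(long_seq))
```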

This leads to the question: how do we use these probabilities to generate text? This is where decoding strategies, such as greedy search and beam search, come into play.

🏃‍♂️ Greedy Search

Greedy search is a decoding method that takes the most probable token at each step as the next token in the sequence. Put simply, it only retains the most likely token at each stage, discarding all other potential options. Using our example:

  • Step 1: Input: “I have a dream” → Most likely token: “ of”
  • Step 2: Input: “I have a dream of” → Most likely token: “ being”
  • Step 3: Input: “I have a dream of being” → Most likely token: “ a”
  • Step 4: Input: “I have a dream of being a” → Most likely token: “ doctor”
  • Step 5: Input: “I have a dream of being a doctor” → Most likely token: “.”
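The steps above can be reproduced with a toy lookup table standing in for GPT-2. This is a minimal sketch: `TOY_NEXT_TOKEN_PROBS` and `greedy_decode` are hypothetical names, and apart from the 17%, 9.68%, and 2.86% probabilities quoted in this article, every number in the table is made up.

```python
# Hypothetical toy "model": maps the text so far to next-token probabilities.
# Only " of" (0.17), " being" (0.0968) and " doctor" (0.0286) come from the
# article; all other numbers are invented for illustration.
TOY_NEXT_TOKEN_PROBS = {
    "I have a dream": {" of": 0.17, ".": 0.12, ",": 0.08},
    "I have a dream of": {" being": 0.0968, " a": 0.05, " the": 0.04},
    "I have a dream of being": {" a": 0.30, " able": 0.10, " the": 0.05},
    "I have a dream of being a": {" doctor": 0.0286, " writer": 0.02, " teacher": 0.015},
    "I have a dream of being a doctor": {".": 0.40, ",": 0.20, " and": 0.10},
}

def greedy_decode(prompt, next_token_probs, max_new_tokens=5):
    """Greedy search: append the single most probable token at each step."""
    text = prompt
    for step in range(1, max_new_tokens + 1):
        candidates = next_token_probs.get(text)
        if not candidates:  # the toy table has no continuation for this prefix
            break
        best_token = max(candidates, key=candidates.get)
        print(f"Step {step}: Input: {text!r} -> Most likely token: {best_token!r}")
        text += best_token
    return text

print(greedy_decode("I have a dream", TOY_NEXT_TOKEN_PROBS))
```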

While this approach might sound intuitive, it’s important to note that greedy search is short-sighted: it only considers the most probable token at each step without considering the overall effect on the sequence. This property makes it fast and efficient, as it doesn’t need to keep track of multiple sequences, but it also means that it can miss better sequences that would have appeared with slightly less probable next tokens.
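A two-step toy example (all probabilities invented) shows how this short-sightedness plays out: greedy search commits to the more likely first token and ends up with a less probable sequence than one that starts with a less likely token. The names `FIRST`, `SECOND`, `greedy_pair`, and `best_pair` are hypothetical.

```python
# Hypothetical two-step model (made-up numbers): probabilities for the first
# token, then for the second token given the first.
FIRST = {"A": 0.6, "B": 0.4}
SECOND = {
    "A": {"x": 0.3, "y": 0.3, "z": 0.4},
    "B": {"x": 0.9, "y": 0.05, "z": 0.05},
}

def joint(first, second):
    """Joint probability P(first, second) = P(first) * P(second | first)."""
    return FIRST[first] * SECOND[first][second]

def greedy_pair():
    # Commit to the most probable first token, then its most probable follow-up
    first = max(FIRST, key=FIRST.get)
    second = max(SECOND[first], key=SECOND[first].get)
    return (first, second), joint(first, second)

def best_pair():
    # Score every two-token sequence and keep the most probable one
    pairs = [(f, s) for f in FIRST for s in SECOND[f]]
    best = max(pairs, key=lambda fs: joint(*fs))
    return best, joint(*best)

print(greedy_pair())  # greedy picks ("A", "z") with P ≈ 0.24
print(best_pair())    # the best sequence is ("B", "x") with P ≈ 0.36
```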

Next, let’s illustrate the greedy search implementation using graphviz and networkx. We select the ID with the highest score, compute its log probability (we take the log to simplify calculations), and add it to the tree. We’ll repeat this process for five tokens.

```python
import matplotlib.pyplot as plt
import networkx as nx
import numpy as np
import time
import torch
from transformers import GPT2LMHeadModel, GPT2Tokenizer

# Assumed setup (defined earlier in the article): GPT-2, its tokenizer, and the prompt
device = 'cuda' if torch.cuda.is_available() else 'cpu'
model = GPT2LMHeadModel.from_pretrained('gpt2').to(device)
tokenizer = GPT2Tokenizer.from_pretrained('gpt2')
model.eval()

text = "I have a dream"
input_ids = tokenizer.encode(text, return_tensors='pt').to(device)

def get_log_prob(logits, token_id):
    # Compute log-probabilities over the vocabulary (log_softmax is more stable than log(softmax))
    log_probabilities = torch.nn.functional.log_softmax(logits, dim=-1)

    # Get the log probability of the token
    token_log_probability = log_probabilities[token_id].item()
    return token_log_probability

def greedy_search(input_ids, node, length=5):
    if length == 0:
        return input_ids

    with torch.no_grad():
        outputs = model(input_ids)
    predictions = outputs.logits

    # Get the predicted next sub-word (greedy: take the highest-scoring token)
    logits = predictions[0, -1, :]
    token_id = torch.argmax(logits).unsqueeze(0)

    # Compute the score of the predicted token
    token_score = get_log_prob(logits, token_id)

    # Add the predicted token to the list of input ids
    new_input_ids = torch.cat([input_ids, token_id.unsqueeze(0)], dim=-1)

    # Add node and edge to graph
    next_token = tokenizer.decode(token_id, skip_special_tokens=True)
    current_node = list(graph.successors(node))[0]
    graph.nodes[current_node]['tokenscore'] = np.exp(token_score) * 100
    graph.nodes[current_node]['token'] = next_token + f"_{length}"

    # Recursive call
    input_ids = greedy_search(new_input_ids, current_node, length-1)

    return input_ids

# Parameters
length = 5
beams = 1

# Create a balanced tree with height 'length' (branching factor 1, i.e. a single path)
graph = nx.balanced_tree(1, length, create_using=nx.DiGraph())

# Add 'tokenscore' and 'token' attributes to each node
for node in graph.nodes:
    graph.nodes[node]['tokenscore'] = 100
    graph.nodes[node]['token'] = text

# Start generating text
output_ids = greedy_search(input_ids, 0, length=length)
output = tokenizer.decode(output_ids.squeeze().tolist(), skip_special_tokens=True)
print(f"Generated text: {output}")
```

Generated text: I have a dream of being a doctor.

Our greedy search generates the same text as the one from the transformers library: “I have a dream of being a doctor.” Let’s visualize the tree we created.

```python
import matplotlib.pyplot as plt
import networkx as nx
import matplotlib.colors as mcolors
from matplotlib.colors import LinearSegmentedColormap

def plot_graph(graph, length, beams, score):
    fig, ax = plt.subplots(figsize=(3+1.2*beams**length, max(5, 2+length)), dpi=300, facecolor='white')

    # Create positions for each node (graphviz_layout requires pygraphviz)
    pos = nx.nx_agraph.graphviz_layout(graph, prog="dot")

    # Normalize the colors along the range of token scores
    if score == 'token':
        scores = [data['tokenscore'] for _, data in graph.nodes(data=True) if data['token'] is not None]
    elif score == 'sequence':
        scores = [data['sequencescore'] for _, data in graph.nodes(data=True) if data['token'] is not None]
    vmin = min(scores)
    vmax = max(scores)
    norm = mcolors.Normalize(vmin=vmin, vmax=vmax)
    cmap = LinearSegmentedColormap.from_list('rg', ["r", "y", "g"], N=256)

    # Draw the nodes
    nx.draw_networkx_nodes(graph, pos, node_size=2000, node_shape='o', alpha=1, linewidths=4,
                           node_color=scores, cmap=cmap)

    # Draw the edges
    nx.draw_networkx_edges(graph, pos)

    # Draw the labels (token on the first line, score on the second)
    if score == 'token':
        labels = {node: data['token'].split('_')[0] + f"\n{data['tokenscore']:.2f}%" for node, data in graph.nodes(data=True) if data['token'] is not None}
    elif score == 'sequence':
        labels = {node: data['token'].split('_')[0] + f"\n{data['sequencescore']:.2f}" for node, data in graph.nodes(data=True) if data['token'] is not None}
    nx.draw_networkx_labels(graph, pos, labels=labels, font_size=10)
    plt.box(False)

    # Add a colorbar
    sm = plt.cm.ScalarMappable(cmap=cmap, norm=norm)
    sm.set_array([])
    if score == 'token':
        fig.colorbar(sm, ax=ax, orientation='vertical', pad=0, label='Token probability (%)')
    elif score == 'sequence':
        fig.colorbar(sm, ax=ax, orientation='vertical', pad=0, label='Sequence score')
    plt.show()

# Plot graph
plot_graph(graph, length, 1.5, 'token')
```

Image by author.

In this graph, the top node stores the input token (thus with a 100% probability), while all other nodes represent generated tokens. Although each token in this sequence was the most likely at the time of prediction, “being” and “doctor” were assigned relatively low probabilities of 9.68% and 2.86%, respectively. This suggests that “of”, our first predicted token, may not have been the most suitable choice, as it led to “being”, which is quite unlikely.

In the next section, we’ll explore how beam search can address this problem.

⚖️ Beam Search

Unlike greedy search, which only considers the next most probable token, beam search takes into account the n most likely tokens, where n represents the number of beams. This procedure is repeated until a predefined maximum length is reached or an end-of-sequence token appears. At this point, the sequence (or “beam”) with the highest overall score is chosen as the output.
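For comparison, here is a minimal sketch of beam search with pruning, where only the `num_beams` highest-scoring partial sequences survive each step. It runs on a made-up toy model; `pruned_beam_search`, `toy_model`, and `TOY` are hypothetical names. Note that the tree-based code later in this section expands the top-n tokens of every node without pruning, so it explores beams^length sequences rather than keeping n at a time.

```python
import math

# Hypothetical toy model (made-up numbers): next-token probabilities keyed by
# the text generated so far; an empty dict means the sequence has ended.
TOY = {
    "": {"A": 0.6, "B": 0.4},
    "A": {"x": 0.3, "y": 0.3, "z": 0.4},
    "B": {"x": 0.9, "y": 0.05, "z": 0.05},
}

def toy_model(text):
    return TOY.get(text, {})

def pruned_beam_search(prompt, next_token_probs, num_beams=2, max_new_tokens=5):
    """Keep the num_beams best partial sequences (by cumulative log-probability) per step."""
    beams = [(prompt, 0.0)]  # (text, cumulative log-probability)
    for _ in range(max_new_tokens):
        candidates = []
        for text, score in beams:
            options = next_token_probs(text)
            if not options:
                # Finished sequence: carry it over unchanged
                candidates.append((text, score))
                continue
            for token, p in options.items():
                candidates.append((text + token, score + math.log(p)))
        # Prune back down to the num_beams highest-scoring sequences
        beams = sorted(candidates, key=lambda c: c[1], reverse=True)[:num_beams]
    return max(beams, key=lambda c: c[1])

best_text, best_score = pruned_beam_search("", toy_model, num_beams=2)
print(best_text, math.exp(best_score))      # "Bx", probability ≈ 0.36
greedy_text, greedy_score = pruned_beam_search("", toy_model, num_beams=1)
print(greedy_text, math.exp(greedy_score))  # "Az": a single beam behaves like greedy search
```

With two beams, the less likely first token “B” survives the first step and leads to the better sequence; with one beam, the search reduces to greedy decoding.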

We can adapt the greedy_search() function to consider the n most probable tokens instead of just one. Here, we’ll maintain the sequence score log P(w), which is the cumulative sum of the log probability of every token in the beam. We normalize this score by the sequence length to prevent a bias towards shorter sequences, since every additional token adds a non-positive log probability to the sum (this factor can be adjusted). Once again, we’ll generate five additional tokens to complete the sentence “I have a dream.”
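The effect of length normalization can be seen on two made-up candidates. In this sketch, `normalized_score` is a hypothetical helper and the `alpha` exponent is an assumed knob for the adjustable factor: `alpha=1.0` is a plain per-token average and `alpha=0.0` turns normalization off.

```python
def normalized_score(cumulative_log_prob, length, alpha=1.0):
    """Length-normalized sequence score: cumulative log-probability / length**alpha.

    alpha=1.0 averages over tokens; alpha=0.0 disables normalization.
    alpha is a hypothetical tuning knob, not a parameter from the article.
    """
    return cumulative_log_prob / (length ** alpha)

# Two made-up candidates: (cumulative log-probability, number of tokens)
short_seq = (-3.0, 3)
long_seq = (-4.5, 6)   # lower total, but a better per-token average

print(normalized_score(*short_seq, alpha=0.0), normalized_score(*long_seq, alpha=0.0))  # -3.0 -4.5
print(normalized_score(*short_seq), normalized_score(*long_seq))                        # -1.0 -0.75
```

Without normalization the shorter candidate wins simply because it accumulates fewer negative terms; after dividing by length, the longer candidate with better per-token probabilities wins.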

```python
from tqdm.notebook import tqdm

def greedy_sampling(logits, beams):
    # Return the IDs of the 'beams' highest-scoring tokens
    return torch.topk(logits, beams).indices

def beam_search(input_ids, node, bar, length, beams, sampling, temperature=0.1):
    if length == 0:
        return None

    with torch.no_grad():
        outputs = model(input_ids)
    predictions = outputs.logits

    # Get the logits of the next sub-word
    logits = predictions[0, -1, :]

    # top_k_sampling() and nucleus_sampling() are defined in the Top-k and Nucleus sampling sections
    if sampling == 'greedy':
        top_token_ids = greedy_sampling(logits, beams)
    elif sampling == 'top_k':
        top_token_ids = top_k_sampling(logits, temperature, 20, beams)
    elif sampling == 'nucleus':
        top_token_ids = nucleus_sampling(logits, temperature, 0.5, beams)

    for j, token_id in enumerate(top_token_ids):
        bar.update(1)

        # Compute the score of the predicted token
        token_score = get_log_prob(logits, token_id)
        cumulative_score = graph.nodes[node]['cumscore'] + token_score

        # Add the predicted token to the list of input ids
        new_input_ids = torch.cat([input_ids, token_id.unsqueeze(0).unsqueeze(0)], dim=-1)

        # Add node and edge to graph
        token = tokenizer.decode(token_id, skip_special_tokens=True)
        current_node = list(graph.successors(node))[j]
        graph.nodes[current_node]['tokenscore'] = np.exp(token_score) * 100
        graph.nodes[current_node]['cumscore'] = cumulative_score
        # Normalize the cumulative log probability by the sequence length (prompt included)
        graph.nodes[current_node]['sequencescore'] = 1/(len(new_input_ids.squeeze())) * cumulative_score
        graph.nodes[current_node]['token'] = token + f"_{length}_{j}"

        # Recursive call
        beam_search(new_input_ids, current_node, bar, length-1, beams, sampling, 1)

# Parameters
length = 5
beams = 2

# Create a balanced tree with height 'length' and branching factor 'beams'
graph = nx.balanced_tree(beams, length, create_using=nx.DiGraph())
bar = tqdm(total=len(graph.nodes))

# Add 'tokenscore', 'cumscore', 'sequencescore', and 'token' attributes to each node
for node in graph.nodes:
    graph.nodes[node]['tokenscore'] = 100
    graph.nodes[node]['cumscore'] = 0
    graph.nodes[node]['sequencescore'] = 0
    graph.nodes[node]['token'] = text

# Start generating text
beam_search(input_ids, 0, bar, length, beams, 'greedy', 1)
```

The function computes the scores for 62 generated tokens (63 nodes counting the root) and beams^length = 2⁵ = 32 possible sequences. In our implementation, all the information is stored in the graph. Our next step is to extract the best sequence.
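The counts follow from the shape of the tree: depth d holds beams^d generated tokens, and only the deepest level holds complete sequences. A small sketch (`tree_counts` is a hypothetical helper) makes the arithmetic explicit:

```python
def tree_counts(beams, length):
    """Generated-token and leaf counts for a full tree with branching factor
    `beams` and depth `length`. The root holds the prompt, so it is not
    counted as a generated token."""
    generated = sum(beams ** depth for depth in range(1, length + 1))
    leaves = beams ** length
    return generated, leaves

print(tree_counts(2, 5))  # (62, 32): 2 + 4 + 8 + 16 + 32 generated tokens, 32 leaves
print(tree_counts(1, 5))  # (5, 1): the greedy case, a single path
```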

First, we identify the leaf node with the highest sequence score. Next, we find the shortest path from the root to this leaf. Every node along this path contains a token from the optimal sequence. Here’s how we can implement it:

```python
def get_best_sequence(G):
    # Create a list of leaf nodes
    leaf_nodes = [node for node in G.nodes() if G.out_degree(node) == 0]

    # Get the leaf node with the highest sequencescore
    max_score_node = None
    max_score = float('-inf')
    for node in leaf_nodes:
        if G.nodes[node]['sequencescore'] > max_score:
            max_score = G.nodes[node]['sequencescore']
            max_score_node = node

    # Retrieve the list of nodes on the path from the root node to this leaf node
    path = nx.shortest_path(G, source=0, target=max_score_node)

    # Return the string of token attributes of this sequence
    sequence = "".join([G.nodes[node]['token'].split('_')[0] for node in path])

    return sequence, max_score

sequence, max_score = get_best_sequence(graph)
print(f"Generated text: {sequence}")
```

Generated text: I have a dream. I have a dream

The best sequence appears to be “I have a dream. I have a dream,” which is a common response from GPT-2, even though it may be surprising. To verify this, let’s plot the graph.

In this visualization, we’ll display the sequence score for each node, which represents the score of the sequence up to that point. If the function get_best_sequence() is correct, the “dream” node in the sequence “I have a dream. I have a dream” should have the highest score among all the leaf nodes.

```python
# Plot graph
plot_graph(graph, length, beams, 'sequence')
```
