OpenAI’s Whisper is a general-purpose speech transcription model that achieves state-of-the-art results across a range of benchmarks and
audio conditions. The latest large-v3 model tops the
OpenASR Leaderboard, ranking as the best open-source
speech transcription model for English. The model also demonstrates strong multilingual performance, achieving less than
30% word error rate (WER) on 42 of the 58 languages tested in the Common Voice 15 dataset.
While the transcription accuracy is exceptional, the inference time can be very slow. A 1 hour audio clip takes upwards of
6 minutes to transcribe on a 16GB T4 GPU, even after leveraging inference optimisations like flash attention,
half-precision, and chunking.
In this blog post, we demonstrate how Speculative Decoding can be employed to reduce the
inference time of Whisper by a factor of 2, while mathematically ensuring exactly the same outputs are achieved
from the model. As a result, this method provides a perfect drop-in replacement for existing Whisper pipelines, since it
provides a free 2x speed-up while maintaining the same accuracy. For a more streamlined version of the blog post
with fewer explanations but all the code, see the accompanying Google Colab.
Speculative Decoding
Speculative Decoding was proposed in Fast Inference from Transformers via Speculative Decoding
by Yaniv Leviathan et al. from Google. It works on the premise that a faster assistant model very often generates the same tokens as a larger main model.
First, the assistant model auto-regressively generates a sequence of candidate tokens.
In the diagram below, the assistant model generates a sequence of 5 candidate tokens: The fast brown sock jumps.
While these candidate tokens are generated quickly, they may differ from those predicted by the main model. Therefore,
in the second step, the candidate tokens are passed to the main model to be “verified”. The main model takes the
candidate tokens as input and performs a single forward pass. The outputs of the main model are the “correct”
token for each step in the token sequence.
In the diagram above, we see that the first three tokens predicted by the main model agree with those from the assistant
model: The fast brown. However, the fourth candidate token from the assistant model,
sock, mismatches with the correct token from the main model, fox.
We know that all candidate tokens up to the first mismatch are correct (The fast brown),
since these agree with the predictions from the main model. However, after the first mismatch, the candidate tokens
diverge from the actual tokens predicted by the main model. Therefore, we can replace the first incorrect candidate
token (sock) with the correct token from the main model (fox),
and discard all predicted tokens that come after this, since these have diverged. The corrected sequence, The fast brown fox,
now forms the new input to the assistant model.
The inference process then repeats, with the assistant model generating a new set of candidate tokens, which are verified
in a single forward pass by the main model.
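The draft, verify and correct loop above can be sketched in plain Python. This is a minimal sketch of the greedy variant only, not the 🤗 Transformers implementation: the "models" are hypothetical toy functions that map a token prefix to the next token, and the list comprehension over prefixes stands in for the main model's single forward pass, which in a real Transformer scores every candidate position at once. The paper also covers sampling with a rejection-sampling acceptance rule, which is not shown here.

```python
from typing import Callable, List, Tuple

# A toy "model": any function mapping a token prefix to its greedy next token.
NextToken = Callable[[List[str]], str]


def greedy_decode(main_next: NextToken, prompt: List[str], max_new_tokens: int, eos: str) -> List[str]:
    """Reference: plain auto-regressive greedy decoding with the main model only."""
    output = list(prompt)
    for _ in range(max_new_tokens):
        token = main_next(output)
        output.append(token)
        if token == eos:
            break
    return output


def speculative_decode(
    main_next: NextToken,
    assistant_next: NextToken,
    prompt: List[str],
    num_candidates: int,
    max_new_tokens: int,
    eos: str,
) -> Tuple[List[str], int]:
    """Greedy speculative decoding. Returns (tokens, number of verification passes)."""
    output = list(prompt)
    verification_passes = 0
    while len(output) - len(prompt) < max_new_tokens:
        # Step 1: the assistant drafts candidate tokens auto-regressively.
        draft: List[str] = []
        for _ in range(num_candidates):
            draft.append(assistant_next(output + draft))
            if draft[-1] == eos:
                break
        # Step 2: the main model verifies the draft. A real Transformer does this in one
        # forward pass over output + draft; here we emulate it by querying each prefix.
        verification_passes += 1
        verified = [main_next(output + draft[:i]) for i in range(len(draft) + 1)]
        # Step 3: keep candidates up to the first mismatch, then append the main model's
        # token at that position (or one extra token if every candidate matched).
        n = 0
        while n < len(draft) and draft[n] == verified[n]:
            n += 1
        new_tokens = draft[:n] + [verified[n]]
        if eos in new_tokens:
            new_tokens = new_tokens[: new_tokens.index(eos) + 1]
        output.extend(new_tokens)
        if output[-1] == eos:
            break
    return output[: len(prompt) + max_new_tokens], verification_passes


# Hypothetical toy models reproducing the example above.
TARGET = ["The", "fast", "brown", "fox", "jumps", "over", "the", "lazy", "dog", "<eos>"]


def main_next(prefix: List[str]) -> str:
    return TARGET[len(prefix)] if len(prefix) < len(TARGET) else "<eos>"


def assistant_next(prefix: List[str]) -> str:
    token = main_next(prefix)
    return "sock" if token == "fox" else token  # the assistant's one mistake


tokens, passes = speculative_decode(main_next, assistant_next, [], num_candidates=5, max_new_tokens=20, eos="<eos>")
assert tokens == greedy_decode(main_next, [], max_new_tokens=20, eos="<eos>")
print(" ".join(tokens), f"({passes} verification passes)")
```

The assertion checks that every emitted token is the main model's own greedy choice, so the output matches main-model-only decoding. In this toy run, two verification passes replace ten sequential main-model calls.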
Since we auto-regressively generate using the fast assistant model, and only perform verification forward passes with
the slow main model, the decoding process is sped up substantially. Furthermore, the verification forward passes
performed by the main model ensure that exactly the same outputs are achieved as if we were using the main model standalone.
This makes speculative decoding a perfect drop-in for existing Whisper pipelines, since one can be certain that the
same quality will be attained.
To get the biggest improvement in latency, the assistant model should be significantly faster than the main model,
while predicting the same token distribution as often as possible. In practice, these two attributes form a trade-off:
the faster a model is, the less accurate it is. However, since 70-80% of all predicted tokens tend to be “easier” tokens,
this trade-off is heavily biased towards selecting a faster model, rather than a more accurate one. Thus, the assistant
model should be at least 3x faster than the main model (the more the better), while predicting all the “easy” tokens
in the examples correctly. The remaining 20-30% of more “difficult” tokens can then be verified by the larger, main model.
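To make the trade-off concrete, we can use the simplified analysis from the Leviathan et al. paper. It assumes each of the gamma candidate tokens is accepted independently with a fixed probability alpha, and that one assistant forward pass costs c times a main-model pass. Under those assumptions the expected number of tokens per main-model pass is (1 − alpha^(gamma+1)) / (1 − alpha), and the expected speed-up is that quantity divided by (gamma · c + 1). The sketch below is a back-of-envelope estimate under these idealised assumptions (real acceptance is not independent, and it ignores Whisper's one-off encoder cost and other overheads); the values of alpha, gamma and c are illustrative, not measurements.

```python
def expected_tokens_per_pass(alpha: float, gamma: int) -> float:
    """Expected tokens produced per main-model forward pass, assuming each of the
    gamma candidates is accepted independently with probability alpha."""
    if alpha >= 1.0:
        return float(gamma + 1)  # every candidate accepted, plus one main-model token
    return (1 - alpha ** (gamma + 1)) / (1 - alpha)


def expected_speedup(alpha: float, gamma: int, cost_ratio: float) -> float:
    """Expected walltime speed-up over the main model alone, where cost_ratio is
    (assistant forward-pass time) / (main-model forward-pass time)."""
    return expected_tokens_per_pass(alpha, gamma) / (gamma * cost_ratio + 1)


# Illustrative values: acceptance rates in the 70-80% range quoted above, an assistant
# 3x or 6x faster than the main model, and 5 candidate tokens per step.
for speed_factor in (3, 6):
    for alpha in (0.7, 0.8):
        s = expected_speedup(alpha, gamma=5, cost_ratio=1 / speed_factor)
        print(f"assistant {speed_factor}x faster, alpha={alpha}: ~{s:.2f}x speed-up")
```

Under this model, the 6x-faster assistant at alpha = 0.7 (about 1.6x) still beats the 3x-faster assistant at alpha = 0.8 (about 1.4x), which illustrates why the trade-off favours speed.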
The only constraint for selecting an assistant model is that it must share the same vocabulary as the main model. That is
to say, the assistant model must use exactly the same tokenizer as the main model.
Therefore, if we want to use speculative decoding with a multilingual variant of Whisper, e.g. large-v2
(multilingual), we need to select a multilingual variant of Whisper as the assistant model, e.g. tiny.
Whereas, if we want to use speculative decoding with an English-only version of Whisper, e.g. medium.en,
we need an English-only version as the assistant model, e.g. tiny.en.
At the time of writing, Whisper large-v3 is an exception, since
it is the only Whisper checkpoint with an expanded vocabulary size, and thus is not compatible with previous Whisper
checkpoints.
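Before pairing two checkpoints, one way to check the shared-vocabulary constraint is to compare the full token-to-id mappings of their tokenizers. The sketch below is our own helper, not part of the 🤗 Transformers API: `vocabularies_match` and `checkpoints_compatible` are hypothetical names, while `AutoTokenizer.from_pretrained` and `get_vocab()` are standard Transformers calls. `transformers` is imported inside the function so the pure comparison works without it installed.

```python
from typing import Dict


def vocabularies_match(main_vocab: Dict[str, int], assistant_vocab: Dict[str, int]) -> bool:
    """True if both vocabularies map exactly the same token strings to the same ids."""
    return main_vocab == assistant_vocab


def checkpoints_compatible(main_id: str, assistant_id: str) -> bool:
    """Hypothetical helper: load both tokenizers from the Hub and compare vocabularies."""
    from transformers import AutoTokenizer  # lazy import: only needed for Hub checkpoints

    main_tok = AutoTokenizer.from_pretrained(main_id)
    assistant_tok = AutoTokenizer.from_pretrained(assistant_id)
    # get_vocab() includes added tokens (such as Whisper's special tokens), so a
    # checkpoint with an expanded vocabulary will not match an older one.
    return vocabularies_match(main_tok.get_vocab(), assistant_tok.get_vocab())
```

For example, `checkpoints_compatible("openai/whisper-large-v2", "openai/whisper-tiny")` should return True for the multilingual pair described above, and pairing large-v3 with an earlier checkpoint should return False because of its expanded vocabulary. Both calls download tokenizer files from the Hub.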
Now that we know the background behind speculative decoding, we’re ready to dive into the practical implementation. In
the 🤗 Transformers library, speculative decoding is implemented as
the “assisted generation” inference strategy. For more details about the implementation, the reader is advised to read
Joao Gante’s excellent blog post on Assisted Generation.
English Speech Transcription
Baseline Implementation
We start by benchmarking Whisper large-v2 to get our baseline number
for inference speed. We can load the main model and its corresponding processor via the convenient
AutoModelForSpeechSeq2Seq
and AutoProcessor classes. We’ll
load the model in float16 precision (falling back to float32 on CPU) and make sure that loading takes as little time as possible by passing
low_cpu_mem_usage=True. In addition,
we want to make sure that the model is loaded in safetensors
format by passing use_safetensors=True.
Finally, we’ll pass the argument attn_implementation="sdpa" to benefit from Flash Attention speed-ups through PyTorch’s
SDPA attention kernel:
import torch
from transformers import AutoModelForSpeechSeq2Seq, AutoProcessor
device = "cuda:0" if torch.cuda.is_available() else "cpu"
torch_dtype = torch.float16 if torch.cuda.is_available() else torch.float32
model_id = "openai/whisper-large-v2"
model = AutoModelForSpeechSeq2Seq.from_pretrained(
    model_id,
    torch_dtype=torch_dtype,
    low_cpu_mem_usage=True,
    use_safetensors=True,
    attn_implementation="sdpa",
)
model.to(device)
processor = AutoProcessor.from_pretrained(model_id)
Let’s load the English speech transcription dataset that we’ll use for benchmarking. We’ll load a small dataset
consisting of 73 samples from the LibriSpeech ASR validation-clean
dataset. This amounts to ~9MB of data, so it’s very lightweight and quick to download on device:
from datasets import load_dataset
dataset = load_dataset("hf-internal-testing/librispeech_asr_dummy", "clean", split="validation")
For the benchmark, we only want to measure the generation time, so let’s write a short helper function that measures
this step. The following function will return both the generated token ids and the time it took to run the model:
import time
def generate_with_time(model, inputs, **kwargs):
    start_time = time.time()
    outputs = model.generate(**inputs, **kwargs)
    generation_time = time.time() - start_time
    return outputs, generation_time
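As an aside on measurement, GPU kernels are launched asynchronously, so a stricter timer can synchronize the device before starting and after stopping the clock, and use the monotonic, high-resolution `time.perf_counter`. The helper below is a hypothetical variant of `generate_with_time`, not what produced the numbers in this post. It takes any callable plus an optional `sync` function (for example `torch.cuda.synchronize`), so it does not depend on PyTorch itself.

```python
import time
from typing import Any, Callable, Optional, Tuple


def timed(
    fn: Callable[..., Any],
    *args: Any,
    sync: Optional[Callable[[], None]] = None,
    **kwargs: Any,
) -> Tuple[Any, float]:
    """Call fn(*args, **kwargs) and return (result, elapsed seconds).

    If `sync` is given, it is called before starting and after stopping the clock,
    so that queued GPU work is included in the measured time.
    """
    if sync is not None:
        sync()
    start = time.perf_counter()  # monotonic, high-resolution clock
    result = fn(*args, **kwargs)
    if sync is not None:
        sync()
    return result, time.perf_counter() - start
```

With the model above, a call might look like `output, gen_time = timed(model.generate, sync=torch.cuda.synchronize, **inputs)`. Running one untimed warm-up generation first would also keep one-off CUDA initialisation out of the totals.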
We can now iterate over the audio samples in our dataset and sum up the overall generation time:
from tqdm import tqdm
all_time = 0
predictions = []
references = []
for sample in tqdm(dataset):
    audio = sample["audio"]
    inputs = processor(audio["array"], sampling_rate=audio["sampling_rate"], return_tensors="pt")
    # cast the input features to the model's dtype (float16 on GPU, float32 on CPU)
    inputs = inputs.to(device=device, dtype=torch_dtype)
    output, gen_time = generate_with_time(model, inputs)
    all_time += gen_time
    predictions.append(processor.batch_decode(output, skip_special_tokens=True, normalize=True)[0])
    references.append(processor.tokenizer._normalize(sample["text"]))
print(all_time)
Output:
100%|██████████| 73/73 [01:37<00:00, 1.33s/it]
72.99542546272278
Alright! We see that transcribing the 73 samples took 73 seconds. Let’s check the WER of the predictions:
from evaluate import load
wer = load("wer")
print(wer.compute(predictions=predictions, references=references))
Output:
0.03507271171941831
Our final baseline number is 73 seconds for a WER of 3.5%.
Speculative Decoding
Now let’s load the assistant model for speculative decoding. In this example, we’ll use a distilled variant of Whisper,
distil-large-v2. The distilled model copies the entire encoder
from Whisper, but only 2 of the 32 decoder layers. As such, it runs 6x faster than Whisper, while performing to within
1% WER on out-of-distribution test sets. This makes it the perfect choice as an assistant model, since it has both
high transcription accuracy and fast generation.
Since Distil-Whisper uses exactly the same encoder as the Whisper model, we can share the encoder across the main and
assistant models. We then only need to load the 2-layer decoder from Distil-Whisper as a “decoder-only” model. We can do
this through the convenient AutoModelForCausalLM
auto class. In practice, this results in only an 8% increase in VRAM over using the main model alone.
from transformers import AutoModelForCausalLM
assistant_model_id = "distil-whisper/distil-large-v2"
assistant_model = AutoModelForCausalLM.from_pretrained(
    assistant_model_id,
    torch_dtype=torch_dtype,
    low_cpu_mem_usage=True,
    use_safetensors=True,
    attn_implementation="sdpa",
)
assistant_model.to(device)
We intend to release an improved variant of Distil-Whisper with a stronger alignment in the token distribution
that will improve speculative decoding performance further. Follow the Distil-Whisper repository
for updates.
We can define a modified function for our speculative decoding benchmark. The only difference from the previous function
is that we pass the assistant model to our call to .generate:
def assisted_generate_with_time(model, inputs, **kwargs):
    start_time = time.time()
    outputs = model.generate(**inputs, assistant_model=assistant_model, **kwargs)
    generation_time = time.time() - start_time
    return outputs, generation_time
Let’s run the benchmark with speculative decoding, using Distil-Whisper as the assistant to Whisper:
all_time = 0
predictions = []
references = []
for sample in tqdm(dataset):
    audio = sample["audio"]
    inputs = processor(audio["array"], sampling_rate=audio["sampling_rate"], return_tensors="pt")
    # cast the input features to the model's dtype (float16 on GPU, float32 on CPU)
    inputs = inputs.to(device=device, dtype=torch_dtype)
    output, gen_time = assisted_generate_with_time(model, inputs)
    all_time += gen_time
    predictions.append(processor.batch_decode(output, skip_special_tokens=True, normalize=True)[0])
    references.append(processor.tokenizer._normalize(sample["text"]))
print(all_time)
Outputs:
100%|██████████| 73/73 [00:38<00:00, 1.88it/s]
32.69683289527893
With speculative decoding, the inference time was just 33 seconds, 2.2x faster than before! Let’s verify we have the same
WER:
print(wer.compute(predictions=predictions, references=references))
Outputs:
0.03507271171941831
Perfect! 3.5% WER again, as we have identical outputs to using the main model standalone.
Speculative decoding can also be used with the easy-to-use 🤗 Transformers pipeline API for inference. Below, we instantiate the pipeline using the model and processor, and then use it to
transcribe the first sample from the toy dataset. This can be extended to transcribe audio samples of arbitrary length,
including with the use of batching:
from transformers import pipeline
pipe = pipeline(
    "automatic-speech-recognition",
    model=model,
    tokenizer=processor.tokenizer,
    feature_extractor=processor.feature_extractor,
    max_new_tokens=128,
    chunk_length_s=15,
    batch_size=4,
    generate_kwargs={"assistant_model": assistant_model},
    torch_dtype=torch_dtype,
    device=device,
)
sample = dataset[0]["audio"]
result = pipe(sample)
print(result["text"])
Outputs:
Mr. Quilter is the apostle of the middle classes and we're glad to welcome his gospel.
An end-to-end code snippet for running speculative decoding with Whisper and Distil-Whisper can be found on the Distil-Whisper model card.
It combines the stages of inference covered in this notebook into a single code example.
Multilingual Speech Transcription
Distil-Whisper is the perfect assistant model for English speech transcription, since it performs to within 1% WER of the
original Whisper model, while being 6x faster over short and long-form audio samples. However, the official Distil-Whisper
checkpoints are English only, which means they cannot be used for multilingual speech transcription.
To use speculative decoding for multilingual speech transcription, one could either use one of the official multilingual Whisper checkpoints,
or a fine-tuned variant of Whisper. At the time of writing, there are over 5,000 fine-tuned Whisper checkpoints
on the Hugging Face Hub in over 100 languages. These provide an excellent starting point for selecting assistant Whisper
checkpoints that perform very well on a single language. In this example, we’ll use the smallest official multilingual
checkpoint, Whisper tiny. Feel free to experiment with different checkpoints
fine-tuned in your language!
Let’s load the weights for our new assistant model, Whisper tiny. Since the encoder in Whisper tiny differs from that in
large-v2, this time we’ll load both the encoder and decoder using the AutoModelForSpeechSeq2Seq class:
assistant_model_id = "openai/whisper-tiny"
assistant_model = AutoModelForSpeechSeq2Seq.from_pretrained(
    assistant_model_id,
    torch_dtype=torch_dtype,
    low_cpu_mem_usage=True,
    use_safetensors=True,
    attn_implementation="sdpa",
)
assistant_model.to(device)
For our benchmarking dataset, we’ll load 73 samples from the Dutch (“nl”) split of the VoxPopuli dataset:
dataset = load_dataset("sanchit-gandhi/voxpopuli_dummy", "nl", split="validation")
Great! We can now re-run our benchmark for our baseline Whisper large-v2 model as before. The only change we make is that
we pass the language and task arguments to our generate function, in order to ensure we perform speech transcription
(not speech translation). Speculative decoding is fully compatible with both the speech transcription and translation tasks. Simply
set the task argument as required below:
all_time = 0
predictions = []
references = []
for sample in tqdm(dataset):
    audio = sample["audio"]
    inputs = processor(audio["array"], sampling_rate=audio["sampling_rate"], return_tensors="pt")
    # cast the input features to the model's dtype (float16 on GPU, float32 on CPU)
    inputs = inputs.to(device=device, dtype=torch_dtype)
    output, gen_time = generate_with_time(model, inputs, language="nl", task="transcribe")
    all_time += gen_time
    predictions.append(processor.batch_decode(output, skip_special_tokens=True, normalize=True)[0])
    references.append(processor.tokenizer._normalize(sample["normalized_text"]))
wer_result = wer.compute(predictions=predictions, references=references)
print("Time:", all_time)
print("WER:", wer_result)
Outputs:
100%|██████████| 73/73 [02:05<00:00, 1.72s/it]
Time: 116.50992178916931
WER: 0.127190136275146
Right! We have our baseline time of 117 seconds and a WER of 12.7%. Let’s re-run the generation process using speculative decoding:
all_time = 0
predictions = []
references = []
for sample in tqdm(dataset):
    audio = sample["audio"]
    inputs = processor(audio["array"], sampling_rate=audio["sampling_rate"], return_tensors="pt")
    # cast the input features to the model's dtype (float16 on GPU, float32 on CPU)
    inputs = inputs.to(device=device, dtype=torch_dtype)
    output, gen_time = assisted_generate_with_time(model, inputs, language="nl", task="transcribe")
    all_time += gen_time
    predictions.append(processor.batch_decode(output, skip_special_tokens=True, normalize=True)[0])
    references.append(processor.tokenizer._normalize(sample["normalized_text"]))
wer_result = wer.compute(predictions=predictions, references=references)
print("Time:", all_time)
print("WER:", wer_result)
Outputs:
100%|██████████| 73/73 [01:08<00:00, 1.06it/s]
Time: 62.10229682922363
WER: 0.127190136275146
Again, we achieve 12.7% WER, but this time in just 62 seconds of inference time, representing a speed-up of 1.9x.
Given the low overhead of loading the assistant model and the mathematical property that exactly the same outputs are
achieved, speculative decoding offers the perfect drop-in replacement for existing Whisper pipelines.
Strategies for Efficient Speculative Decoding
In this final section, we cover two strategies for ensuring the fastest possible inference time with speculative decoding.
Assistant Model
Our objective is to select an assistant model that is at least 3x faster than the main model and transcribes at least
70-80% of the predicted tokens correctly, typically the “easier” tokens in the examples. If you have a particular
language in which you want to transcribe, an effective strategy is to train two Whisper models of different sizes, and
use one as the assistant to the other:
- First, fine-tune Whisper large-v3 to act as your main model
- Second, distil Whisper large-v3 on the same dataset to act as a fast assistant model
Fine-tuning and distillation can improve the WER performance of both the main and assistant models in your chosen language,
while maximising the alignment in the token distributions. A complete guide to Whisper fine-tuning can be found
here, and distillation here.
Batch Size
It’s worth noting that the largest speed gains with speculative decoding come with a batch size of 1. For batched
speculative decoding, all candidate tokens across the batch must match the validation tokens in order for the tokens
to be accepted. If a token in the batch at a given position does not agree, all candidate tokens that follow that position
are discarded. Consequently, speculative decoding favours lower batch sizes. In practice, we find that speculative decoding
provides a speed-up up to a batch size of 4. Above batch size 4, speculative decoding gives slower inference than the
main model alone. For full results, refer to Section D.3 of the Distil-Whisper paper.
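The batch acceptance rule described above can be sketched as follows: each row has its own matching prefix, and the batch keeps only the shortest one. This is a simplified illustration of the rule as stated, not the 🤗 Transformers implementation. The helper names are hypothetical, and `expected_accepted_per_step` additionally assumes each row accepts each candidate independently with probability alpha, so a position survives for the whole batch with probability alpha ** batch_size.

```python
from typing import List, Sequence


def matching_prefix_length(candidate: Sequence[int], verified: Sequence[int]) -> int:
    """Number of leading candidate tokens that agree with the main model."""
    n = 0
    for c, v in zip(candidate, verified):
        if c != v:
            break
        n += 1
    return n


def batch_accepted_length(candidates: List[Sequence[int]], verified: List[Sequence[int]]) -> int:
    """Candidate positions accepted for the whole batch: a position survives only if
    every row's candidate matches, so the shortest matching prefix wins."""
    return min(matching_prefix_length(c, v) for c, v in zip(candidates, verified))


def expected_accepted_per_step(alpha: float, gamma: int, batch_size: int) -> float:
    """Expected accepted candidates per step under the idealised independence assumption."""
    p = alpha ** batch_size  # probability that all rows accept a given position
    # Position k is kept only if positions 1..k are all kept: E[N] = sum_{k=1}^{gamma} p^k.
    return sum(p ** k for k in range(1, gamma + 1))


for batch_size in (1, 2, 4, 8):
    print(batch_size, round(expected_accepted_per_step(alpha=0.8, gamma=5, batch_size=batch_size), 2))
```

Even with a fairly high per-row acceptance rate, the expected number of accepted candidates shrinks quickly as the batch grows, which is consistent with the observation that speculative decoding favours small batch sizes.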
Conclusion
In this blog post, we covered the inference strategy of speculative decoding, as applied to the Whisper model for speech
transcription. We demonstrated how 2x speed-ups can be achieved, while mathematically ensuring the same outputs as using
the original model alone. We encourage you to try speculative decoding as a drop-in replacement for existing Whisper
pipelines, given the low overhead of using the additional assistant model and the guarantee of the same transcription results.
Acknowledgements
Blog post by Sanchit Gandhi. Many thanks to Patrick von Platen
and Pedro Cuenca for their constructive comments, and to Joao Gante
for the assisted generation implementation in 🤗 Transformers.
