Update (11/2021): This blog post has been updated to feature XLSR’s
successor, called XLS-R.
Wav2Vec2 is a pretrained model for Automatic Speech Recognition
(ASR) and was released in September 2020
by Alexei Baevski, Michael Auli, and Alex Conneau. Soon after the
superior performance of Wav2Vec2 was demonstrated on one of the most
popular English datasets for ASR, called LibriSpeech,
Facebook AI presented a multi-lingual version of Wav2Vec2, called
XLSR. XLSR stands for cross-lingual
speech representations and refers to the model’s ability to learn speech
representations that are useful across multiple languages.
XLSR’s successor, simply called XLS-R (referring to the
“XLM-R for Speech”), was released in November 2021 by Arun
Babu, Changhan Wang, Andros Tjandra, et al. XLS-R used almost half a
million hours of audio data in 128 languages for self-supervised
pre-training and is available in sizes ranging from 300 million up to two
billion parameters. You can find the pretrained checkpoints on the 🤗
Hub.
Similar to BERT’s masked language modeling
objective, XLS-R learns
contextualized speech representations by randomly masking feature
vectors before passing them to a transformer network during
self-supervised pre-training (i.e. diagram on the left below).
For fine-tuning, a single linear layer is added on top of the
pre-trained network to train the model on labeled data of audio
downstream tasks such as speech recognition, speech translation and
audio classification (i.e. diagram on the right below).
XLS-R shows impressive improvements over previous state-of-the-art
results on speech recognition, speech translation and
speaker/language identification, cf. Tables 3-6, Tables 7-10, and
Tables 11-12, respectively, of the official paper.
Setup
In this blog post, we will give an in-detail explanation of how XLS-R,
more specifically the pre-trained checkpoint
Wav2Vec2-XLS-R-300M, can be fine-tuned for ASR.
For demonstration purposes, we fine-tune the model on the low-resource
ASR dataset of Common Voice that contains only
ca. 4h of validated training data.
XLS-R is fine-tuned using Connectionist Temporal Classification (CTC),
which is an algorithm that is used to train neural networks for
sequence-to-sequence problems, such as ASR and handwriting recognition.
I highly recommend reading the well-written blog post Sequence
Modeling with CTC (2017) by Awni Hannun.
Before we start, let’s install datasets and transformers. Also, we
need torchaudio to load audio files and jiwer to evaluate our
fine-tuned model using the word error rate (WER) metric.
!pip install datasets==1.18.3
!pip install transformers==4.11.3
!pip install huggingface_hub==0.1
!pip install torchaudio
!pip install librosa
!pip install jiwer
We strongly suggest uploading your training checkpoints directly to the
Hugging Face Hub while training. The Hugging Face
Hub has integrated version control so you can
be sure that no model checkpoint is lost during training.
To do so you have to store your authentication token from the Hugging
Face website (sign up here if you
haven’t already!)
from huggingface_hub import notebook_login
notebook_login()
Print Output:
Login successful
Your token has been saved to /root/.huggingface/token
Then you need to install Git-LFS to upload your model checkpoints:
!apt install git-lfs
In the paper, the model
was evaluated using the phoneme error rate (PER), but by far the most
common metric in ASR is the word error rate (WER). To keep this notebook
as general as possible we decided to evaluate the model using WER.
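To make the metric concrete, here is a minimal sketch of how a word error rate can be computed: the word-level edit distance between reference and hypothesis, divided by the number of reference words. The function name `word_error_rate` is hypothetical and this is not the implementation used by jiwer, which handles many more edge cases.

```python
def word_error_rate(reference: str, hypothesis: str) -> float:
    """Word-level edit distance divided by the number of reference words."""
    ref_words = reference.split()
    hyp_words = hypothesis.split()
    if not ref_words:
        raise ValueError("reference must contain at least one word")

    # previous_row[j] holds the edit distance between the reference words
    # processed so far and the first j hypothesis words.
    previous_row = list(range(len(hyp_words) + 1))
    for i, ref_word in enumerate(ref_words, start=1):
        current_row = [i]
        for j, hyp_word in enumerate(hyp_words, start=1):
            substitution = previous_row[j - 1] + (ref_word != hyp_word)
            deletion = previous_row[j] + 1
            insertion = current_row[j - 1] + 1
            current_row.append(min(substitution, deletion, insertion))
        previous_row = current_row
    return previous_row[-1] / len(ref_words)


# One substituted word out of four reference words.
example_wer = word_error_rate("söz konusu süre altı", "söz konusu süre alt")
```

Note that the WER can exceed 1.0 when the hypothesis contains many inserted words, since the denominator only counts reference words.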
Prepare Data, Tokenizer, Feature Extractor
ASR models transcribe speech to text, which means that we need both a
feature extractor that processes the speech signal to the model’s input
format, e.g. a feature vector, and a tokenizer that processes the
model’s output format to text.
In 🤗 Transformers, the XLS-R model is thus accompanied by both a
tokenizer, called
Wav2Vec2CTCTokenizer,
and a feature extractor, called
Wav2Vec2FeatureExtractor.
Let’s start by creating the tokenizer to decode the predicted output
classes to the output transcription.
Create Wav2Vec2CTCTokenizer
A pre-trained XLS-R model maps the speech signal to a sequence of
context representations as illustrated in the figure above. However, for
speech recognition the model has to map this sequence of context
representations to its corresponding transcription, which means that a
linear layer has to be added on top of the transformer block (shown in
yellow in the diagram above). This linear layer is used to classify
each context representation to a token class, analogous to how
a linear layer is added on top of BERT’s embeddings
for further classification after pre-training (cf. the ‘BERT’ section of the following blog
post).
The output size of this layer corresponds to the number of tokens in the
vocabulary, which does not depend on XLS-R’s pretraining task, but
only on the labeled dataset used for fine-tuning. So in the first step,
we will take a look at the chosen dataset of Common Voice and define a
vocabulary based on the transcriptions.
First, let’s go to the official Common Voice
website and pick a
language to fine-tune XLS-R on. For this notebook, we will use Turkish.
For each language-specific dataset, you can find a language code
corresponding to your chosen language. On Common
Voice, look for the field
“Version”. The language code then corresponds to the prefix before the
underscore. For Turkish, e.g., the language code is "tr".
Great, now we can use 🤗 Datasets’ simple API to download the data. The
dataset name is "common_voice", the configuration name corresponds to
the language code, which is "tr" in our case.
Common Voice has many different splits including invalidated, which
refers to data that was not rated as “clean enough” to be considered
useful. In this notebook, we will only make use of the splits "train",
"validation" and "test".
Because the Turkish dataset is so small, we will merge both the
validation and training data into a training dataset and only use the
test data for validation.
from datasets import load_dataset, load_metric, Audio
common_voice_train = load_dataset("common_voice", "tr", split="train+validation")
common_voice_test = load_dataset("common_voice", "tr", split="test")
Many ASR datasets only provide the target text, 'sentence', for each
audio array 'audio' and file 'path'. Common Voice actually provides
much more information about each audio file, such as the 'accent',
etc. To keep the notebook as general as possible, we only consider the
transcribed text for fine-tuning.
common_voice_train = common_voice_train.remove_columns(["accent", "age", "client_id", "down_votes", "gender", "locale", "segment", "up_votes"])
common_voice_test = common_voice_test.remove_columns(["accent", "age", "client_id", "down_votes", "gender", "locale", "segment", "up_votes"])
Let’s write a short function to display some random samples of the
dataset and run it a couple of times to get a feeling for the
transcriptions.
from datasets import ClassLabel
import random
import pandas as pd
from IPython.display import display, HTML
def show_random_elements(dataset, num_examples=10):
    assert num_examples <= len(dataset), "Cannot pick more elements than there are in the dataset."
    # draw distinct random indices
    picks = []
    for _ in range(num_examples):
        pick = random.randint(0, len(dataset)-1)
        while pick in picks:
            pick = random.randint(0, len(dataset)-1)
        picks.append(pick)
    # render the selected rows as an HTML table in the notebook
    df = pd.DataFrame(dataset[picks])
    display(HTML(df.to_html()))
show_random_elements(common_voice_train.remove_columns(["path", "audio"]), num_examples=10)
Print Output:
| Idx | Sentence |
|---|---|
| 1 | Jonuz, kısa süreli görevi kabul eden tek adaydı. |
| 2 | Biz umudumuzu bu mücadeleden almaktayız. |
| 3 | Sergide beş Hırvat yeniliği sergilendi. |
| 4 | Herşey adıyla bilinmeli. |
| 5 | Kuruluş özelleştirmeye hazır. |
| 6 | Yerleşim yerlerinin manzarası harika. |
| 7 | Olayların failleri bulunamadı. |
| 8 | Fakat bu çabalar boşa çıktı. |
| 9 | Projenin değeri iki virgül yetmiş yedi milyon avro. |
| 10 | Büyük yeniden yapım projesi dört aşamaya bölündü. |
Alright! The transcriptions look fairly clean. Having translated the
transcribed sentences, it seems that the language corresponds more to
written-out text than noisy dialogue. This makes sense considering that
Common Voice is a
crowd-sourced read speech corpus.
We can see that the transcriptions contain some special characters, such
as ,.?!;:. Without a language model, it is much harder to classify
speech chunks to such special characters because they do not really
correspond to a characteristic sound unit. E.g., the letter "s" has
a more or less clear sound, whereas the special character "." does
not. Also, in order to understand the meaning of a speech signal, it is
usually not necessary to include special characters in the
transcription.
Let’s simply remove all characters that do not contribute to the
meaning of a word and cannot really be represented by an acoustic sound
and normalize the text.
import re
# raw string; the hyphen and the single quote are escaped inside the character class
chars_to_remove_regex = r'[,?.!\-;:"“%‘”�\']'
def remove_special_characters(batch):
    # strip the special characters and lower-case the transcription
    batch["sentence"] = re.sub(chars_to_remove_regex, '', batch["sentence"]).lower()
    return batch
common_voice_train = common_voice_train.map(remove_special_characters)
common_voice_test = common_voice_test.map(remove_special_characters)
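One detail worth checking when writing such a character class is the hyphen: unescaped and placed between two characters, it denotes a range. The following sketch compares both variants on a made-up sample sentence (the sentence and the variable names are illustrative only, not taken from Common Voice).

```python
import re

# Unescaped hyphen: "!-;" is read as the range from "!" (U+0021) to ";" (U+003B),
# which also covers the digits 0-9 and characters such as "'" and "/".
range_pattern = "[,?.!-;:]"
# Escaped hyphen: only the listed characters are removed.
literal_pattern = r"[,?.!\-;:]"

sample = "Saat 10:30'da, A-1 kapısında!"

with_range = re.sub(range_pattern, "", sample).lower()
with_literal = re.sub(literal_pattern, "", sample).lower()

print(with_range)
print(with_literal)
```

With the range variant, digits and apostrophes silently disappear from the labels, which is easy to miss on a dataset where numbers are already written out.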
Let’s take a look at the processed text labels again.
show_random_elements(common_voice_train.remove_columns(["path","audio"]))
Print Output:
| Idx | Transcription |
|---|---|
| 1 | birisi beyazlar için dediler |
| 2 | maktouf’un cezası haziran ayında sona erdi |
| 3 | orijinalin aksine kıyafetler çıkarılmadı |
| 4 | bunların toplam değeri yüz milyon avroyu buluyor |
| 5 | masada en az iki seçenek bulunuyor |
| 6 | bu hiç de haksız bir heveslilik değil |
| 7 | bu durum bin dokuz yüz doksanlarda ülkenin bölünmesiyle değişti |
| 8 | söz konusu süre altı ay |
| 9 | ancak bedel çok daha yüksek olabilir |
| 10 | başkent fira bir tepenin üzerinde yer alıyor |
Good! This looks better. We have removed most special characters from
transcriptions and normalized them to lower-case only.
Before finalizing the pre-processing, it is always advantageous to
consult a native speaker of the target language to see whether the text
can be further simplified. For this blog post,
Merve was kind enough to take a quick
look and noted that “hatted” characters, like â, are not really
used anymore in Turkish and can be replaced by their “un-hatted”
equivalent, e.g. a.
This means that we should replace a sentence like
"yargı sistemi hâlâ sağlıksız" with "yargı sistemi hala sağlıksız".
Let’s write another short mapping function to further simplify the text
labels. Remember, the simpler the text labels, the easier it is for the
model to learn to predict those labels.
def replace_hatted_characters(batch):
    batch["sentence"] = re.sub('[â]', 'a', batch["sentence"])
    batch["sentence"] = re.sub('[î]', 'i', batch["sentence"])
    batch["sentence"] = re.sub('[ô]', 'o', batch["sentence"])
    batch["sentence"] = re.sub('[û]', 'u', batch["sentence"])
    return batch
common_voice_train = common_voice_train.map(replace_hatted_characters)
common_voice_test = common_voice_test.map(replace_hatted_characters)
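As a side note, the same replacement can be sketched on a plain string with a single translation table instead of four regex passes. This is only an alternative illustration; `HATTED_TO_PLAIN` and `replace_hatted` are hypothetical names that are not used elsewhere in this notebook.

```python
# Translation table built once; maps each circumflexed vowel to its plain form.
HATTED_TO_PLAIN = str.maketrans({"â": "a", "î": "i", "ô": "o", "û": "u"})


def replace_hatted(text: str) -> str:
    """Replace all hatted vowels in one pass over the string."""
    return text.translate(HATTED_TO_PLAIN)


simplified = replace_hatted("yargı sistemi hâlâ sağlıksız")
print(simplified)
```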
In CTC, it is common to classify speech chunks into letters, so we will
do the same here. Let’s extract all distinct letters of the training
and test data and build our vocabulary from this set of letters.
We write a mapping function that concatenates all transcriptions into
one long transcription and then transforms the string into a set of
chars. It is important to pass the argument batched=True to the
map(...) function so that the mapping function has access to all
transcriptions at once.
def extract_all_chars(batch):
    all_text = " ".join(batch["sentence"])
    vocab = list(set(all_text))
    return {"vocab": [vocab], "all_text": [all_text]}
vocab_train = common_voice_train.map(extract_all_chars, batched=True, batch_size=-1, keep_in_memory=True, remove_columns=common_voice_train.column_names)
vocab_test = common_voice_test.map(extract_all_chars, batched=True, batch_size=-1, keep_in_memory=True, remove_columns=common_voice_test.column_names)
Now, we create the union of all distinct letters in the training dataset
and test dataset and convert the resulting list into an enumerated
dictionary.
vocab_list = list(set(vocab_train["vocab"][0]) | set(vocab_test["vocab"][0]))
vocab_dict = {v: k for k, v in enumerate(sorted(vocab_list))}
vocab_dict
Print Output:
{
' ': 0,
'a': 1,
'b': 2,
'c': 3,
'd': 4,
'e': 5,
'f': 6,
'g': 7,
'h': 8,
'i': 9,
'j': 10,
'k': 11,
'l': 12,
'm': 13,
'n': 14,
'o': 15,
'p': 16,
'q': 17,
'r': 18,
's': 19,
't': 20,
'u': 21,
'v': 22,
'w': 23,
'x': 24,
'y': 25,
'z': 26,
'ç': 27,
'ë': 28,
'ö': 29,
'ü': 30,
'ğ': 31,
'ı': 32,
'ş': 33,
'̇': 34
}
Cool, we see that all letters of the alphabet occur in the dataset
(which is not really surprising) and we also extracted the special
characters " " and '. Note that we did not exclude those special
characters because the model has to learn to predict when a word is
finished; otherwise the model prediction would always be a sequence of
chars, which would make it impossible to separate words from each other.
One should always keep in mind that pre-processing is a very important
step before training your model. E.g., we don’t want our model to
differentiate between a and A just because we forgot to normalize
the data. The difference between a and A does not depend on the
“sound” of the letter at all, but more on grammatical rules, e.g.
use a capitalized letter at the beginning of the sentence. So it is
sensible to remove the difference between capitalized and
non-capitalized letters so that the model has an easier time learning to
transcribe speech.
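The vocabulary construction described above can be reproduced on a toy example without downloading any data. The sketch below assumes that a batched mapping function receives a dict that maps each column name to a list of values; `toy_train` and `toy_test` are made-up stand-ins for the real splits.

```python
def extract_all_chars(batch):
    # `batch` mimics the input of a batched map: column name -> list of values.
    all_text = " ".join(batch["sentence"])
    return sorted(set(all_text))


toy_train = {"sentence": ["söz konusu süre altı ay", "masada en az iki seçenek"]}
toy_test = {"sentence": ["fakat bu çabalar boşa çıktı"]}

# Union of the characters from both splits, then enumerate them in sorted order.
toy_chars = set(extract_all_chars(toy_train)) | set(extract_all_chars(toy_test))
toy_vocab = {char: index for index, char in enumerate(sorted(toy_chars))}

print(toy_vocab)
```

Because the space character has the smallest code point of all characters involved, it ends up with index 0, just like in the printed vocabulary above.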
To make it clearer that " " has its own token class, we give it a more
visible character |. In addition, we also add an “unknown” token so
that the model can later deal with characters not encountered in Common
Voice’s training set.
vocab_dict["|"] = vocab_dict[" "]
del vocab_dict[" "]
Finally, we also add a padding token that corresponds to CTC’s “blank
token”. The “blank token” is a core component of the CTC algorithm.
For more information, please take a look at the “Alignment” section
here.
vocab_dict["[UNK]"] = len(vocab_dict)
vocab_dict["[PAD]"] = len(vocab_dict)
len(vocab_dict)
Cool, now our vocabulary is complete and consists of 39 tokens, which
means that the linear layer that we will add on top of the pretrained
XLS-R checkpoint will have an output dimension of 39.
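To illustrate what the word delimiter and the unknown token are for, here is a small character-level lookup on a toy vocabulary. It is only a sketch of the idea and not how Wav2Vec2CTCTokenizer is implemented; `toy_vocab` and `encode` are hypothetical names.

```python
toy_vocab = {" ": 0, "a": 1, "e": 2, "h": 3, "l": 4, "o": 5}

# Same three steps as above: make the word delimiter visible, then append
# the unknown token and the padding (CTC blank) token at the end.
toy_vocab["|"] = toy_vocab.pop(" ")
toy_vocab["[UNK]"] = len(toy_vocab)
toy_vocab["[PAD]"] = len(toy_vocab)


def encode(text):
    """Character-level lookup; unseen characters fall back to [UNK]."""
    unk_id = toy_vocab["[UNK]"]
    return [toy_vocab.get("|" if ch == " " else ch, unk_id) for ch in text]


ids = encode("hello all")
ids_with_unknown = encode("halo x")
print(ids, ids_with_unknown)
```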
Let’s now save the vocabulary as a json file.
import json
with open('vocab.json', 'w') as vocab_file:
    json.dump(vocab_dict, vocab_file)
In a final step, we use the json file to load the vocabulary into an
instance of the Wav2Vec2CTCTokenizer class.
from transformers import Wav2Vec2CTCTokenizer
tokenizer = Wav2Vec2CTCTokenizer.from_pretrained("./", unk_token="[UNK]", pad_token="[PAD]", word_delimiter_token="|")
If one wants to re-use the just created tokenizer with the fine-tuned
model of this notebook, it is strongly advised to upload the tokenizer
to the Hugging Face Hub. Let’s call the repo to which
we will upload the files "wav2vec2-large-xls-r-300m-tr-colab":
repo_name = "wav2vec2-large-xls-r-300m-tr-colab"
and upload the tokenizer to the 🤗 Hub.
tokenizer.push_to_hub(repo_name)
Great, you can see the just created repository under
https://huggingface.co/
Create Wav2Vec2FeatureExtractor
Speech is a continuous signal, and, to be treated by computers, it first
has to be discretized, which is usually called sampling. The
sampling rate plays an important role here since it defines how many
data points of the speech signal are measured per second. Therefore,
sampling with a higher sampling rate results in a better approximation
of the real speech signal but also necessitates more values per
second.
A pretrained checkpoint expects its input data to have been sampled more
or less from the same distribution as the data it was trained on. The
same speech signals sampled at two different rates have a very different
distribution. For example, doubling the sampling rate results in twice as
many data points for the same signal. Thus, before fine-tuning a pretrained checkpoint of
an ASR model, it is crucial to verify that the sampling rate of the data
that was used to pretrain the model matches the sampling rate of the
dataset used to fine-tune the model.
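The effect of the sampling rate on the number of data points can be sketched with a synthetic tone. The helper `sample_sine` is hypothetical and only measures a sine wave at a given rate; it does not resample real audio.

```python
import math


def sample_sine(frequency_hz, duration_s, sampling_rate):
    """Measure a sine wave `sampling_rate` times per second."""
    num_samples = int(duration_s * sampling_rate)
    return [
        math.sin(2 * math.pi * frequency_hz * n / sampling_rate)
        for n in range(num_samples)
    ]


# The same half-second tone, measured at the two rates discussed in this post.
tone_16k = sample_sine(frequency_hz=440, duration_s=0.5, sampling_rate=16_000)
tone_48k = sample_sine(frequency_hz=440, duration_s=0.5, sampling_rate=48_000)

print(len(tone_16k), len(tone_48k))
```

The 48 kHz version of the same signal is three times as long as the 16 kHz version, so a model trained on one rate would see the other as a signal that is stretched or compressed in time.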
XLS-R was pretrained on audio data of
Babel,
Multilingual LibriSpeech
(MLS),
Common Voice,
VoxPopuli, and
VoxLingua107 at a sampling rate of
16kHz. Common Voice, in its original form, has a sampling rate of 48kHz,
thus we will have to downsample the fine-tuning data to 16kHz in the
following.
A Wav2Vec2FeatureExtractor object requires the following parameters to
be instantiated:
- feature_size: Speech models take a sequence of feature vectors as
  an input. While the length of this sequence obviously varies, the
  feature size should not. In the case of Wav2Vec2, the feature size
  is 1 because the model was trained on the raw speech signal.
- sampling_rate: The sampling rate the model was trained on.
- padding_value: For batched inference, shorter inputs need to be
  padded with a specific value.
- do_normalize: Whether the input should be
  zero-mean-unit-variance normalized or not. Usually, speech models
  perform better when normalizing the input.
- return_attention_mask: Whether the model should make use of an
  attention_mask for batched inference. In general, XLS-R model
  checkpoints should always use the attention_mask.
from transformers import Wav2Vec2FeatureExtractor
feature_extractor = Wav2Vec2FeatureExtractor(feature_size=1, sampling_rate=16000, padding_value=0.0, do_normalize=True, return_attention_mask=True)
Great, XLS-R’s feature extraction pipeline is thereby fully defined!
For improved user-friendliness, the feature extractor and tokenizer are
wrapped into a single Wav2Vec2Processor class so that one only needs
a model and processor object.
from transformers import Wav2Vec2Processor
processor = Wav2Vec2Processor(feature_extractor=feature_extractor, tokenizer=tokenizer)
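To make the do_normalize, padding_value and return_attention_mask parameters of the feature extractor more tangible, here is a pure-Python sketch of zero-mean-unit-variance normalization followed by padding. It only illustrates the idea; the small `eps` stabilizer and the helper names are assumptions, and the library's own implementation may differ in detail.

```python
import math


def zero_mean_unit_variance(values, eps=1e-7):
    """Shift to mean 0 and scale to variance 1 (eps is an assumed stabilizer)."""
    mean = sum(values) / len(values)
    variance = sum((v - mean) ** 2 for v in values) / len(values)
    return [(v - mean) / math.sqrt(variance + eps) for v in values]


def pad_batch(sequences, padding_value=0.0):
    """Right-pad to the longest sequence; the mask is 1 for real values, 0 for padding."""
    longest = max(len(seq) for seq in sequences)
    padded = [list(seq) + [padding_value] * (longest - len(seq)) for seq in sequences]
    mask = [[1] * len(seq) + [0] * (longest - len(seq)) for seq in sequences]
    return padded, mask


raw_batch = [[0.1, 0.3, 0.5, 0.7], [2.0, 4.0]]
normalized = [zero_mean_unit_variance(seq) for seq in raw_batch]
input_values, attention_mask = pad_batch(normalized)

print(attention_mask)
```

Normalizing each sequence before padding keeps the padded zeros from influencing the statistics, and the attention mask tells the model which positions to ignore.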
Next, we will prepare the dataset.
Preprocess Data
So far, we have not looked at the actual values of the speech signal but
just the transcription. In addition to sentence, our datasets include
two more column names, path and audio. path states the absolute
path of the audio file. Let’s take a look.
common_voice_train[0]["path"]
XLS-R expects the input in the format of a 1-dimensional array sampled at 16
kHz. This means that the audio file has to be loaded and resampled.
Thankfully, datasets does this automatically when the other
column, audio, is accessed. Let’s try it out.
common_voice_train[0]["audio"]
{'array': array([ 0.0000000e+00, 0.0000000e+00, 0.0000000e+00, ...,
-8.8930130e-05, -3.8027763e-05, -2.9146671e-05], dtype=float32),
'path': '/root/.cache/huggingface/datasets/downloads/extracted/05be0c29807a73c9b099873d2f5975dae6d05e9f7d577458a2466ecb9a2b0c6b/cv-corpus-6.1-2020-12-11/tr/clips/common_voice_tr_21921195.mp3',
'sampling_rate': 48000}
Great, we can see that the audio file has automatically been loaded.
This is thanks to the new "Audio"
feature
introduced in datasets == 1.18.3, which loads and resamples audio
files on-the-fly upon calling.
In the example above we can see that the audio data is loaded with a
sampling rate of 48kHz whereas 16kHz is expected by the model. We can
set the audio feature to the correct sampling rate by making use of
cast_column:
common_voice_train = common_voice_train.cast_column("audio", Audio(sampling_rate=16_000))
common_voice_test = common_voice_test.cast_column("audio", Audio(sampling_rate=16_000))
Let’s take a look at "audio" again.
common_voice_train[0]["audio"]
{'array': array([ 0.0000000e+00, 0.0000000e+00, 0.0000000e+00, ...,
-7.4556941e-05, -1.4621433e-05, -5.7861507e-05], dtype=float32),
'path': '/root/.cache/huggingface/datasets/downloads/extracted/05be0c29807a73c9b099873d2f5975dae6d05e9f7d577458a2466ecb9a2b0c6b/cv-corpus-6.1-2020-12-11/tr/clips/common_voice_tr_21921195.mp3',
'sampling_rate': 16000}
This seems to have worked! Let’s listen to a couple of audio files to
better understand the dataset and verify that the audio was correctly
loaded.
import IPython.display as ipd
import numpy as np
import random
rand_int = random.randint(0, len(common_voice_train)-1)
print(common_voice_train[rand_int]["sentence"])
ipd.Audio(data=common_voice_train[rand_int]["audio"]["array"], autoplay=True, rate=16000)
Print Output:
sunulan bütün teklifler i̇ngilizce idi
It seems like the data is now correctly loaded and resampled.
One can hear that the speakers change along with their speaking
rate, accent, background environment, etc. Overall, the recordings
sound acceptably clear though, which is to be expected from a
crowd-sourced read speech corpus.
Let’s do a final check that the data is correctly prepared, by printing
the shape of the speech input, its transcription, and the corresponding
sampling rate.
rand_int = random.randint(0, len(common_voice_train)-1)
print("Target text:", common_voice_train[rand_int]["sentence"])
print("Input array shape:", common_voice_train[rand_int]["audio"]["array"].shape)
print("Sampling rate:", common_voice_train[rand_int]["audio"]["sampling_rate"])
Print Output:
Target text: makedonya bu yıl otuz adet tyetmiş iki tankı aldı
Input array shape: (71040,)
Sampling rate: 16000
Good! Everything looks fine: the data is a 1-dimensional array, the
sampling rate always corresponds to 16kHz, and the target text is
normalized.
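As a quick sanity check, the printed array shape and sampling rate can be converted into a clip duration. The helper name below is hypothetical; the numbers are taken from the printed example above.

```python
def duration_in_seconds(num_samples, sampling_rate):
    """Length of an audio clip given its sample count and sampling rate."""
    return num_samples / sampling_rate


# Shape and sampling rate taken from the printed example above.
clip_seconds = duration_in_seconds(71_040, 16_000)

# The same clip at the original 48 kHz would need three times as many values.
samples_at_48k = 71_040 * 48_000 // 16_000

print(clip_seconds, samples_at_48k)
```

A clip of roughly four and a half seconds is plausible for a single read sentence, which is a cheap way to confirm that the resampling step did what was intended.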
Finally, we can leverage Wav2Vec2Processor to process the data to the
format expected by Wav2Vec2ForCTC for training. To do so let’s make
use of Dataset’s
map(...)
function.
First, we load and resample the audio data, simply by calling
batch["audio"]. Second, we extract the input_values from the loaded
audio file. In our case, the Wav2Vec2Processor only normalizes the
data. For other speech models, however, this step can include more
complex feature extraction, such as Log-Mel feature
extraction.
Third, we encode the transcriptions to label ids.
Note: This mapping function is a good example of how the
Wav2Vec2Processor class should be used. In “normal” context, calling
processor(...) is redirected to Wav2Vec2FeatureExtractor’s call
method. When wrapping the processor into the as_target_processor
context, however, the same method is redirected to
Wav2Vec2CTCTokenizer’s call method. For more information please check
the
docs.
def prepare_dataset(batch):
    audio = batch["audio"]
    # batched output is "un-batched"
    batch["input_values"] = processor(audio["array"], sampling_rate=audio["sampling_rate"]).input_values[0]
    batch["input_length"] = len(batch["input_values"])
    with processor.as_target_processor():
        batch["labels"] = processor(batch["sentence"]).input_ids
    return batch
Let’s apply the data preparation function to all examples.
common_voice_train = common_voice_train.map(prepare_dataset, remove_columns=common_voice_train.column_names)
common_voice_test = common_voice_test.map(prepare_dataset, remove_columns=common_voice_test.column_names)
Note: Currently datasets makes use of
torchaudio and
librosa for audio loading
and resampling. If you wish to implement your own customized data
loading/sampling, feel free to just make use of the "path" column
instead and disregard the "audio" column.
Long input sequences require a lot of memory. XLS-R is based on
self-attention. The memory requirement scales quadratically with the
input length for long input sequences (cf.
this
reddit post). In case this demo crashes with an “Out-of-memory” error
for you, you might want to filter out all training
sequences that are longer than 5 seconds.
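A length filter of this kind could look roughly like the following sketch. The predicate `is_short_enough` and the toy lengths are made up for illustration; the commented-out call at the end assumes the "input_length" column created in prepare_dataset.

```python
def is_short_enough(input_length, max_seconds=5.0, sampling_rate=16_000):
    """True if a clip with `input_length` samples lasts less than `max_seconds`."""
    return input_length < max_seconds * sampling_rate


# Toy stand-in for the "input_length" column computed in prepare_dataset.
toy_input_lengths = [71_040, 96_000, 40_000, 80_000]
kept = [n for n in toy_input_lengths if is_short_enough(n)]
print(kept)

# With 🤗 Datasets, the same predicate could be applied along the lines of:
# common_voice_train = common_voice_train.filter(is_short_enough, input_columns=["input_length"])
```

Only the training split should be filtered this way, so that the evaluation numbers remain comparable.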
Awesome, now we’re ready to begin training!
Training
The data is processed so that we are ready to start setting up the
training pipeline. We will make use of 🤗’s
Trainer
for which we essentially need to do the following:
- Define a data collator. In contrast to most NLP models, XLS-R has a
  much larger input length than output length. E.g., a sample of
  input length 50000 has an output length of no more than 100. Given
  the large input sizes, it is much more efficient to pad the training
  batches dynamically, meaning that all training samples should only be
  padded to the longest sample in their batch and not the overall
  longest sample. Therefore, fine-tuning XLS-R requires a special
  padding data collator, which we will define below.
- Evaluation metric. During training, the model should be evaluated on
  the word error rate. We should define a compute_metrics function
  accordingly.
- Load a pretrained checkpoint. We need to load a pretrained
  checkpoint and configure it correctly for training.
- Define the training configuration.
After having fine-tuned the model, we will correctly evaluate it on the
test data and verify that it has indeed learned to correctly transcribe
speech.
Set-up Trainer
Let’s start by defining the data collator. The code for the data
collator was copied from this
example.
Without going into too many details, in contrast to the common data
collators, this data collator treats the input_values and labels
differently and thus applies separate padding functions to them
(again making use of XLS-R processor’s context manager). This is
necessary because in speech, input and output are of different modalities,
meaning that they should not be treated by the same padding function.
Analogous to the common data collators, the padding tokens in the labels
are replaced with -100 so that those tokens are not taken into account when
computing the loss.
import torch
from dataclasses import dataclass, field
from typing import Any, Dict, List, Optional, Union
@dataclass
class DataCollatorCTCWithPadding:
    """
    Data collator that will dynamically pad the inputs received.
    Args:
        processor (:class:`~transformers.Wav2Vec2Processor`)
            The processor used for processing the data.
        padding (:obj:`bool`, :obj:`str` or :class:`~transformers.tokenization_utils_base.PaddingStrategy`, `optional`, defaults to :obj:`True`):
            Select a strategy to pad the returned sequences (according to the model's padding side and padding index)
            among:
            * :obj:`True` or :obj:`'longest'`: Pad to the longest sequence in the batch (or no padding if only a single
              sequence if provided).
            * :obj:`'max_length'`: Pad to a maximum length specified with the argument :obj:`max_length` or to the
              maximum acceptable input length for the model if that argument is not provided.
            * :obj:`False` or :obj:`'do_not_pad'` (default): No padding (i.e., can output a batch with sequences of
              different lengths).
    """
    processor: Wav2Vec2Processor
    padding: Union[bool, str] = True
    def __call__(self, features: List[Dict[str, Union[List[int], torch.Tensor]]]) -> Dict[str, torch.Tensor]:
        # split inputs and labels since they have different lengths and need
        # different padding methods
        input_features = [{"input_values": feature["input_values"]} for feature in features]
        label_features = [{"input_ids": feature["labels"]} for feature in features]
        batch = self.processor.pad(
            input_features,
            padding=self.padding,
            return_tensors="pt",
        )
        with self.processor.as_target_processor():
            labels_batch = self.processor.pad(
                label_features,
                padding=self.padding,
                return_tensors="pt",
            )
        # replace padding with -100 to ignore those positions in the loss
        labels = labels_batch["input_ids"].masked_fill(labels_batch.attention_mask.ne(1), -100)
        batch["labels"] = labels
        return batch
data_collator = DataCollatorCTCWithPadding(processor=processor, padding=True)
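The label side of the collator can be sketched in plain Python to show what the -100 masking does. The function `pad_labels` and the toy ids are hypothetical; the real collator works on tensors and delegates the padding to the processor.

```python
def pad_labels(label_ids, pad_token_id, ignore_index=-100):
    """Pad label sequences to the longest one, then mark padding as ignorable."""
    longest = max(len(ids) for ids in label_ids)
    padded = [ids + [pad_token_id] * (longest - len(ids)) for ids in label_ids]
    mask = [[1] * len(ids) + [0] * (longest - len(ids)) for ids in label_ids]
    # Positions where the mask is 0 are padding; the loss should skip them.
    return [
        [tok if keep else ignore_index for tok, keep in zip(row, row_mask)]
        for row, row_mask in zip(padded, mask)
    ]


toy_labels = [[3, 2, 4, 4, 5], [3, 1]]
padded_labels = pad_labels(toy_labels, pad_token_id=7)
print(padded_labels)
```

Masking by position rather than by comparing against the pad id keeps genuine occurrences of that id in the labels untouched.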
Next, the evaluation metric is defined. As mentioned earlier, the
predominant metric in ASR is the word error rate (WER), hence we will
use it in this notebook as well.
wer_metric = load_metric("wer")
The model will return a sequence of logit vectors, one per output time
step, where the number of output steps is much smaller than the number
of input samples.
Each logit vector contains the log-odds for each token in the
vocabulary we defined earlier, so its length equals
config.vocab_size. We are interested in the most likely prediction of
the model and thus take the argmax(...) of the logits. Also, we
transform the encoded labels back to the original string by replacing
-100 with the pad_token_id and decoding the ids while making sure
that consecutive tokens are not grouped to the same token in CTC
style.
def compute_metrics(pred):
    pred_logits = pred.predictions
    pred_ids = np.argmax(pred_logits, axis=-1)
    pred.label_ids[pred.label_ids == -100] = processor.tokenizer.pad_token_id
    pred_str = processor.batch_decode(pred_ids)
    # we do not want to group tokens when computing the metrics
    label_str = processor.batch_decode(pred.label_ids, group_tokens=False)
    wer = wer_metric.compute(predictions=pred_str, references=label_str)
    return {"wer": wer}
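The two array operations inside compute_metrics can be illustrated on toy data without NumPy. The helpers `argmax_ids` and `restore_padding` are hypothetical names that mirror, in simplified form, the argmax over logits and the replacement of -100 by the pad token id.

```python
def argmax_ids(logits):
    """Pick the highest-scoring class for every time step."""
    return [max(range(len(frame)), key=frame.__getitem__) for frame in logits]


def restore_padding(label_ids, pad_token_id, ignore_index=-100):
    """Undo the -100 masking so the labels can be decoded by a tokenizer."""
    return [pad_token_id if tok == ignore_index else tok for tok in label_ids]


# Three time steps over a toy vocabulary of four classes.
toy_logits = [[0.1, 2.3, -1.0, 0.0], [1.5, 0.2, 0.3, 0.1], [-0.5, -0.2, -0.1, 3.0]]
predicted_ids = argmax_ids(toy_logits)
restored = restore_padding([3, 1, -100, -100], pad_token_id=7)

print(predicted_ids, restored)
```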
Now, we can load the pretrained checkpoint of
Wav2Vec2-XLS-R-300M.
The tokenizer’s pad_token_id must be used to define the model’s
pad_token_id or, in the case of Wav2Vec2ForCTC, also CTC’s blank
token. To save GPU memory, we enable PyTorch’s gradient
checkpointing and also
set the loss reduction to “mean”.
Because the dataset is quite small (~6h of training data) and because
Common Voice is quite noisy, fine-tuning Facebook’s
wav2vec2-xls-r-300m checkpoint seems to require some
hyper-parameter tuning. Therefore, I had to play around a bit with
different values for dropout,
SpecAugment’s masking dropout rate,
layer dropout, and the learning rate until training seemed to be stable
enough.
Note: When using this notebook to train XLS-R on another language of
Common Voice those hyper-parameter settings might not work very well.
Feel free to adapt those depending on your use case.
from transformers import Wav2Vec2ForCTC
model = Wav2Vec2ForCTC.from_pretrained(
    "facebook/wav2vec2-xls-r-300m",
    attention_dropout=0.0,
    hidden_dropout=0.0,
    feat_proj_dropout=0.0,
    mask_time_prob=0.05,
    layerdrop=0.0,
    ctc_loss_reduction="mean",
    pad_token_id=processor.tokenizer.pad_token_id,
    vocab_size=len(processor.tokenizer),
)
The first component of XLS-R consists of a stack of CNN layers that are
used to extract acoustically meaningful, but contextually independent,
features from the raw speech signal. This part of the model has already
been sufficiently trained during pretraining and, as stated in the
paper, does not need to be
fine-tuned anymore. Thus, we can set requires_grad to False for
all parameters of the feature extraction part.
model.freeze_feature_extractor()
In a final step, we define all parameters related to training. To give
more explanation on some of the parameters:
- group_by_length makes training more efficient by grouping training
  samples of similar input length into one batch. This can
  significantly speed up training time by heavily reducing the overall
  number of useless padding tokens that are passed through the model.
- learning_rate and weight_decay were heuristically tuned until
  fine-tuning became stable. Note that those parameters strongly
  depend on the Common Voice dataset and might be suboptimal for other
  speech datasets.
For more explanations on other parameters, one can take a look at the
docs.
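The benefit of grouping samples by length can be estimated with a few lines of arithmetic. The sketch below uses plain sorting as a simplified stand-in for length grouping (the sampler used by Trainer is more involved), and the lengths are made up.

```python
def padding_needed(lengths, batch_size):
    """Total padding when each batch is padded to its own longest sample."""
    total = 0
    for start in range(0, len(lengths), batch_size):
        batch = lengths[start:start + batch_size]
        total += sum(max(batch) - length for length in batch)
    return total


toy_lengths = [10, 90, 20, 80, 30, 70]
padding_unsorted = padding_needed(toy_lengths, batch_size=2)
padding_grouped = padding_needed(sorted(toy_lengths), batch_size=2)

print(padding_unsorted, padding_grouped)
```

In this toy case, batching similar lengths together cuts the padded positions to a third, and every padded position saved is computation the model does not have to perform.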
During training, a checkpoint will be uploaded asynchronously to the Hub
every 400 training steps. This allows you to play around with the
demo widget even while your model is still training.
Note: If one does not want to upload the model checkpoints to the
Hub, simply set push_to_hub=False.
from transformers import TrainingArguments
training_args = TrainingArguments(
    output_dir=repo_name,
    group_by_length=True,
    per_device_train_batch_size=16,
    gradient_accumulation_steps=2,
    evaluation_strategy="steps",
    num_train_epochs=30,
    gradient_checkpointing=True,
    fp16=True,
    save_steps=400,
    eval_steps=400,
    logging_steps=400,
    learning_rate=3e-4,
    warmup_steps=500,
    save_total_limit=2,
    push_to_hub=True,
)
Now, all instances can be passed to Trainer and we are ready to start
training!
from transformers import Trainer
trainer = Trainer(
model=model,
data_collator=data_collator,
args=training_args,
compute_metrics=compute_metrics,
train_dataset=common_voice_train,
eval_dataset=common_voice_test,
tokenizer=processor.feature_extractor,
)
To allow models to become independent of the speaker rate, in
CTC, consecutive tokens that are identical are simply grouped as a
single token. However, the encoded labels should not be grouped when
decoding since they don't correspond to the predicted tokens of the
model, which is why the group_tokens=False parameter has to be passed.
If we didn't pass this parameter, a word like "hello" would
incorrectly be encoded, and decoded as "helo".
The blank token allows the model to predict a word such as
"hello" by forcing it to insert the blank token between the two l's.
A CTC-conform prediction of "hello" of our model would be
[PAD] [PAD] "h" "e" "e" "l" "l" [PAD] "l" "o" "o" [PAD].
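The grouping rule can be sketched in a few lines of plain Python. This is a simplified stand-in for what happens when tokens are grouped during decoding, not the tokenizer's actual implementation. The function name ctc_collapse is made up for illustration, and [PAD] is assumed to be the blank token, as in the example above.

```python
BLANK = "[PAD]"  # assumed blank token, matching the example above


def ctc_collapse(tokens, blank=BLANK):
    """Merge consecutive identical tokens, then drop the blank token."""
    collapsed = []
    previous = None
    for token in tokens:
        # Keep a token only if it differs from the one right before it.
        if token != previous:
            collapsed.append(token)
        previous = token
    return "".join(t for t in collapsed if t != blank)


# A CTC-conform model prediction: the blank separates the two l's.
prediction = ["[PAD]", "[PAD]", "h", "e", "e", "l", "l", "[PAD]", "l", "o", "o", "[PAD]"]
decoded_prediction = ctc_collapse(prediction)

# A label sequence has no blank between the l's, so grouping merges them.
decoded_label = ctc_collapse(list("hello"))

print(decoded_prediction)
print(decoded_label)
```

The second call shows why label ids must be decoded with group_tokens=False: grouping is only meaningful for model predictions, where the blank token keeps repeated letters apart.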
Training
Training will take multiple hours depending on the GPU allocated to this
notebook. While the trained model yields somewhat satisfying results on
Common Voice's test data of Turkish, it is by no means an optimally
fine-tuned model. The purpose of this notebook is just to demonstrate
how to fine-tune XLS-R on an ASR dataset.
Depending on what GPU was allocated to your Google Colab, it might be
possible that you are seeing an "out-of-memory" error here. In this
case, it's probably best to reduce per_device_train_batch_size to 8
or even less and increase
gradient_accumulation_steps.
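As a rough guide for choosing the new values, the sketch below computes the number of samples that contribute to one optimizer step. The helper effective_batch_size is hypothetical and assumes a single GPU by default. Halving the per-device batch size while doubling the accumulation steps keeps that number unchanged, while lowering peak memory.

```python
def effective_batch_size(per_device_batch_size, accumulation_steps, num_devices=1):
    """Samples that contribute to a single optimizer step."""
    return per_device_batch_size * accumulation_steps * num_devices


# Settings used in this post: batch size 16, accumulation over 2 steps.
original = effective_batch_size(16, 2)

# Memory-friendlier alternative: batch size 8, accumulation over 4 steps.
reduced = effective_batch_size(8, 4)

print(original, reduced)
```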
trainer.train()
Print Output:
| Training Loss | Epoch | Step | Validation Loss | Wer |
|---|---|---|---|---|
| 3.8842 | 3.67 | 400 | 0.6794 | 0.7000 |
| 0.4115 | 7.34 | 800 | 0.4304 | 0.4548 |
| 0.1946 | 11.01 | 1200 | 0.4466 | 0.4216 |
| 0.1308 | 14.68 | 1600 | 0.4526 | 0.3961 |
| 0.0997 | 18.35 | 2000 | 0.4567 | 0.3696 |
| 0.0784 | 22.02 | 2400 | 0.4193 | 0.3442 |
| 0.0633 | 25.69 | 2800 | 0.4153 | 0.3347 |
| 0.0498 | 29.36 | 3200 | 0.4077 | 0.3195 |
The training loss and validation WER go down nicely.
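For readers unfamiliar with the Wer column, the following sketch shows the standard definition of word error rate: the word-level edit distance between reference and prediction, divided by the number of reference words. It is an independent illustration with a made-up function name and example sentences, not the metric implementation used by compute_metrics, and it assumes a non-empty reference.

```python
def word_error_rate(reference, prediction):
    """Word-level Levenshtein distance divided by the reference length."""
    ref = reference.split()
    hyp = prediction.split()
    if not ref:
        raise ValueError("reference must contain at least one word")

    # previous_row[j] holds the distance between the processed reference
    # prefix and the first j predicted words.
    previous_row = list(range(len(hyp) + 1))
    for i, ref_word in enumerate(ref, start=1):
        current_row = [i]
        for j, hyp_word in enumerate(hyp, start=1):
            substitution = previous_row[j - 1] + (ref_word != hyp_word)
            insertion = current_row[j - 1] + 1
            deletion = previous_row[j] + 1
            current_row.append(min(substitution, insertion, deletion))
        previous_row = current_row
    return previous_row[-1] / len(ref)


# One substituted word and one missing word against six reference words.
example_wer = word_error_rate("the cat sat on the mat", "the cat sit on mat")
print(round(example_wer, 4))
```

A value of 0.3195, as in the last row of the table, therefore means roughly one word-level error for every three reference words.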
You can now upload the result of the training to the Hub. Just execute
this instruction:
trainer.push_to_hub()
You can now share this model with all your friends, family, and favorite
pets: they can all load it with the identifier
"your-username/the-name-you-picked", so for instance:
from transformers import AutoModelForCTC, Wav2Vec2Processor
model = AutoModelForCTC.from_pretrained("patrickvonplaten/wav2vec2-large-xls-r-300m-tr-colab")
processor = Wav2Vec2Processor.from_pretrained("patrickvonplaten/wav2vec2-large-xls-r-300m-tr-colab")
For more examples of how XLS-R can be fine-tuned, please take a look at the official
🤗 Transformers examples.
Evaluation
As a final check, let’s load the model and confirm that it indeed has
learned to transcribe Turkish speech.
Let’s first load the pretrained checkpoint.
model = Wav2Vec2ForCTC.from_pretrained(repo_name).to("cuda")
processor = Wav2Vec2Processor.from_pretrained(repo_name)
Now, we will just take the first example of the test set, run it through
the model and take the argmax(...) of the logits to retrieve the
predicted token ids.
input_dict = processor(common_voice_test[0]["input_values"], return_tensors="pt", padding=True)
logits = model(input_dict.input_values.to("cuda")).logits
pred_ids = torch.argmax(logits, dim=-1)[0]
It is strongly recommended to pass the sampling_rate argument to this function. Failing to do so can result in silent errors that might be hard to debug.
We adapted common_voice_test quite a bit, so the dataset instance
does not contain the original sentence label anymore. Thus, we re-use
the original dataset to get the label of the first example.
common_voice_test_transcription = load_dataset("common_voice", "tr", data_dir="./cv-corpus-6.1-2020-12-11", split="test")
Finally, we can decode the example.
print("Prediction:")
print(processor.decode(pred_ids))
print("\nReference:")
print(common_voice_test_transcription[0]["sentence"].lower())
Print Output:
| pred_str | target_text |
|---|---|
| hatta küçük şeyleri için bir büyt bir şeyleri kolluyor veyınıki çuk şeyler için bir bir mizi inciltiyoruz | hayatta küçük şeyleri kovalıyor ve yine küçük şeyler için birbirimizi incitiyoruz. |
Alright! The transcription can definitely be recognized from our
prediction, but it is not perfect yet. Training the model a bit longer,
spending more time on the data preprocessing, and especially using a
language model for decoding would certainly improve the model's overall
performance.
For a demonstration model on a low-resource language, the results are
quite acceptable, however 🤗.

