Fine-Tune W2V2-BERT for low-resource ASR with 🤗 Transformers



Yoach Lacombe



New (01/2024): This blog post is strongly inspired by “Fine-tuning XLS-R on Multi-Lingual ASR” and “Fine-tuning MMS Adapter Models for Multi-Lingual ASR”.



Introduction

Last month, Meta AI released Wav2Vec2-BERT as a building block of Seamless Communication, their family of AI translation models.

Wav2Vec2-BERT is the result of a series of improvements on an original model: Wav2Vec2, a pre-trained model for Automatic Speech Recognition (ASR) released in September 2020 by Alexei Baevski, Michael Auli, and Alex Conneau. With as little as 10 minutes of labeled audio data, Wav2Vec2 could be fine-tuned to achieve a word-error rate of around 5% on the LibriSpeech dataset, demonstrating low-resource transfer learning for ASR for the first time.

Following a series of multilingual improvements (XLSR, XLS-R and MMS), Wav2Vec2-BERT is a 580M-parameter versatile audio model that has been pre-trained on 4.5M hours of unlabeled audio data covering more than 143 languages. For comparison, XLS-R used almost half a million hours of audio data in 128 languages, and MMS checkpoints were pre-trained on more than half a million hours of audio in over 1,400 languages. Scaling up to millions of hours enables Wav2Vec2-BERT to achieve even more competitive results in speech-related tasks, regardless of the language.

To use it for ASR, Wav2Vec2-BERT can be fine-tuned using Connectionist Temporal Classification (CTC), an algorithm used to train neural networks for sequence-to-sequence problems such as ASR and handwriting recognition. We highly recommend reading the well-written blog post Sequence Modeling with CTC (2017) by Awni Hannun to learn more about the CTC algorithm.
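
To build intuition for what CTC decoding does, here is a minimal, self-contained sketch of greedy CTC collapsing (merge repeated frame predictions, then drop the blank token). It is purely illustrative and is not the decoding code used later in this notebook.

def ctc_greedy_collapse(frame_predictions, blank="[PAD]"):
    # 1. merge consecutive duplicates (the model emits one prediction per audio frame)
    collapsed = []
    previous = None
    for token in frame_predictions:
        if token != previous:
            collapsed.append(token)
        previous = token
    # 2. drop the blank token, which CTC uses to separate genuinely repeated characters
    return "".join(token for token in collapsed if token != blank)

print(ctc_greedy_collapse(["h", "h", "[PAD]", "e", "l", "[PAD]", "l", "o", "o"]))  # -> "hello"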

The aim of this notebook is to give you all the elements you need to train the Wav2Vec2-BERT model – more specifically the pre-trained checkpoint facebook/w2v-bert-2.0 – on ASR tasks, using open-source tools and models. It first presents the complete pre-processing pipeline, then performs a short fine-tuning run of W2V2-BERT. The final section gathers training tips from Hugging Face experts to scale up CTC training.

For demonstration purposes, we fine-tune the model on the low-resource Mongolian ASR dataset of Common Voice 16.0, which contains about 14 hours of validated training data.



Motivation

Whisper is a collection of ASR models, commonly regarded as the best-performing models for the ASR task. It provides state-of-the-art performance for English ASR, while being well suited to multilingual fine-tuning from limited resources.

However, when it comes to low-resource languages such as Mongolian, Whisper performs poorly, as seen in section D.2.2 of the Whisper paper – Mongolian and Malayalam exceed 100% WER for every Whisper checkpoint. The available checkpoints also have a limited vocabulary and therefore cannot be fine-tuned on a language whose alphabet does not overlap with this vocabulary.

In addition, Whisper is a sequence-to-sequence model that performs ASR autoregressively, making it inherently slow. Whisper’s slowness is exacerbated for languages whose characteristics are infrequent in the training dataset: in that case, Whisper has to generate more tokens per word on average, and therefore takes longer.

Faced with limited resources – both in terms of training data availability and inference constraints – more “frugal” models are needed. In this case, Wav2Vec2-BERT fits the bill.

Wav2Vec2-BERT transcribes speech in a single forward pass, making it much faster than Whisper. As this notebook will show, it requires little data to achieve competitive performance, is easily adaptable to any alphabet, and is more resource-efficient.

In fact, it achieves similar WER performance on Mongolian ASR to Whisper-large-v3 after similar fine-tuning, while being 10x to 30x faster and 2.5x more resource-efficient.

Note: The benchmark was carried out with a 16GB V100 on Google Colab, using batch sizes ranging from 1 to 8 on the Mongolian CV16 test set.



Notebook Setup

Before we start, let’s install datasets and transformers. We also need accelerate for training, torchaudio to load audio files, and jiwer to evaluate our fine-tuned model using the word error rate (WER) metric.

%%capture
!pip install datasets
!pip install --upgrade transformers
!pip install torchaudio
!pip install jiwer
!pip install accelerate -U

We strongly suggest uploading your training checkpoints directly to the 🤗 Hub while training. The 🤗 Hub provides:

  • Integrated version control: you can make sure that no model checkpoint is lost during training.
  • Tensorboard logs: track important metrics over the course of training.
  • Model cards: document what a model does and its intended use cases.
  • Community: an easy way to share and collaborate with the community!

To do so, you have to store your authentication token from the Hugging Face website (sign up here if you haven’t already!). This is done by entering your Hub authentication token when prompted below. Find your Hub authentication token here:

from huggingface_hub import notebook_login

notebook_login()


Prepare Data, Tokenizer, Feature Extractor

ASR models transcribe speech to text, which means that we need both a feature extractor that processes the speech signal into the model’s input format, e.g. a feature vector, and a tokenizer that processes the model’s output format into text.

In 🤗 Transformers, the Wav2Vec2-BERT model is thus accompanied by both a tokenizer, called Wav2Vec2CTCTokenizer, and a feature extractor, called SeamlessM4TFeatureExtractor, which the model shares with the first and second versions of Seamless-M4T, as they all process audio in the same way.

Let’s start by creating the tokenizer to decode the predicted output classes into the output transcription.



Create Wav2Vec2CTCTokenizer

Keep in mind that Wav2Vec2-like models fine-tuned with CTC transcribe an audio file in a single forward pass by first processing the audio input into a sequence of context representations and then using the final vocabulary output layer to classify each context representation into a character that represents the transcription.

The output size of this layer corresponds to the number of tokens in the vocabulary, and therefore depends only on the labeled dataset used for fine-tuning. So in a first step, we will take a look at the chosen dataset of Common Voice and define a vocabulary based on the transcriptions.

For this notebook, we will use Common Voice’s 16.0 dataset for Mongolian. Mongolian corresponds to the language code "mn".

Now we can use 🤗 Datasets’ simple API to download the data. The dataset name is "mozilla-foundation/common_voice_16_0", and the configuration name corresponds to the language code, which is "mn" in our case.

Note: Before being able to download the dataset, you have to accept its terms of access by logging into your Hugging Face account, going to the dataset repo page, and clicking on “Agree and Access repository”.

Common Voice has many different splits, including invalidated, which refers to data that was not rated as “clean enough” to be considered useful. In this notebook, we will only make use of the splits "train", "validation" and "test".

Because the Mongolian dataset is so small, we will merge both the validation and training data into a training dataset and only use the test data for validation.

from datasets import load_dataset, load_metric, Audio

common_voice_train = load_dataset("mozilla-foundation/common_voice_16_0", "mn", split="train+validation", use_auth_token=True)
common_voice_test = load_dataset("mozilla-foundation/common_voice_16_0", "mn", split="test", use_auth_token=True)

Many ASR datasets only provide the target text, 'sentence', for each audio array 'audio' and file 'path'. Common Voice actually provides much more information about each audio file, such as the 'accent', etc. To keep the notebook as general as possible, we only consider the transcribed text for fine-tuning.

common_voice_train = common_voice_train.remove_columns(["accent", "age", "client_id", "down_votes", "gender", "locale", "segment", "up_votes"])
common_voice_test = common_voice_test.remove_columns(["accent", "age", "client_id", "down_votes", "gender", "locale", "segment", "up_votes"])

Let’s write a short function to display some random samples of the dataset and run it a couple of times to get a feel for the transcriptions.

from datasets import ClassLabel
import random
import pandas as pd
from IPython.display import display, HTML

def show_random_elements(dataset, num_examples=10):
    assert num_examples <= len(dataset), "Cannot pick more elements than there are in the dataset."
    picks = []
    for _ in range(num_examples):
        pick = random.randint(0, len(dataset)-1)
        while pick in picks:
            pick = random.randint(0, len(dataset)-1)
        picks.append(pick)

    df = pd.DataFrame(dataset[picks])
    display(HTML(df.to_html()))

show_random_elements(common_voice_train.remove_columns(["path", "audio"]), num_examples=10)

Alright! The transcriptions look fairly clean. Having translated the transcribed sentences, it seems that the language corresponds more to written-out text than noisy dialogue. This makes sense considering that Common Voice is a crowd-sourced read speech corpus.

We can see that the transcriptions contain some special characters, such as ,.?!;:. Without a language model, it is much harder to classify speech chunks into such special characters because they don’t really correspond to a characteristic sound unit. E.g., the letter "s" has a more or less clear sound, whereas the special character "." does not.
Also, in order to understand the meaning of a speech signal, it is usually not necessary to include special characters in the transcription.

Let’s simply remove all characters that don’t contribute to the meaning of a word and cannot really be represented by an acoustic sound, and normalize the text.

import re
chars_to_remove_regex = "[,?.!;:\"“%‘”�'»«-]"

def remove_special_characters(batch):
    # remove punctuation and other special characters, and lower-case the transcription
    batch["sentence"] = re.sub(chars_to_remove_regex, '', batch["sentence"]).lower()
    return batch

common_voice_train = common_voice_train.map(remove_special_characters)
common_voice_test = common_voice_test.map(remove_special_characters)

Let us take a look at the processed text labels again.

show_random_elements(common_voice_train.remove_columns(["path","audio"]))
Хойч үе юуны төлөө тэмцэлдэхийг би мэдэхгүй.	
Тэр өвдгөн дээрээ толгойгоо тавиад сулхан гиншинэ.	
Эхнэргүй ганц бие хүн гэсэн санагдана.	
Дамиран хотод төрж өссөн хээнцэр залуусын нэг билээ.	
Мөн судлаачид шинжлэх ухааны үндэстэй тайлбар хайдаг.	
Судалгааны ажил нь бүтэлгүй болсонд л гутарч маргааш илүү ажиллах тухай бодсон бололтой.	
Ийм зөрчлөөс гэтлэх гарц "Оноосон нэрийн сан"-г үүсгэснээр шийдвэрлэгдэнэ.	
Үүлтэй тэнгэрийн доогуур үзүүртэй моддын дээгүүр дүүлэн нисэх сэн.	
Та нар ямар юмаа ингэж булаацалдаа вэ?	
Тэд амьд хэлтрээ болов уу яагаа бол гэхээс одоо ч дотор арзганан бачуурдаг юм.	

In CTC, it is common to classify speech chunks into letters, so we will do the same here.
Let’s extract all distinct letters of the training and test data and build our vocabulary from this set of letters.

We write a mapping function that concatenates all transcriptions into one long transcription and then transforms the string into a set of characters.
It is important to pass the argument batched=True to the map(...) function so that the mapping function has access to all transcriptions at once.

def extract_all_chars(batch):
  all_text = " ".join(batch["sentence"])
  vocab = list(set(all_text))
  return {"vocab": [vocab], "all_text": [all_text]}

vocab_train = common_voice_train.map(extract_all_chars, batched=True, batch_size=-1, keep_in_memory=True, remove_columns=common_voice_train.column_names)
vocab_test = common_voice_test.map(extract_all_chars, batched=True, batch_size=-1, keep_in_memory=True, remove_columns=common_voice_test.column_names)

Now, we create the union of all distinct letters in the training and test datasets and convert the resulting list into an enumerated dictionary.

vocab_list = list(set(vocab_train["vocab"][0]) | set(vocab_test["vocab"][0]))

vocab_dict = {v: k for k, v in enumerate(sorted(vocab_list))}
vocab_dict
{' ': 0,
 'a': 1,
 'b': 2,
 'c': 3,
 'd': 4,
 'e': 5,
 'g': 6,
 'h': 7,
 'i': 8,
 'l': 9,
 'n': 10,
 'o': 11,
 'r': 12,
 't': 13,
 'x': 14,
 'а': 15,
 'б': 16,
 'в': 17,
 'г': 18,
 'д': 19,
 'е': 20,
 'ж': 21,
 'з': 22,
 'и': 23,
 'й': 24,
 'к': 25,
 'л': 26,
 'м': 27,
 'н': 28,
 'о': 29,
 'п': 30,
 'р': 31,
 'с': 32,
 'т': 33,
 'у': 34,
 'ф': 35,
 'х': 36,
 'ц': 37,
 'ч': 38,
 'ш': 39,
 'ъ': 40,
 'ы': 41,
 'ь': 42,
 'э': 43,
 'ю': 44,
 'я': 45,
 'ё': 46,
 'ү': 47,
 'ө': 48}

Cleaning up a dataset is a back-and-forth process that should be done with care.

Looking at the separate letters in the training and test datasets, we see a mixture of Latin and Mongolian Cyrillic characters. After discussing with a native speaker of the target language (thanks Mishig for taking a look), we will remove the Latin characters for two reasons:

  1. the CTC algorithm benefits from a reduced vocabulary size, so it is recommended to remove redundant characters;
  2. in this example, we are focusing entirely on the Mongolian alphabet.

def remove_latin_characters(batch):
    batch["sentence"] = re.sub(r'[a-z]+', '', batch["sentence"])
    return batch


common_voice_train = common_voice_train.map(remove_latin_characters)
common_voice_test = common_voice_test.map(remove_latin_characters)


vocab_train = common_voice_train.map(extract_all_chars, batched=True, batch_size=-1, keep_in_memory=True, remove_columns=common_voice_train.column_names)
vocab_test = common_voice_test.map(extract_all_chars, batched=True, batch_size=-1, keep_in_memory=True, remove_columns=common_voice_test.column_names)
vocab_list = list(set(vocab_train["vocab"][0]) | set(vocab_test["vocab"][0]))

vocab_dict = {v: k for k, v in enumerate(sorted(vocab_list))}
vocab_dict
{' ': 0,
 'а': 1,
 'б': 2,
 'в': 3,
 'г': 4,
 'д': 5,
 'е': 6,
 'ж': 7,
 'з': 8,
 'и': 9,
 'й': 10,
 'к': 11,
 'л': 12,
 'м': 13,
 'н': 14,
 'о': 15,
 'п': 16,
 'р': 17,
 'с': 18,
 'т': 19,
 'у': 20,
 'ф': 21,
 'х': 22,
 'ц': 23,
 'ч': 24,
 'ш': 25,
 'ъ': 26,
 'ы': 27,
 'ь': 28,
 'э': 29,
 'ю': 30,
 'я': 31,
 'ё': 32,
 'ү': 33,
 'ө': 34}

Cool, we see that all letters of the Mongolian alphabet occur in the dataset (which is not really surprising) and we also extracted the special character " ". Note that we did not exclude this special character because
the model has to learn to predict when a word is finished, otherwise the model prediction would always be a sequence of characters, which would make it impossible to separate words from each other.

One should always keep in mind that pre-processing is a very important step before training your model. E.g., we don’t want our model to differentiate between a and A just because we forgot to normalize the data. The difference between a and A does not depend on the “sound” of the letter at all, but rather on grammatical rules – e.g. using a capitalized letter at the beginning of a sentence. So it is sensible to remove the difference between capitalized and non-capitalized letters so that the model has an easier time learning to transcribe speech. You can read more about the effects of pre-processing on the ASR task in the Audio Transformers Course.

To make it clearer that " " has its own token class, we give it a more visible character, |. In addition, we also add an “unknown” token so that the model can later deal with characters not encountered in Common Voice’s training set.

vocab_dict["|"] = vocab_dict[" "]
del vocab_dict[" "]

Finally, we also add a padding token that corresponds to CTC’s “blank token”. The “blank token” is a core component of the CTC algorithm. For more information, please take a look at the “Alignment” section of this blog post.

vocab_dict["[UNK]"] = len(vocab_dict)
vocab_dict["[PAD]"] = len(vocab_dict)
len(vocab_dict)
37

Cool, now our vocabulary is complete and consists of 37 tokens, which means that the linear layer that we will add on top of the pre-trained Wav2Vec2-BERT checkpoint will have an output dimension of 37.

Let’s now save the vocabulary as a json file.

import json
with open('vocab.json', 'w') as vocab_file:
    json.dump(vocab_dict, vocab_file)

In a final step, we use the json file to load the vocabulary into an instance of the Wav2Vec2CTCTokenizer class.

from transformers import Wav2Vec2CTCTokenizer

tokenizer = Wav2Vec2CTCTokenizer.from_pretrained("./", unk_token="[UNK]", pad_token="[PAD]", word_delimiter_token="|")
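
As a quick sanity check (purely illustrative; the exact ids depend on the vocabulary built above), we can round-trip a transcription through the tokenizer:

ids = tokenizer("мөн судлаачид").input_ids
print(ids)
print(tokenizer.decode(ids, group_tokens=False))  # group_tokens=False keeps repeated characters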

If one wants to re-use the just-created tokenizer with the fine-tuned model of this notebook, it is strongly advised to upload the tokenizer to the 🤗 Hub. Let’s call the repo to which we will upload the files
"w2v-bert-2.0-mongolian-colab-CV16.0":

repo_name = "w2v-bert-2.0-mongolian-colab-CV16.0"

and upload the tokenizer to the 🤗 Hub.

tokenizer.push_to_hub(repo_name)

Great, you can see the just-created repository under https://huggingface.co/<your-username>/w2v-bert-2.0-mongolian-colab-CV16.0


Create SeamlessM4TFeatureExtractor

The role of the SeamlessM4TFeatureExtractor is to prepare the raw audio input in a format that the model can “understand”. It therefore maps the sequence of one-dimensional amplitude values (i.e. the raw audio input) to a two-dimensional matrix of log-mel spectrogram values. The latter encodes the signal’s frequency information as a function of time. See this section of the Audio Transformers Course to learn more about spectrograms and why they are important.

Unlike the tokenizer, the feature extractor does not need to be “learned” from the data, so we can load it directly from the initial model checkpoint.

from transformers import SeamlessM4TFeatureExtractor

feature_extractor = SeamlessM4TFeatureExtractor.from_pretrained("facebook/w2v-bert-2.0")
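
As a quick, illustrative sanity check (the exact number of frames and the feature dimension depend on the extractor’s configuration), we can run one second of silence through the feature extractor:

import numpy as np

one_second = np.zeros(16_000, dtype=np.float32)  # 1 second of silence at 16 kHz
features = feature_extractor(one_second, sampling_rate=16_000, return_tensors="np")
print(features.input_features.shape)  # (batch, frames, feature_dim)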

Great, Wav2Vec2-BERT’s feature extraction pipeline is thereby fully defined!

For improved user-friendliness, the feature extractor and tokenizer are wrapped into a single Wav2Vec2BertProcessor class, so that one only needs a model and a processor object.

from transformers import Wav2Vec2BertProcessor

processor = Wav2Vec2BertProcessor(feature_extractor=feature_extractor, tokenizer=tokenizer)
processor.push_to_hub(repo_name)

Next, we can prepare the dataset.



Preprocess Data

So far, we have not looked at the actual values of the speech signal, only at the transcription. In addition to sentence, our datasets include two more columns, path and audio. path states the absolute path of the audio file. Let’s take a look.

common_voice_train[0]["path"]
/root/.cache/huggingface/datasets/downloads/extracted/276aa682ce2b6a24934bc401b1f30e004c3fb178dd41d6295b273329f592844a/mn_train_0/common_voice_mn_18578097.mp3

Wav2Vec2-BERT expects its input as a 1-dimensional array sampled at 16 kHz. This means that the audio file has to be loaded and resampled.

Thankfully, datasets does this automatically when the other column, audio, is accessed. Let’s try it out.

common_voice_train[0]["audio"]
{'path': '/root/.cache/huggingface/datasets/downloads/extracted/276aa682ce2b6a24934bc401b1f30e004c3fb178dd41d6295b273329f592844a/mn_train_0/common_voice_mn_18578097.mp3',
 'array': array([ 0.00000000e+00, -1.64773251e-14,  1.81765166e-13, ...,
        -3.23167333e-05,  2.20304846e-05,  3.26883201e-05]),
 'sampling_rate': 48000}

Great, we can see that the audio file has automatically been loaded. This is thanks to the Audio feature of datasets, which loads and resamples audio files on-the-fly when the column is accessed.

In the example above we can see that the audio data is loaded with a sampling rate of 48kHz, whereas Wav2Vec2-BERT was pre-trained on audio sampled at 16kHz. The sampling rate plays an important role in that it defines how many data points of the speech signal are measured per second. Sampling at a higher rate therefore results in a better approximation of the real speech signal but also requires more values per second.

A pre-trained checkpoint expects its input data to have been sampled roughly from the same distribution as the data it was trained on. The same speech signals sampled at two different rates have a very different distribution; e.g., doubling the sampling rate results in twice as many data points. Thus,
before fine-tuning a pre-trained checkpoint of an ASR model, it is crucial to verify that the sampling rate of the data that was used to pre-train the model matches the sampling rate of the dataset used to fine-tune the model.
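
To illustrate what resampling does to the raw array, here is a minimal sketch using torchaudio (which we installed earlier); in the pipeline below, datasets handles this for us:

import torch
import torchaudio.functional as F

waveform_48k = torch.randn(48_000)  # 1 second of (random) audio at 48 kHz
waveform_16k = F.resample(waveform_48k, orig_freq=48_000, new_freq=16_000)
print(waveform_48k.shape, waveform_16k.shape)  # torch.Size([48000]) torch.Size([16000])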

Luckily, we can set the audio feature to the correct sampling rate by making use of cast_column:

common_voice_train = common_voice_train.cast_column("audio", Audio(sampling_rate=16_000))
common_voice_test = common_voice_test.cast_column("audio", Audio(sampling_rate=16_000))

Let’s take a look at "audio" again:

common_voice_train[0]["audio"]
{'path': '/root/.cache/huggingface/datasets/downloads/extracted/276aa682ce2b6a24934bc401b1f30e004c3fb178dd41d6295b273329f592844a/mn_train_0/common_voice_mn_18578097.mp3',
 'array': array([ 9.09494702e-12, -2.27373675e-13,  5.45696821e-12, ...,
        -5.22854862e-06, -1.21556368e-05, -9.76262163e-06]),
 'sampling_rate': 16000}

This seems to have worked! Let’s listen to a couple of audio files to better understand the dataset and verify that the audio was correctly loaded.

import IPython.display as ipd
import numpy as np
import random

rand_int = random.randint(0, len(common_voice_train)-1)

print(common_voice_train[rand_int]["sentence"])
ipd.Audio(data=common_voice_train[rand_int]["audio"]["array"], autoplay=True, rate=16000)

It seems like the data is now correctly loaded and resampled.

It can be heard that the speakers vary, along with their speaking rate, accent, background environment, etc. Overall, the recordings sound acceptably clear though, which is to be expected from a crowd-sourced read speech corpus.

Let’s do a final check that the data is correctly prepared by printing the shape of the speech input, its transcription, and the corresponding sampling rate.

rand_int = random.randint(0, len(common_voice_train)-1)

print("Goal text:", common_voice_train[rand_int]["sentence"])
print("Input array shape:", common_voice_train[rand_int]["audio"]["array"].shape)
print("Sampling rate:", common_voice_train[rand_int]["audio"]["sampling_rate"])
Target text: энэ бол тэдний амжилтын бодит нууц
Input array shape: (74496,)
Sampling rate: 16000

Good! Everything looks fine – the data is a 1-dimensional array, the sampling rate always corresponds to 16kHz, and the target text is normalized.

Finally, we can leverage Wav2Vec2BertProcessor to process the data into the format expected by Wav2Vec2BertForCTC for training. To do so, let’s make use of Dataset’s map(...) function.

First, we load and resample the audio data, simply by calling batch["audio"].
Second, we extract the input_features from the loaded audio file. In our case, the Wav2Vec2BertProcessor creates a more complex representation than the raw waveform, known as log-mel feature extraction.
Third, we encode the transcriptions to label ids.

def prepare_dataset(batch):
    audio = batch["audio"]
    batch["input_features"] = processor(audio["array"], sampling_rate=audio["sampling_rate"]).input_features[0]
    batch["input_length"] = len(batch["input_features"])

    batch["labels"] = processor(text=batch["sentence"]).input_ids
    return batch

Let’s apply the data preparation function to all examples.

common_voice_train = common_voice_train.map(prepare_dataset, remove_columns=common_voice_train.column_names)
common_voice_test = common_voice_test.map(prepare_dataset, remove_columns=common_voice_test.column_names)

Note: datasets automatically takes care of audio loading and resampling. If you wish to implement your own customized data loading/resampling, feel free to just make use of the "path" column instead and disregard the "audio" column.

Awesome, now we are ready to start training!



Training

The data is processed so that we are ready to start setting up the training pipeline. We will make use of 🤗 Transformers’ Trainer class, for which we essentially need to do the following:

  • Define a data collator. In contrast to most NLP models, Wav2Vec2-BERT has a much larger input length than output length. Given the large input sizes, it is much more efficient to pad the training batches dynamically, meaning that all training samples are only padded to the longest sample in their batch and not to the overall longest sample. Therefore, fine-tuning Wav2Vec2-BERT requires a special padding data collator, which we will define below.

  • Evaluation metric. During training, the model should be evaluated on the word error rate. We should define a compute_metrics function accordingly.

  • Load a pre-trained checkpoint. We need to load a pre-trained checkpoint and configure it correctly for training.

  • Define the training configuration.

After having fine-tuned the model, we will evaluate it on the test data and verify that it has indeed learned to correctly transcribe speech.



Set-up Trainer

Let’s start by defining the data collator. The code for the data collator was copied from this example.

Without going into too many details, in contrast to common data collators, this data collator treats the input_features and labels differently and thus applies separate padding functions to them. This is necessary because, in speech, input and output are of different modalities, which means they should not be treated by the same padding function.
Analogous to common data collators, the padding tokens in the labels are replaced with -100 so that those tokens are not taken into account when computing the loss.

import torch

from dataclasses import dataclass, field
from typing import Any, Dict, List, Optional, Union

@dataclass
class DataCollatorCTCWithPadding:

    processor: Wav2Vec2BertProcessor
    padding: Union[bool, str] = True

    def __call__(self, features: List[Dict[str, Union[List[int], torch.Tensor]]]) -> Dict[str, torch.Tensor]:
        # split inputs and labels since they have to be of different lengths
        # and need different padding methods
        input_features = [{"input_features": feature["input_features"]} for feature in features]
        label_features = [{"input_ids": feature["labels"]} for feature in features]

        batch = self.processor.pad(
            input_features,
            padding=self.padding,
            return_tensors="pt",
        )

        labels_batch = self.processor.pad(
            labels=label_features,
            padding=self.padding,
            return_tensors="pt",
        )
        
        # replace padding with -100 to ignore these tokens when computing the loss
        labels = labels_batch["input_ids"].masked_fill(labels_batch.attention_mask.ne(1), -100)

        batch["labels"] = labels

        return batch

data_collator = DataCollatorCTCWithPadding(processor=processor, padding=True)

Next, the evaluation metric is defined. As mentioned earlier, the predominant metric in ASR is the word error rate (WER), so we will use it in this notebook as well.

wer_metric = load_metric("wer")

The model will return a sequence of logit vectors y_1, ..., y_m. A logit vector y_i contains the log-odds for each token in the vocabulary we defined earlier; we are interested in the most likely prediction of the model and thus take the argmax(...) of the logits, as done in compute_metrics below.

def compute_metrics(pred):
    pred_logits = pred.predictions
    pred_ids = np.argmax(pred_logits, axis=-1)

    # restore the pad token id in the labels so that they can be decoded
    pred.label_ids[pred.label_ids == -100] = processor.tokenizer.pad_token_id

    pred_str = processor.batch_decode(pred_ids)
    # we do not want to group tokens when computing the metrics
    label_str = processor.batch_decode(pred.label_ids, group_tokens=False)

    wer = wer_metric.compute(predictions=pred_str, references=label_str)

    return {"wer": wer}

Now, we can load the main pre-trained checkpoint. The tokenizer’s pad_token_id must be used to define the model’s pad_token_id or, in the case of Wav2Vec2BertForCTC, also CTC’s blank token.

Since we are only training a small subset of weights, the model is not prone to overfitting. Therefore, we make sure to disable all dropout layers.

Note: When using this notebook to train Wav2Vec2-BERT on another language of Common Voice, these hyper-parameter settings may not work very well. Feel free to adapt them depending on your use case.

from transformers import Wav2Vec2BertForCTC

model = Wav2Vec2BertForCTC.from_pretrained(
    "facebook/w2v-bert-2.0",
    attention_dropout=0.0,
    hidden_dropout=0.0,
    feat_proj_dropout=0.0,
    mask_time_prob=0.0,
    layerdrop=0.0,
    ctc_loss_reduction="mean",
    add_adapter=True,
    pad_token_id=processor.tokenizer.pad_token_id,
    vocab_size=len(processor.tokenizer),
)

In a final step, we define all parameters related to training.
To give more explanation on some of the parameters:

  • group_by_length makes training more efficient by grouping training samples of similar input length into one batch. This can significantly speed up training by heavily reducing the overall number of useless padding tokens that are passed through the model.
  • learning_rate was heuristically tuned until fine-tuning became stable. Note that these parameters strongly depend on the Common Voice dataset and might be suboptimal for other speech datasets.

For more explanations on the other parameters, one can take a look at the docs.

During training, a checkpoint will be uploaded asynchronously to the Hub every 600 training steps. This allows you to play around with the demo widget even while your model is still training.

Note: If you don’t want to upload the model checkpoints to the Hub, simply set push_to_hub=False.

from transformers import TrainingArguments

training_args = TrainingArguments(
  output_dir=repo_name,
  group_by_length=True,
  per_device_train_batch_size=16,
  gradient_accumulation_steps=2,
  evaluation_strategy="steps",
  num_train_epochs=10,
  gradient_checkpointing=True,
  fp16=True,
  save_steps=600,
  eval_steps=300,
  logging_steps=300,
  learning_rate=5e-5,
  warmup_steps=500,
  save_total_limit=2,
  push_to_hub=True,
)

Now, all instances can be passed to Trainer and we are ready to start training!

from transformers import Trainer

trainer = Trainer(
    model=model,
    data_collator=data_collator,
    args=training_args,
    compute_metrics=compute_metrics,
    train_dataset=common_voice_train,
    eval_dataset=common_voice_test,
    tokenizer=processor.feature_extractor,
)




Training

Training will take multiple hours depending on the GPU allocated to this notebook. While the trained model yields somewhat satisfying results on Common Voice’s Mongolian test data, it is by no means an optimally fine-tuned model. The purpose of this notebook is just to demonstrate how Wav2Vec2-BERT can be fine-tuned on an ASR dataset.

trainer.train()
Step  Training Loss  Validation Loss  WER
300   1.712700       0.647740         0.517892
600   0.349300       0.615849         0.442027
900   0.180500       0.525088         0.367305
1200  0.075400       0.528768         0.324016

The training loss and validation WER go down nicely. In comparison, the same training with whisper-large-v3, the commonly recognized state-of-the-art ASR model from OpenAI, yields a final WER of 33.3%. You can find the resulting Whisper checkpoint here. This shows that Wav2Vec2-BERT can achieve performance close to or on par with the state of the art for low-resource languages.

You can now upload the result of the training to the 🤗 Hub; just execute this instruction:

trainer.push_to_hub()

You can now share this model with all your friends, family, and favorite pets: they can all load it with the identifier “your-username/the-name-you-picked”, so for instance:

from transformers import AutoModelForCTC, Wav2Vec2BertProcessor

model = AutoModelForCTC.from_pretrained("ylacombe/w2v-bert-2.0-mongolian-colab-CV16.0")
processor = Wav2Vec2BertProcessor.from_pretrained("ylacombe/w2v-bert-2.0-mongolian-colab-CV16.0")

For more examples of how Wav2Vec2-BERT can be fine-tuned, please take a look at the official speech recognition examples.



Evaluation

As a final check, let’s load the model and confirm that it indeed has learned to transcribe Mongolian speech.

Let’s first load the fine-tuned checkpoint.

model = Wav2Vec2BertForCTC.from_pretrained(repo_name).to("cuda")
processor = Wav2Vec2BertProcessor.from_pretrained(repo_name)

Let’s process the audio, run a forward pass and predict the ids.

sample = common_voice_test[0]
input_features = torch.tensor(sample["input_features"]).to("cuda").unsqueeze(0)

with torch.no_grad():
    logits = model(input_features).logits

pred_ids = torch.argmax(logits, dim=-1)[0]

Finally, we can decode the example from the predicted tokens and compare it to the reference transcription:

print(processor.decode(pred_ids))
print(processor.decode(sample["labels"]).lower())
эрчүүдийн ганцаардлыг эмэхтэйчүүд ойлгох нь ховор юм
эрчүдийн ганцардлыг эмэгтэйчүд ойлгох нь ховор юм

Alright! The transcription can definitely be recognized from our prediction, but it is not perfect yet. Training the model a bit longer, spending more time on data pre-processing, and especially using a language model for decoding would certainly improve the model’s overall performance.
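
As a sketch of what LM-boosted decoding could look like – assuming a KenLM n-gram model trained on Mongolian text and saved as mn_5gram.arpa (a hypothetical file), and using the pyctcdecode library, neither of which is covered in this notebook:

from pyctcdecode import build_ctcdecoder

# sort the vocabulary by token id so that the labels line up with the logit dimensions,
# and map the CTC word delimiter "|" back to a space for the decoder
sorted_vocab = sorted(processor.tokenizer.get_vocab().items(), key=lambda kv: kv[1])
labels = [" " if token == "|" else token for token, _ in sorted_vocab]

decoder = build_ctcdecoder(labels, kenlm_model_path="mn_5gram.arpa")  # hypothetical LM file
print(decoder.decode(logits[0].cpu().numpy()))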

For a demonstration model on a low-resource language, the results are quite acceptable nonetheless 🤗.



Scaling-up the training

We have shown in this blog post how fine-tuning Meta’s w2v-bert-2.0 can provide near state-of-the-art performance on low-resource languages.

To take things a step further, I’ve put together a set of tips and pointers given by my colleagues at Hugging Face on how to scale up training for this model. These tips came to light when I showed them this blog post’s training run, as well as other training attempts (here and here).

Many thanks to Patrick, Sanchit and Pablo for their valuable expertise and help 🤗

Note that Common Voice’s latest version (CV16) provides many more hours of data for many languages, and thus provides fertile ground for much more efficient models in many low-resource languages.



Dataset-related tips

CTC ASR is usually done with lower-cased, unpunctuated transcriptions. This simplifies the CTC task, since the model is considered “acoustic only”, meaning that it makes predictions largely based on the phonetic sounds of the audio, rather than on any language-modelling context of the spoken sentence.

Very low-frequency characters can significantly affect the loss during training by causing loss spikes via erroneous targets. By default, the CTC tokenizer created in this blog post would add them to the vocabulary even when their frequency is negligible compared with more frequent characters. We can treat these characters as “errors” in the dataset annotation, so that they can be removed from the vocabulary and simply classified as "[UNK]" during training.

It is therefore absolutely necessary to recheck the tokenizer vocabulary and remove all low-frequency characters, in much the same way as we removed Latin characters when creating the tokenizer; a minimal sketch follows below.

Note that the Common Voice dataset is particularly prone to such “flawed” characters, for instance characters from other languages (阪).
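
As a minimal sketch (assuming the dataset still has its raw "sentence" column, i.e. before prepare_dataset, and using a hypothetical frequency threshold of 10), low-frequency characters could be filtered out before building the vocabulary:

from collections import Counter

# count character occurrences over the training transcriptions
char_counts = Counter()
for sentence in common_voice_train["sentence"]:
    char_counts.update(sentence)

# drop characters below a (hypothetical) threshold; during training they will simply
# be classified as "[UNK]" since they are no longer part of the vocabulary
min_count = 10
vocab_list = [c for c in vocab_list if char_counts[c] >= min_count]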



Training-related tips

Average duration seen by each CTC token: through experimentation, we found that the best ratio of duration seen per CTC token is 10 to 35 ms. In other words, to be able to learn and predict correctly, the duration of the acoustic information a CTC token sees should be neither too short nor too long. In fact, it should roughly correspond to a fraction of the time it takes us humans to pronounce a phoneme.

One of my training runs had a loss curve that initially went nicely downwards, as expected, but at some point started to blow up. I noticed that I had been using a basic checkpoint with no architecture changes, and that each CTC token was seeing a chunk of signal of 30 to 60 ms. Adding a convolutional adapter layer to sub-sample the encoder hidden states along the time dimension was enough to bring the chunk duration back into the desired range and to prevent such a loss curve. You can sanity-check this ratio with the sketch below.
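
As a rough, illustrative sketch: this estimates the audio duration covered by each CTC output token. It assumes the processor and model defined earlier, and a dataset row that still has its raw "audio" column (i.e. before prepare_dataset); the exact frame rate depends on the model configuration, e.g. whether add_adapter=True.

import torch

audio_array = common_voice_test[0]["audio"]["array"]
inputs = processor(audio_array, sampling_rate=16_000, return_tensors="pt")

with torch.no_grad():
    num_frames = model(inputs.input_features.to(model.device)).logits.shape[1]

audio_seconds = len(audio_array) / 16_000
print(f"~{1000 * audio_seconds / num_frames:.0f} ms of audio per CTC token")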

Under-training: my colleagues quickly noticed from my training runs that the model was severely under-trained, something that could have been spotted from the loss curve, which looks like it was stopped in the middle of a steep descent. This pointed to other issues as well, notably the loss curve not being smooth enough, a sign of flawed hyper-parameter settings.

Here are a few ways to solve under-training in our case:

  • The warm-up ratio might be too high, causing the learning rate to drop too quickly. One way to solve this is to keep the warm-up ratio at 5 to 15% and scale up the number of epochs. The warm-up steps are essential to gradually bring the new language-model head weights into alignment with the pre-trained model.
  • A lack of smoothness in the loss curve can be addressed by playing with AdamW’s β₂ parameter, for example lowering it from its default of 0.999. Both adjustments are sketched below.
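
As an illustration only (the exact values below are assumptions, not tuned hyper-parameters), here is how the earlier TrainingArguments could be adapted to use a warm-up ratio, more epochs and a lower AdamW β₂:

from transformers import TrainingArguments

training_args = TrainingArguments(
  output_dir=repo_name,
  group_by_length=True,
  per_device_train_batch_size=16,
  gradient_accumulation_steps=2,
  evaluation_strategy="steps",
  num_train_epochs=30,   # more epochs than the 10 used above
  gradient_checkpointing=True,
  fp16=True,
  save_steps=600,
  eval_steps=300,
  logging_steps=300,
  learning_rate=5e-5,
  warmup_ratio=0.1,      # 10% warm-up instead of a fixed number of warm-up steps
  adam_beta2=0.98,       # lower than the default of 0.999 to smooth the loss curve
  save_total_limit=2,
  push_to_hub=True,
)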

Related posts and additional links are listed here:


