Fine-tuning the ASR model
In this section, we’ll cover a step-by-step guide on fine-tuning Whisper for speech recognition on the Common Voice 13 dataset. We’ll use the ‘small’ version of the model and a relatively lightweight dataset, enabling you to run fine-tuning fairly quickly on any 16GB+ GPU with low disk space requirements, such as the 16GB T4 GPU provided in the Google Colab free tier.
Should you have a smaller GPU or encounter memory issues during training, you can follow the suggestions provided for reducing memory usage. Conversely, should you have access to a larger GPU, you can amend the training arguments to maximise your throughput. Thus, this guide is accessible regardless of your GPU specifications!
Likewise, this guide outlines how to fine-tune the Whisper model for the Dhivehi language. However, the steps covered here generalise to any language in the Common Voice dataset, and more generally to any ASR dataset on the Hugging Face Hub. You can tweak the code to quickly switch to a language of your choice and fine-tune a Whisper model in your native tongue 🌍
Right! Now that that’s out of the way, let’s get started and kick off our fine-tuning pipeline!
Prepare Environment
We strongly advise you to upload model checkpoints directly to the Hugging Face Hub while training. The Hub provides:
- Integrated version control: you can be sure that no model checkpoint is lost during training.
- Tensorboard logs: track important metrics over the course of training.
- Model cards: document what a model does and its intended use cases.
- Community: an easy way to share and collaborate with the community! 🤗
Linking the notebook to the Hub is straightforward - it simply requires entering your Hub authentication token when prompted. Find your Hub authentication token at https://huggingface.co/settings/tokens and enter it when prompted:
from huggingface_hub import notebook_login
notebook_login()
Output:
Login successful
Your token has been saved to /root/.huggingface/token
Load Dataset
Common Voice 13 contains approximately ten hours of labelled Dhivehi data, three of which are held-out test data. This is extremely little data for fine-tuning, so we’ll rely on the extensive multilingual ASR knowledge acquired by Whisper during pre-training to handle the low-resource Dhivehi language.
Using 🤗 Datasets, downloading and preparing data is extremely simple. We can download and prepare the Common Voice 13 splits in just one line of code. Since Dhivehi is very low-resource, we’ll combine the train and validation splits to give approximately seven hours of training data. We’ll use the three hours of test data as our held-out test set:
from datasets import load_dataset, DatasetDict
common_voice = DatasetDict()
common_voice["train"] = load_dataset(
"mozilla-foundation/common_voice_13_0", "dv", split="train+validation"
)
common_voice["test"] = load_dataset(
"mozilla-foundation/common_voice_13_0", "dv", split="test"
)
print(common_voice)
Output:
DatasetDict({
    train: Dataset({
        features: ['client_id', 'path', 'audio', 'sentence', 'up_votes', 'down_votes', 'age', 'gender', 'accent', 'locale', 'segment', 'variant'],
        num_rows: 4904
    })
    test: Dataset({
        features: ['client_id', 'path', 'audio', 'sentence', 'up_votes', 'down_votes', 'age', 'gender', 'accent', 'locale', 'segment', 'variant'],
        num_rows: 2212
    })
})
Most ASR datasets only provide input audio samples (audio) and the corresponding transcribed text (sentence). Common Voice contains additional metadata information, such as accent and locale, which we can disregard for ASR. Keeping the notebook as general as possible, we only consider the input audio and transcribed text for fine-tuning, discarding the additional metadata information:
common_voice = common_voice.select_columns(["audio", "sentence"])
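Before moving on, it can be helpful to peek at a single example to see what we’re working with. This is an optional sanity check, and the transcription and array length you see will depend on the example:
sample = common_voice["train"][0]

print(sample["sentence"])                # the Dhivehi transcription
print(sample["audio"]["sampling_rate"])  # 48000 - the native Common Voice sampling rate
print(len(sample["audio"]["array"]))     # number of samples in the raw waveform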
Feature Extractor, Tokenizer and Processor
The ASR pipeline can be decomposed into three stages:
- The feature extractor, which pre-processes the raw audio inputs to log-mel spectrograms
- The model, which performs the sequence-to-sequence mapping
- The tokenizer, which post-processes the predicted tokens to text
In 🤗 Transformers, the Whisper model has an associated feature extractor and tokenizer, called WhisperFeatureExtractor and WhisperTokenizer respectively. To make our lives simple, these two objects are wrapped under a single class, called the WhisperProcessor. We can call the WhisperProcessor to perform both the audio pre-processing and the text token post-processing. In doing so, we only need to keep track of two objects during training: the processor and the model.
When performing multilingual fine-tuning, we need to set the "language" and "task" when instantiating the processor. The "language" should be set to the source audio language, and the task to "transcribe" for speech recognition or "translate" for speech translation. These arguments modify the behaviour of the tokenizer, and should be set correctly to ensure the target labels are encoded properly.
We can see all possible languages supported by Whisper by importing the list of languages:
from transformers.models.whisper.tokenization_whisper import TO_LANGUAGE_CODE
TO_LANGUAGE_CODE
If you scroll through this list, you’ll notice that many languages are present, but Dhivehi is one of the few that are not! This means that Whisper was not pre-trained on Dhivehi. However, this doesn’t mean that we can’t fine-tune Whisper on it. In doing so, we’ll be teaching Whisper a new language, one that the pre-trained checkpoint does not support. That’s pretty cool, right?
When you fine-tune it on a new language, Whisper does a good job at leveraging its knowledge of the other 96 languages it’s pre-trained on. Broadly speaking, all modern languages will be linguistically similar to at least one of the 96 languages Whisper already knows, so we fall under this paradigm of cross-lingual knowledge representation.
What we need to do to fine-tune Whisper on a new language is find the most similar language that Whisper was pre-trained on. The Wikipedia article for Dhivehi states that Dhivehi is closely related to the Sinhalese language of Sri Lanka. If we check the language codes again, we can see that Sinhalese is present in the Whisper language set, so we can safely set our language argument to "sinhalese".
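If you’d like to double-check this programmatically, you can query the language mapping we imported above. A quick, optional sanity check:
# "sinhalese" is an accepted alias in Whisper's language mapping and resolves to the code "si",
# whereas Dhivehi does not appear at all - confirming it was not part of Whisper's pre-training
print(TO_LANGUAGE_CODE["sinhalese"])  # 'si'
print("dhivehi" in TO_LANGUAGE_CODE)  # False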
Right! We’ll load our processor from the pre-trained checkpoint, setting the language to "sinhalese" and the task to "transcribe" as explained above:
from transformers import WhisperProcessor
processor = WhisperProcessor.from_pretrained(
    "openai/whisper-small", language="sinhalese", task="transcribe"
)
It’s worth reiterating that in most circumstances, you’ll find that the language you want to fine-tune on is in the set of pre-training languages, in which case you can simply set the language directly as your source audio language! Note that both of these arguments should be omitted for English-only fine-tuning, where there is only one option for the language ("English") and task ("transcribe").
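To convince ourselves that the tokenizer handles Dhivehi text sensibly even though Whisper wasn’t pre-trained on it, we can encode and decode a sample from the training set and check that we recover the original sentence. This is an optional sanity check rather than part of the fine-tuning pipeline:
# round-trip a Dhivehi sentence through the tokenizer: encode to label ids, then decode back
input_str = common_voice["train"][0]["sentence"]
labels = processor.tokenizer(input_str).input_ids
decoded_with_special = processor.tokenizer.decode(labels, skip_special_tokens=False)
decoded_str = processor.tokenizer.decode(labels, skip_special_tokens=True)

print(f"Input:                 {input_str}")
print(f"Decoded w/ special:    {decoded_with_special}")
print(f"Decoded w/out special: {decoded_str}")
print(f"Are equal:             {input_str == decoded_str}")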
Pre-Process the Data
Let’s have a look at the dataset features. Pay particular attention to the "audio" column - this details the sampling rate of our audio inputs:
common_voice["train"].features
Output:
{'audio': Audio(sampling_rate=48000, mono=True, decode=True, id=None),
'sentence': Value(dtype='string', id=None)}
Since our input audio is sampled at 48kHz, we need to downsample it to 16kHz prior to passing it to the Whisper feature extractor, 16kHz being the sampling rate expected by the Whisper model.
We’ll set the audio inputs to the correct sampling rate using 🤗 Datasets’ cast_column method. This operation does not change the audio in-place, but rather signals to 🤗 Datasets to resample the audio samples on the fly when they are loaded:
from datasets import Audio
sampling_rate = processor.feature_extractor.sampling_rate
common_voice = common_voice.cast_column("audio", Audio(sampling_rate=sampling_rate))
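Accessing an audio sample now triggers the on-the-fly resampling, so we can quickly confirm that the data is delivered at 16kHz. This check is optional and takes a moment, since the audio is decoded and resampled at load time:
# the first access after cast_column decodes and resamples the audio on the fly
print(common_voice["train"][0]["audio"]["sampling_rate"])  # 16000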
Now we can write a function to prepare our data ready for the model:
- We load and resample the audio data on a sample-by-sample basis by calling example["audio"]. As explained above, 🤗 Datasets performs any necessary resampling operations on the fly.
- We use the feature extractor to compute the log-mel spectrogram input features from our 1-dimensional audio array.
- We encode the transcriptions to label ids through the use of the tokenizer.
def prepare_dataset(example):
    audio = example["audio"]

    example = processor(
        audio=audio["array"],
        sampling_rate=audio["sampling_rate"],
        text=example["sentence"],
    )

    # compute input length of audio sample in seconds
    example["input_length"] = len(audio["array"]) / audio["sampling_rate"]

    return example
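Before mapping this function over the whole dataset, we can try it on a single example to check the output. A small sketch; the exact input length depends on the sample:
# apply the preparation function to the first training example and inspect the result
processed = prepare_dataset(common_voice["train"][0])

print(list(processed.keys()))     # expect ['input_features', 'labels', 'input_length']
print(processed["input_length"])  # duration of the audio sample in seconds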
We can apply the data preparation function to all of our training examples using 🤗 Datasets’ .map method. We’ll remove the columns from the raw training data (the audio and text), leaving just the columns returned by the prepare_dataset function:
common_voice = common_voice.map(
    prepare_dataset, remove_columns=common_voice.column_names["train"], num_proc=1
)
Finally, we filter any training data with audio samples longer than 30s. These samples would otherwise be truncated by the Whisper feature extractor, which could affect the stability of training. We define a function that returns True for samples that are less than 30s, and False for those that are longer:
max_input_length = 30.0
def is_audio_in_length_range(length):
    return length < max_input_length
We apply our filter function to all samples of our training dataset through 🤗 Datasets’ .filter method:
common_voice["train"] = common_voice["train"].filter(
is_audio_in_length_range,
input_columns=["input_length"],
)
Let’s check how much training data we removed through this filtering step:
common_voice["train"]
Output:
Dataset({
    features: ['input_features', 'labels', 'input_length'],
    num_rows: 4904
})
Alright! In this case we actually have the same number of samples as before, so there were no samples longer than 30s. This might not be the case if you switch languages, so it’s best to keep this filter step in-place for robustness. With that, we have our data fully prepared for training! Let’s continue and take a look at how we can use this data to fine-tune Whisper.
Training and Evaluation
Now that we’ve prepared our data, we’re ready to dive into the training pipeline. The 🤗 Trainer will do much of the heavy lifting for us. All we have to do is:
- Define a data collator: the data collator takes our pre-processed data and prepares PyTorch tensors ready for the model.
- Evaluation metrics: during evaluation, we want to evaluate the model using the word error rate (WER) metric. We need to define a compute_metrics function that handles this computation.
- Load a pre-trained checkpoint: we need to load a pre-trained checkpoint and configure it correctly for training.
- Define the training arguments: these will be used by the 🤗 Trainer in constructing the training schedule.
Once we’ve fine-tuned the model, we will evaluate it on the test data to verify that we have correctly trained it to transcribe speech in Dhivehi.
Define a Data Collator
The data collator for a sequence-to-sequence speech model is unique in the sense that it treats the input_features and labels independently: the input_features must be handled by the feature extractor and the labels by the tokenizer.
The input_features are already padded to 30s and converted to a log-mel spectrogram of fixed dimension, so all we have to do is convert them to batched PyTorch tensors. We do this using the feature extractor’s .pad method with return_tensors="pt". Note that no additional padding is applied here since the inputs are of fixed dimension; the input_features are simply converted to PyTorch tensors.
On the other hand, the labels are un-padded. We first pad the sequences to the maximum length in the batch using the tokenizer’s .pad method. The padding tokens are then replaced by -100 so that these tokens are not taken into account when computing the loss. We then cut the start-of-transcript token from the beginning of the label sequence, as we append it later during training.
We can leverage the WhisperProcessor we defined earlier to perform both the feature extractor and the tokenizer operations:
import torch
from dataclasses import dataclass
from typing import Any, Dict, List, Union
@dataclass
class DataCollatorSpeechSeq2SeqWithPadding:
    processor: Any

    def __call__(
        self, features: List[Dict[str, Union[List[int], torch.Tensor]]]
    ) -> Dict[str, torch.Tensor]:
        # split inputs and labels since they have to be of different lengths and need different padding methods
        # first treat the audio inputs by simply returning torch tensors
        input_features = [
            {"input_features": feature["input_features"][0]} for feature in features
        ]
        batch = self.processor.feature_extractor.pad(input_features, return_tensors="pt")

        # get the tokenized label sequences
        label_features = [{"input_ids": feature["labels"]} for feature in features]
        # pad the labels to max length
        labels_batch = self.processor.tokenizer.pad(label_features, return_tensors="pt")

        # replace padding with -100 to correctly ignore these tokens in the loss
        labels = labels_batch["input_ids"].masked_fill(
            labels_batch.attention_mask.ne(1), -100
        )

        # if the bos token was appended in the previous tokenization step,
        # cut it here since it is appended again later during training
        if (labels[:, 0] == self.processor.tokenizer.bos_token_id).all().cpu().item():
            labels = labels[:, 1:]

        batch["labels"] = labels

        return batch
We can now initialise the data collator we’ve just defined:
data_collator = DataCollatorSpeechSeq2SeqWithPadding(processor=processor)
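As an optional check, we can pass a couple of pre-processed examples through the collator and inspect the shapes of the resulting tensors. A rough sketch; the label length depends on the sampled transcriptions:
# collate the first two training examples into a single batch
batch = data_collator([common_voice["train"][i] for i in range(2)])

print(batch["input_features"].shape)  # torch.Size([2, 80, 3000]): fixed-size log-mel spectrograms
print(batch["labels"].shape)          # (2, longest label sequence in the batch)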
Onwards!
Evaluation Metrics
Next, we define the evaluation metric we’ll use on our evaluation set. We’ll use the Word Error Rate (WER) metric introduced in the section on Evaluation, the ‘de-facto’ metric for assessing ASR systems.
We’ll load the WER metric from 🤗 Evaluate:
import evaluate
metric = evaluate.load("wer")
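If you haven’t used the WER metric before, a toy example helps build intuition. Note that 🤗 Evaluate returns the WER as a fraction, which is why we multiply by 100 in the compute_metrics function below:
# one substituted word ("there" -> "here") out of two reference words gives a WER of 0.5
print(metric.compute(predictions=["hello there"], references=["hello here"]))  # 0.5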
We then simply have to define a function that takes our model predictions and returns the WER metric. This function, called compute_metrics, first replaces -100 with the pad_token_id in the label_ids (undoing the step we applied in the data collator to ignore padded tokens correctly in the loss). It then decodes the predicted and label ids to strings. Finally, it computes the WER between the predictions and reference labels. Here, we have the option of evaluating with the ‘normalised’ transcriptions and predictions, which have punctuation and casing removed. We recommend you follow this approach to benefit from the WER improvement obtained by normalising the transcriptions.
from transformers.models.whisper.english_normalizer import BasicTextNormalizer
normalizer = BasicTextNormalizer()
def compute_metrics(pred):
    pred_ids = pred.predictions
    label_ids = pred.label_ids

    # replace -100 with the pad_token_id
    label_ids[label_ids == -100] = processor.tokenizer.pad_token_id

    # we do not want to group tokens when computing the metrics
    pred_str = processor.batch_decode(pred_ids, skip_special_tokens=True)
    label_str = processor.batch_decode(label_ids, skip_special_tokens=True)

    # compute orthographic wer
    wer_ortho = 100 * metric.compute(predictions=pred_str, references=label_str)

    # compute normalised WER
    pred_str_norm = [normalizer(pred) for pred in pred_str]
    label_str_norm = [normalizer(label) for label in label_str]
    # filtering step to only evaluate the samples that correspond to non-zero references:
    pred_str_norm = [
        pred_str_norm[i] for i in range(len(pred_str_norm)) if len(label_str_norm[i]) > 0
    ]
    label_str_norm = [
        label_str_norm[i]
        for i in range(len(label_str_norm))
        if len(label_str_norm[i]) > 0
    ]

    wer = 100 * metric.compute(predictions=pred_str_norm, references=label_str_norm)

    return {"wer_ortho": wer_ortho, "wer": wer}
Load a Pre-Trained Checkpoint
Now let’s load the pre-trained Whisper small checkpoint. Again, this is trivial through use of 🤗 Transformers!
from transformers import WhisperForConditionalGeneration
model = WhisperForConditionalGeneration.from_pretrained("openai/whisper-small")
We’ll set use_cache to False for training since we’re using gradient checkpointing and the two are incompatible. We’ll also override two generation arguments to control the behaviour of the model during inference: we’ll force the language and task tokens during generation by setting the language and task arguments, and re-enable the cache for generation to speed up inference time:
from functools import partial
# disable cache during training since it's incompatible with gradient checkpointing
model.config.use_cache = False
# set language and task for generation and re-enable cache
model.generate = partial(
    model.generate, language="sinhalese", task="transcribe", use_cache=True
)
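As an aside, recent versions of 🤗 Transformers also let you store these generation defaults on the model’s generation config rather than wrapping generate. This is an alternative sketch, not required if you use the partial above:
# alternative (assumes a recent transformers version): persist the language and task on the
# generation config so that every subsequent call to model.generate() picks them up
model.generation_config.language = "sinhalese"
model.generation_config.task = "transcribe"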
Define the Training Configuration
In the final step, we define all the parameters related to training. Here, we set the number of training steps to 500. This is enough steps to see a big WER improvement compared to the pre-trained Whisper model, while ensuring that fine-tuning can be run in approximately 45 minutes on a Google Colab free tier. For more detail on the training arguments, refer to the Seq2SeqTrainingArguments docs.
from transformers import Seq2SeqTrainingArguments
training_args = Seq2SeqTrainingArguments(
    output_dir="./whisper-small-dv",  # name on the HF Hub
    per_device_train_batch_size=16,
    gradient_accumulation_steps=1,  # increase by 2x for every 2x decrease in batch size
    learning_rate=1e-5,
    lr_scheduler_type="constant_with_warmup",
    warmup_steps=50,
    max_steps=500,  # increase to 4000 if you have your own GPU or a Colab paid plan
    gradient_checkpointing=True,
    fp16=True,
    fp16_full_eval=True,
    evaluation_strategy="steps",
    per_device_eval_batch_size=16,
    predict_with_generate=True,
    generation_max_length=225,
    save_steps=500,
    eval_steps=500,
    logging_steps=25,
    report_to=["tensorboard"],
    load_best_model_at_end=True,
    metric_for_best_model="wer",
    greater_is_better=False,
    push_to_hub=True,
)
We can forward the training arguments to the 🤗 Trainer along with our model, dataset, data collator and compute_metrics function:
from transformers import Seq2SeqTrainer
trainer = Seq2SeqTrainer(
    args=training_args,
    model=model,
    train_dataset=common_voice["train"],
    eval_dataset=common_voice["test"],
    data_collator=data_collator,
    compute_metrics=compute_metrics,
    tokenizer=processor,
)
And with that, we’re ready to start training!
Training
To launch training, simply execute:
trainer.train()
Training will take approximately 45 minutes, depending on your GPU or the one allocated by Google Colab. Depending on your GPU, it is possible that you will encounter a CUDA "out-of-memory" error when you start training. In this case, you can reduce the per_device_train_batch_size incrementally by factors of 2 and employ gradient_accumulation_steps to compensate.
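For example, if you hit an out-of-memory error with the defaults above, one possible lower-memory configuration halves the train and eval batch sizes and doubles the gradient accumulation steps, keeping the effective train batch size at 16 (a sketch only; all other arguments are unchanged). Re-create the trainer with these arguments before calling trainer.train() again:
from transformers import Seq2SeqTrainingArguments

# lower-memory variant: batch sizes halved from 16 to 8, gradient accumulation doubled
# from 1 to 2 so the effective train batch size (8 * 2 = 16) stays the same
training_args = Seq2SeqTrainingArguments(
    output_dir="./whisper-small-dv",
    per_device_train_batch_size=8,
    gradient_accumulation_steps=2,
    learning_rate=1e-5,
    lr_scheduler_type="constant_with_warmup",
    warmup_steps=50,
    max_steps=500,
    gradient_checkpointing=True,
    fp16=True,
    fp16_full_eval=True,
    evaluation_strategy="steps",
    per_device_eval_batch_size=8,
    predict_with_generate=True,
    generation_max_length=225,
    save_steps=500,
    eval_steps=500,
    logging_steps=25,
    report_to=["tensorboard"],
    load_best_model_at_end=True,
    metric_for_best_model="wer",
    greater_is_better=False,
    push_to_hub=True,
)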
Output:
| Training Loss | Epoch | Step | Validation Loss | Wer Ortho | Wer     |
|---------------|-------|------|-----------------|-----------|---------|
| 0.136         | 1.63  | 500  | 0.1727          | 63.8972   | 14.0661 |
Our final WER is 14.1% - not bad for seven hours of training data and just 500 training steps! That amounts to a 112% improvement versus the pre-trained model! That means we’ve taken a model that previously had no knowledge about Dhivehi, and fine-tuned it to recognise Dhivehi speech with adequate accuracy in under one hour 🤯
The big question is how this compares to other ASR systems. For that, we can view the autoevaluate leaderboard, which categorises models by language and dataset, and subsequently ranks them according to their WER.
Looking at the leaderboard, we see that our model trained for 500 steps convincingly beats the pre-trained Whisper Small checkpoint that we evaluated in the previous section. Nice job 👏
We see that there are a few checkpoints that do better than the one we trained. The beauty of the Hugging Face Hub is that it’s a collaborative platform - if we don’t have the time or resources to perform a longer training run ourselves, we can load a checkpoint that someone else in the community has trained and been kind enough to share (making sure to thank them for it!). You’ll be able to load these checkpoints in exactly the same way as the pre-trained ones using the pipeline class as we did previously! So there’s nothing stopping you from cherry-picking the best model on the leaderboard to use for your task!
We can automatically submit our checkpoint to the leaderboard when we push the training results to the Hub - we simply have to set the appropriate keyword arguments (kwargs). You can change these values to match your dataset, language and model name accordingly:
kwargs = {
    "dataset_tags": "mozilla-foundation/common_voice_13_0",
    "dataset": "Common Voice 13",  # a 'pretty' name for the training dataset
    "language": "dv",
    "model_name": "Whisper Small Dv - Sanchit Gandhi",  # a 'pretty' name for your model
    "finetuned_from": "openai/whisper-small",
    "tasks": "automatic-speech-recognition",
}
The training results can now be uploaded to the Hub. To do so, execute the push_to_hub command:
trainer.push_to_hub(**kwargs)
This will save the training logs and model weights under "your-username/the-name-you-picked". For this example, check out the upload at sanchit-gandhi/whisper-small-dv.
While the fine-tuned model yields satisfactory results on the Common Voice 13 Dhivehi test data, it is by no means optimal. The purpose of this guide is to demonstrate how to fine-tune an ASR model using the 🤗 Trainer for multilingual speech recognition.
If you have access to your own GPU or are subscribed to a Google Colab paid plan, you can increase max_steps to 4000 steps to improve the WER further by training for more steps. Training for 4000 steps will take approximately 3-5 hours depending on your GPU and yield WER results approximately 3% lower than training for 500 steps. If you decide to train for 4000 steps, we also recommend changing the learning rate scheduler to a linear schedule (set lr_scheduler_type="linear"), as this will yield an additional performance boost over long training runs.
The results could likely be improved further by optimising the training hyperparameters, such as learning rate and dropout, and using a larger pre-trained checkpoint (medium or large). We leave this as an exercise to the reader.
Sharing Your Model
You can now share this model with anyone using the link on the Hub. They can load it with the identifier "your-username/the-name-you-picked" directly into the pipeline() object. For instance, to load the fine-tuned checkpoint “sanchit-gandhi/whisper-small-dv”:
from transformers import pipeline
pipe = pipeline("automatic-speech-recognition", model="sanchit-gandhi/whisper-small-dv")
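You can then transcribe audio exactly as in the earlier sections. For example, with a local audio file (the filename here is a placeholder - any audio file on disk will do, and the pipeline takes care of resampling):
# "audio.mp3" is a placeholder path to an audio file on disk
result = pipe("audio.mp3")
print(result["text"])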
Conclusion
In this section, we covered a step-by-step guide on fine-tuning the Whisper model for speech recognition using 🤗 Datasets, 🤗 Transformers and the Hugging Face Hub. We first loaded the Dhivehi subset of the Common Voice 13 dataset and pre-processed it by computing log-mel spectrograms and tokenising the text. We then defined a data collator, evaluation metric and training arguments, before using the 🤗 Trainer to train and evaluate our model. We finished by uploading the fine-tuned model to the Hugging Face Hub, and showcased how to share and use it with the pipeline() class.
If you followed through to this point, you should now have a fine-tuned checkpoint for speech recognition, well done! 🥳 Even more importantly, you’re equipped with all the tools you need to fine-tune the Whisper model on any speech recognition dataset or domain. So what are you waiting for? Pick one of the datasets covered in the section Choosing a Dataset or select a dataset of your own, and see whether you can get state-of-the-art performance! The leaderboard is waiting for you…