pere committed on
Commit 5d3b348
1 Parent(s): dea1939

updated training code

=0.16.4 ADDED
File without changes
=2.0.0 ADDED
File without changes
distil_whisper/__pycache__/__init__.cpython-310.pyc CHANGED
Binary files a/distil_whisper/__pycache__/__init__.cpython-310.pyc and b/distil_whisper/__pycache__/__init__.cpython-310.pyc differ
 
distil_whisper/__pycache__/layers.cpython-310.pyc CHANGED
Binary files a/distil_whisper/__pycache__/layers.cpython-310.pyc and b/distil_whisper/__pycache__/layers.cpython-310.pyc differ
 
distil_whisper/__pycache__/modeling_flax_whisper.cpython-310.pyc CHANGED
Binary files a/distil_whisper/__pycache__/modeling_flax_whisper.cpython-310.pyc and b/distil_whisper/__pycache__/modeling_flax_whisper.cpython-310.pyc differ
 
distil_whisper/__pycache__/partitioner.cpython-310.pyc CHANGED
Binary files a/distil_whisper/__pycache__/partitioner.cpython-310.pyc and b/distil_whisper/__pycache__/partitioner.cpython-310.pyc differ
 
distil_whisper/__pycache__/pipeline.cpython-310.pyc CHANGED
Binary files a/distil_whisper/__pycache__/pipeline.cpython-310.pyc and b/distil_whisper/__pycache__/pipeline.cpython-310.pyc differ
 
distil_whisper/__pycache__/train_state.cpython-310.pyc CHANGED
Binary files a/distil_whisper/__pycache__/train_state.cpython-310.pyc and b/distil_whisper/__pycache__/train_state.cpython-310.pyc differ
 
run_distillation.py CHANGED
@@ -558,7 +558,7 @@ def get_data_loader(
     dataset: IterableDataset,
     batch_size: int,
     data_collator: FlaxDataCollatorSpeechSeq2SeqWithPadding,
-    shuffle: bool = True,
+    shuffle: bool = False,
     drop_last: bool = True,
     dataloader_num_workers: int = 0,
     skip_batches: int = 0,
run_distillation_debug.py ADDED
@@ -0,0 +1,2162 @@
1
+ #!/usr/bin/env python
2
+ # coding=utf-8
3
+ # Copyright 2023 The HuggingFace Inc. team. All rights reserved.
4
+ #
5
+ # Licensed under the Apache License, Version 2.0 (the "License");
6
+ # you may not use this file except in compliance with the License.
7
+ # You may obtain a copy of the License at
8
+ #
9
+ # http://www.apache.org/licenses/LICENSE-2.0
10
+ #
11
+ # Unless required by applicable law or agreed to in writing, software
12
+ # distributed under the License is distributed on an "AS IS" BASIS,
13
+ # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
14
+ # See the License for the specific language governing permissions and
15
+ # limitations under the License.
16
+ """
17
+ Training the Whisper model for sequence to sequence speech recognition via teacher-student distillation.
18
+ """
19
+ # You can also adapt this script for your own distillation tasks. Pointers for this are left as comments.
20
+
21
+ import logging
22
+ import os
23
+ import re
24
+ import shutil
25
+ import string
26
+ import sys
27
+ import time
28
+ from dataclasses import dataclass, field
29
+ from functools import partial
30
+ from pathlib import Path
31
+ from typing import Any, Callable, Dict, List, Optional, Union
32
+
33
+ import datasets
34
+ import evaluate
35
+ import flax
36
+ import jax
37
+ import jax.numpy as jnp
38
+ import numpy as np
39
+ import optax
40
+ import torch
41
+ import transformers
42
+ from datasets import (
43
+ DatasetDict,
44
+ IterableDataset,
45
+ IterableDatasetDict,
46
+ concatenate_datasets,
47
+ interleave_datasets,
48
+ load_dataset,
49
+ )
50
+ from flax import jax_utils, traverse_util
51
+ from flax.jax_utils import pad_shard_unpad, unreplicate
52
+ from flax.serialization import from_bytes, to_bytes
53
+ from flax.training import train_state
54
+ from flax.training.common_utils import get_metrics, onehot, shard, shard_prng_key
55
+ from huggingface_hub import Repository, create_repo
56
+ from jax.experimental.compilation_cache import compilation_cache as cc
57
+ from optax._src import linear_algebra
58
+ from torch.utils.data import DataLoader
59
+ from torchdata.datapipes.iter import IterableWrapper
60
+ from tqdm import tqdm
61
+ from transformers import (
62
+ AddedToken,
63
+ HfArgumentParser,
64
+ Seq2SeqTrainingArguments,
65
+ WhisperConfig,
66
+ WhisperFeatureExtractor,
67
+ WhisperProcessor,
68
+ WhisperTokenizerFast,
69
+ is_tensorboard_available,
70
+ is_wandb_available,
71
+ set_seed,
72
+ )
73
+ from transformers.file_utils import get_full_repo_name
74
+ from transformers.modeling_flax_outputs import FlaxBaseModelOutput
75
+ from transformers.models.whisper.english_normalizer import BasicTextNormalizer, EnglishTextNormalizer
76
+ from transformers.utils import check_min_version, send_example_telemetry
77
+ from transformers.utils.versions import require_version
78
+
79
+ from distil_whisper import FlaxWhisperForConditionalGeneration
80
+
81
+
82
+ # Will error if the minimal version of Transformers is not installed. Remove at your own risks.
83
+ check_min_version("4.27.0.dev0")
84
+
85
+ require_version(
86
+ "datasets>=1.18.0",
87
+ "To fix: pip install -r examples/flax/speech-recognition/requirements.txt",
88
+ )
89
+
90
+ logger = logging.getLogger(__name__)
91
+
92
+
93
+ @flax.struct.dataclass
94
+ class ModelArguments:
95
+ """
96
+ Arguments pertaining to which model/config/tokenizer we are going to fine-tune from.
97
+ """
98
+
99
+ model_name_or_path: str = field(
100
+ metadata={"help": ("Path to pretrained student model or model identifier from huggingface.co/models")}
101
+ )
102
+ teacher_model_name_or_path: str = field(
103
+ metadata={"help": ("Path to pretrained teacher model or model identifier from huggingface.co/models")}
104
+ )
105
+ config_name: Optional[str] = field(
106
+ default=None,
107
+ metadata={"help": "Pretrained config name or path if not the same as model_name"},
108
+ )
109
+ tokenizer_name: Optional[str] = field(
110
+ default=None,
111
+ metadata={"help": "Pretrained tokenizer name or path if not the same as model_name"},
112
+ )
113
+ feature_extractor_name: Optional[str] = field(
114
+ default=None,
115
+ metadata={"help": "feature extractor name or path if not the same as model_name"},
116
+ )
117
+ cache_dir: Optional[str] = field(
118
+ default=None,
119
+ metadata={"help": ("Where to store the pretrained models downloaded from huggingface.co")},
120
+ )
121
+ use_fast_tokenizer: bool = field(
122
+ default=True,
123
+ metadata={"help": ("Whether to use one of the fast tokenizers (backed by the tokenizers library) or not.")},
124
+ )
125
+ model_revision: str = field(
126
+ default="main",
127
+ metadata={"help": ("The specific model version to use (can be a branch name, tag name or commit id).")},
128
+ )
129
+ subfolder: str = field(
130
+ default="",
131
+ metadata={
132
+ "help": "In case the relevant files are located inside a subfolder of the model repo on huggingface.co, you can"
133
+ " specify the folder name here."
134
+ },
135
+ )
136
+ use_auth_token: bool = field(
137
+ default=False,
138
+ metadata={
139
+ "help": (
140
+ "Will use the token generated when running `transformers-cli login`"
141
+ " (necessary to use this script with private models)."
142
+ )
143
+ },
144
+ )
145
+ dtype: Optional[str] = field(
146
+ default="float32",
147
+ metadata={
148
+ "help": (
149
+ "Floating-point format in which the model weights should be initialized"
150
+ " and trained. Choose one of `[float32, float16, bfloat16]`."
151
+ )
152
+ },
153
+ )
154
+ load_with_scan_weights: bool = field(
155
+ default=False,
156
+ metadata={
157
+ "help": "Whether the pre-trained checkpoint has its weights stored in scan format. Set to True for scanned "
158
+ "weights, defaults to False for non-scan (unrolled) weights."
159
+ },
160
+ )
161
+ activation_dropout: float = field(
162
+ default=0.0,
163
+ metadata={"help": "The dropout ratio for activations inside the fully connected layer."},
164
+ )
165
+ attention_dropout: float = field(
166
+ default=0.0,
167
+ metadata={"help": "The dropout ratio for the attention probabilities."},
168
+ )
169
+ dropout: float = field(
170
+ default=0.0,
171
+ metadata={
172
+ "help": "The dropout probability for all fully connected layers in the embeddings, encoder, and pooler."
173
+ },
174
+ )
175
+
176
+
177
+ @flax.struct.dataclass
178
+ class DataTrainingArguments:
179
+ """
180
+ Arguments pertaining to what data we are going to input our model for training and eval.
181
+ """
182
+
183
+ train_dataset_name: str = field(
184
+ default=None,
185
+ metadata={
186
+ "help": "The name of the training dataset to use (via the datasets library). Load and combine "
187
+ "multiple datasets by separating dataset ids by a '+' symbol. For example, to load and combine "
188
+ " librispeech and common voice, set `train_dataset_name='librispeech_asr+common_voice'`."
189
+ },
190
+ )
191
+ train_dataset_config_name: Optional[str] = field(
192
+ default=None,
193
+ metadata={
194
+ "help": "The configuration name of the training dataset to use (via the datasets library). Load and combine "
195
+ "multiple datasets by separating dataset configs by a '+' symbol."
196
+ },
197
+ )
198
+ train_dataset_samples: str = field(
199
+ default=None,
200
+ metadata={
201
+ "help": "Number of samples in the training data. Load and combine "
202
+ "multiple datasets by separating dataset samples by a '+' symbol."
203
+ },
204
+ )
205
+ eval_dataset_name: str = field(
206
+ default=None,
207
+ metadata={
208
+ "help": "The name of the evaluation dataset to use (via the datasets library). Defaults to the training dataset name if unspecified."
209
+ },
210
+ )
211
+ eval_dataset_config_name: Optional[str] = field(
212
+ default=None,
213
+ metadata={
214
+ "help": "The configuration name of the evaluation dataset to use (via the datasets library). Defaults to the training dataset config name if unspecified"
215
+ },
216
+ )
217
+ dataset_cache_dir: Optional[str] = field(
218
+ default=None,
219
+ metadata={"help": "Path to cache directory for saving and loading datasets"},
220
+ )
221
+ overwrite_cache: bool = field(
222
+ default=False,
223
+ metadata={"help": "Overwrite the cached training and evaluation sets"},
224
+ )
225
+ preprocessing_num_workers: Optional[int] = field(
226
+ default=None,
227
+ metadata={"help": "The number of processes to use for the preprocessing."},
228
+ )
229
+ max_train_samples: Optional[int] = field(
230
+ default=None,
231
+ metadata={
232
+ "help": (
233
+ "For debugging purposes or quicker training, truncate the number of"
234
+ " training examples to this value if set."
235
+ )
236
+ },
237
+ )
238
+ max_eval_samples: Optional[int] = field(
239
+ default=None,
240
+ metadata={
241
+ "help": (
242
+ "For debugging purposes or quicker training, truncate the number of"
243
+ " evaluation examples to this value if set."
244
+ )
245
+ },
246
+ )
247
+ audio_column_name: str = field(
248
+ default="audio",
249
+ metadata={"help": ("The name of the dataset column containing the audio data. Defaults to 'audio'")},
250
+ )
251
+ train_text_column_name: str = field(
252
+ default="whisper_transcript",
253
+ metadata={
254
+ "help": (
255
+ "The name of the dataset column containing the text data. Defaults to"
256
+ " 'whisper_transcript', which is the pseudo-labelled Whisper"
257
+ " transcription data."
258
+ )
259
+ },
260
+ )
261
+ eval_text_column_name: str = field(
262
+ default="text",
263
+ metadata={
264
+ "help": (
265
+ "The name of the dataset column containing the text data. Defaults to"
266
+ " 'text', which is the original text data"
267
+ )
268
+ },
269
+ )
270
+ max_duration_in_seconds: float = field(
271
+ default=30.0,
272
+ metadata={"help": ("Filter audio files that are longer than `max_duration_in_seconds` seconds")},
273
+ )
274
+ min_duration_in_seconds: float = field(
275
+ default=0.0,
276
+ metadata={"help": ("Filter audio files that are shorter than `min_duration_in_seconds` seconds")},
277
+ )
278
+ max_label_length: int = field(
279
+ default=128,
280
+ metadata={"help": "Truncate transcriptions that are longer than `max_label_length` tokens."},
281
+ )
282
+ pad_target_to_multiple_of: Optional[int] = field(
283
+ default=None,
284
+ metadata={
285
+ "help": (
286
+ "If set will pad the target sequence to a multiple of the provided"
287
+ " value. This is important to avoid triggering recompilations on TPU."
288
+ " If unspecified, will default to padding the targets to max length."
289
+ )
290
+ },
291
+ )
292
+ preprocessing_only: bool = field(
293
+ default=False,
294
+ metadata={
295
+ "help": (
296
+ "Whether to only do data preprocessing and skip training. This is"
297
+ " especially useful when data preprocessing errors out in distributed"
298
+ " training due to timeout. In this case, one should run the"
299
+ " preprocessing in a non-distributed setup with"
300
+ " `preprocessing_only=True` so that the cached datasets can"
301
+ " consequently be loaded in distributed training"
302
+ )
303
+ },
304
+ )
305
+ train_split_name: str = field(
306
+ default="train",
307
+ metadata={
308
+ "help": ("The name of the training data set split to use (via the datasets library). Defaults to 'train'")
309
+ },
310
+ )
311
+ eval_split_name: str = field(
312
+ default="validation",
313
+ metadata={
314
+ "help": (
315
+ "The name of the evaluation data set split to use (via the datasets"
316
+ " library). Defaults to 'validation'"
317
+ )
318
+ },
319
+ )
320
+ wandb_project: str = field(
321
+ default="distil-whisper",
322
+ metadata={"help": "The name of the wandb project."},
323
+ )
324
+ wandb_name: str = field(
325
+ default=None,
326
+ metadata={"help": "The name of the wandb run."},
327
+ )
328
+ wandb_job_type: str = field(
329
+ default="distil-whisper",
330
+ metadata={"help": "The name of the wandb job type."},
331
+ )
332
+ wandb_dir: str = field(
333
+ default=None,
334
+ metadata={"help": "The absolute path to save the wandb logs."},
335
+ )
336
+ save_code_to_wandb: bool = field(
337
+ default=False,
338
+ metadata={
339
+ "help": (
340
+ "Whether to save main script to wandb. This is valuable for improving"
341
+ " experiment reproducibility and to diff code across experiments in"
342
+ " the UI."
343
+ )
344
+ },
345
+ )
346
+ streaming: bool = field(
347
+ default=True,
348
+ metadata={"help": "Whether to use Datasets' streaming mode to load and pre-process the data."},
349
+ )
350
+ wer_threshold: float = field(
351
+ default=None,
352
+ metadata={
353
+ "help": "Filter training data with Whisper transcriptions that have greater than `wer_threshold` "
354
+ "WER with the normalised transcriptions."
355
+ },
356
+ )
357
+ prefetch_size: int = field(
358
+ default=0,
359
+ metadata={"help": "Number of samples to pre-fetch if using an iterable dataset."},
360
+ )
361
+ timestamp_probability: float = field(
362
+ default=0.5, metadata={"help": "Probability for training on timestamped tokens if the data contains it."}
363
+ )
364
+ return_timestamps: bool = field(
365
+ default=False, metadata={"help": "Whether or not to predict timestamps in the generation step."}
366
+ )
367
+ round_timestamps: bool = field(
368
+ default=False,
369
+ metadata={
370
+ "help": "Whether or not to round the timestamp tokens to the nearest tenth of a second."
371
+ "By default, Whisper predicts timestamps to the nearest hundredth of a second."
372
+ "Reducing the timestamp precision to one tenth of a second simplifies the timestamp "
373
+ "prediction task, at the expense of timestamp granularity."
374
+ },
375
+ )
376
+
377
+
378
+ @dataclass
379
+ class FlaxSeq2SeqTrainingArguments(Seq2SeqTrainingArguments):
380
+ use_scan: Optional[bool] = field(
381
+ default=True,
382
+ metadata={
383
+ "help": (
384
+ "Whether or not to use `scan_with_axes` over the encoder and decoder blocks. Using scan results "
385
+ "in faster compile times and more efficient memory use during training, since all of the layers "
386
+ "in the encoder/decoder are stacked, and we perform a lax.scan over the stacked block to index "
387
+ "each layer. However, it results in slower inference time due to the overhead of stacking the "
388
+ "layers this way. Thus, we **always** default to disabling scan for the inference step."
389
+ )
390
+ },
391
+ )
392
+ freeze_encoder: Optional[bool] = field(
393
+ default=False,
394
+ metadata={
395
+ "help": (
396
+ "Whether to freeze the entire encoder model. Only recommended when the entire encoder has been "
397
+ "copied from the teacher model."
398
+ )
399
+ },
400
+ )
401
+ temperature: Optional[float] = field(
402
+ default=2.0, metadata={"help": "Temperature to anneal the logits when computing the softmax."}
403
+ )
404
+ kl_weight: Optional[float] = field(
405
+ default=1.0,
406
+ metadata={
407
+ "help": (
408
+ "Weighting assigned to the KL divergence loss in the KD formulation. KL loss is "
409
+ "computed between the teacher and student logits."
410
+ )
411
+ },
412
+ )
413
+ mse_weight: Optional[float] = field(
414
+ default=0.0,
415
+ metadata={
416
+ "help": (
417
+ "Weighting assigned to the MSE loss in the KD formulation. MSE loss is "
418
+ "computed between the teacher-student hidden states and attentions."
419
+ )
420
+ },
421
+ )
422
+ precision: Optional[str] = field(
423
+ default="half_mixed",
424
+ metadata={
425
+ "help": (
426
+ "Precision with which to run training. Can be one of `full`, `half_mixed` or `full_mixed`, the latter two "
427
+ "of which enable *mixed-precision* training. **Note that this only specifies the dtype of the computation "
428
+ "and optimizer state. It does not influence the dtype of model parameters.** An explanation of the three "
429
+ "settings is provided below:"
430
+ " 1. Full precision: forward pass, backward pass and optimiser states all in float32."
431
+ " 2. Half mixed precision: forward pass in bfloat16, backward pass and optimiser states in float32. This "
432
+ " corresponds to setting the dtype argument to bfloat16 when instantiating the model."
433
+ " 3. Full mixed precision: forward pass, backward pass and optimiser states all in bfloat16. The dtype "
434
+ " argument is set to bfloat16 for the forward pass, and the gradients computed with respect to the bfloat16 "
435
+ " parameters in the backward pass (giving bfloat16 gradients). The new optimiser states and parameter "
436
+ " updates are computed in float32 by upcasting the bfloat16 gradients and optimiser states to float32 "
437
+ " prior to the optimiser update step. The optimiser states are returned in float32 (but not saved to "
438
+ " memory) and then downcasted to bfloat16 (saved to memory) for the subsequent train step."
439
+ "For further details, refer to https://github.com/deepmind/optax/discussions/336"
440
+ )
441
+ },
442
+ )
443
+ compilation_cache: Optional[bool] = field(
444
+ default=False,
445
+ metadata={
446
+ "help": (
447
+ "Whether to enable the JAX (experimental) compilation cache. The compilation step is *cached* the "
448
+ "first time it is run. Successive compilation steps for the same function utilise the cache to reduce "
449
+ "the compilation time."
450
+ )
451
+ },
452
+ )
453
+ save_train_state: Optional[bool] = field(
454
+ default=False,
455
+ metadata={
456
+ "help": "Whether or not to save the Flax Train State on each `save_steps` steps. Required if you intend "
457
+ "to resume training from partial training runs. If False, only the model weights will be saved. "
458
+ "If True, both the model weights and Flax Train state will be saved."
459
+ },
460
+ )
461
+
462
+
463
+ def shift_tokens_right(label_ids: np.array, decoder_start_token_id: int) -> np.ndarray:
464
+ """
465
+ Shift label ids one token to the right.
466
+ """
467
+ shifted_label_ids = np.zeros_like(label_ids)
468
+ shifted_label_ids[:, 1:] = label_ids[:, :-1]
469
+ shifted_label_ids[:, 0] = decoder_start_token_id
470
+
471
+ return shifted_label_ids
472
+
473
+
474
+ @flax.struct.dataclass
475
+ class FlaxDataCollatorSpeechSeq2SeqWithPadding:
476
+ """
477
+ Data collator that will dynamically pad the inputs received.
478
+ Args:
479
+ processor ([`Wav2Vec2Processor`])
480
+ The processor used for processing the data.
481
+ decoder_start_token_id (:obj: `int`)
482
+ The start-of-sequence token id of the decoder.
483
+ decoder_prev_token_id (:obj: `int`)
484
+ The start-of-prompt token id of the decoder
485
+ input_padding (:obj:`bool`, :obj:`str` or :class:`~transformers.tokenization_utils_base.PaddingStrategy`, `optional`, defaults to :obj:`True`):
486
+ Select a strategy to pad the returned input sequences (according to the model's padding side and padding index)
487
+ among:
488
+ * :obj:`True` or :obj:`'longest'`: Pad to the longest sequence in the batch (or no padding if only a single
489
+ sequence if provided).
490
+ * :obj:`'max_length'`: Pad to a maximum length specified with the argument :obj:`max_length` or to the
491
+ maximum acceptable input length for the model if that argument is not provided.
492
+ * :obj:`False` or :obj:`'do_not_pad'` (default): No padding (i.e., can output a batch with sequences of
493
+ different lengths).
494
+ target_padding (:obj:`bool`, :obj:`str` or :class:`~transformers.tokenization_utils_base.PaddingStrategy`, `optional`, defaults to :obj:`True`):
495
+ Select a strategy to pad the returned target sequences (according to the model's padding side and padding index).
496
+ See above for details.
497
+ max_target_length (:obj:`int`, `optional`):
498
+ Maximum length of the ``labels`` of the returned list and optionally padding length (see above).
499
+ """
500
+
501
+ processor: Any
502
+ decoder_start_token_id: int
503
+ decoder_prev_token_id: int
504
+ input_padding: Union[bool, str] = "max_length"
505
+ target_padding: Union[bool, str] = "max_length"
506
+ max_target_length: Optional[int] = None
507
+
508
+ def __call__(self, features: List[Dict[str, Union[List[int], np.ndarray]]]) -> Dict[str, np.ndarray]:
509
+ # split inputs and labels since they have to be of different lengths and need
510
+ # different padding methods
511
+ model_input_name = self.processor.model_input_names[0]
512
+
513
+ # dataloader returns a list of features which we convert to a dict
514
+ input_features = {model_input_name: [feature[model_input_name] for feature in features]}
515
+ label_features = {"input_ids": [feature["labels"] for feature in features]}
516
+
517
+ # reformat list to dict and set to pytorch format
518
+ batch = self.processor.feature_extractor.pad(
519
+ input_features,
520
+ padding=self.input_padding,
521
+ return_tensors="np",
522
+ )
523
+
524
+ labels_batch = self.processor.tokenizer.pad(
525
+ label_features,
526
+ max_length=self.max_target_length,
527
+ padding=self.target_padding,
528
+ return_tensors="np",
529
+ )
530
+
531
+ # if bos token is appended in previous tokenization step,
532
+ # cut bos token here as it's appended later anyways
533
+ labels = labels_batch["input_ids"]
534
+ if set(np.unique(labels[:, 0])).issubset({self.decoder_start_token_id, self.decoder_prev_token_id}):
535
+ decoder_input_ids = labels[:, :-1]
536
+ labels = labels[:, 1:]
537
+ labels_batch.attention_mask = labels_batch.attention_mask[:, 1:]
538
+ else:
539
+ decoder_input_ids = shift_tokens_right(labels, self.decoder_start_token_id)
540
+
541
+ # replace padding with -100 to ignore correctly when computing the loss
542
+ labels = np.ma.array(labels, mask=np.not_equal(labels_batch.attention_mask, 1))
543
+ labels = labels.filled(fill_value=-100)
544
+
545
+ # replace initial prompt tokens with -100 to ignore correctly when computing the loss
546
+ bos_index = np.argmax(labels == self.decoder_start_token_id, axis=1)
547
+ prompt_mask = np.arange(labels.shape[1]) < bos_index[:, None]
548
+ labels = np.where(prompt_mask, -100, labels)
549
+
550
+ batch["labels"] = labels
551
+ batch["decoder_input_ids"] = decoder_input_ids
552
+
553
+ return batch
554
+
555
+
556
+ def get_data_loader(
557
+ seed: int,
558
+ dataset: IterableDataset,
559
+ batch_size: int,
560
+ data_collator: FlaxDataCollatorSpeechSeq2SeqWithPadding,
561
+ shuffle: bool = False,
562
+ drop_last: bool = True,
563
+ dataloader_num_workers: int = 0,
564
+ skip_batches: int = 0,
565
+ pin_memory: bool = True,
566
+ prefetch_size: int = 0,
567
+ ) -> DataLoader:
568
+ """
569
+ Returns batches of size `batch_size` from `dataset`. If `drop_last` is set to `False`, the final batch may be incomplete,
570
+ and range in size from 1 to `batch_size`. Shuffle batches if `shuffle` is `True`.
571
+
572
+ Args:
573
+ seed (int): Numpy seed for generating pseudo random numbers. Used if shuffling the dataset.
574
+ dataset (IterableDataset): streaming dataset from which to load the data.
575
+ batch_size (int): how many samples per batch to load.
576
+ data_collator (FlaxDataCollatorSpeechSeq2SeqWithPadding, optional): merges a list of samples to form a
577
+ mini-batch of Tensor(s). Used when using batched loading from a map-style dataset.
578
+ shuffle (bool, optional): set to `True` to have the batches reshuffled.
579
+ drop_last (bool, optional): set to ``True`` to drop the last incomplete batch,
580
+ if the dataset size is not divisible by the batch size. If ``False`` and
581
+ the size of dataset is not divisible by the batch size, then the last batch
582
+ will be smaller. (default: ``False``)
583
+ dataloader_num_workers (int, optional): how many subprocesses to use for data
584
+ loading. ``0`` means that the data will be loaded in the main process.
585
+ (default: ``0``)
586
+ skip_batches (int, optional): Efficiently skip the first `skip_batches`.
587
+ pin_memory (bool, optional): If ``True``, the data loader will copy Tensors
588
+ into device/CUDA pinned memory before returning them. If your data elements
589
+ are a custom type, or your :attr:`collate_fn` returns a batch that is a custom type,
590
+ see the example below.
591
+
592
+ """
593
+ if shuffle:
594
+ dataset = dataset.shuffle(seed)
595
+
596
+ if skip_batches > 0:
597
+ dataset = dataset.skip(skip_batches * batch_size)
598
+
599
+ if prefetch_size > 0:
600
+ dataset = IterableWrapper(dataset)
601
+ dataset = dataset.prefetch(prefetch_size)
602
+
603
+ data_loader = DataLoader(
604
+ dataset,
605
+ batch_size=batch_size,
606
+ drop_last=drop_last,
607
+ pin_memory=pin_memory,
608
+ collate_fn=data_collator,
609
+ num_workers=dataloader_num_workers,
610
+ )
611
+
612
+ return data_loader
613
+
614
+
615
+ def sorted_checkpoints(output_dir=None, checkpoint_prefix="checkpoint", use_mtime=False) -> List[str]:
616
+ ordering_and_checkpoint_path = []
617
+
618
+ glob_checkpoints = [str(x) for x in Path(output_dir).glob(f"{checkpoint_prefix}-*") if os.path.isdir(x)]
619
+
620
+ for path in glob_checkpoints:
621
+ if use_mtime:
622
+ ordering_and_checkpoint_path.append((os.path.getmtime(path), path))
623
+ else:
624
+ regex_match = re.match(f".*{checkpoint_prefix}-([0-9]+)", path)
625
+ if regex_match is not None and regex_match.groups() is not None:
626
+ ordering_and_checkpoint_path.append((int(regex_match.groups()[0]), path))
627
+
628
+ checkpoints_sorted = sorted(ordering_and_checkpoint_path)
629
+ checkpoints_sorted = [checkpoint[1] for checkpoint in checkpoints_sorted]
630
+ return checkpoints_sorted
631
+
632
+
633
+ def rotate_checkpoints(
634
+ save_total_limit=None, use_mtime=False, output_dir=None, checkpoint_prefix="checkpoint"
635
+ ) -> None:
636
+ if save_total_limit is None or save_total_limit <= 0:
637
+ return
638
+
639
+ # Check if we should delete older checkpoint(s)
640
+ checkpoints_sorted = sorted_checkpoints(
641
+ use_mtime=use_mtime, output_dir=output_dir, checkpoint_prefix=checkpoint_prefix
642
+ )
643
+ if len(checkpoints_sorted) <= save_total_limit:
644
+ return
645
+
646
+ number_of_checkpoints_to_delete = max(0, len(checkpoints_sorted) - save_total_limit)
647
+ checkpoints_to_be_deleted = checkpoints_sorted[:number_of_checkpoints_to_delete]
648
+ for checkpoint in checkpoints_to_be_deleted:
649
+ logger.info(f"Deleting older checkpoint [{checkpoint}] due to args.save_total_limit")
650
+ shutil.rmtree(checkpoint, ignore_errors=True)
651
+
652
+
653
+ def to_fp32(t):
654
+ return jax.tree_map(lambda x: x.astype(jnp.float32) if x.dtype == jnp.bfloat16 else x, t)
655
+
656
+
657
+ def to_bf16(t):
658
+ return jax.tree_map(lambda x: x.astype(jnp.bfloat16) if x.dtype == jnp.float32 else x, t)
659
+
660
+
661
+ class TrainState(train_state.TrainState):
662
+ dropout_rng: jnp.ndarray
663
+ max_grad_norm: float
664
+
665
+ def apply_gradients(self, *, grads, to_dtype: to_fp32, **kwargs):
666
+ """Updates `step`, `params`, `opt_state` and `**kwargs` in return value, clipping the
667
+ gradients by the maximum grad norm.
668
+
669
+ Note that internally this function calls `.tx.update()` followed by a call
670
+ to `optax.apply_updates()` to update `params` and `opt_state`.
671
+
672
+ Args:
673
+ grads: Gradients that have the same pytree structure as `.params`.
674
+ **kwargs: Additional dataclass attributes that should be `.replace()`-ed.
675
+
676
+ Returns:
677
+ An updated instance of `self` with `step` incremented by one, `params`
678
+ and `opt_state` updated by applying `grads`, and additional attributes
679
+ replaced as specified by `kwargs`.
680
+ """
681
+ # clip gradients by global l2 norm
682
+ casted_max_grad_norm = to_dtype(self.max_grad_norm)
683
+ g_norm = linear_algebra.global_norm(grads)
684
+ g_norm = jnp.maximum(casted_max_grad_norm, g_norm)
685
+ grads = jax.tree_map(lambda t: (t / g_norm) * casted_max_grad_norm, grads)
686
+
687
+ # perform update step in fp32 and subsequently downcast optimizer states if mixed precision training
688
+ # grads and opt_state in bf16 (need to upcast), params in fp32 (leave as is)
689
+ updates, new_opt_state = self.tx.update(to_fp32(grads), to_fp32(self.opt_state), self.params)
690
+
691
+ new_params = optax.apply_updates(self.params, updates)
692
+
693
+ return self.replace(
694
+ step=self.step + 1,
695
+ params=new_params,
696
+ opt_state=to_dtype(new_opt_state),
697
+ **kwargs,
698
+ )
699
+
700
+ @classmethod
701
+ def create(cls, *, apply_fn, params, tx, to_dtype: to_fp32, **kwargs):
702
+ """Creates a new instance with `step=0` and initialized `opt_state`."""
703
+ # downcast optimizer state to bf16 if mixed-precision training
704
+ opt_state = tx.init(to_dtype(params))
705
+ return cls(
706
+ step=0,
707
+ apply_fn=apply_fn,
708
+ params=params,
709
+ tx=tx,
710
+ opt_state=opt_state,
711
+ **kwargs,
712
+ )
713
+
714
+ def replicate(self):
715
+ return jax_utils.replicate(self).replace(dropout_rng=shard_prng_key(self.dropout_rng))
716
+
717
+ def unreplicate(self):
718
+ return jax_utils.unreplicate(self)
719
+
720
+ def save_state(self, output_dir, save_total_limit=None, checkpoint_prefix="checkpoint"):
721
+ step = int(jax.device_get(unreplicate(self.step)))
722
+ serialized_state = to_bytes(self.unreplicate())
723
+
724
+ output_file = Path(os.path.join(output_dir, f"{checkpoint_prefix}-{step}", "train_state.msgpack"))
725
+ output_file.parent.mkdir(exist_ok=True, parents=True)
726
+
727
+ with output_file.open("wb") as f:
728
+ f.write(serialized_state)
729
+
730
+ logger.info(f"Flax train state saved in {output_file}")
731
+ rotate_checkpoints(
732
+ save_total_limit=save_total_limit, output_dir=output_dir, checkpoint_prefix=checkpoint_prefix
733
+ )
734
+
735
+
736
+ def save_hf_weights(
737
+ student_state: TrainState,
738
+ student_model: FlaxWhisperForConditionalGeneration,
739
+ processor: WhisperProcessor,
740
+ output_dir: str,
741
+ cur_step: int,
742
+ total_train_steps: int,
743
+ use_scan: bool = True,
744
+ checkpoint_prefix: str = "checkpoint",
745
+ ) -> None:
746
+ # always disable scan in the params / model so that we can load from PyTorch directly - this is a no-op if we're not using scan for training
747
+ student_state_params = unreplicate(student_state.params)
748
+ student_state_params = student_model.convert_scan_to_unroll(student_state_params)
749
+ student_params = jax.device_get(student_state_params)
750
+ student_model.disable_scan()
751
+
752
+ if cur_step != total_train_steps:
753
+ output_dir = os.path.join(output_dir, f"{checkpoint_prefix}-{cur_step}")
754
+ os.makedirs(output_dir, exist_ok=True)
755
+
756
+ student_model.save_pretrained(output_dir, params=student_params)
757
+ processor.save_pretrained(output_dir)
758
+
759
+ # re-enable scan only if required for training
760
+ if use_scan:
761
+ student_model.enable_scan()
762
+
763
+
764
+ def write_train_metric(summary_writer, train_metrics, train_time, step, logging_steps):
765
+ summary_writer.scalar("train/time", train_time, step)
766
+
767
+ train_metrics = get_metrics(train_metrics)
768
+ for key, vals in train_metrics.items():
769
+ steps_arr = np.arange(0, step, logging_steps)[-len(vals) :]
770
+ tag = f"train/{key}"
771
+ for i, val in enumerate(vals):
772
+ summary_writer.scalar(tag, val, steps_arr[i])
773
+
774
+
775
+ def write_eval_metric(summary_writer, eval_metrics, step, prefix="eval"):
776
+ for metric_name, value in eval_metrics.items():
777
+ summary_writer.scalar(f"{prefix}/{metric_name}", value, step)
778
+
779
+
780
+ def write_wandb_metric(wandb_logger, metrics, train_time, step, epoch, prefix="train"):
781
+ log_metrics = {}
782
+ for k, v in metrics.items():
783
+ log_metrics[f"{prefix}/{k}"] = v
784
+ log_metrics[f"{prefix}/time"] = train_time
785
+ log_metrics[f"{prefix}/epoch"] = epoch
786
+ wandb_logger.log(log_metrics, step)
787
+
788
+
789
+ def write_wandb_pred(
790
+ wandb_logger, pred_str, label_str, norm_pred_str, norm_label_str, cur_step, prefix="eval", num_lines=200000
791
+ ):
792
+ # pretty name for current step: step 50000 -> step 50k
793
+ cur_step_pretty = f"{int(cur_step // 1000)}k" if cur_step > 1000 else cur_step
794
+ # convert str data to a wandb compatible format
795
+ str_data = [[label_str[i], pred_str[i], norm_label_str[i], norm_pred_str[i]] for i in range(len(pred_str))]
796
+ # log as a table with the appropriate headers
797
+ wandb_logger.log(
798
+ {
799
+ f"predictions/{prefix.replace('/', '-')}-step-{cur_step_pretty}": wandb_logger.Table(
800
+ columns=["Target", "Pred", "Norm Target", "Norm Pred"], data=str_data[:num_lines]
801
+ )
802
+ },
803
+ cur_step,
804
+ )
805
+ # log incorrect normalised predictions
806
+ str_data = np.asarray(str_data)
807
+ str_data_incorrect = str_data[str_data[:, -2] != str_data[:, -1]]
808
+ # log as a table with the appropriate headers
809
+ wandb_logger.log(
810
+ {
811
+ f"incorrect_predictions/{prefix.replace('/', '-')}-step-{cur_step_pretty}": wandb_logger.Table(
812
+ columns=["Target", "Pred", "Norm Target", "Norm Pred"], data=str_data_incorrect[:num_lines]
813
+ )
814
+ },
815
+ cur_step,
816
+ )
817
+
818
+
819
+ def create_learning_rate_fn(
820
+ num_train_steps: int, lr_scheduler_type: str, num_warmup_steps: int, learning_rate: float
821
+ ) -> Callable[[int], jnp.array]:
822
+ """Returns a linear warmup, linear_decay learning rate function."""
823
+ lr_scheduler_types = ("linear", "constant_with_warmup")
824
+
825
+ if lr_scheduler_type not in lr_scheduler_types:
826
+ raise ValueError(
827
+ f"lr_scheduler_type of type {lr_scheduler_type} not supported, choose from {lr_scheduler_types}."
828
+ )
829
+
830
+ warmup_fn = optax.linear_schedule(init_value=0.0, end_value=learning_rate, transition_steps=num_warmup_steps)
831
+ decay_fn = optax.linear_schedule(
832
+ init_value=learning_rate,
833
+ end_value=0 if lr_scheduler_type == "linear" else learning_rate,
834
+ transition_steps=num_train_steps - num_warmup_steps,
835
+ )
836
+ schedule_fn = optax.join_schedules(schedules=[warmup_fn, decay_fn], boundaries=[num_warmup_steps])
837
+ return schedule_fn
838
+
839
+
840
+ def convert_dataset_str_to_list(
841
+ dataset_names,
842
+ dataset_config_names,
843
+ splits=None,
844
+ text_column_names=None,
845
+ dataset_samples=None,
846
+ default_split="train",
847
+ ):
848
+ if isinstance(dataset_names, str):
849
+ dataset_names = dataset_names.split("+")
850
+
851
+ # we assume that all the datasets we're using derive from the distil-whisper org on the Hub - prepend the org name if necessary
852
+ for i in range(len(dataset_names)):
853
+ ds_name = dataset_names[i]
854
+ dataset_names[i] = f"distil-whisper/{ds_name}" if "/" not in ds_name else ds_name
855
+
856
+ dataset_config_names = dataset_config_names.split("+")
857
+ splits = splits.split("+") if splits is not None else None
858
+ text_column_names = text_column_names.split("+") if text_column_names is not None else None
859
+ dataset_samples = dataset_samples.split("+") if dataset_samples is not None else None
860
+
861
+ # basic checks to ensure we've got the right number of datasets/configs/splits/columns/probs
862
+ if len(dataset_names) != len(dataset_config_names):
863
+ raise ValueError(
864
+ f"Ensure one config is passed for each dataset, got {len(dataset_names)} datasets and"
865
+ f" {len(dataset_config_names)} configs."
866
+ )
867
+
868
+ if splits is not None and len(splits) != len(dataset_names):
869
+ raise ValueError(
870
+ f"Ensure one split is passed for each dataset, got {len(dataset_names)} datasets and {len(splits)} splits."
871
+ )
872
+
873
+ if text_column_names is not None and len(text_column_names) != len(dataset_names):
874
+ raise ValueError(
875
+ f"Ensure one text column name is passed for each dataset, got {len(dataset_names)} datasets and"
876
+ f" {len(text_column_names)} text column names."
877
+ )
878
+
879
+ if dataset_samples is not None:
880
+ if len(dataset_samples) != len(dataset_names):
881
+ raise ValueError(
882
+ f"Ensure one sample is passed for each dataset, got {len(dataset_names)} datasets and "
883
+ f"{len(dataset_samples)} samples."
884
+ )
885
+ dataset_samples = [float(ds_sample) for ds_sample in dataset_samples]
886
+ else:
887
+ dataset_samples = [None] * len(dataset_names)
888
+
889
+ text_column_names = (
890
+ text_column_names if text_column_names is not None else ["text" for _ in range(len(dataset_names))]
891
+ )
892
+ splits = splits if splits is not None else [default_split for _ in range(len(dataset_names))]
893
+
894
+ dataset_names_dict = []
895
+ for i, ds_name in enumerate(dataset_names):
896
+ dataset_names_dict.append(
897
+ {
898
+ "name": ds_name,
899
+ "config": dataset_config_names[i],
900
+ "split": splits[i],
901
+ "text_column_name": text_column_names[i],
902
+ "samples": dataset_samples[i],
903
+ }
904
+ )
905
+ return dataset_names_dict
906
+
907
+
908
+ def load_multiple_datasets(
909
+ dataset_names: Union[List, str],
910
+ dataset_config_names: Union[List, str],
911
+ splits: Optional[Union[List, str]] = None,
912
+ text_column_names: Optional[List] = None,
913
+ sampling_rate: Optional[int] = 16000,
914
+ stopping_strategy: Optional[str] = "first_exhausted",
915
+ dataset_samples: Optional[Union[List, np.array]] = None,
916
+ streaming: bool = True,
917
+ seed: int = None,
918
+ **kwargs,
919
+ ) -> IterableDataset:
920
+ dataset_names_dict = convert_dataset_str_to_list(
921
+ dataset_names, dataset_config_names, splits, text_column_names, dataset_samples
922
+ )
923
+
924
+ if dataset_samples is not None:
925
+ dataset_samples = [ds_dict["samples"] for ds_dict in dataset_names_dict]
926
+ probabilities = np.array(dataset_samples) / np.sum(dataset_samples)
927
+ else:
928
+ probabilities = None
929
+
930
+ if len(dataset_names_dict) == 1:
931
+ dataset_dict = dataset_names_dict[0]
932
+ # we have a single dataset so just return it as is
933
+ return load_dataset(
934
+ dataset_dict["name"],
935
+ dataset_dict["config"],
936
+ split=dataset_dict["split"],
937
+ streaming=streaming,
938
+ **kwargs,
939
+ )
940
+
941
+ all_datasets = []
942
+ # iterate over the datasets we want to interleave
943
+ for dataset_dict in tqdm(dataset_names_dict, desc="Combining datasets..."):
944
+ dataset = load_dataset(
945
+ dataset_dict["name"],
946
+ dataset_dict["config"],
947
+ split=dataset_dict["split"],
948
+ streaming=streaming,
949
+ **kwargs,
950
+ )
951
+ # resample to specified sampling rate
952
+ dataset = dataset.cast_column("audio", datasets.features.Audio(sampling_rate))
953
+ dataset = dataset.remove_columns(
954
+ set(dataset.features.keys()) - {"audio", dataset_dict["text_column_name"], "whisper_transcript"}
955
+ )
956
+ all_datasets.append(dataset)
957
+
958
+ if streaming:
959
+ interleaved_dataset = interleave_datasets(
960
+ all_datasets,
961
+ stopping_strategy=stopping_strategy,
962
+ probabilities=probabilities,
963
+ seed=seed,
964
+ )
965
+ else:
966
+ interleaved_dataset = concatenate_datasets(all_datasets)
967
+
968
+ return interleaved_dataset
969
+
970
+
971
+ def get_layers_to_supervise(student_layers: int, teacher_layers: int) -> dict:
972
+ """Helper function to map the student layer i to the teacher layer j whose output we'd like them to emulate. Used
973
+ for MSE loss terms in distillation (hidden-states and activations). Student layers are paired with teacher layers
974
+ in equal increments, e.g. for a 12-layer model distilled to a 3-layer model, student layer 0 emulates teacher layer
975
+ 3 (such that it behaves like the first 4 teacher layers), student layer 1 emulates teacher layer 7, and student layer
976
+ 2 emulates teacher layer 11. This mapping is summarised by the dictionary: {0: 3, 1: 7, 2: 11}, which is precisely
977
+ the output of this function for the arguments (student_layers=3, teacher_layers=12)."""
978
+ layer_intervals = np.linspace(teacher_layers // student_layers - 1, teacher_layers - 1, student_layers, dtype=int)
979
+ layer_intervals[-1] = teacher_layers - 1
980
+ layer_map = {}
981
+
982
+ for student_layer, teacher_layer in enumerate(layer_intervals):
983
+ layer_map[student_layer] = teacher_layer
984
+
985
+ return layer_map
986
+
987
+
988
+ class FlaxWhisperFeatureExtractor(WhisperFeatureExtractor):
989
+ def _np_extract_fbank_features(self, waveform: np.array) -> np.ndarray:
990
+ """
991
+ Compute the log-mel spectrogram of the provided audio using torch filters. Using the torch implementation
992
+ computes stft filter banks approx 5x faster than its numpy counterpart, which is the native implementation
993
+ in transformers, and matches to within 1e-5 abs tolerance.
994
+ """
995
+ waveform = torch.from_numpy(waveform).type(torch.float32)
996
+
997
+ window = torch.hann_window(self.n_fft)
998
+ stft = torch.stft(waveform, self.n_fft, self.hop_length, window=window, return_complex=True)
999
+ magnitudes = stft[..., :-1].abs() ** 2
1000
+
1001
+ mel_filters = torch.from_numpy(self.mel_filters).type(torch.float32)
1002
+ mel_spec = mel_filters.T @ magnitudes
1003
+
1004
+ log_spec = torch.clamp(mel_spec, min=1e-10).log10()
1005
+ log_spec = torch.maximum(log_spec, log_spec.max() - 8.0)
1006
+ log_spec = (log_spec + 4.0) / 4.0
1007
+ return log_spec.numpy()
1008
+
1009
+
1010
+ def main():
1011
+ # 1. Parse input arguments
1012
+ # See all possible arguments in src/transformers/training_args.py
1013
+ # or by passing the --help flag to this script.
1014
+ # We now keep distinct sets of args, for a cleaner separation of concerns.
1015
+ parser = HfArgumentParser((ModelArguments, DataTrainingArguments, FlaxSeq2SeqTrainingArguments))
1016
+
1017
+ if len(sys.argv) == 2 and sys.argv[1].endswith(".json"):
1018
+ # If we pass only one argument to the script and it's the path to a json file,
1019
+ # let's parse it to get our arguments.
1020
+ model_args, data_args, training_args = parser.parse_json_file(json_file=os.path.abspath(sys.argv[1]))
1021
+ else:
1022
+ model_args, data_args, training_args = parser.parse_args_into_dataclasses()
1023
+
1024
+ # Sending telemetry. Tracking the example usage helps us better allocate resources to maintain them. The
1025
+ # information sent is the one passed as arguments along with your JAX/Flax versions.
1026
+ send_example_telemetry("run_flax_speech_recognition_seq2seq", model_args, data_args, framework="flax")
1027
+
1028
+ # 2. Define remote logging - do this early so that we get the full traceback on our remote logs
1029
+ # Enable tensorboard only on the master node
1030
+ has_tensorboard = is_tensorboard_available()
1031
+ if has_tensorboard:
1032
+ if jax.process_index() == 0:
1033
+ try:
1034
+ from flax.metrics.tensorboard import SummaryWriter
1035
+
1036
+ summary_writer = SummaryWriter(log_dir=os.path.join(Path(training_args.output_dir), "runs"))
1037
+ except ImportError as ie:
1038
+ has_tensorboard = False
1039
+ logger.warning(
1040
+ "Unable to display metrics through TensorBoard because some packages" f" are not installed: {ie}"
1041
+ )
1042
+ else:
1043
+ logger.warning(
1044
+ "Unable to display metrics through TensorBoard because the package is not"
1045
+ " installed: Please run `pip install tensorboard` to enable."
1046
+ )
1047
+
1048
+ # Enable wandb only on the master node
1049
+ has_wandb = is_wandb_available()
1050
+ if has_wandb:
1051
+ import wandb as wandb_logger
1052
+
1053
+ # Set up wandb run
1054
+ if jax.process_index() == 0:
1055
+ wandb_logger.init(
1056
+ project=data_args.wandb_project,
1057
+ name=data_args.wandb_name,
1058
+ job_type=data_args.wandb_job_type,
1059
+ dir=data_args.wandb_dir,
1060
+ save_code=data_args.save_code_to_wandb,
1061
+ )
1062
+ else:
1063
+ logger.warning("Wandb logging requires wandb to be installed. Run `pip install wandb` to enable.")
1064
+
1065
+ # 3. Setup local logging
1066
+ # Make one log on every process with the configuration for debugging.
1067
+ logging.basicConfig(
1068
+ format="%(asctime)s - %(levelname)s - %(name)s - %(message)s",
1069
+ datefmt="%m/%d/%Y %H:%M:%S",
1070
+ handlers=[logging.StreamHandler(sys.stdout)],
1071
+ )
1072
+ # Set the verbosity to info of the Transformers logger.
1073
+ # We only want one process per machine to log things on the screen.
1074
+ logger.setLevel(logging.INFO if jax.process_index() == 0 else logging.ERROR)
1075
+ if jax.process_index() == 0:
1076
+ datasets.utils.logging.set_verbosity_warning()
1077
+ transformers.utils.logging.set_verbosity_info()
1078
+ else:
1079
+ datasets.utils.logging.set_verbosity_error()
1080
+ transformers.utils.logging.set_verbosity_error()
1081
+
1082
+ logger.info("Training/evaluation parameters %s", training_args)
1083
+
1084
+ # Check the output dir is valid
1085
+ if (
1086
+ os.path.exists(training_args.output_dir)
1087
+ and os.listdir(training_args.output_dir)
1088
+ and training_args.do_train
1089
+ and not training_args.overwrite_output_dir
1090
+ ):
1091
+ raise ValueError(
1092
+ f"Output directory ({training_args.output_dir}) already exists and is not"
1093
+ " empty. Use `--overwrite_output_dir` to overcome."
1094
+ )
1095
+
1096
+ # 4. Handle the repository creation
1097
+ if training_args.push_to_hub:
1098
+ if training_args.hub_model_id is None:
1099
+ repo_name = get_full_repo_name(
1100
+ Path(training_args.output_dir).absolute().name,
1101
+ token=training_args.hub_token,
1102
+ )
1103
+ else:
1104
+ repo_name = training_args.hub_model_id
1105
+ create_repo(repo_name, exist_ok=True, token=training_args.hub_token)
1106
+ repo = Repository(
1107
+ training_args.output_dir,
1108
+ clone_from=repo_name,
1109
+ token=training_args.hub_token,
1110
+ )
1111
+
1112
+ if training_args.compilation_cache:
1113
+ cc.initialize_cache(os.path.join(model_args.cache_dir, "jax_cache"))
1114
+
1115
+ # 5. Load dataset
1116
+ raw_datasets = IterableDatasetDict() if data_args.streaming else DatasetDict()
1117
+
1118
+ # set seed for determinism
1119
+ set_seed(training_args.seed)
1120
+
1121
+ if training_args.do_train:
1122
+ print("loading raw")
1123
+ raw_datasets["train"] = load_multiple_datasets(
1124
+ data_args.train_dataset_name,
1125
+ data_args.train_dataset_config_name,
1126
+ splits=data_args.train_split_name,
1127
+ streaming=data_args.streaming,
1128
+ dataset_samples=data_args.train_dataset_samples,
1129
+ seed=training_args.seed,
1130
+ cache_dir=data_args.dataset_cache_dir,
1131
+ token=True if model_args.use_auth_token else None,
1132
+ )
1133
+
1134
+ if training_args.do_eval:
1135
+ dataset_names_dict = convert_dataset_str_to_list(
1136
+ data_args.eval_dataset_name if data_args.eval_dataset_name else data_args.train_dataset_name,
1137
+ (
1138
+ data_args.eval_dataset_config_name
1139
+ if data_args.eval_dataset_config_name
1140
+ else data_args.train_dataset_config_name
1141
+ ),
1142
+ splits=data_args.eval_split_name,
1143
+ text_column_names=data_args.eval_text_column_name,
1144
+ )
1145
+ all_eval_splits = []
1146
+ if len(dataset_names_dict) == 1:
1147
+ # load a single eval set
1148
+ dataset_dict = dataset_names_dict[0]
1149
+ all_eval_splits.append("eval")
1150
+ raw_datasets["eval"] = load_dataset(
1151
+ dataset_dict["name"],
1152
+ dataset_dict["config"],
1153
+ split=dataset_dict["split"],
1154
+ cache_dir=data_args.dataset_cache_dir,
1155
+ token=True if model_args.use_auth_token else None,
1156
+ streaming=data_args.streaming,
1157
+ )
1158
+ else:
1159
+ # load multiple eval sets
1160
+ for dataset_dict in dataset_names_dict:
1161
+ if dataset_dict["name"] == "esb/diagnostic-dataset":
1162
+ # for the ESB diagnostic dataset, the dataset name is effectively the config
1163
+ pretty_name = f"{dataset_dict['config']}-diagnostic/{dataset_dict['split']}"
1164
+ else:
1165
+ pretty_name = f"{dataset_dict['name'].split('/')[-1]}/{dataset_dict['split'].replace('.', '-')}"
1166
+ all_eval_splits.append(pretty_name)
1167
+ raw_datasets[pretty_name] = load_dataset(
1168
+ dataset_dict["name"],
1169
+ dataset_dict["config"],
1170
+ split=dataset_dict["split"],
1171
+ cache_dir=data_args.dataset_cache_dir,
1172
+ token=True if model_args.use_auth_token else None,
1173
+ streaming=data_args.streaming,
1174
+ )
1175
+ features = raw_datasets[pretty_name].features.keys()
1176
+ if "text" not in features:
1177
+ raw_datasets[pretty_name] = raw_datasets[pretty_name].rename_column(
1178
+ dataset_dict["text_column_name"], "text"
1179
+ )
1180
+ raw_datasets[pretty_name] = raw_datasets[pretty_name].remove_columns(
1181
+ set(raw_datasets[pretty_name].features.keys()) - {"audio", "text"}
1182
+ )
1183
+
1184
+ if not training_args.do_train and not training_args.do_eval:
1185
+ raise ValueError(
1186
+ "Cannot not train and not do evaluation. At least one of training or evaluation has to be performed."
1187
+ )
1188
+
1189
+ raw_datasets_train_features = list(raw_datasets["train"].features.keys())
1190
+ print("debug 1")
1191
+
1192
+ if data_args.audio_column_name not in raw_datasets_train_features:
1193
+ raise ValueError(
1194
+ f"--audio_column_name '{data_args.audio_column_name}' not found in dataset"
1195
+ f" '{data_args.dataset_name}'. Make sure to set `--audio_column_name` to"
1196
+ " the correct audio column - one of"
1197
+ f" {', '.join(raw_datasets_train_features)}."
1198
+ )
1199
+
1200
+ if data_args.train_text_column_name not in raw_datasets_train_features:
1201
+ raise ValueError(
1202
+ f"--train_text_column_name {data_args.train_text_column_name} not found in dataset"
1203
+ f" '{data_args.dataset_name}'. Make sure to set `--train_text_column_name` to the"
1204
+ " correct text column - one of"
1205
+ f" {', '.join(raw_datasets_train_features)}."
1206
+ )
1207
+
1208
+ # 6. Load pretrained model, tokenizer, and feature extractor
1209
+ config = WhisperConfig.from_pretrained(
1210
+ (model_args.config_name if model_args.config_name else model_args.model_name_or_path),
1211
+ cache_dir=model_args.cache_dir,
1212
+ revision=model_args.model_revision,
1213
+ token=True if model_args.use_auth_token else None,
1214
+ )
1215
+ feature_extractor = FlaxWhisperFeatureExtractor.from_pretrained(
1216
+ (model_args.feature_extractor_name if model_args.feature_extractor_name else model_args.model_name_or_path),
1217
+ cache_dir=model_args.cache_dir,
1218
+ revision=model_args.model_revision,
1219
+ token=True if model_args.use_auth_token else None,
1220
+ )
1221
+ tokenizer = WhisperTokenizerFast.from_pretrained(
1222
+ (model_args.tokenizer_name if model_args.tokenizer_name else model_args.model_name_or_path),
1223
+ cache_dir=model_args.cache_dir,
1224
+ use_fast=model_args.use_fast_tokenizer,
1225
+ revision=model_args.model_revision,
1226
+ token=True if model_args.use_auth_token else None,
1227
+ )
1228
+ print("debug2")
1229
+ # override timestamp tokens until tokenizer issues are fixed in transformers
1230
+ timestamps = [AddedToken("<|%.2f|>" % (i * 0.02), lstrip=False, rstrip=False) for i in range(1500 + 1)]
1231
+ tokenizer.add_tokens(timestamps)
1232
+
1233
+ config.update(
1234
+ {
1235
+ "activation_dropout": model_args.activation_dropout,
1236
+ "attention_dropout": model_args.attention_dropout,
1237
+ "dropout": model_args.dropout,
1238
+ }
1239
+ )
1240
+
1241
+ if training_args.precision == "full_mixed":
1242
+ # forward pass, backward pass and optimiser states in bf16
1243
+ dtype = jnp.bfloat16
1244
+ to_dtype = to_bf16
1245
+ elif training_args.precision == "half_mixed" or model_args.dtype == "bfloat16":
1246
+ # forward pass in bf16, backward pass and optimiser states in fp32
1247
+ dtype = jnp.bfloat16
1248
+ to_dtype = to_fp32
1249
+ else:
1250
+ if training_args.precision != "full":
1251
+ raise ValueError(
1252
+ f"`precision` should be one of: `full`, `half_mixed` or `full_mixed`, got {training_args.precision}"
1253
+ )
1254
+ # forward pass, backward pass and optimiser states in fp32
1255
+ dtype = jnp.float32
1256
+ to_dtype = to_fp32
1257
+
1258
+ student_model, student_params = FlaxWhisperForConditionalGeneration.from_pretrained(
1259
+ model_args.model_name_or_path,
1260
+ config=config,
1261
+ dtype=dtype,
1262
+ cache_dir=model_args.cache_dir,
1263
+ revision=model_args.model_revision,
1264
+ subfolder=model_args.subfolder,
1265
+ token=True if model_args.use_auth_token else None,
1266
+ _do_init=False,
1267
+ use_scan=model_args.load_with_scan_weights,
1268
+ )
1269
+
1270
+ teacher_model, teacher_params = FlaxWhisperForConditionalGeneration.from_pretrained(
1271
+ model_args.teacher_model_name_or_path,
1272
+ # config=config,
1273
+ dtype=dtype,
1274
+ cache_dir=model_args.cache_dir,
1275
+ # revision=model_args.model_revision,
1276
+ token=True if model_args.use_auth_token else None,
1277
+ _do_init=False,
1278
+ )
1279
+ print("debug 3")
1280
+ if student_model.config.decoder_start_token_id is None or teacher_model.config.decoder_start_token_id is None:
1281
+ raise ValueError(
1282
+ f"Make sure that `config.decoder_start_token_id` is correctly defined for both the "
1283
+ f"student and teacher model. Got {student_model.config.decoder_start_token_id} for the "
1284
+ f"student and {teacher_model.config.decoder_start_token_id} for the teacher."
1285
+ )
1286
+
1287
+ # enable scan / gradient checkpointing if necessary
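+ # scanning stacks the repeated transformer layers into a single lax.scan loop, which greatly reduces XLA compile time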
1288
+ if training_args.use_scan:
1289
+ student_model.enable_scan() # to enable scan in the nn.Module
1290
+ student_params = student_model.convert_unroll_to_scan(student_params) # to convert the unrolled params to scan
1291
+
1292
+ teacher_model.enable_scan() # faster compile time (even though we don't train the teacher)
1293
+ teacher_params = teacher_model.convert_unroll_to_scan(teacher_params)
1294
+
1295
+ if training_args.gradient_checkpointing:
1296
+ student_model.enable_gradient_checkpointing() # to enable checkpointing in the nn.Module, there is no change to the params structure
1297
+ teacher_model.enable_gradient_checkpointing()
1298
+ print("debug 4")
1299
+ if hasattr(teacher_model.generation_config, "is_multilingual") and teacher_model.generation_config.is_multilingual:
1300
+ # We need to set the language and task ids for previously multilingual checkpoints - for now we hardcode this to Norwegian
1301
+ tokenizer.set_prefix_tokens(language="Norwegian", task="transcribe", predict_timestamps=False)
1302
+ student_model.generation_config.update(
1303
+ **{
1304
+ "language": "<|no|>",
1305
+ "task": "transcribe",
1306
+ }
1307
+ )
1308
+ print("debug 5")
1309
+ # 7. Resample speech dataset: `datasets` takes care of automatically loading and resampling the audio,
1310
+ # so we just need to set the correct target sampling rate.
1311
+ raw_datasets = raw_datasets.cast_column(
1312
+ data_args.audio_column_name,
1313
+ datasets.features.Audio(sampling_rate=feature_extractor.sampling_rate),
1314
+ )
1315
+
1316
+ # 8. Preprocessing the datasets.
1317
+ # We need to read the audio files as arrays and tokenize the targets.
1318
+ max_input_length = int(data_args.max_duration_in_seconds * feature_extractor.sampling_rate)
1319
+ min_input_length = int(data_args.min_duration_in_seconds * feature_extractor.sampling_rate)
1320
+ max_label_length = (
1321
+ data_args.max_label_length if data_args.max_label_length is not None else student_model.config.max_length
1322
+ )
1323
+ audio_column_name = data_args.audio_column_name
1324
+ num_workers = data_args.preprocessing_num_workers
1325
+ dataloader_num_workers = training_args.dataloader_num_workers
1326
+ dataloader_prefetch_size = data_args.prefetch_size
1327
+ train_text_column_name = data_args.train_text_column_name
1328
+ eval_text_column_name = "text"
1329
+ model_input_name = feature_extractor.model_input_names[0]
1330
+ normalizer = BasicTextNormalizer(tokenizer.english_spelling_normalizer)
1331
+ wer_threshold = data_args.wer_threshold
1332
+ round_timestamps = data_args.round_timestamps
1333
+ print("debug 6")
1334
+ if training_args.do_train and data_args.max_train_samples is not None:
1335
+ raw_datasets["train"] = (
1336
+ raw_datasets["train"].take(data_args.max_train_samples)
1337
+ if data_args.streaming
1338
+ else raw_datasets["train"].select(range(data_args.max_train_samples))
1339
+ )
1340
+
1341
+ if training_args.do_eval and data_args.max_eval_samples is not None:
1342
+ for eval_split in all_eval_splits:
1343
+ raw_datasets[eval_split] = (
1344
+ raw_datasets[eval_split].take(data_args.max_eval_samples)
1345
+ if data_args.streaming
1346
+ else raw_datasets[eval_split].select(range(data_args.max_eval_samples))
1347
+ )
1348
+ print("debug 7")
1349
+ # 10.3: filter training data based on WER threshold -> this is KEY to good distillation performance
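+ # a training example is kept only if the WER between the Whisper pseudo-label and the ground-truth text is below wer_threshold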
1350
+ def is_wer_in_range(ground_truth, whisper_transcript):
1351
+ norm_ground_truth = normalizer(ground_truth)
1352
+ if whisper_transcript is not None and whisper_transcript.upper() == whisper_transcript:
1353
+ # filter entirely upper-case transcriptions: these are erroneous generations from large-v3
1354
+ return False
1355
+ elif whisper_transcript is not None and len(norm_ground_truth) == 0 and len(normalizer(whisper_transcript)) == 0:
1356
+ return True
1357
+ elif len(norm_ground_truth.strip()) > 0 and whisper_transcript is not None and len(normalizer(whisper_transcript).strip()) > 0:
1358
+ norm_whisper_transcript = normalizer(whisper_transcript)
1359
+ wer = 100 * metric.compute(predictions=[norm_whisper_transcript], references=[norm_ground_truth])
1360
+ return wer < wer_threshold
1361
+ else:
1362
+ # filter automatically since we can't compute the WER without a transcript
1363
+ return False
1364
+
1365
+
1366
+ filter_by_wer_threshold = partial(
1367
+ raw_datasets["train"].filter,
1368
+ function=is_wer_in_range,
1369
+ input_columns=[eval_text_column_name, train_text_column_name],
1370
+ )
1371
+
1372
+ if wer_threshold is not None:
1373
+ raw_datasets["train"] = (
1374
+ filter_by_wer_threshold(num_proc=num_workers, desc="filtering train dataset by wer")
1375
+ if not data_args.streaming
1376
+ else filter_by_wer_threshold()
1377
+ )
1378
+
1379
+ def has_timestamp_tokens(input_str):
1380
+ """
1381
+ Identify whether the input string contains timestamp tokens, of the form <|0.00|>, by searching for
1382
+ pairs of left and right-angle brackets.
1383
+ """
1384
+ return bool(re.search(r"<[^>]*>", input_str))
1385
+
1386
+ def round_timestamp_tokens(input_str: str, ndigits: int = 1):
1387
+ timestamps = re.findall(r"<[^>]*>", input_str, re.DOTALL)
1388
+ for token in timestamps:
1389
+ # extract time digits from timestamp token, e.g. <|6.24|> to 6.24
1390
+ time_digit = token[2:-2]
1391
+ # round to specified number of digits, e.g. 6.24 to 6.2
1392
+ time_digit = round(float(time_digit), ndigits=ndigits)
1393
+ # replace in original string with the same precision, e.g. <|6.24|> to <|6.20|>
1394
+ input_str = input_str.replace(token, "<|{:.2f}|>".format(time_digit))
1395
+ return input_str
1396
+
1397
+ def prepare_train_dataset(batch):
1398
+ # process audio input
1399
+ sample = batch[audio_column_name]
1400
+ inputs = feature_extractor(sample["array"], sampling_rate=sample["sampling_rate"])
1401
+ batch[model_input_name] = inputs.get(model_input_name)[0]
1402
+ batch["input_length"] = len(sample["array"])
1403
+
1404
+ # process text targets
1405
+ input_str = batch[train_text_column_name]
1406
+
1407
+ # prompt & timestamp processing: for now, we only do one or the other
1408
+ if input_str.startswith("<|startoftranscript|>") or input_str.startswith("<|startofprev|>"):
1409
+ # prompted target text already has special ids added, so don't add them here
1410
+ batch["labels"] = tokenizer(input_str, add_special_tokens=False).input_ids
1411
+ return batch
1412
+
1413
+ has_timestamps = has_timestamp_tokens(input_str)
1414
+
1415
+ if has_timestamps:
1416
+ predict_timestamps = bool(np.random.binomial(1, data_args.timestamp_probability))
1417
+ if not predict_timestamps:
1418
+ # filter timestamp token ids if not part of the prediction task
1419
+ input_str = tokenizer._filter_timestamp_ids(input_str)
1420
+ elif round_timestamps:
1421
+ input_str = round_timestamp_tokens(input_str)
1422
+ else:
1423
+ predict_timestamps = False
1424
+
1425
+ tokenizer.set_prefix_tokens(language="Norwegian", task="transcribe", predict_timestamps=predict_timestamps)
1426
+ input_ids = tokenizer(input_str).input_ids
1427
+ batch["labels"] = input_ids
1428
+ return batch
1429
+
1430
+ def prepare_eval_dataset(batch):
1431
+ # process audio
1432
+ sample = batch[audio_column_name]
1433
+ inputs = feature_extractor(sample["array"], sampling_rate=sample["sampling_rate"])
1434
+ # process audio length
1435
+ batch[model_input_name] = inputs.get(model_input_name)[0]
1436
+ batch["input_length"] = len(sample["array"])
1437
+
1438
+ # process targets
1439
+ input_str = batch[eval_text_column_name]
1440
+ batch["labels"] = tokenizer(input_str).input_ids
1441
+ return batch
1442
+
1443
+ vectorized_datasets = IterableDatasetDict() if data_args.streaming else DatasetDict()
1444
+ if training_args.do_train:
1445
+ map_fn_train = partial(
1446
+ raw_datasets["train"].map, function=prepare_train_dataset, remove_columns=raw_datasets_train_features
1447
+ )
1448
+ vectorized_datasets["train"] = (
1449
+ map_fn_train(num_proc=num_workers, desc="preprocess train dataset")
1450
+ if not data_args.streaming
1451
+ else map_fn_train()
1452
+ )
1453
+ if training_args.do_eval:
1454
+ for eval_split in all_eval_splits:
1455
+ raw_datasets_eval_features = list(raw_datasets[eval_split].features.keys())
1456
+ map_fn_eval = partial(
1457
+ raw_datasets[eval_split].map, function=prepare_eval_dataset, remove_columns=raw_datasets_eval_features
1458
+ )
1459
+ vectorized_datasets[eval_split] = (
1460
+ map_fn_eval(num_proc=num_workers, desc="preprocess eval dataset")
1461
+ if not data_args.streaming
1462
+ else map_fn_eval()
1463
+ )
1464
+
1465
+ # filter training data with inputs longer than max_input_length
1466
+ def is_audio_in_length_range(length):
1467
+ return min_input_length < length < max_input_length
1468
+
1469
+ filter_by_audio_fn = partial(
1470
+ vectorized_datasets.filter, function=is_audio_in_length_range, input_columns=["input_length"]
1471
+ )
1472
+ vectorized_datasets = (
1473
+ filter_by_audio_fn(num_proc=num_workers, desc="filtering train dataset by audio length")
1474
+ if not data_args.streaming
1475
+ else filter_by_audio_fn()
1476
+ )
1477
+
1478
+ # filter training data with labels longer than max_label_length
1479
+ def is_labels_in_length_range(labels):
1480
+ return 0 < len(labels) < max_label_length
1481
+
1482
+ filter_by_labels_fn = partial(
1483
+ vectorized_datasets.filter, function=is_labels_in_length_range, input_columns=["labels"]
1484
+ )
1485
+ vectorized_datasets = (
1486
+ filter_by_labels_fn(num_proc=num_workers, desc="filtering train dataset")
1487
+ if not data_args.streaming
1488
+ else filter_by_labels_fn()
1489
+ )
1490
+
1491
+ # for large datasets it is advised to run the preprocessing on a
1492
+ # single machine first with `args.preprocessing_only` since there will most likely
1493
+ # be a timeout when running the script in distributed mode.
1494
+ # In a second step `args.preprocessing_only` can then be set to `False` to load the
1495
+ # cached dataset
1496
+ if data_args.preprocessing_only:
1497
+ cache = {k: v.cache_files for k, v in vectorized_datasets.items()}
1498
+ logger.info(f"Data preprocessing finished. Files cached at {cache}.")
1499
+ return
1500
+
1501
+ # 8. Load Metric
1502
+ metric = evaluate.load("wer")
1503
+ # convention is that we space all punctuation *except* apostrophes
1504
+ all_punctuation = list(string.punctuation.replace("'", ""))
1505
+ return_timestamps = data_args.return_timestamps if data_args.timestamp_probability > 0 else False
1506
+
1507
+ def compute_metrics(preds, labels):
1508
+ # replace padded labels by the padding token
1509
+ for idx in range(len(labels)):
1510
+ labels[idx][labels[idx] == -100] = tokenizer.pad_token_id
1511
+
1512
+ pred_str = tokenizer.batch_decode(preds, skip_special_tokens=True, decode_with_timestamps=return_timestamps)
1513
+ # we do not want to group tokens when computing the metrics
1514
+ label_str = tokenizer.batch_decode(labels, skip_special_tokens=True)
1515
+
1516
+ # space punctuation for orthographic WER (c.f. ESB paper https://arxiv.org/abs/2210.13352)
1517
+ spaced_pred_str = [
1518
+ pred_str[i].replace(punctuation, f" {punctuation} ")
1519
+ for punctuation in all_punctuation
1520
+ for i in range(len(pred_str))
1521
+ ]
1522
+ spaced_label_str = [
1523
+ label_str[i].replace(punctuation, f" {punctuation} ")
1524
+ for punctuation in all_punctuation
1525
+ for i in range(len(label_str))
1526
+ ]
1527
+ wer_ortho = 100 * metric.compute(predictions=spaced_pred_str, references=spaced_label_str)
1528
+
1529
+ # Iterate through all predictions and labels, collecting their normalized forms
+ norm_pred_str = []
+ norm_label_str = []
1530
+ for pred, label in zip(pred_str, label_str):
1531
+ # Normalize the prediction and label
1532
+ normalized_pred = normalizer(pred)
1533
+ normalized_label = normalizer(label)
1534
+
1535
+ # If either normalized string is empty after normalization, replace with "<|nospeech|>"
1536
+ if not normalized_pred.strip():
1537
+ normalized_pred = "<|nospeech|>"
1538
+ if not normalized_label.strip():
1539
+ normalized_label = "<|nospeech|>"
1540
+
1541
+ norm_pred_str.append(normalized_pred)
1542
+ norm_label_str.append(normalized_label)
1543
+
1544
+ # Replace empty original strings with "<|nospeech|>" for consistency with the normalized strings
1545
+ pred_str = [pred if len(pred.strip()) > 0 else "<|nospeech|>" for pred in pred_str]
1546
+ label_str = [label if len(label.strip()) > 0 else "<|nospeech|>" for label in label_str]
1547
+
1548
+ # Compute WER on the normalized strings, including entries replaced with "<|nospeech|>"
1549
+ wer = 100 * metric.compute(predictions=norm_pred_str, references=norm_label_str)
1550
+ return {"wer": wer, "wer_ortho": wer_ortho}, pred_str, label_str, norm_pred_str, norm_label_str
1551
+
1552
+
1553
+ # 9. Save feature extractor, tokenizer, config and generation config
1554
+ feature_extractor.save_pretrained(training_args.output_dir)
1555
+ tokenizer.save_pretrained(training_args.output_dir)
1556
+ config.save_pretrained(training_args.output_dir)
1557
+ student_model.generation_config.save_pretrained(
1558
+ training_args.output_dir
1559
+ ) # generation config stays bound to model to make it easy to jit
1560
+
1561
+ processor = WhisperProcessor.from_pretrained(training_args.output_dir)
1562
+
1563
+ data_collator = FlaxDataCollatorSpeechSeq2SeqWithPadding(
1564
+ processor=processor,
1565
+ decoder_start_token_id=student_model.config.decoder_start_token_id, # <|startoftranscript|>
1566
+ decoder_prev_token_id=tokenizer.all_special_ids[-3], # <|startofprev|>
1567
+ input_padding="longest",
1568
+ target_padding="max_length",
1569
+ max_target_length=max_label_length,
1570
+ )
1571
+
1572
+ # Initialize our training
1573
+ rng = jax.random.PRNGKey(training_args.seed)
1574
+ rng, dropout_rng = jax.random.split(rng)
1575
+
1576
+ # Store some constants
1577
+ train_batch_size = int(training_args.per_device_train_batch_size) * jax.device_count()
1578
+ gradient_accumulation_steps = int(training_args.gradient_accumulation_steps)
1579
+ per_device_eval_batch_size = int(training_args.per_device_eval_batch_size)
1580
+ eval_batch_size = per_device_eval_batch_size * jax.device_count()
1581
+
1582
+ if not data_args.streaming and training_args.max_steps < 0:
1583
+ num_epochs = int(training_args.num_train_epochs)
1584
+ steps_per_epoch = len(vectorized_datasets["train"]) // train_batch_size
1585
+ total_train_steps = steps_per_epoch * num_epochs
1586
+ elif training_args.max_steps > 0:
1587
+ logger.info("max_steps is given, it will override any value given in num_train_epochs")
1588
+ total_train_steps = int(training_args.max_steps)
1589
+ # Setting a very large number of epochs so we go as many times as necessary over the iterator.
1590
+ num_epochs = sys.maxsize
1591
+ steps_per_epoch = total_train_steps
1592
+ else:
1593
+ raise ValueError("max_steps must be specified when training with a streaming (iterable) dataset")
1594
+
1595
+ if training_args.eval_steps is None:
1596
+ logger.info(
1597
+ f"eval_steps is not set, evaluating at the end of {'each epoch' if not data_args.streaming else 'training'}"
1598
+ )
1599
+ eval_steps = steps_per_epoch
1600
+ else:
1601
+ eval_steps = training_args.eval_steps
1602
+
1603
+ # Create learning rate schedule
1604
+ linear_decay_lr_schedule_fn = create_learning_rate_fn(
1605
+ total_train_steps * gradient_accumulation_steps,
1606
+ training_args.lr_scheduler_type,
1607
+ training_args.warmup_steps * gradient_accumulation_steps,
1608
+ training_args.learning_rate,
1609
+ )
1610
+
1611
+ # We use Optax's "masking" functionality to not apply weight decay
1612
+ # to bias and LayerNorm scale parameters. decay_mask_fn returns a
1613
+ # mask boolean with the same structure as the parameters.
1614
+ # The mask is True for parameters that should be decayed.
1615
+ def decay_mask_fn(params):
1616
+ flat_params = traverse_util.flatten_dict(params)
1617
+ # find out all LayerNorm parameters
1618
+ layer_norm_candidates = [
1619
+ "layer_norm",
1620
+ "self_attn_layer_norm",
1621
+ "final_layer_norm",
1622
+ "encoder_attn_layer_norm",
1623
+ ]
1624
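+ # collect the trailing (module, parameter) path pairs of every LayerNorm weight so they can be excluded from weight decay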
+ layer_norm_named_params = {
1625
+ layer[-2:]
1626
+ for layer_norm_name in layer_norm_candidates
1627
+ for layer in flat_params.keys()
1628
+ if layer_norm_name in "".join(layer).lower()
1629
+ }
1630
+ flat_mask = {path: path[-1] != "bias" and path[-2:] not in layer_norm_named_params for path in flat_params}
1631
+ return traverse_util.unflatten_dict(flat_mask)
1632
+
1633
+ # create adam optimizer
1634
+ adamw = optax.adamw(
1635
+ learning_rate=linear_decay_lr_schedule_fn,
1636
+ b1=training_args.adam_beta1,
1637
+ b2=training_args.adam_beta2,
1638
+ eps=training_args.adam_epsilon,
1639
+ weight_decay=training_args.weight_decay,
1640
+ mask=decay_mask_fn,
1641
+ )
1642
+
1643
+ if gradient_accumulation_steps > 1:
1644
+ # accumulate gradients and apply once every k steps
1645
+ adamw = optax.MultiSteps(adamw, every_k_schedule=gradient_accumulation_steps)
1646
+
1647
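+ # the teacher's encoder outputs can be re-used for the student only when the encoder is frozen and both models have the same hidden size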
+ share_hidden_states = training_args.freeze_encoder and student_model.config.d_model == teacher_model.config.d_model
1648
+ encoder_layer_mapping = get_layers_to_supervise(
1649
+ student_model.config.encoder_layers, teacher_model.config.encoder_layers
1650
+ )
1651
+ decoder_layer_mapping = get_layers_to_supervise(
1652
+ student_model.config.decoder_layers, teacher_model.config.decoder_layers
1653
+ )
1654
+
1655
+ # Setup train state
1656
+ student_state = TrainState.create(
1657
+ apply_fn=student_model.decode if share_hidden_states else student_model.__call__,
1658
+ params=student_params,
1659
+ tx=adamw,
1660
+ to_dtype=to_dtype,
1661
+ dropout_rng=dropout_rng,
1662
+ max_grad_norm=training_args.max_grad_norm,
1663
+ )
1664
+
1665
+ if training_args.resume_from_checkpoint is not None:
1666
+ if os.path.isfile(os.path.join(training_args.resume_from_checkpoint, "train_state.msgpack")):
1667
+ logger.info(
1668
+ f"Checkpoint detected, resuming training at {training_args.resume_from_checkpoint}. To avoid "
1669
+ "this behavior, omit the resume_from_checkpoint argument."
1670
+ )
1671
+ with Path(os.path.join(training_args.resume_from_checkpoint, "train_state.msgpack")).open("rb") as f:
1672
+ student_state = from_bytes(student_state, f.read())
1673
+ else:
1674
+ logger.warning(
1675
+ f"Checkpoint {training_args.resume_from_checkpoint} not detected, training from scratch. Ensure "
1676
+ f"you pass the path to a folder with a valid checkpoint for your model."
1677
+ )
1678
+
1679
+ def cross_entropy_loss(logits, labels):
1680
+ vocab_size = logits.shape[-1]
1681
+ # optax onehot always returns a float32 device array, need to downcast if performing mixed precision training
1682
+ onehot_targets = to_dtype(onehot(labels, vocab_size))
1683
+ loss = optax.softmax_cross_entropy(logits, onehot_targets)
1684
+ # mask out padded tokens (labels set to -100) from the loss
1685
+ padding = labels >= 0
1686
+ loss = loss * padding
1687
+ loss = loss.sum()
1688
+ num_labels = padding.sum()
1689
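+ # return the summed loss and the token count; the true mean is taken after psum over the batch axis in train_step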
+ return loss, num_labels
1690
+
1691
+ # temperature smoothed kl-divergence
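+ # computes sum(p_teacher * (log p_teacher - log p_student)) over the vocabulary, with eps guarding against log(0)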
1692
+ def kl_divergence(target_distribution, log_predicted_distribution, labels, eps=1e-20):
1693
+ divergence = -target_distribution * (log_predicted_distribution - jnp.log(target_distribution + eps))
1694
+ # mask out padded tokens (labels set to -100) from the divergence
1695
+ padding_mask = labels >= 0
1696
+ padding_mask = jnp.expand_dims(padding_mask, axis=-1)
1697
+ divergence = (divergence * padding_mask).sum()
1698
+ return to_dtype(divergence) # respect the dtype of the backprop
1699
+
1700
+ def mean_square_error_loss(student_outputs, teacher_outputs):
1701
+ mse = dtype(0.0)
1702
+
1703
+ # tie encoder embeddings
1704
+ mse += jnp.mean(
1705
+ jnp.square(teacher_outputs.encoder_hidden_states[0] - student_outputs.encoder_hidden_states[0])
1706
+ )
1707
+
1708
+ for student_layer_id, teacher_layer_id in encoder_layer_mapping.items():
1709
+ # offset the hidden-state layer ids by 1 to account for the extra embedding hidden-state
1710
+ student_hidden_state = student_outputs.encoder_hidden_states[student_layer_id + 1]
1711
+ teacher_hidden_state = teacher_outputs.encoder_hidden_states[teacher_layer_id + 1]
1712
+ mse += jnp.mean(jnp.square(teacher_hidden_state - student_hidden_state))
1713
+
1714
+ # student_attention = student_outputs.encoder_attentions[student_layer_id]
1715
+ # teacher_attention = teacher_outputs.encoder_attentions[teacher_layer_id]
1716
+ # mse += jnp.mean(jnp.square(student_attention - teacher_attention))
1717
+
1718
+ # tie decoder embeddings
1719
+ mse += jnp.mean(
1720
+ jnp.square(teacher_outputs.decoder_hidden_states[0] - student_outputs.decoder_hidden_states[0])
1721
+ )
1722
+
1723
+ for student_layer_id, teacher_layer_id in decoder_layer_mapping.items():
1724
+ # offset the hidden-state layer ids by 1 to account for the extra embedding hidden-state
1725
+ student_hidden_state = student_outputs.decoder_hidden_states[student_layer_id + 1]
1726
+ teacher_hidden_state = teacher_outputs.decoder_hidden_states[teacher_layer_id + 1]
1727
+ mse += jnp.mean(jnp.square(teacher_hidden_state - student_hidden_state))
1728
+
1729
+ # student_attention = student_outputs.decoder_attentions[student_layer_id]
1730
+ # teacher_attention = teacher_outputs.decoder_attentions[teacher_layer_id]
1731
+ # mse += jnp.mean(jnp.square(student_attention - teacher_attention))
1732
+
1733
+ # student_cross_attention = student_outputs.cross_attentions[student_layer_id]
1734
+ # teacher_cross_attention = teacher_outputs.cross_attentions[teacher_layer_id]
1735
+ # mse += jnp.mean(jnp.square(student_cross_attention - teacher_cross_attention))
1736
+
1737
+ return to_dtype(mse) # respect the dtype of the backprop
1738
+
1739
+ # Define gradient update step fn
1740
+ def train_step(
1741
+ student_state,
1742
+ teacher_params,
1743
+ batch,
1744
+ freeze_encoder,
1745
+ share_hidden_states,
1746
+ temperature=2.0,
1747
+ ):
1748
+ dropout_rng, new_dropout_rng = jax.random.split(student_state.dropout_rng)
1749
+
1750
+ def compute_loss(student_params):
1751
+ labels = batch.pop("labels")
1752
+ output_hidden_states = not share_hidden_states and training_args.mse_weight > 0.0
1753
+
1754
+ teacher_outputs = teacher_model(
1755
+ **batch,
1756
+ params=teacher_params,
1757
+ freeze_encoder=True,
1758
+ output_hidden_states=output_hidden_states,
1759
+ train=False,
1760
+ )
1761
+
1762
+ if share_hidden_states:
1763
+ # if the student and teacher share the same frozen encoder then we don't have to recompute the
1764
+ # encoder hidden-states for the student model; we can simply re-use the teacher's
1765
+ encoder_hidden_states = jax.lax.stop_gradient(teacher_outputs.encoder_last_hidden_state)
1766
+ encoder_outputs = FlaxBaseModelOutput(last_hidden_state=encoder_hidden_states)
1767
+
1768
+ student_outputs = student_state.apply_fn(
1769
+ decoder_input_ids=batch["decoder_input_ids"],
1770
+ encoder_outputs=encoder_outputs,
1771
+ params=student_params,
1772
+ dropout_rng=dropout_rng,
1773
+ train=True,
1774
+ )
1775
+ else:
1776
+ # do the full forward pass for the student model (encoder + decoder)
1777
+ student_outputs = student_state.apply_fn(
1778
+ **batch,
1779
+ params=student_params,
1780
+ dropout_rng=dropout_rng,
1781
+ freeze_encoder=freeze_encoder,
1782
+ output_hidden_states=output_hidden_states,
1783
+ train=True,
1784
+ )
1785
+
1786
+ # CE (data) loss
1787
+ ce_loss, num_labels = cross_entropy_loss(student_outputs.logits, labels)
1788
+
1789
+ # rescale by temperature to ensure gradients scale correctly
1790
+ teacher_distribution = jax.nn.softmax(teacher_outputs.logits / temperature, axis=-1)
1791
+ # ensure no information flow backwards through teacher
1792
+ teacher_distribution = jax.lax.stop_gradient(teacher_distribution)
1793
+ # log softmax of student predictions for numerical stability
1794
+ student_distribution = jax.nn.log_softmax(student_outputs.logits / temperature, axis=-1)
1795
+ # KL-divergence loss (scaled by temperature)
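+ # the T**2 factor compensates for the 1/T**2 scaling of the soft-target gradients, keeping the KL and CE terms comparable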
1796
+ kl_loss = kl_divergence(teacher_distribution, student_distribution, labels) * temperature**2
1797
+
1798
+ # MSE loss between enc-dec hidden-states and attentions
1799
+ mse_loss = (
1800
+ mean_square_error_loss(student_outputs, teacher_outputs)
1801
+ if output_hidden_states
1802
+ else jnp.zeros_like(kl_loss)
1803
+ )
1804
+
1805
+ # use DistilBart formulation - only tune the MSE weight and take remaining HPs from DistilBERT
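+ # total loss = 0.8 * CE + kl_weight * KL + mse_weight * MSE (the CE weight falls back to 1.0 when KL distillation is disabled)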
1806
+ ce_weight = 0.8 if training_args.kl_weight > 0 else 1.0
1807
+ loss = ce_weight * ce_loss + training_args.kl_weight * kl_loss + training_args.mse_weight * mse_loss
1808
+
1809
+ return loss, (
1810
+ ce_loss,
1811
+ kl_loss,
1812
+ mse_loss,
1813
+ num_labels,
1814
+ )
1815
+
1816
+ grad_fn = jax.value_and_grad(compute_loss, has_aux=True)
1817
+ (loss, (ce_loss, kl_loss, mse_loss, num_labels)), grad = grad_fn(to_dtype(student_state.params))
1818
+
1819
+ # true loss = total loss / total samples
1820
+ loss = jax.lax.psum(loss, "batch")
1821
+ num_labels = jax.lax.psum(num_labels, "batch")
1822
+ loss = jax.tree_util.tree_map(lambda x: x / num_labels, loss)
1823
+
1824
+ # true grad = total grad / total samples
1825
+ grad = jax.lax.psum(grad, "batch")
1826
+ grad = jax.tree_util.tree_map(lambda x: x / num_labels, grad)
1827
+ new_state = student_state.apply_gradients(grads=grad, dropout_rng=new_dropout_rng, to_dtype=to_dtype)
1828
+
1829
+ # CE/KL/MSE losses for logging
1830
+ ce_loss = jax.lax.psum(ce_loss, "batch")
1831
+ ce_loss = jax.tree_util.tree_map(lambda x: x / num_labels, ce_loss)
1832
+
1833
+ kl_loss = jax.lax.psum(kl_loss, "batch")
1834
+ kl_loss = jax.tree_util.tree_map(lambda x: x / num_labels, kl_loss)
1835
+
1836
+ mse_loss = jax.lax.psum(mse_loss, "batch")
1837
+ mse_loss = jax.tree_util.tree_map(lambda x: x / num_labels, mse_loss)
1838
+
1839
+ metrics = {
1840
+ "loss": loss,
1841
+ "learning_rate": linear_decay_lr_schedule_fn(student_state.step),
1842
+ "ce_loss": ce_loss,
1843
+ "kl_loss": kl_loss,
1844
+ "mse_loss": mse_loss,
1845
+ }
1846
+ return new_state, metrics
1847
+
1848
+ # Define eval fn
1849
+ def eval_step(student_params, teacher_params, batch):
1850
+ labels = batch.pop("labels")
1851
+ output_hidden_states = not share_hidden_states and training_args.mse_weight > 0
1852
+
1853
+ student_outputs = student_model(
1854
+ **batch,
1855
+ params=student_params,
1856
+ output_hidden_states=output_hidden_states,
1857
+ train=False,
1858
+ )
1859
+ student_distribution = jax.nn.log_softmax(student_outputs.logits, axis=-1)
1860
+ ce_loss, num_labels = cross_entropy_loss(student_outputs.logits, labels)
1861
+
1862
+ teacher_outputs = teacher_model(
1863
+ **batch,
1864
+ params=teacher_params,
1865
+ output_hidden_states=output_hidden_states,
1866
+ train=False,
1867
+ )
1868
+ teacher_distribution = jax.nn.softmax(teacher_outputs.logits, axis=-1)
1869
+ # temperature is always 1 for eval
1870
+ kl_loss = kl_divergence(teacher_distribution, student_distribution, labels)
1871
+
1872
+ mse_loss = (
1873
+ mean_square_error_loss(student_outputs, teacher_outputs)
1874
+ if output_hidden_states
1875
+ else jnp.zeros_like(kl_loss)
1876
+ )
1877
+
1878
+ ce_weight = 0.8 if training_args.kl_weight > 0 else 1.0
1879
+ loss = ce_weight * ce_loss + training_args.kl_weight * kl_loss + training_args.mse_weight * mse_loss
1880
+ # true loss = total loss / total samples
1881
+ loss = jax.lax.psum(loss, "batch")
1882
+ num_labels = jax.lax.psum(num_labels, "batch")
1883
+ loss = jax.tree_util.tree_map(lambda x: x / num_labels, loss)
1884
+
1885
+ # CE/KL/MSE losses for logging
1886
+ ce_loss = jax.lax.psum(ce_loss, "batch")
1887
+ ce_loss = jax.tree_util.tree_map(lambda x: x / num_labels, ce_loss)
1888
+
1889
+ kl_loss = jax.lax.psum(kl_loss, "batch")
1890
+ kl_loss = jax.tree_util.tree_map(lambda x: x / num_labels, kl_loss)
1891
+
1892
+ mse_loss = jax.lax.psum(mse_loss, "batch")
1893
+ mse_loss = jax.tree_util.tree_map(lambda x: x / num_labels, mse_loss)
1894
+
1895
+ metrics = {"loss": loss, "ce_loss": ce_loss, "kl_loss": kl_loss, "mse_loss": mse_loss}
1896
+ return metrics
1897
+
1898
+ # Define generation function
1899
+ num_beams = (
1900
+ training_args.generation_num_beams
1901
+ if training_args.generation_num_beams is not None
1902
+ else student_model.config.num_beams
1903
+ )
1904
+
1905
+ # forcing the language and task tokens helps the model in its generations
1906
+ gen_kwargs = {
1907
+ "max_length": max_label_length,
1908
+ "num_beams": num_beams,
1909
+ "language": "<|en|>",
1910
+ "task": "transcribe",
1911
+ "return_timestamps": return_timestamps,
1912
+ }
1913
+
1914
+ def generate_step(student_params, batch):
1915
+ output_ids = student_model.generate(
1916
+ batch[model_input_name],
1917
+ attention_mask=batch.get("attention_mask"),
1918
+ params=student_params,
1919
+ **gen_kwargs,
1920
+ )
1921
+ return output_ids.sequences
1922
+
1923
+ # Replicate the train state on each device
1924
+ student_state = student_state.replicate()
1925
+
1926
+ # Replicate the teacher params on each device
1927
+ teacher_params = jax_utils.replicate(teacher_params)
1928
+
1929
+ # Create parallel version of the train and eval step
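+ # the state, teacher params and batch are mapped over the leading device axis; freeze_encoder and share_hidden_states (argnums 3 and 4) are static compile-time flags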
1930
+ p_train_step = jax.pmap(
1931
+ train_step,
1932
+ "batch",
1933
+ in_axes=(0, 0, 0, None, None, None),
1934
+ donate_argnums=(0,),
1935
+ static_broadcasted_argnums=(
1936
+ 3,
1937
+ 4,
1938
+ ),
1939
+ )
1940
+ p_eval_step = jax.pmap(eval_step, "batch")
1941
+ p_generate_step = jax.pmap(generate_step, "batch")
1942
+
1943
+ logger.info("***** Running training *****")
1944
+ logger.info(f" Num examples = {total_train_steps * train_batch_size * gradient_accumulation_steps}")
1945
+ logger.info(" Instantaneous batch size per device =" f" {training_args.per_device_train_batch_size}")
1946
+ logger.info(" Gradient accumulation steps =" f" {gradient_accumulation_steps}")
1947
+ logger.info(
1948
+ f" Total train batch size (w. parallel & distributed) = {train_batch_size * gradient_accumulation_steps}"
1949
+ )
1950
+ logger.info(f" Total optimization steps = {total_train_steps}")
1951
+
1952
+ # ======================== Training ================================
1953
+ train_time = 0
1954
+ train_start = time.time()
1955
+ train_metrics = []
1956
+ batches_to_skip = jax.device_get(unreplicate(student_state.step))
1957
+ cur_step = int(batches_to_skip) # will be zero if starting from scratch
1958
+ epochs_trained = batches_to_skip // steps_per_epoch
1959
+ steps_trained_progress_bar = tqdm(range(total_train_steps), desc="Train steps ... ", position=0)
1960
+ steps_trained_progress_bar.update(batches_to_skip)
1961
+ continue_training = True
1962
+ minibatch_steps = 0
1963
+ print("Debug 8")
1964
+ if batches_to_skip > 0:
1965
+ logger.info(" Continuing training from checkpoint, will skip to saved global_step")
1966
+ logger.info(f" Continuing training from epoch {epochs_trained}")
1967
+ logger.info(f" Continuing training from global step {batches_to_skip}")
1968
+ print("debug 9")
1969
+ # Generate a training data loader by shuffling sampling indices from the train dataset
1970
+ train_loader = get_data_loader(
1971
+ training_args.seed,
1972
+ vectorized_datasets["train"],
1973
+ batch_size=train_batch_size,
1974
+ data_collator=data_collator,
1975
+ dataloader_num_workers=dataloader_num_workers,
1976
+ skip_batches=batches_to_skip,
1977
+ prefetch_size=dataloader_prefetch_size,
1978
+ )
1979
+ print("debug 10")
1980
+
1981
+ for epoch in range(epochs_trained, num_epochs):
1982
+ print("Debug 11")
1983
+ if hasattr(train_loader, "dataset") and isinstance(train_loader.dataset, IterableDataset):
1984
+ print("Debug 11B")
1985
+ train_loader.dataset.set_epoch(epoch)
1986
+ breakpoint()
1987
+ print("debug 12")
1988
+ for batch in train_loader:
1989
+ print("debug 13")
1990
+ minibatch_steps += 1
1991
+ update_step = minibatch_steps == gradient_accumulation_steps
1992
+
1993
+ if update_step:
1994
+ steps_trained_progress_bar.update(1)
1995
+ cur_step += 1
1996
+ minibatch_steps = 0
1997
+ print("debug 14")
1998
+ batch = shard(batch.data)
1999
+ student_state, train_metric = p_train_step(
2000
+ student_state,
2001
+ teacher_params,
2002
+ batch,
2003
+ training_args.freeze_encoder,
2004
+ share_hidden_states,
2005
+ training_args.temperature,
2006
+ )
2007
+ print("debug 15")
2008
+ if cur_step % training_args.logging_steps == 0 and update_step:
2009
+ train_metrics.append(train_metric)
2010
+ train_metric_to_write = unreplicate(train_metric)
2011
+ steps_trained_progress_bar.write(
2012
+ f"Step... ({cur_step} / {total_train_steps} | Loss:"
2013
+ f" {train_metric_to_write['loss']}, Learning Rate:"
2014
+ f" {train_metric_to_write['learning_rate']})"
2015
+ )
2016
+ print("debug 16")
2017
+ if has_wandb and jax.process_index() == 0:
2018
+ write_wandb_metric(
2019
+ wandb_logger,
2020
+ train_metric_to_write,
2021
+ train_time + time.time() - train_start,
2022
+ cur_step,
2023
+ epoch,
2024
+ prefix="train",
2025
+ )
2026
+ print("debug 17")
2027
+ # save checkpoint and weights after each save_steps and at the end of training
2028
+ if (cur_step % training_args.save_steps == 0 and update_step) or cur_step == total_train_steps:
2029
+ if jax.process_index() == 0:
2030
+ save_hf_weights(
2031
+ student_state,
2032
+ student_model,
2033
+ processor,
2034
+ training_args.output_dir,
2035
+ cur_step,
2036
+ total_train_steps,
2037
+ use_scan=training_args.use_scan,
2038
+ )
2039
+ if training_args.save_train_state:
2040
+ student_state.save_state(
2041
+ training_args.output_dir, save_total_limit=training_args.save_total_limit
2042
+ )
2043
+ if training_args.push_to_hub:
2044
+ repo.push_to_hub(
2045
+ commit_message=f"Saving train state of step {cur_step}",
2046
+ blocking=False,
2047
+ )
2048
+
2049
+ if training_args.do_eval and (
2050
+ (cur_step % eval_steps == 0 and update_step) or cur_step == total_train_steps
2051
+ ):
2052
+ train_time += time.time() - train_start
2053
+ # ======================== Evaluating ==============================
2054
+ for eval_split in all_eval_splits:
2055
+ eval_metrics = []
2056
+ eval_preds = []
2057
+ eval_labels = []
2058
+ eval_start = time.time()
2059
+
2060
+ eval_loader = get_data_loader(
2061
+ training_args.seed,
2062
+ vectorized_datasets[eval_split],
2063
+ batch_size=eval_batch_size,
2064
+ data_collator=data_collator,
2065
+ shuffle=False,
2066
+ drop_last=False,
2067
+ dataloader_num_workers=dataloader_num_workers,
2068
+ )
2069
+ for batch in tqdm(eval_loader, desc=f"Evaluating {eval_split}...", position=2):
2070
+ # Model forward
2071
+ labels = batch["labels"]
2072
+
2073
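+ # pad_shard_unpad pads the final (possibly smaller) batch so it splits evenly across devices before the pmapped step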
+ metrics = pad_shard_unpad(
2074
+ p_eval_step,
2075
+ static_argnums=(
2076
+ 0,
2077
+ 1,
2078
+ ),
2079
+ static_return=True,
2080
+ )(
2081
+ student_state.params,
2082
+ teacher_params,
2083
+ batch.data,
2084
+ min_device_batch=per_device_eval_batch_size,
2085
+ )
2086
+ eval_metrics.append(metrics)
2087
+
2088
+ # generation
2089
+ if training_args.predict_with_generate:
2090
+ generated_ids = pad_shard_unpad(p_generate_step)(
2091
+ student_state.params, batch.data, min_device_batch=per_device_eval_batch_size
2092
+ )
2093
+ eval_preds.extend(jax.device_get(generated_ids.reshape(-1, gen_kwargs["max_length"])))
2094
+ eval_labels.extend(labels)
2095
+
2096
+ eval_time = time.time() - eval_start
2097
+
2098
+ # normalize eval metrics
2099
+ eval_metrics = get_metrics(eval_metrics)
2100
+ eval_metrics = jax.tree_util.tree_map(jnp.mean, eval_metrics)
2101
+
2102
+ # compute WER metric
2103
+ wer_desc = ""
2104
+ if training_args.predict_with_generate:
2105
+ wer_metric, pred_str, label_str, norm_pred_str, norm_label_str = compute_metrics(
2106
+ eval_preds, eval_labels
2107
+ )
2108
+ eval_metrics.update(wer_metric)
2109
+ wer_desc = " ".join([f"Eval {key}: {value} |" for key, value in wer_metric.items()])
2110
+
2111
+ # Print metrics and update progress bar
2112
+ steps_trained_progress_bar.write(
2113
+ f"Eval results for step ({cur_step} / {total_train_steps} | Eval Loss: {eval_metrics['loss']} |"
2114
+ f" {wer_desc})"
2115
+ )
2116
+
2117
+ if has_tensorboard and jax.process_index() == 0:
2118
+ write_eval_metric(
2119
+ summary_writer,
2120
+ eval_metrics,
2121
+ cur_step,
2122
+ prefix=eval_split,
2123
+ )
2124
+
2125
+ if has_wandb and jax.process_index() == 0:
2126
+ write_wandb_metric(wandb_logger, eval_metrics, eval_time, cur_step, epoch, prefix=eval_split)
2127
+ if training_args.predict_with_generate:
2128
+ write_wandb_pred(
2129
+ wandb_logger,
2130
+ pred_str,
2131
+ label_str,
2132
+ norm_pred_str,
2133
+ norm_label_str,
2134
+ cur_step,
2135
+ prefix=eval_split,
2136
+ )
2137
+
2138
+ if has_tensorboard and jax.process_index() == 0:
2139
+ # we'll only log to tensorboard every eval steps
2140
+ write_train_metric(
2141
+ summary_writer,
2142
+ train_metrics,
2143
+ train_time,
2144
+ cur_step,
2145
+ training_args.logging_steps,
2146
+ )
2147
+
2148
+ # flush the train metrics
2149
+ train_start = time.time()
2150
+ train_metrics = []
2151
+
2152
+ # break condition
2153
+ if cur_step == total_train_steps:
2154
+ continue_training = False
2155
+ break
2156
+
2157
+ if not continue_training:
2158
+ break
2159
+
2160
+
2161
+ if __name__ == "__main__":
2162
+ main()
run_large_training.sh CHANGED
@@ -1,5 +1,5 @@
1
  #!/usr/bin/env bash
2
- python3 run_distillation.py \
3
  --model_name_or_path "./nb-distil-large-init" \
4
  --teacher_model_name_or_path "NbAiLab/nb-whisper-large" \
5
  --train_dataset_name "NbAiLab/annotated_distil_raw_ncc_speech_v7_compact8_large" \
@@ -8,14 +8,14 @@ python3 run_distillation.py \
8
  --eval_dataset_name "NbAiLab/annotated_distil_raw_ncc_speech_v7_compact8_large" \
9
  --eval_dataset_config_name "no" \
10
  --eval_split_name "validation_norwegian_fleurs" \
11
- --eval_steps 5000 \
12
- --save_steps 5000 \
13
- --warmup_steps 500 \
14
  --learning_rate 0.0001 \
15
  --lr_scheduler_type "linear" \
16
  --logging_steps 25 \
17
  --save_total_limit 1 \
18
- --max_steps 100000 \
19
  --wer_threshold 10 \
20
  --per_device_train_batch_size 64 \
21
  --per_device_eval_batch_size 64 \
@@ -32,7 +32,7 @@ python3 run_distillation.py \
32
  --streaming \
33
  --use_auth_token \
34
  --report_to "wandb" \
35
- --wandb_project "nb-distil-whisper-large-test2" \
36
  --hub_model_id "NbAiLab/nb-distil-whisper-large-flax1-no" \
37
  --push_to_hub
38
 
 
1
  #!/usr/bin/env bash
2
+ TOKENIZERS_PARALLELISM=false python3 run_distillation.py \
3
  --model_name_or_path "./nb-distil-large-init" \
4
  --teacher_model_name_or_path "NbAiLab/nb-whisper-large" \
5
  --train_dataset_name "NbAiLab/annotated_distil_raw_ncc_speech_v7_compact8_large" \
 
8
  --eval_dataset_name "NbAiLab/annotated_distil_raw_ncc_speech_v7_compact8_large" \
9
  --eval_dataset_config_name "no" \
10
  --eval_split_name "validation_norwegian_fleurs" \
11
+ --eval_steps 1000 \
12
+ --save_steps 1000 \
13
+ --warmup_steps 100 \
14
  --learning_rate 0.0001 \
15
  --lr_scheduler_type "linear" \
16
  --logging_steps 25 \
17
  --save_total_limit 1 \
18
+ --max_steps 10000 \
19
  --wer_threshold 10 \
20
  --per_device_train_batch_size 64 \
21
  --per_device_eval_batch_size 64 \
 
32
  --streaming \
33
  --use_auth_token \
34
  --report_to "wandb" \
35
+ --wandb_project "nb-distil-whisper-large-test3" \
36
  --hub_model_id "NbAiLab/nb-distil-whisper-large-flax1-no" \
37
  --push_to_hub
38
 
run_large_training_debug.sh ADDED
@@ -0,0 +1,38 @@
1
+ #!/usr/bin/env bash
2
+ TOKENIZERS_PARALLELISM=false python3 run_distillation_debug.py \
3
+ --model_name_or_path "./nb-distil-large-init" \
4
+ --teacher_model_name_or_path "NbAiLab/nb-whisper-large" \
5
+ --train_dataset_name "NbAiLab/annotated_distil_raw_ncc_speech_v7_compact8_large" \
6
+ --train_dataset_config_name "no" \
7
+ --train_split_name "train" \
8
+ --eval_dataset_name "NbAiLab/annotated_distil_raw_ncc_speech_v7_compact8_large" \
9
+ --eval_dataset_config_name "no" \
10
+ --eval_split_name "validation_norwegian_fleurs" \
11
+ --eval_steps 5000 \
12
+ --save_steps 5000 \
13
+ --warmup_steps 500 \
14
+ --learning_rate 0.0001 \
15
+ --lr_scheduler_type "linear" \
16
+ --logging_steps 25 \
17
+ --save_total_limit 1 \
18
+ --max_steps 100000 \
19
+ --wer_threshold 10 \
20
+ --per_device_train_batch_size 64 \
21
+ --per_device_eval_batch_size 64 \
22
+ --dataloader_num_workers 16 \
23
+ --dtype "bfloat16" \
24
+ --output_dir "./" \
25
+ --do_train \
26
+ --do_eval \
27
+ --use_scan \
28
+ --gradient_checkpointing \
29
+ --overwrite_output_dir \
30
+ --predict_with_generate \
31
+ --freeze_encoder \
32
+ --streaming \
33
+ --use_auth_token \
34
+ --report_to "wandb" \
35
+ --wandb_project "nb-distil-whisper-large-test2" \
36
+ --hub_model_id "NbAiLab/nb-distil-whisper-large-flax1-no" \
37
+ --push_to_hub
38
+