bjelkenhed committed
Commit c85c8df
1 Parent(s): c06c457

initial commit
deep_run.sh ADDED
@@ -0,0 +1,44 @@
+ deepspeed run_speech_recognition_seq2seq_deepspeed.py \
+     --deepspeed="ds_config.json" \
+     --model_name_or_path="openai/whisper-large-v2" \
+     --dataset_name="KBLab/rixvox" \
+     --dataset_config_name="train" \
+     --evalset_name="mozilla-foundation/common_voice_11_0" \
+     --evalset_config_name="sv-SE" \
+     --language="swedish" \
+     --train_split_name="train" \
+     --eval_split_name="test" \
+     --model_index_name="Whisper Large Swedish Rixvox" \
+     --text_column_name="text" \
+     --eval_text_column_name="sentence" \
+     --max_steps="20000" \
+     --output_dir="./" \
+     --per_device_train_batch_size="64" \
+     --gradient_accumulation_steps="2" \
+     --per_device_eval_batch_size="32" \
+     --logging_steps="25" \
+     --learning_rate="1e-5" \
+     --seed="42" \
+     --warmup_steps="500" \
+     --evaluation_strategy="steps" \
+     --eval_steps="1000" \
+     --save_strategy="steps" \
+     --save_steps="1000" \
+     --generation_max_length="225" \
+     --length_column_name="input_length" \
+     --max_duration_in_seconds="30" \
+     --freeze_feature_encoder="False" \
+     --report_to="tensorboard" \
+     --metric_for_best_model="wer" \
+     --greater_is_better="False" \
+     --load_best_model_at_end \
+     --gradient_checkpointing \
+     --fp16 \
+     --overwrite_output_dir \
+     --do_train \
+     --do_eval \
+     --predict_with_generate \
+     --do_normalize_eval \
+     --streaming="True" \
+     --use_auth_token \
+     --push_to_hub
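For reference, the effective global batch size produced by these flags is per_device_train_batch_size × gradient_accumulation_steps × the number of GPUs visible to the deepspeed launcher. A minimal sketch of that arithmetic; the GPU count is an assumption for illustration and depends on the machine the launcher runs on:

# Hypothetical illustration of the effective batch size implied by deep_run.sh.
per_device_train_batch_size = 64
gradient_accumulation_steps = 2
num_gpus = 8  # assumption: whatever deepspeed detects on the host

effective_batch_size = per_device_train_batch_size * gradient_accumulation_steps * num_gpus
print(effective_batch_size)  # 1024 with 8 GPUs, 128 on a single GPU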
ds_config.json ADDED
@@ -0,0 +1,50 @@
+ {
+     "fp16": {
+         "enabled": "auto",
+         "loss_scale": 0,
+         "loss_scale_window": 1000,
+         "initial_scale_power": 16,
+         "hysteresis": 2,
+         "min_loss_scale": 1
+     },
+ 
+     "optimizer": {
+         "type": "AdamW",
+         "params": {
+             "lr": "auto",
+             "betas": "auto",
+             "eps": "auto",
+             "weight_decay": "auto"
+         }
+     },
+ 
+     "scheduler": {
+         "type": "WarmupDecayLR",
+         "params": {
+             "last_batch_iteration": -1,
+             "total_num_steps": "auto",
+             "warmup_min_lr": "auto",
+             "warmup_max_lr": "auto",
+             "warmup_num_steps": "auto"
+         }
+     },
+ 
+     "zero_optimization": {
+         "stage": 2,
+         "offload_optimizer": {
+             "device": "cpu",
+             "pin_memory": true
+         },
+         "allgather_partitions": true,
+         "allgather_bucket_size": 2e8,
+         "overlap_comm": true,
+         "reduce_scatter": true,
+         "reduce_bucket_size": 2e8,
+         "contiguous_gradients": true
+     },
+ 
+     "gradient_accumulation_steps": "auto",
+     "gradient_clipping": "auto",
+     "train_batch_size": "auto",
+     "train_micro_batch_size_per_gpu": "auto"
+ }
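The "auto" entries are placeholders: when the script is launched with `--deepspeed="ds_config.json"`, the 🤗 Trainer's DeepSpeed integration fills them in from the corresponding training arguments (learning rate, warmup steps, batch sizes, fp16, and so on), which keeps the two files in sync. A small sketch that simply lists which fields are left for the Trainer to resolve; it assumes the file sits next to the script as in this commit:

import json

def auto_fields(node, prefix=""):
    # Walk the (possibly nested) config and collect every key whose value is the string "auto".
    found = []
    if isinstance(node, dict):
        for key, value in node.items():
            path = f"{prefix}.{key}" if prefix else key
            if value == "auto":
                found.append(path)
            else:
                found.extend(auto_fields(value, path))
    return found

with open("ds_config.json") as f:
    print(auto_fields(json.load(f)))  # e.g. ['fp16.enabled', 'optimizer.params.lr', ...]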
requirements.txt ADDED
@@ -0,0 +1,14 @@
+ 
+ torch>=1.7
+ git+https://github.com/huggingface/transformers
+ git+https://github.com/huggingface/datasets
+ librosa
+ jiwer
+ evaluate>=0.3.0
+ more-itertools
+ tensorboard
+ deepspeed
+ accelerate
+ pysrt
+ prefetch_generator
+ audiomentations
run_speech_recognition_seq2seq_deepspeed.py ADDED
@@ -0,0 +1,911 @@
1
+ #!/usr/bin/env python
2
+ # coding=utf-8
3
+ # Copyright 2022 The HuggingFace Team. All rights reserved.
4
+ #
5
+ # Licensed under the Apache License, Version 2.0 (the "License");
6
+ # you may not use this file except in compliance with the License.
7
+ # You may obtain a copy of the License at
8
+ #
9
+ # http://www.apache.org/licenses/LICENSE-2.0
10
+ #
11
+ # Unless required by applicable law or agreed to in writing, software
12
+ # distributed under the License is distributed on an "AS IS" BASIS,
13
+ # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
14
+ # See the License for the specific language governing permissions and
15
+ # limitations under the License.
16
+ """
17
+ Fine-tuning the library models for sequence to sequence speech recognition
18
+ with 🤗 Datasets' streaming mode.
19
+ """
20
+ # This program was modified by Michael Kamfonas ([email protected]) on Dec 11, 2022
21
+ # - added options for dropout, gradient_checkpointing, use_cache, stopping_strategy and streaming
22
+ # - restructured it to enable both streaming and non-streaming modes
23
+ # - allows concatenation of multiple datasets (single-string comma-separated) for interleaving
24
+ # The following params must have the same number of comma-separated (,) elements:
25
+ # dataset_name,
26
+ # dataset_config_name,
27
+ # train_split_name and eval_split_name (each element plus-separated (+) for multiple splits),
28
+ # text_column_name and audio_column_name
29
+
30
+
31
+ import logging
32
+ import os
33
+ import sys
34
+ from dataclasses import dataclass, field
35
+ from typing import Any, Dict, List, Optional, Union
36
+
37
+ import datasets
38
+ import torch
39
+ from datasets import Audio, DatasetDict, IterableDatasetDict, interleave_datasets, load_dataset
40
+ from torch.utils.data import IterableDataset
41
+
42
+ import evaluate
43
+ import transformers
44
+ from transformers import (
45
+ AutoConfig,
46
+ AutoFeatureExtractor,
47
+ AutoModelForSpeechSeq2Seq,
48
+ AutoProcessor,
49
+ AutoTokenizer,
50
+ HfArgumentParser,
51
+ Seq2SeqTrainer,
52
+ Seq2SeqTrainingArguments,
53
+ TrainerCallback,
54
+ set_seed,
55
+ )
56
+ from transformers.trainer_pt_utils import IterableDatasetShard
57
+ from transformers.trainer_utils import get_last_checkpoint, is_main_process
58
+ from transformers.utils import check_min_version, send_example_telemetry
59
+ from transformers.utils.versions import require_version
60
+ from transformers.models.whisper.english_normalizer import BasicTextNormalizer
61
+ from prefetch_generator import BackgroundGenerator
62
+ import torchaudio.transforms as T
63
+ import numpy as np
64
+ from transformers import WhisperTokenizer
65
+
66
+ TEXT_COL_NAME="text"
67
+ AUDIO_COL_NAME="audio"
68
+
69
+ # Will error if the minimal version of Transformers is not installed. Remove at your own risks.
70
+ check_min_version("4.25.0.dev0")
71
+
72
+ require_version("datasets>=1.18.2", "To fix: pip install -r examples/pytorch/speech-recognition/requirements.txt")
73
+
74
+ require_version("prefetch-generator>=1.0.3", "To fix: pip install prefetch-generator")
75
+
76
+ logger = logging.getLogger(__name__)
77
+
78
+
79
+ @dataclass
80
+ class ModelArguments:
81
+ """
82
+ Arguments pertaining to which model/config/tokenizer we are going to fine-tune from.
83
+ """
84
+
85
+ model_name_or_path: str = field(
86
+ metadata={"help": "Path to pretrained model or model identifier from huggingface.co/models"}
87
+ )
88
+ config_name: Optional[str] = field(
89
+ default=None, metadata={"help": "Pretrained config name or path if not the same as model_name"}
90
+ )
91
+ tokenizer_name: Optional[str] = field(
92
+ default=None, metadata={"help": "Pretrained tokenizer name or path if not the same as model_name"}
93
+ )
94
+ feature_extractor_name: Optional[str] = field(
95
+ default=None, metadata={"help": "feature extractor name or path if not the same as model_name"}
96
+ )
97
+ cache_dir: Optional[str] = field(
98
+ default=None,
99
+ metadata={"help": "Where to store the pretrained models downloaded from huggingface.co"},
100
+ )
101
+ use_fast_tokenizer: bool = field(
102
+ default=True,
103
+ metadata={"help": "Whether to use one of the fast tokenizer (backed by the tokenizers library) or not."},
104
+ )
105
+ model_revision: str = field(
106
+ default="main",
107
+ metadata={"help": "The specific model version to use (can be a branch name, tag name or commit id)."},
108
+ )
109
+ use_auth_token: bool = field(
110
+ default=False,
111
+ metadata={
112
+ "help": (
113
+ "Will use the token generated when running `huggingface-cli login` (necessary to use this script "
114
+ "with private models)."
115
+ )
116
+ },
117
+ )
118
+ freeze_feature_encoder: bool = field(
119
+ default=True, metadata={"help": "Deprecated - Whether to freeze the feature encoder layers of the model."}
120
+ )
121
+ freeze_encoder: bool = field(
122
+ default=False, metadata={"help": "Whether to freeze the entire encoder of the seq2seq model."}
123
+ )
124
+ forced_decoder_ids: List[List[int]] = field(
125
+ default=None,
126
+ metadata={
127
+ "help": (
128
+ "A list of pairs of integers which indicates a mapping from generation indices to token indices "
129
+ "that will be forced before sampling. For example, [[0, 123]] means the first generated token "
130
+ "will always be a token of index 123."
131
+ )
132
+ },
133
+ )
134
+ suppress_tokens: List[int] = field(
135
+ default=None, metadata={"help": "A list of tokens that will be suppressed at generation."}
136
+ )
137
+ model_index_name: str = field(default=None, metadata={"help": "Pretty name for the model card."})
138
+
139
+ ## added by Michael Kamfonas
140
+ use_cache: bool = field(
141
+ default=False, metadata={"help": "Whether to use cache."}
142
+ )
143
+
144
+ dropout: float = field(
145
+ default = 0.0, metadata = {"help": "dropout probability."}
146
+ )
147
+
148
+ attention_dropout: float = field(
149
+ default = 0.0, metadata = {"help": "attention_dropout probability."}
150
+ )
151
+
152
+
153
+
154
+ @dataclass
155
+ class DataTrainingArguments:
156
+ """
157
+ Arguments pertaining to what data we are going to input our model for training and eval.
158
+ """
159
+
160
+ dataset_name: str = field(
161
+ default=None,
162
+ metadata={"help": "The name of the dataset to use (via the datasets library)."}
163
+ )
164
+ dataset_config_name: Optional[str] = field(
165
+ default=None,
166
+ metadata={"help": "The configuration name of the dataset to use (via the datasets library)."}
167
+ )
168
+ evalset_name: str = field(
169
+ default=None,
170
+ metadata={"help": "The name of the evaluation dataset to use (via the datasets library)."}
171
+ )
172
+ evalset_config_name: Optional[str] = field(
173
+ default=None,
174
+ metadata={"help": "The configuration name of the evaluation dataset to use (via the datasets library)."}
175
+ )
176
+ text_column: Optional[str] = field(
177
+ default=None,
178
+ metadata={"help": "The name of the column in the datasets containing the transcription."},
179
+ )
180
+ max_train_samples: Optional[int] = field(
181
+ default=None,
182
+ metadata={
183
+ "help": (
184
+ "For debugging purposes or quicker training, truncate the number of training examples to this "
185
+ "value if set."
186
+ )
187
+ },
188
+ )
189
+ max_eval_samples: Optional[int] = field(
190
+ default=None,
191
+ metadata={
192
+ "help": (
193
+ "For debugging purposes or quicker training, truncate the number of evaluation examples to this "
194
+ "value if set."
195
+ )
196
+ },
197
+ )
198
+ audio_column_name: str = field(
199
+ default="audio",
200
+ metadata={"help": "The name of the dataset column containing the audio data. Defaults to 'audio'"},
201
+ )
202
+ text_column_name: str = field(
203
+ default="text",
204
+ metadata={"help": "The name of the dataset column containing the text data. Defaults to 'text'"},
205
+ )
206
+ eval_text_column_name: str = field(
207
+ default="text",
208
+ metadata={"help": "The name of the evalset column containing the text data. Defaults to 'text'"},
209
+ )
210
+ max_duration_in_seconds: float = field(
211
+ default=20.0,
212
+ metadata={
213
+ "help": (
214
+ "Truncate audio files that are longer than `max_duration_in_seconds` seconds to"
215
+ " `max_duration_in_seconds`"
216
+ )
217
+ },
218
+ )
219
+ min_duration_in_seconds: float = field(
220
+ default=0.0, metadata={"help": "Filter audio files that are shorter than `min_duration_in_seconds` seconds"}
221
+ )
222
+ train_split_name: str = field(
223
+ default="train",
224
+ metadata={
225
+ "help": "The name of the training data set split to use (via the datasets library). Defaults to 'train'"
226
+ },
227
+ )
228
+ eval_split_name: str = field(
229
+ default="test",
230
+ metadata={
231
+ "help": "The name of the evaluation data set split to use (via the datasets library). Defaults to 'test'"
232
+ },
233
+ )
234
+ do_lower_case: bool = field(
235
+ default=False,
236
+ metadata={"help": "Whether the target text should be lower cased."},
237
+ )
238
+ do_remove_punctuation: bool = field(
239
+ default=False,
240
+ metadata={"help": "Whether the target text should be stripped of punctuation."},
241
+ )
242
+ do_normalize_eval: bool = field(
243
+ default=True,
244
+ metadata={"help": "Whether to normalise the references and predictions in the eval WER calculation."},
245
+ )
246
+ language: str = field(
247
+ default=None,
248
+ metadata={
249
+ "help": (
250
+ "Language for multilingual fine-tuning. This argument should be set for multilingual fine-tuning "
251
+ "only. For English speech recognition, it should be set to `None`."
252
+ )
253
+ },
254
+ )
255
+ task: str = field(
256
+ default="transcribe",
257
+ metadata={"help": "Task, either `transcribe` for speech recognition or `translate` for speech translation."},
258
+ )
259
+ shuffle_buffer_size: Optional[int] = field(
260
+ default=500,
261
+ metadata={
262
+ "help": (
263
+ "The number of streamed examples to download before shuffling them. The larger the buffer, "
264
+ "the closer it is to real offline shuffling."
265
+ )
266
+ },
267
+ )
268
+ stopping_strategy: Optional[str] = field(
269
+ default="all_exhausted",
270
+ metadata={
271
+ "help": "Strategy used to consume interleaved data. Default = 'all_exhausted'"
272
+ }
273
+ )
274
+ streaming: bool = field(
275
+ default=True,
276
+ metadata={"help": "Whether to use streaming mode to load and pre-process the data."},
277
+ )
278
+
279
+ @dataclass
280
+ class DataCollatorSpeechSeq2SeqWithPadding:
281
+ """
282
+ Data collator that will dynamically pad the inputs received.
283
+ Args:
284
+ processor ([`WhisperProcessor`])
285
+ The processor used for processing the data.
286
+ decoder_start_token_id (`int`)
287
+ The begin-of-sentence token id of the decoder.
288
+ """
289
+
290
+ processor: Any
291
+ decoder_start_token_id: int
292
+
293
+ def __call__(self, features: List[Dict[str, Union[List[int], torch.Tensor]]]) -> Dict[str, torch.Tensor]:
294
+ # split inputs and labels since they have to be of different lengths and need
295
+ # different padding methods
296
+ model_input_name = self.processor.model_input_names[0]
297
+ input_features = [{model_input_name: feature[model_input_name]} for feature in features]
298
+ label_features = [{"input_ids": feature["labels"]} for feature in features]
299
+
300
+ batch = self.processor.feature_extractor.pad(input_features, return_tensors="pt")
301
+
302
+ labels_batch = self.processor.tokenizer.pad(label_features, return_tensors="pt")
303
+
304
+ # replace padding with -100 to ignore loss correctly
305
+ labels = labels_batch["input_ids"].masked_fill(labels_batch.attention_mask.ne(1), -100)
306
+
307
+ # if bos token is appended in previous tokenization step,
308
+ # cut bos token here as it's appended later anyway
309
+ if (labels[:, 0] == self.decoder_start_token_id).all().cpu().item():
310
+ labels = labels[:, 1:]
311
+
312
+ batch["labels"] = labels
313
+
314
+ return batch
315
+
316
+
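+ # Minimal sketch (not used anywhere in this script) of the masking step the collator performs:
+ # padded label positions are replaced with -100 so the loss ignores them. The token ids and
+ # attention mask below are invented purely for illustration.
+ def _label_masking_example():
+     padded_ids = torch.tensor([[7, 8, 9, 0], [7, 8, 0, 0]])
+     attention_mask = torch.tensor([[1, 1, 1, 0], [1, 1, 0, 0]])
+     return padded_ids.masked_fill(attention_mask.ne(1), -100)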
317
+ def load_streaming_dataset(dataset_name, dataset_config_name, text_column_name, split="train", **kwargs):
318
+ """
319
+ Utility function to load a dataset in streaming mode. For datasets with multiple splits,
320
+ each split is loaded individually and the splits are then combined by taking alternating examples from
321
+ each (interleaving).
322
+ """
323
+ if "+" in split:
324
+ # load multiple splits separated by the `+` symbol with streaming mode
325
+ dataset_splits = [
326
+ load_dataset(dataset_name, dataset_config_name, split=split_name, streaming=True, **kwargs)
327
+ for split_name in split.split("+")
328
+ ]
329
+ # interleave multiple splits to form one dataset
330
+ interleaved_dataset = interleave_datasets(dataset_splits)
331
+ return interleaved_dataset
332
+ else:
333
+ # load a single split *with* streaming mode
334
+ dataset = load_dataset(dataset_name, dataset_config_name, split=split, streaming=True, **kwargs)
335
+ dataset = dataset.rename_column(text_column_name, TEXT_COL_NAME)
336
+ return dataset
337
+
338
+ def load_multiple_streaming_datasets(
339
+ dataset_names: List,
340
+ dataset_config_names: List,
341
+ splits: Optional[List] = None,
342
+ text_column_names: Optional[List] = None,
343
+ audio_column_names: Optional[List] = None,
344
+ sampling_rate: Optional[int] = 16000,
345
+ stopping_strategy: Optional[str] = "all_exhausted",
346
+ streaming = True,
347
+ **kwargs
348
+ ):
349
+
350
+ if len(dataset_names) != len(dataset_config_names):
351
+ raise ValueError(
352
+ f"Ensure one config is passed for each dataset, got {len(dataset_names)} datasets and"
353
+ f" {len(dataset_config_names)} configs."
354
+ )
355
+
356
+ if splits is not None and len(splits) != len(dataset_names):
357
+ raise ValueError(
358
+ f"Ensure one train_split is passed for each dataset, got {len(dataset_names)} datasets and {len(splits)} splits."
359
+ )
360
+
361
+ if text_column_names is not None and len(text_column_names) != len(dataset_names):
362
+ raise ValueError(
363
+ f"Ensure one text column name is passed for each dataset, got {len(dataset_names)} datasets and"
364
+ f" {len(text_column_names)} text column names."
365
+ )
366
+
367
+ if audio_column_names is not None and len(audio_column_names) != len(dataset_names):
368
+ raise ValueError(
369
+ f"Ensure one audio column name is passed for each dataset, got {len(dataset_names)} datasets and"
370
+ f" {len(audio_column_names)} audio column names."
371
+ )
372
+
373
+ splits = splits if splits is not None \
374
+ else ["train" for i in range(len(dataset_names))]
375
+
376
+ text_column_names = (
377
+ text_column_names if text_column_names is not None \
378
+ else [TEXT_COL_NAME for i in range(len(dataset_names))]
379
+ )
380
+
381
+ audio_column_names = (
382
+ audio_column_names if audio_column_names is not None \
383
+ else [AUDIO_COL_NAME for i in range(len(dataset_names))]
384
+ )
385
+
386
+ all_data_splits = []
387
+ # iterate over the datasets we want to interleave
388
+ for dset, cfgNm, splt, txtColNm, audColNm in zip(dataset_names,dataset_config_names,\
389
+ splits,text_column_names, audio_column_names):
390
+
391
+ dset_splits = [load_dataset(dset, cfgNm, split=c, streaming=streaming, **kwargs) \
392
+ for c in splt.split('+') if c != '-']
393
+
394
+ if streaming:
395
+ dset_splits = [ds if TEXT_COL_NAME in ds.features else ds.rename_column(txtColNm, TEXT_COL_NAME) \
396
+ for ds in dset_splits ]
397
+ dset_splits = [ds if AUDIO_COL_NAME in ds.features else ds.rename_column(audColNm, AUDIO_COL_NAME) \
398
+ for ds in dset_splits]
399
+
400
+ if len(dset_splits)>0 and sampling_rate != next(iter(dset_splits[0]))[AUDIO_COL_NAME]['sampling_rate']:
401
+ dset_splits = [ds.cast_column(AUDIO_COL_NAME, Audio(sampling_rate)) for ds in dset_splits]
402
+ else:
403
+
404
+ dset_splits = [ds if TEXT_COL_NAME in ds.column_names else ds.rename_column(txtColNm, TEXT_COL_NAME) \
405
+ for ds in dset_splits ]
406
+ dset_splits = [ds if AUDIO_COL_NAME in ds.column_names else ds.rename_column(audColNm, AUDIO_COL_NAME) \
407
+ for ds in dset_splits]
408
+
409
+ if len(dset_splits)>0 and sampling_rate != next(iter(dset_splits[0]))[AUDIO_COL_NAME]['sampling_rate']:
410
+ dset_splits = [ds.cast_column(AUDIO_COL_NAME, Audio(sampling_rate)) for ds in dset_splits]
411
+
412
+ cols2keep = set([AUDIO_COL_NAME, TEXT_COL_NAME])
413
+
414
+ dset_splits = [ds.remove_columns(set(ds.features.keys()) - cols2keep) for ds in dset_splits]
415
+
416
+ all_data_splits += dset_splits
417
+
418
+ return interleave_datasets(all_data_splits, stopping_strategy=stopping_strategy)
419
+
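+ # Illustrative sketch (never called): how the comma-separated CLI strings described at the top of
+ # this file map onto the parallel lists this function expects. The second dataset, its splits and
+ # its column names are assumptions added for illustration only.
+ def _example_interleave():
+     return load_multiple_streaming_datasets(
+         dataset_names="KBLab/rixvox,mozilla-foundation/common_voice_11_0".split(","),
+         dataset_config_names="train,sv-SE".split(","),
+         splits="train,train+validation".split(","),
+         text_column_names="text,sentence".split(","),
+         audio_column_names="audio,audio".split(","),
+         sampling_rate=16000,
+         streaming=True,
+     )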
420
+ def main():
421
+ # 1. Parse input arguments
422
+ # See all possible arguments in src/transformers/training_args.py
423
+ # or by passing the --help flag to this script.
424
+ # We now keep distinct sets of args, for a cleaner separation of concerns.
425
+ parser = HfArgumentParser((ModelArguments, DataTrainingArguments, Seq2SeqTrainingArguments))
426
+
427
+ if len(sys.argv) == 2 and sys.argv[1].endswith(".json"):
428
+ # If we pass only one argument to the script and it's the path to a json file,
429
+ # let's parse it to get our arguments.
430
+ model_args, data_args, training_args = parser.parse_json_file(json_file=os.path.abspath(sys.argv[1]))
431
+ else:
432
+ model_args, data_args, training_args = parser.parse_args_into_dataclasses()
433
+
434
+ # Sending telemetry. Tracking the example usage helps us better allocate resources to maintain them. The
435
+ # information sent is the one passed as arguments along with your Python/PyTorch versions.
436
+ send_example_telemetry("run_speech_recognition_seq2seq_streaming", model_args, data_args)
437
+
438
+ # 2. Setup logging
439
+ logging.basicConfig(
440
+ format="%(asctime)s - %(levelname)s - %(name)s - %(message)s",
441
+ datefmt="%m/%d/%Y %H:%M:%S",
442
+ handlers=[logging.StreamHandler(sys.stdout)],
443
+ )
444
+ log_level = training_args.get_process_log_level()
445
+ logger.setLevel(log_level)
446
+ datasets.utils.logging.set_verbosity(log_level)
447
+ transformers.utils.logging.set_verbosity(log_level)
448
+ transformers.utils.logging.enable_default_handler()
449
+ transformers.utils.logging.enable_explicit_format()
450
+
451
+ logger.setLevel(logging.INFO if is_main_process(training_args.local_rank) else logging.WARN)
452
+
453
+ # Log on each process the small summary:
454
+ logger.warning(
455
+ f"Process rank: {training_args.local_rank}, device: {training_args.device}, n_gpu: {training_args.n_gpu}, "
456
+ f"distributed training: {bool(training_args.local_rank != -1)}, 16-bits training: {training_args.fp16}"
457
+ )
458
+ logger.info(f"Training/evaluation parameters {training_args}")
459
+
460
+ # Set the verbosity to info of the Transformers logger (on main process only):
461
+ if is_main_process(training_args.local_rank):
462
+ transformers.utils.logging.set_verbosity_info()
463
+ logger.info("Training/evaluation parameters %s", training_args)
464
+
465
+ # 3. Detect the last checkpoint and potentially continue from it
466
+ last_checkpoint = None
467
+ if os.path.isdir(training_args.output_dir) and training_args.do_train and not training_args.overwrite_output_dir:
468
+ last_checkpoint = get_last_checkpoint(training_args.output_dir)
469
+ if last_checkpoint is None and len(os.listdir(training_args.output_dir)) > 0:
470
+ raise ValueError(
471
+ f"Output directory ({training_args.output_dir}) already exists and is not empty. "
472
+ "Use --overwrite_output_dir to overcome."
473
+ )
474
+ elif last_checkpoint is not None and training_args.resume_from_checkpoint is None:
475
+ logger.info(
476
+ f"Checkpoint detected, resuming training at {last_checkpoint}. To avoid this behavior, change "
477
+ "the `--output_dir` or add `--overwrite_output_dir` to train from scratch."
478
+ )
479
+
480
+ # Set seed before initializing model.
481
+ set_seed(training_args.seed)
482
+
483
+ # 4. Load pretrained model, tokenizer, and feature extractor
484
+ #
485
+ # Distributed training:
486
+ # The .from_pretrained methods guarantee that only one local process can concurrently download model & vocab.
487
+ config = AutoConfig.from_pretrained(
488
+ model_args.config_name if model_args.config_name else model_args.model_name_or_path,
489
+ cache_dir=model_args.cache_dir,
490
+ revision=model_args.model_revision,
491
+ use_auth_token=True if model_args.use_auth_token else None,
492
+ )
493
+
494
+ config.update({ "forced_decoder_ids": model_args.forced_decoder_ids,
495
+ "suppress_tokens": model_args.suppress_tokens})
496
+
497
+ feature_extractor = AutoFeatureExtractor.from_pretrained(
498
+ model_args.feature_extractor_name if model_args.feature_extractor_name else model_args.model_name_or_path,
499
+ cache_dir=model_args.cache_dir,
500
+ revision=model_args.model_revision,
501
+ use_auth_token=True if model_args.use_auth_token else None,
502
+ )
503
+ tokenizer = AutoTokenizer.from_pretrained(
504
+ model_args.tokenizer_name if model_args.tokenizer_name else model_args.model_name_or_path,
505
+ cache_dir=model_args.cache_dir,
506
+ use_fast=model_args.use_fast_tokenizer,
507
+ revision=model_args.model_revision,
508
+ use_auth_token=True if model_args.use_auth_token else None,
509
+ )
510
+ model = AutoModelForSpeechSeq2Seq.from_pretrained(
511
+ model_args.model_name_or_path,
512
+ config=config,
513
+ cache_dir=model_args.cache_dir,
514
+ revision=model_args.model_revision,
515
+ use_auth_token=True if model_args.use_auth_token else None,
516
+ )
517
+
518
+ model.config.use_cache = model_args.use_cache
519
+ model.config.dropout = model_args.dropout
520
+ model.config.attention_dropout = model_args.attention_dropout
521
+ if training_args.gradient_checkpointing:
522
+ model.gradient_checkpointing_enable()
523
+
524
+ if model.config.decoder_start_token_id is None:
525
+ raise ValueError("Make sure that `config.decoder_start_token_id` is correctly defined")
526
+
527
+ # deprecated
528
+ #if model_args.freeze_feature_encoder:
529
+ # model.freeze_feature_encoder()
530
+
531
+ if model_args.freeze_encoder:
532
+ model.freeze_encoder()
533
+ model.model.encoder.gradient_checkpointing = False
534
+
535
+ if data_args.language is not None:
536
+ # We only need to set the task id when the language is specified (i.e. in a multilingual setting)
537
+ tokenizer.set_prefix_tokens(language=data_args.language, task=data_args.task)
538
+
539
+
540
+ # 5. Load dataset
541
+ raw_datasets = IterableDatasetDict() if data_args.streaming else DatasetDict()
542
+
543
+ if training_args.do_train:
544
+ raw_datasets["train"] = load_multiple_streaming_datasets(
545
+ dataset_names=data_args.dataset_name.split(","),
546
+ dataset_config_names=data_args.dataset_config_name.split(","),
547
+ splits = data_args.train_split_name.split(","),
548
+ text_column_names = data_args.text_column_name.split(","),
549
+ sampling_rate = feature_extractor.sampling_rate,
550
+ streaming=data_args.streaming,
551
+ use_auth_token=True if model_args.use_auth_token else None,
552
+ )
553
+
554
+ if training_args.do_eval:
555
+ raw_datasets["eval"] = load_streaming_dataset(
556
+ data_args.evalset_name,
557
+ data_args.evalset_config_name,
558
+ text_column_name = data_args.eval_text_column_name,
559
+ split=data_args.eval_split_name,
560
+ use_auth_token=True if model_args.use_auth_token else None,
561
+ )
562
+
563
+ raw_datasets_features = list(next(iter(raw_datasets.values())).features.keys())
564
+
565
+ if AUDIO_COL_NAME not in raw_datasets_features:
566
+ raise ValueError(
567
+ f"--audio_column_name '{data_args.audio_column_name}' not found in dataset '{data_args.dataset_name}'. "
568
+ "Make sure to set `--audio_column_name` to the correct audio column - one of "
569
+ f"{', '.join(raw_datasets_features)}."
570
+ )
571
+
572
+ if TEXT_COL_NAME not in raw_datasets_features:
573
+ raise ValueError(
574
+ f"--text_column_name {TEXT_COL_NAME} not found in dataset. "
575
+ "Make sure to set `--text_column_name` to the respective correct text column."
576
+ )
577
+
578
+ # 6. Resample eval common voice dataset as it has sampling_rate 48000
579
+ raw_datasets['eval'] = raw_datasets['eval'].cast_column("audio", datasets.features.Audio(sampling_rate=16000))
580
+
581
+
582
+ # 7. Preprocessing the datasets.
583
+ # We need to read the audio files as arrays and tokenize the targets.
584
+ max_input_length = data_args.max_duration_in_seconds * feature_extractor.sampling_rate
585
+ min_input_length = data_args.min_duration_in_seconds * feature_extractor.sampling_rate
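+ # For the run script above this works out to 30 s * 16000 Hz = 480000 samples for the upper
+ # bound (and 0 for the lower bound); the numbers are only an illustration of the arithmetic.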
586
+ audio_column_name = AUDIO_COL_NAME
587
+ text_column_name = TEXT_COL_NAME
588
+ model_input_name = feature_extractor.model_input_names[0]
589
+ do_lower_case = data_args.do_lower_case
590
+ do_remove_punctuation = data_args.do_remove_punctuation
591
+ normalizer = BasicTextNormalizer() # 'official' text normalizer from OpenAI
592
+
593
+ if data_args.max_train_samples is not None:
594
+ raw_datasets["train"] = (
595
+ raw_datasets["train"].take(data_args.max_train_samples)
596
+ if data_args.streaming
597
+ else raw_datasets["train"].select(range(data_args.max_train_samples))
598
+ )
599
+
600
+ if data_args.max_eval_samples is not None:
601
+ raw_datasets["eval"] = (
602
+ raw_datasets["eval"].take(data_args.max_eval_samples)
603
+ if data_args.streaming
604
+ else raw_datasets["eval"].select(range(data_args.max_eval_samples))
605
+ )
606
+
607
+ def prepare_dataset(batch):
608
+ # process audio
609
+ sample = batch[audio_column_name]
610
+ inputs = feature_extractor(sample["array"], sampling_rate=sample["sampling_rate"])
611
+ # process audio length
612
+ batch[model_input_name] = inputs.get(model_input_name)[0]
613
+ batch["input_length"] = len(sample["array"])
614
+
615
+ # process targets
616
+ input_str = batch[text_column_name].lower() if do_lower_case else batch[text_column_name]
617
+ if do_remove_punctuation:
618
+ input_str = normalizer(input_str).strip()
619
+ batch["labels"] = tokenizer(input_str).input_ids
620
+
621
+ # compute labels length **with** special tokens! -> total label length
622
+ batch["labels_length"] = len(batch["labels"])
623
+
624
+ return batch
625
+
626
+
627
+ ts_tokenizer = WhisperTSTokenizer("openai/whisper-medium", language="sv")
628
+
629
+ def prepare_train_dataset(batch):
630
+ # load and (possibly) resample audio data to 16kHz
631
+ audio = batch["audio"]
632
+
633
+ # compute log-Mel input features from input audio array
634
+ batch["input_features"] = processor.feature_extractor(audio["array"], sampling_rate=audio["sampling_rate"]).input_features[0]
635
+ # compute input length of audio sample in seconds
636
+ batch["input_length"] = len(audio["array"]) / audio["sampling_rate"]
637
+
638
+ batch["labels"] = ts_tokenizer(batch["text"]).input_ids
639
+
640
+ # compute labels length **with** special tokens! -> total label length
641
+ batch["labels_length"] = len(batch["labels"])
642
+
643
+ return batch
644
+
645
+
646
+ def prepare_default_dataset(batch):
647
+ # process audio
648
+ sample = batch[audio_column_name]
649
+ inputs = feature_extractor(sample["array"], sampling_rate=sample["sampling_rate"])
650
+ # process audio length
651
+ batch[model_input_name] = inputs.get(model_input_name)[0]
652
+ batch["input_length"] = len(sample["array"])
653
+
654
+ # process targets
655
+ input_str = batch[text_column_name].lower() if do_lower_case else batch[text_column_name]
656
+ if do_remove_punctuation:
657
+ input_str = normalizer(input_str).strip()
658
+ batch["labels"] = tokenizer(input_str).input_ids
659
+
660
+ # compute labels length **with** special tokens! -> total label length
661
+ batch["labels_length"] = len(batch["labels"])
662
+
663
+ return batch
664
+
665
+
666
+ with training_args.main_process_first(desc="dataset map pre-processing"):
667
+ eval_dataset = raw_datasets['eval'].map(
668
+ prepare_dataset,
669
+ remove_columns=list(raw_datasets['eval'].features),
670
+ ).with_format("torch")
671
+
672
+ if training_args.do_train and data_args.streaming:
673
+ train_dataset = raw_datasets['train'].map(
674
+ prepare_default_dataset,
675
+ remove_columns=list(raw_datasets['train'].features),
676
+ ).with_format("torch")
677
+
678
+
679
+ train_dataset = train_dataset.shuffle(
680
+ buffer_size=data_args.shuffle_buffer_size,
681
+ seed=training_args.seed,
682
+ )
683
+
684
+ # filter training data that is shorter than min_input_length or longer than
685
+ # max_input_length
686
+ def is_audio_in_length_range(length):
687
+ return min_input_length < length < max_input_length
688
+
689
+ max_label_length = model.config.max_length
690
+
691
+ def filter_labels(labels_length):
692
+ """Filter label sequences longer than max length (448)"""
693
+ return labels_length < max_label_length
694
+
695
+ if training_args.do_train:
696
+ train_dataset = train_dataset.filter(
697
+ is_audio_in_length_range,
698
+ input_columns=["input_length"],
699
+ )
700
+ train_dataset = train_dataset.filter(
701
+ filter_labels,
702
+ input_columns=["labels_length"],
703
+ )
704
+
705
+ # 8. Load Metric
706
+ metric = evaluate.load("wer")
707
+ do_normalize_eval = data_args.do_normalize_eval
708
+
709
+ def compute_metrics(pred):
710
+ pred_ids = pred.predictions
711
+
712
+ pred.label_ids[pred.label_ids == -100] = tokenizer.pad_token_id
713
+
714
+ pred_str = tokenizer.batch_decode(pred_ids, skip_special_tokens=True)
715
+ # we do not want to group tokens when computing the metrics
716
+ label_str = tokenizer.batch_decode(pred.label_ids, skip_special_tokens=True)
717
+
718
+ if do_normalize_eval:
719
+ pred_str = [normalizer(pred) for pred in pred_str]
720
+ label_str = [normalizer(label) for label in label_str]
721
+ # filtering step to only evaluate the samples that correspond to non-zero references:
722
+ pred_str = [pred_str[i] for i in range(len(pred_str)) if len(label_str[i]) > 0]
723
+ label_str = [label_str[i] for i in range(len(label_str)) if len(label_str[i]) > 0]
724
+
725
+ wer = 100 * metric.compute(predictions=pred_str, references=label_str)
726
+
727
+ return {"wer": wer}
728
+
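+ # Minimal sketch (never called) of why normalisation matters for WER, using jiwer directly;
+ # the Swedish strings are invented for illustration.
+ def _normalised_wer_example():
+     import jiwer
+     ref, hyp = "Hej, världen!", "hej världen"
+     raw = jiwer.wer(ref, hyp)                     # casing/punctuation count as errors
+     norm = BasicTextNormalizer()
+     normalised = jiwer.wer(norm(ref), norm(hyp))  # both sides normalise to "hej världen"
+     return raw, normalised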
729
+ # 9. Create a single speech processor
730
+ if is_main_process(training_args.local_rank):
731
+ # save feature extractor, tokenizer and config
732
+ feature_extractor.save_pretrained(training_args.output_dir)
733
+ tokenizer.save_pretrained(training_args.output_dir)
734
+ config.save_pretrained(training_args.output_dir)
735
+
736
+ processor = AutoProcessor.from_pretrained(training_args.output_dir)
737
+
738
+ # 10. Define data collator
739
+ data_collator = DataCollatorSpeechSeq2SeqWithPadding(
740
+ processor=processor,
741
+ decoder_start_token_id=model.config.decoder_start_token_id,
742
+ )
743
+
744
+ # 11. Configure Trainer
745
+ # Trainer callback to reinitialise and reshuffle the streamable datasets at the beginning of each epoch
746
+ # Only required for streaming: Trainer automatically shuffles non-streaming datasets
747
+ class ShuffleCallback(TrainerCallback):
748
+ def on_epoch_begin(self, args, state, control, train_dataloader, **kwargs):
749
+ if isinstance(train_dataloader.dataset, IterableDatasetShard):
750
+ pass # set_epoch() is handled by the Trainer
751
+ elif isinstance(train_dataloader.dataset, IterableDataset):
752
+ train_dataloader.dataset.set_epoch(train_dataloader.dataset._epoch + 1)
753
+
754
+ # Apply Spec Augment
755
+ def apply_spec_augment(batch):
756
+ freq_masking = T.FrequencyMasking(freq_mask_param=27)
757
+ time_masking = T.TimeMasking(time_mask_param=100)
758
+
759
+ spec_input = torch.from_numpy(batch[model_input_name])
760
+ spec_input_augmented = freq_masking(spec_input)
761
+ spec_input_augmented = time_masking(spec_input_augmented.reshape(1,80,-1))
762
+ batch[model_input_name] = spec_input_augmented.reshape(80,-1).numpy()
763
+ return batch
764
+
765
+ #train_dataset = train_dataset.map(apply_spec_augment)
766
+
767
+
768
+ # Create prefetched dataset
769
+ def prefetch_iterator_train():
770
+ for example in BackgroundGenerator(train_dataset, max_prefetch=1000):
771
+ yield example
772
+
773
+ def prefetch_iterator_eval():
774
+ for example in BackgroundGenerator(eval_dataset, max_prefetch=1000):
775
+ yield example
776
+
777
+ prefetched_ds_train = datasets.iterable_dataset.IterableDataset.from_generator(prefetch_iterator_train)
778
+ prefetched_ds_eval = datasets.iterable_dataset.IterableDataset.from_generator(prefetch_iterator_eval)
779
+
780
+ # Initialize Trainer
781
+ trainer = Seq2SeqTrainer(
782
+ model=model,
783
+ args=training_args,
784
+ train_dataset=prefetched_ds_train if training_args.do_train else None,
785
+ eval_dataset=prefetched_ds_eval if training_args.do_eval else None,
786
+ tokenizer=feature_extractor,
787
+ data_collator=data_collator,
788
+ compute_metrics=compute_metrics if training_args.predict_with_generate else None,
789
+ callbacks=[ShuffleCallback()] if data_args.streaming else None,
790
+ )
791
+
792
+ # 12. Training
793
+ if training_args.do_train:
794
+ checkpoint = None
795
+ if training_args.resume_from_checkpoint is not None:
796
+ checkpoint = training_args.resume_from_checkpoint
797
+ elif last_checkpoint is not None:
798
+ checkpoint = last_checkpoint
799
+ train_result = trainer.train(resume_from_checkpoint=checkpoint)
800
+ trainer.save_model() # Saves the feature extractor too for easy upload
801
+
802
+ metrics = train_result.metrics
803
+ if data_args.max_train_samples:
804
+ metrics["train_samples"] = data_args.max_train_samples
805
+ trainer.log_metrics("train", metrics)
806
+ trainer.save_metrics("train", metrics)
807
+ trainer.save_state()
808
+
809
+ # 13. Evaluation
810
+ results = {}
811
+ if training_args.do_eval:
812
+ logger.info("*** Evaluate ***")
813
+ metrics = trainer.evaluate(
814
+ metric_key_prefix="eval",
815
+ max_length=training_args.generation_max_length,
816
+ num_beams=training_args.generation_num_beams,
817
+ )
818
+ if data_args.max_eval_samples:
819
+ metrics["eval_samples"] = data_args.max_eval_samples
820
+
821
+ trainer.log_metrics("eval", metrics)
822
+ trainer.save_metrics("eval", metrics)
823
+
824
+ # 14. Write Training Stats
825
+ kwargs = {
826
+ "finetuned_from": model_args.model_name_or_path,
827
+ "tasks": "automatic-speech-recognition",
828
+ "tags": "whisper-event",
829
+ }
830
+ if data_args.dataset_name is not None:
831
+ kwargs["dataset_tags"] = data_args.dataset_name
832
+ if data_args.dataset_config_name is not None:
833
+ kwargs["dataset"] = f"{data_args.dataset_name} {data_args.dataset_config_name}"
834
+ else:
835
+ kwargs["dataset"] = data_args.dataset_name
836
+ if "common_voice" in data_args.dataset_name:
837
+ kwargs["language"] = data_args.dataset_config_name[:2]
838
+ if model_args.model_index_name is not None:
839
+ kwargs["model_name"] = model_args.model_index_name
840
+
841
+ if training_args.push_to_hub:
842
+ trainer.push_to_hub(**kwargs)
843
+ else:
844
+ trainer.create_model_card(**kwargs)
845
+
846
+ return results
847
+
848
+
849
+ class WhisperTSTokenizer():
850
+
851
+ def __init__(self, pretrained_model_name_or_path, language="sv", task="transcribe", *init_inputs, **kwargs):
852
+ self.tokenizer = WhisperTokenizer.from_pretrained(pretrained_model_name_or_path,
853
+ predict_timestamps=False,
854
+ language=language,
855
+ task=task,
856
+ *init_inputs,
857
+ **kwargs)
858
+
859
+ self.ts_tokenizer = WhisperTokenizer.from_pretrained(pretrained_model_name_or_path,
860
+ predict_timestamps=True,
861
+ language=language,
862
+ task=task,
863
+ *init_inputs,
864
+ **kwargs)
865
+
866
+ timestamp_tokens = ["<|{}|>".format(round(number, 2)) for number in np.arange(0.0, 30.02, 0.02)]
867
+ self.ts_tokenizer.add_tokens(timestamp_tokens)
868
+
869
+ def decode(self, input_tokens, output_offsets=True, skip_special_tokens=False):
870
+ return self.tokenizer.decode(input_tokens,
871
+ output_offsets=output_offsets,
872
+ skip_special_tokens=skip_special_tokens)
873
+
874
+ def __call__(self, example_text, return_timestamps = True):
875
+
876
+ if return_timestamps is False:
877
+ return self.tokenizer(example_text['text'])
878
+
879
+ segments_str = self._segments_to_str(example_text['offsets'])
880
+ results = self.ts_tokenizer(segments_str)
881
+ input_ids = []
882
+ attention_mask = []
883
+ for input_id, attention in zip(results.input_ids, results.attention_mask):
884
+ if input_id != 62:
885
+ input_ids.append(input_id)
886
+ attention_mask.append(attention)
887
+
888
+ results['input_ids'] = input_ids
889
+ results['attention_mask'] = attention_mask
890
+ return results
891
+
892
+ def _segments_to_str(self, segments):
893
+ result = ""
894
+ for segment in segments:
895
+ if len(result) > 0: result += " _ "
896
+ result += self._segment_to_str(segment)
897
+ result += " "
898
+ return result
899
+
900
+ def _segment_to_str(self, segment):
901
+ text = segment['text']
902
+ ts_start = segment['timestamp'][0]
903
+ ts_end = segment['timestamp'][1]
904
+ result = "<|{}|>_ {} <|{}|>".format(ts_start, text, ts_end)
905
+ result = result.replace("  ", " ")
906
+ return result
907
+
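+ # Illustrative sketch (never called): the segment string _segment_to_str builds for one entry of
+ # the `offsets` structure, assuming offsets of the shape produced by decode(..., output_offsets=True).
+ def _segment_format_example():
+     segment = {"text": "hej världen", "timestamp": (0.0, 1.2)}
+     return "<|{}|>_ {} <|{}|>".format(segment["timestamp"][0], segment["text"], segment["timestamp"][1])
+     # -> "<|0.0|>_ hej världen <|1.2|>"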
908
+
909
+
910
+ if __name__ == "__main__":
911
+ main()