---
license: apache-2.0
language:
- sl
- hr
- sr
base_model:
- facebook/w2v-bert-2.0
pipeline_tag: audio-classification
metrics:
- f1
---

# Frame classification for filled pauses

This model classifies individual 20 ms frames of audio based on the presence of filled pauses ("eee", "errm", ...).

It was trained on the human-annotated Slovenian speech corpus ROG-Artur and achieves an F1 score of 0.952868 on the test split of the same dataset.

Evaluation on 800 human-annotated instances from ParlaSpeech-HR (Croatian) and ParlaSpeech-RS (Serbian) produced the following metrics:

```
Performance on RS:

Classification report for human vs model on event level:
              precision    recall  f1-score   support

           0       0.97      0.87      0.92       234
           1       0.95      0.99      0.97       542

    accuracy                           0.95       776
   macro avg       0.96      0.93      0.94       776
weighted avg       0.95      0.95      0.95       776

Performance on HR:

Classification report for human vs model on event level:
              precision    recall  f1-score   support

           0       0.94      0.84      0.89       242
           1       0.93      0.98      0.95       531

    accuracy                           0.93       773
   macro avg       0.93      0.91      0.92       773
weighted avg       0.93      0.93      0.93       773
```

The metrics are reported at the event level: if a true and a predicted filled pause overlap at least partially, the pair is counted as a true-positive event.
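The partial-overlap rule can be sketched in plain Python. This is a minimal illustration, not the authors' evaluation script: the helper names (`frames_to_events`, `to_seconds`, `overlaps`, `event_counts`) are hypothetical, and the sketch assumes that label `1` marks a filled-pause frame, that each frame covers 20 ms, and that true positives are counted from the reference (human) side. It does not attempt to reproduce how the per-class supports in the reports above were derived.

```python
from typing import List, Sequence, Tuple

# Assumption: one model output frame per 20 ms of audio.
FRAME_MS = 20

Event = Tuple[int, int]  # (start_frame, end_frame), end exclusive


def frames_to_events(labels: Sequence[int]) -> List[Event]:
    """Merge runs of consecutive 1-frames into (start, end) frame intervals."""
    events: List[Event] = []
    start = None
    for i, label in enumerate(labels):
        if label == 1 and start is None:
            start = i  # a filled pause begins at frame i
        elif label != 1 and start is not None:
            events.append((start, i))  # the filled pause ended before frame i
            start = None
    if start is not None:  # the pause runs until the end of the clip
        events.append((start, len(labels)))
    return events


def to_seconds(event: Event) -> Tuple[float, float]:
    """Convert a frame interval into (start_s, end_s)."""
    return event[0] * FRAME_MS / 1000, event[1] * FRAME_MS / 1000


def overlaps(a: Event, b: Event) -> bool:
    """True if the two half-open intervals share at least one frame."""
    return a[0] < b[1] and b[0] < a[1]


def event_counts(true_events: List[Event], pred_events: List[Event]) -> Tuple[int, int, int]:
    """Return (TP, FP, FN), treating any partial overlap as a match."""
    tp = sum(any(overlaps(t, p) for p in pred_events) for t in true_events)
    fn = len(true_events) - tp
    fp = sum(not any(overlaps(p, t) for t in true_events) for p in pred_events)
    return tp, fp, fn


pred = frames_to_events([0, 0, 1, 1, 1, 0, 0, 1, 1, 0])
true = [(2, 4), (10, 12)]
print(pred)                      # [(2, 5), (7, 9)]
print(event_counts(true, pred))  # (1, 1, 1)
```

Here the first predicted event only partly overlaps the first human event, yet it still counts as a true positive; the second prediction matches nothing (a false positive), and the second human event is missed (a false negative).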
# Example use

```python
from transformers import AutoFeatureExtractor, Wav2Vec2BertForAudioFrameClassification
from datasets import Dataset, Audio
import torch

# Use a GPU if one is available, otherwise fall back to the CPU
device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
model_name = "5roop/wav2vecbert2-filledPause"
feature_extractor = AutoFeatureExtractor.from_pretrained(model_name)
model = Wav2Vec2BertForAudioFrameClassification.from_pretrained(model_name).to(device)

# Replace with the path(s) to your own audio; files are resampled to 16 kHz mono
ds = Dataset.from_dict(
    {
        "audio": [
            "/cache/peterr/mezzanine_resources/filled_pauses/data/dev/Iriss-J-Gvecg-P500001-avd_2082.293_2112.194.wav"
        ],
    }
).cast_column("audio", Audio(sampling_rate=16_000, mono=True))


def evaluator(chunks):
    sampling_rate = chunks["audio"][0]["sampling_rate"]
    with torch.no_grad():
        inputs = feature_extractor(
            [i["array"] for i in chunks["audio"]],
            return_tensors="pt",
            sampling_rate=sampling_rate,
        ).to(device)
        logits = model(**inputs).logits
    # Pick the most likely class for every 20 ms frame
    y_pred = logits.argmax(dim=-1).cpu().numpy()
    return {"y_pred": y_pred.tolist()}


ds = ds.map(evaluator, batched=True)
print(ds["y_pred"][0])
# Returns a list of 20 ms frames: [0,0,0,0,1,1,1,1,1,1,1,1,1,1,1,....]
# with 0 indicating no filled pause detected in that frame and 1 a filled pause
```

# Citation

Coming soon.