"""
NeMo diarizer
"""
import os
import json

import wget
import matplotlib.pyplot as plt
from omegaconf import OmegaConf
from nemo.collections.asr.models import ClusteringDiarizer
from nemo.collections.asr.parts.utils.speaker_utils import rttm_to_labels, labels_to_pyannote_object
from pyannote.core import notebook

from diarizers.diarizer import Diarizer


class NemoDiarizer(Diarizer):
    """Class for Nemo Diarizer"""

    def __init__(self, audio_path: str, data_dir: str):
        """
        Nemo diarizer class
        Args:
            audio_path (str): the path to the audio file
        """
        self.audio_path = audio_path
        self.data_dir = data_dir
        self.diarization = None
        self.manifest_path = os.path.join(self.data_dir, 'input_manifest.json')
        self.model_config = os.path.join(self.data_dir, 'offline_diarization.yaml')
        if not os.path.exists(self.model_config):
            config_url = "https://raw.githubusercontent.com/NVIDIA/NeMo/main/examples/" \
                         "speaker_tasks/diarization/conf/offline_diarization.yaml"
            self.model_config = wget.download(config_url, self.data_dir)
        self.config = OmegaConf.load(self.model_config)

    def _create_manifest_file(self):
        """
        Function that creates inference manifest file
        """
        meta = {
            'audio_filepath': self.audio_path,
            'offset': 0,
            'duration': None,
            'label': 'infer',
            'text': '-',
            'num_speakers': None,
            'rttm_filepath': None,
            'uem_filepath': None
        }
        with open(self.manifest_path, 'w') as fp:
            json.dump(meta, fp)
            fp.write('\n')

    def _apply_config(self, pretrained_speaker_model: str = 'ecapa_tdnn'):
        """
        Function that edits the inference configuration file
        Args:
            pretrained_speaker_model (str): the pre-trained speaker embedding model;
            options include ('titanet_large', 'ecapa_tdnn'), see
            https://docs.nvidia.com/deeplearning/nemo/user-guide/docs/en/stable/asr/
            speaker_diarization/results.html
        """

        pretrained_vad = 'vad_marblenet'

        output_dir = os.path.join(self.data_dir, 'outputs')
        self.config.diarizer.manifest_filepath = self.manifest_path
        self.config.diarizer.out_dir = output_dir
        self.config.diarizer.ignore_overlap = False

        self.config.diarizer.speaker_embeddings.model_path = pretrained_speaker_model
        self.config.diarizer.speaker_embeddings.parameters.window_length_in_sec = 1.5
        self.config.diarizer.speaker_embeddings.parameters.shift_length_in_sec = 0.75
        self.config.diarizer.oracle_vad = False
        self.config.diarizer.clustering.parameters.oracle_num_speakers = False

        # Use NVIDIA's pretrained NeMo VAD model
        self.config.diarizer.vad.model_path = pretrained_vad
        self.config.diarizer.vad.window_length_in_sec = 0.15
        self.config.diarizer.vad.shift_length_in_sec = 0.01
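        # VAD post-processing: onset/offset are the thresholds for detecting the
        # start/end of speech; min_duration_on/off drop speech segments and gaps
        # shorter than these values (in seconds)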
        self.config.diarizer.vad.parameters.onset = 0.8
        self.config.diarizer.vad.parameters.offset = 0.6
        self.config.diarizer.vad.parameters.min_duration_on = 0.1
        self.config.diarizer.vad.parameters.min_duration_off = 0.4

    def diarize_audio(self, pretrained_speaker_model: str = 'ecapa_tdnn'):
        """
        function that diarizes the audio
        Args:
            pretrained_speaker_model (str): the pre-trained embedding model options are
            ('escapa_tdnn', vad_telephony_marblenet, titanet_large, ecapa_tdnn)
            https://docs.nvidia.com/deeplearning/nemo/user-guide/docs/en/stable/asr/
            speaker_diarization/results.html
        """
        self._create_manifest_file()
        self._apply_config(pretrained_speaker_model)
        sd_model = ClusteringDiarizer(cfg=self.config)
        sd_model.diarize()
        audio_file_name_without_extension = os.path.basename(self.audio_path).rsplit('.', 1)[0]
        output_diarization_pred = os.path.join(
            self.data_dir, 'outputs', 'pred_rttms',
            f'{audio_file_name_without_extension}.rttm')
        pred_labels = rttm_to_labels(output_diarization_pred)
        self.diarization = labels_to_pyannote_object(pred_labels)

    def get_diarization_figure(self) -> plt.Figure:
        """
        Function that returns the diarization figure
        """
        if self.diarization is None:
            self.diarize_audio()
        figure, ax = plt.subplots()
        notebook.plot_annotation(self.diarization, ax=ax, time=True, legend=True)
        return figure
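

# A minimal usage sketch (assumes a local 'sample.wav' and a writable './data'
# directory; both paths are placeholders, not part of the repository):
if __name__ == '__main__':
    nemo_diarizer = NemoDiarizer(audio_path='sample.wav', data_dir='data')
    nemo_diarizer.diarize_audio(pretrained_speaker_model='titanet_large')
    print(nemo_diarizer.diarization)  # pyannote.core.Annotation of speaker turns
    nemo_diarizer.get_diarization_figure().savefig('diarization.png')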