|
|
|
|
|
|
|
|
|
|
|
""" |
|
|
|
This module is the single entry point that integrates all the functions for extracting features from raw audio. |
|
|
|
The common audio features include: |
|
1. Acoustic features such as Mel Spectrogram, F0, Energy, etc. |
|
2. Content features such as phonetic posteriorgrams (PPG) and bottleneck features (BNF) from pretrained models |
|
|
|
Note: |
|
All feature extraction is designed to utilize the GPU to the maximum extent, which eases on-the-fly extraction for large-scale datasets. |
|
|
|
""" |
|
|
|
import torch |
|
from torch.nn.utils.rnn import pad_sequence |
|
|
|
from utils.mel import extract_mel_features |
|
from utils.f0 import get_f0 as extract_f0_features |
|
from processors.content_extractor import ( |
|
WhisperExtractor, |
|
ContentvecExtractor, |
|
WenetExtractor, |
|
) |
|
|
|
|
|
class AudioFeaturesExtractor:
    """Extract acoustic and content features from a batch of waveforms.

    The content extractors (Whisper, ContentVec, WeNet) are built and loaded
    on first use and then cached as attributes of this instance.
    """

    def __init__(self, cfg):
        """
        Args:
            cfg: Amphion config that would be used to specify the processing parameters
        """
        self.cfg = cfg

    def _get_content_extractor(self, attr_name, extractor_cls):
        """Return the extractor cached under ``attr_name``, creating it on first use.

        Args:
            attr_name: Name of the instance attribute used as the cache slot
            extractor_cls: Extractor class; it is called with ``self.cfg`` and
                must provide ``load_model()``

        Returns:
            The loaded extractor instance
        """
        if not hasattr(self, attr_name):
            extractor = extractor_cls(self.cfg)
            extractor.load_model()
            # Cache only after load_model() succeeded. Caching first would leave
            # an unloaded extractor behind if loading raised, and every later
            # call would silently skip the load.
            setattr(self, attr_name, extractor)
        return getattr(self, attr_name)

    def get_mel_spectrogram(self, wavs):
        """Get Mel Spectrogram Features

        Args:
            wavs: Tensor whose shape is (B, T)

        Returns:
            Tensor whose shape is (B, n_mels, n_frames)
        """
        return extract_mel_features(y=wavs, cfg=self.cfg.preprocess)

    def get_f0(self, wavs, wav_lens=None, use_interpolate=False, return_uv=False):
        """Get F0 Features

        Each waveform is processed separately: it is copied to the CPU as a
        numpy array for the F0 extractor, and the results are moved back to the
        device of ``wavs``. Per-utterance results are zero-padded to the longest
        one in the batch.

        Args:
            wavs: Tensor whose shape is (B, T)
            wav_lens: Optional per-utterance sample counts, indexable by batch
                position. When given, waveform ``i`` is truncated to
                ``wav_lens[i]`` samples before extraction.
            use_interpolate: Forwarded unchanged to the F0 extractor
            return_uv: Whether to also return the ``uv`` output of the F0
                extractor (presumably voiced/unvoiced flags -- confirm against
                ``utils.f0.get_f0``)

        Returns:
            Tensor whose shape is (B, n_frames). If ``return_uv`` is True, a
            tuple ``(f0s, uvs)`` where ``uvs`` is a long tensor padded the same way.
        """
        device = wavs.device

        f0s = []
        uvs = []
        for i, w in enumerate(wavs):
            if wav_lens is not None:
                w = w[: wav_lens[i]]

            # The extractor always returns uv; whether the caller receives it
            # is decided at the end of this method.
            f0, uv = extract_f0_features(
                # detach() so that inputs tracking gradients can be converted;
                # Tensor.numpy() raises on tensors that require grad.
                w.detach().cpu().numpy(),
                self.cfg.preprocess,
                use_interpolate=use_interpolate,
                return_uv=True,
            )
            f0s.append(torch.as_tensor(f0, device=device))
            uvs.append(torch.as_tensor(uv, device=device, dtype=torch.long))

        # Stack the variable-length results into (B, n_frames), padding with 0.
        f0s = pad_sequence(f0s, batch_first=True, padding_value=0)
        uvs = pad_sequence(uvs, batch_first=True, padding_value=0)

        if return_uv:
            return f0s, uvs

        return f0s

    def get_energy(self, wavs, mel_spec=None):
        """Get Energy Features

        The energy of a frame is the L2 norm, over the mel axis, of
        ``exp(mel_spec)`` (this presumably assumes a log-scale mel spectrogram
        -- confirm against ``utils.mel.extract_mel_features``).

        Args:
            wavs: Tensor whose shape is (B, T). Only used when ``mel_spec`` is None.
            mel_spec: Tensor whose shape is (B, n_mels, n_frames). Computed from
                ``wavs`` when omitted.

        Returns:
            Tensor whose shape is (B, n_frames)
        """
        if mel_spec is None:
            mel_spec = self.get_mel_spectrogram(wavs)

        energies = (mel_spec.exp() ** 2).sum(dim=1).sqrt()
        return energies

    def get_whisper_features(self, wavs, target_frame_len):
        """Get Whisper Features

        Args:
            wavs: Tensor whose shape is (B, T)
            target_frame_len: int

        Returns:
            Tensor whose shape is (B, target_frame_len, D)
        """
        extractor = self._get_content_extractor("whisper_extractor", WhisperExtractor)

        whisper_feats = extractor.extract_content_features(wavs)
        whisper_feats = extractor.ReTrans(whisper_feats, target_frame_len)
        return whisper_feats

    def get_contentvec_features(self, wavs, target_frame_len):
        """Get ContentVec Features

        Args:
            wavs: Tensor whose shape is (B, T)
            target_frame_len: int

        Returns:
            Tensor whose shape is (B, target_frame_len, D)
        """
        extractor = self._get_content_extractor(
            "contentvec_extractor", ContentvecExtractor
        )

        contentvec_feats = extractor.extract_content_features(wavs)
        contentvec_feats = extractor.ReTrans(contentvec_feats, target_frame_len)
        return contentvec_feats

    def get_wenet_features(self, wavs, target_frame_len, wav_lens=None):
        """Get WeNet Features

        Args:
            wavs: Tensor whose shape is (B, T)
            target_frame_len: int
            wav_lens: Tensor whose shape is (B); forwarded to the extractor as ``lens``

        Returns:
            Tensor whose shape is (B, target_frame_len, D)
        """
        extractor = self._get_content_extractor("wenet_extractor", WenetExtractor)

        wenet_feats = extractor.extract_content_features(wavs, lens=wav_lens)
        wenet_feats = extractor.ReTrans(wenet_feats, target_frame_len)
        return wenet_feats
|
|