import numpy as np
import torch
from torch import nn
from torchaudio.transforms import MelSpectrogram as TorchMelSpectrogram

from .hparams import HParams

|
class MelSpectrogram(nn.Module):
    def __init__(self, hp: HParams):
        """
        Torch implementation of Resemble's mel extraction.
        Note that the values are NOT identical to librosa's implementation
        due to floating point precision.
        """
        super().__init__()
        self.hp = hp
        self.melspec = TorchMelSpectrogram(
            hp.wav_rate,
            n_fft=hp.n_fft,
            win_length=hp.win_size,
            hop_length=hp.hop_size,
            f_min=0,
            f_max=hp.wav_rate // 2,
            n_mels=hp.num_mels,
            power=1,
            normalized=False,
            # The settings below follow librosa's defaults.
            pad_mode="constant",
            norm="slaney",
            mel_scale="slaney",
        )
        self.register_buffer("stft_magnitude_min", torch.FloatTensor([hp.stft_magnitude_min]))
        self.min_level_db = 20 * np.log10(hp.stft_magnitude_min)
        self.preemphasis = hp.preemphasis
        self.hop_size = hp.hop_size
|
    def forward(self, wav, pad=True):
        """
        Args:
            wav: [B, T]
        """
        device = wav.device
        if wav.is_mps:
            # Fall back to CPU for MPS tensors (and move the module along with them).
            wav = wav.cpu()
            self.to(wav.device)
        if self.preemphasis > 0:
            wav = torch.nn.functional.pad(wav, [1, 0], value=0)
            wav = wav[..., 1:] - self.preemphasis * wav[..., :-1]
        mel = self.melspec(wav)
        mel = self._amp_to_db(mel)
        mel_normed = self._normalize(mel)
        # Sanity check: the centered STFT should yield 1 + T // hop_size frames.
        assert not pad or mel_normed.shape[-1] == 1 + wav.shape[-1] // self.hop_size
        mel_normed = mel_normed.to(device)
        return mel_normed  # [B, num_mels, T']
|
    def _normalize(self, s, headroom_db=15):
        # Rescale dB values so that min_level_db maps to 0 and headroom_db maps to 1.
        return (s - self.min_level_db) / (-self.min_level_db + headroom_db)

    def _amp_to_db(self, x):
        # 20 * log10(x), clamped from below to avoid taking the log of zero.
        return x.clamp_min(self.hp.stft_magnitude_min).log10() * 20
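

# A minimal usage sketch, not part of the module. It assumes HParams accepts the
# fields read above (wav_rate, n_fft, win_size, hop_size, num_mels,
# stft_magnitude_min, preemphasis) as keyword arguments; the concrete values are
# illustrative only, not the project's defaults. Because of the relative import,
# this file would have to be run as part of its package.
if __name__ == "__main__":
    hp = HParams(
        wav_rate=24_000,          # assumed sample rate
        n_fft=1024,
        win_size=1024,
        hop_size=256,
        num_mels=80,
        stft_magnitude_min=1e-4,
        preemphasis=0.0,          # 0 disables the pre-emphasis branch in forward()
    )
    mel_extractor = MelSpectrogram(hp)
    wav = torch.randn(2, 24_000)  # [B, T]: two one-second waveforms
    mel = mel_extractor(wav)      # [B, num_mels, 1 + T // hop_size]
    print(mel.shape)              # expected: torch.Size([2, 80, 94])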