Spaces:
Running
Running
darksakura
commited on
Commit
•
ba49987
1
Parent(s):
93d7220
Upload RMVPEF0Predictor.py
Browse files- RMVPEF0Predictor.py +107 -0
RMVPEF0Predictor.py
ADDED
@@ -0,0 +1,107 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
1 |
+
from typing import Union
|
2 |
+
|
3 |
+
import numpy as np
|
4 |
+
import torch
|
5 |
+
import torch.nn.functional as F
|
6 |
+
|
7 |
+
from modules.F0Predictor.F0Predictor import F0Predictor
|
8 |
+
|
9 |
+
from .rmvpe import RMVPE
|
10 |
+
|
11 |
+
|
12 |
+
class RMVPEF0Predictor(F0Predictor):
|
13 |
+
def __init__(self,hop_length=512,f0_min=50,f0_max=1100, dtype=torch.float32, device=None,sampling_rate=44100,threshold=0.05):
|
14 |
+
self.rmvpe = RMVPE(model_path="pretrain/rmvpe.pt",dtype=dtype,device=device)
|
15 |
+
self.hop_length = hop_length
|
16 |
+
self.f0_min = f0_min
|
17 |
+
self.f0_max = f0_max
|
18 |
+
if device is None:
|
19 |
+
self.device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
|
20 |
+
#self.device = 'cuda' if torch.cuda.is_available() else 'cpu'
|
21 |
+
else:
|
22 |
+
self.dev = torch.device("cpu")
|
23 |
+
self.threshold = threshold
|
24 |
+
self.sampling_rate = sampling_rate
|
25 |
+
self.dtype = dtype
|
26 |
+
|
27 |
+
def repeat_expand(
|
28 |
+
self, content: Union[torch.Tensor, np.ndarray], target_len: int, mode: str = "nearest"
|
29 |
+
):
|
30 |
+
ndim = content.ndim
|
31 |
+
|
32 |
+
if content.ndim == 1:
|
33 |
+
content = content[None, None]
|
34 |
+
elif content.ndim == 2:
|
35 |
+
content = content[None]
|
36 |
+
|
37 |
+
assert content.ndim == 3
|
38 |
+
|
39 |
+
is_np = isinstance(content, np.ndarray)
|
40 |
+
if is_np:
|
41 |
+
content = torch.from_numpy(content)
|
42 |
+
|
43 |
+
results = torch.nn.functional.interpolate(content, size=target_len, mode=mode)
|
44 |
+
|
45 |
+
if is_np:
|
46 |
+
results = results.numpy()
|
47 |
+
|
48 |
+
if ndim == 1:
|
49 |
+
return results[0, 0]
|
50 |
+
elif ndim == 2:
|
51 |
+
return results[0]
|
52 |
+
|
53 |
+
def post_process(self, x, sampling_rate, f0, pad_to):
|
54 |
+
if isinstance(f0, np.ndarray):
|
55 |
+
f0 = torch.from_numpy(f0).float().to(x.device)
|
56 |
+
|
57 |
+
if pad_to is None:
|
58 |
+
return f0
|
59 |
+
|
60 |
+
f0 = self.repeat_expand(f0, pad_to)
|
61 |
+
|
62 |
+
vuv_vector = torch.zeros_like(f0)
|
63 |
+
vuv_vector[f0 > 0.0] = 1.0
|
64 |
+
vuv_vector[f0 <= 0.0] = 0.0
|
65 |
+
|
66 |
+
# 去掉0频率, 并线性插值
|
67 |
+
nzindex = torch.nonzero(f0).squeeze()
|
68 |
+
f0 = torch.index_select(f0, dim=0, index=nzindex).cpu().numpy()
|
69 |
+
time_org = self.hop_length / sampling_rate * nzindex.cpu().numpy()
|
70 |
+
time_frame = np.arange(pad_to) * self.hop_length / sampling_rate
|
71 |
+
|
72 |
+
vuv_vector = F.interpolate(vuv_vector[None,None,:],size=pad_to)[0][0]
|
73 |
+
|
74 |
+
if f0.shape[0] <= 0:
|
75 |
+
return torch.zeros(pad_to, dtype=torch.float, device=x.device),vuv_vector.cpu().numpy()
|
76 |
+
if f0.shape[0] == 1:
|
77 |
+
return torch.ones(pad_to, dtype=torch.float, device=x.device) * f0[0],vuv_vector.cpu().numpy()
|
78 |
+
|
79 |
+
# 大概可以用 torch 重写?
|
80 |
+
f0 = np.interp(time_frame, time_org, f0, left=f0[0], right=f0[-1])
|
81 |
+
#vuv_vector = np.ceil(scipy.ndimage.zoom(vuv_vector,pad_to/len(vuv_vector),order = 0))
|
82 |
+
|
83 |
+
return f0,vuv_vector.cpu().numpy()
|
84 |
+
|
85 |
+
def compute_f0(self,wav,p_len=None):
|
86 |
+
x = torch.FloatTensor(wav).to(self.dtype).to(self.device)
|
87 |
+
if p_len is None:
|
88 |
+
p_len = x.shape[0]//self.hop_length
|
89 |
+
else:
|
90 |
+
assert abs(p_len-x.shape[0]//self.hop_length) < 4, "pad length error"
|
91 |
+
f0 = self.rmvpe.infer_from_audio(x,self.sampling_rate,self.threshold)
|
92 |
+
if torch.all(f0 == 0):
|
93 |
+
rtn = f0.cpu().numpy() if p_len is None else np.zeros(p_len)
|
94 |
+
return rtn,rtn
|
95 |
+
return self.post_process(x,self.sampling_rate,f0,p_len)[0]
|
96 |
+
|
97 |
+
def compute_f0_uv(self,wav,p_len=None):
|
98 |
+
x = torch.FloatTensor(wav).to(self.dtype).to(self.device)
|
99 |
+
if p_len is None:
|
100 |
+
p_len = x.shape[0]//self.hop_length
|
101 |
+
else:
|
102 |
+
assert abs(p_len-x.shape[0]//self.hop_length) < 4, "pad length error"
|
103 |
+
f0 = self.rmvpe.infer_from_audio(x,self.sampling_rate,self.threshold)
|
104 |
+
if torch.all(f0 == 0):
|
105 |
+
rtn = f0.cpu().numpy() if p_len is None else np.zeros(p_len)
|
106 |
+
return rtn,rtn
|
107 |
+
return self.post_process(x,self.sampling_rate,f0,p_len)
|