darksakura committed on
Commit
ba49987
1 Parent(s): 93d7220

Upload RMVPEF0Predictor.py

Files changed (1)
  1. RMVPEF0Predictor.py +107 -0
RMVPEF0Predictor.py ADDED
@@ -0,0 +1,107 @@
+from typing import Union
+
+import numpy as np
+import torch
+import torch.nn.functional as F
+
+from modules.F0Predictor.F0Predictor import F0Predictor
+
+from .rmvpe import RMVPE
+
+
+class RMVPEF0Predictor(F0Predictor):
+    def __init__(self,hop_length=512,f0_min=50,f0_max=1100, dtype=torch.float32, device=None,sampling_rate=44100,threshold=0.05):
+        self.rmvpe = RMVPE(model_path="pretrain/rmvpe.pt",dtype=dtype,device=device)
+        self.hop_length = hop_length
+        self.f0_min = f0_min
+        self.f0_max = f0_max
+        if device is None:
+            self.device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
+            #self.device = 'cuda' if torch.cuda.is_available() else 'cpu'
+        else:
+            self.device = torch.device(device)
+        self.threshold = threshold
+        self.sampling_rate = sampling_rate
+        self.dtype = dtype
+
+    def repeat_expand(
+        self, content: Union[torch.Tensor, np.ndarray], target_len: int, mode: str = "nearest"
+    ):
+        ndim = content.ndim
+
+        if content.ndim == 1:
+            content = content[None, None]
+        elif content.ndim == 2:
+            content = content[None]
+
+        assert content.ndim == 3
+
+        is_np = isinstance(content, np.ndarray)
+        if is_np:
+            content = torch.from_numpy(content)
+
+        results = torch.nn.functional.interpolate(content, size=target_len, mode=mode)
+
+        if is_np:
+            results = results.numpy()
+
+        if ndim == 1:
+            return results[0, 0]
+        elif ndim == 2:
+            return results[0]
+
+    def post_process(self, x, sampling_rate, f0, pad_to):
+        if isinstance(f0, np.ndarray):
+            f0 = torch.from_numpy(f0).float().to(x.device)
+
+        if pad_to is None:
+            return f0
+
+        f0 = self.repeat_expand(f0, pad_to)
+
+        vuv_vector = torch.zeros_like(f0)
+        vuv_vector[f0 > 0.0] = 1.0
+        vuv_vector[f0 <= 0.0] = 0.0
+
+        # Drop zero-frequency frames, then interpolate linearly
+        nzindex = torch.nonzero(f0).squeeze()
+        f0 = torch.index_select(f0, dim=0, index=nzindex).cpu().numpy()
+        time_org = self.hop_length / sampling_rate * nzindex.cpu().numpy()
+        time_frame = np.arange(pad_to) * self.hop_length / sampling_rate
+
+        vuv_vector = F.interpolate(vuv_vector[None,None,:],size=pad_to)[0][0]
+
+        if f0.shape[0] <= 0:
+            return torch.zeros(pad_to, dtype=torch.float, device=x.device),vuv_vector.cpu().numpy()
+        if f0.shape[0] == 1:
+            return torch.ones(pad_to, dtype=torch.float, device=x.device) * f0[0],vuv_vector.cpu().numpy()
+
+        # Could probably be rewritten with torch?
+        f0 = np.interp(time_frame, time_org, f0, left=f0[0], right=f0[-1])
+        #vuv_vector = np.ceil(scipy.ndimage.zoom(vuv_vector,pad_to/len(vuv_vector),order = 0))
+
+        return f0,vuv_vector.cpu().numpy()
+
+    def compute_f0(self,wav,p_len=None):
+        x = torch.FloatTensor(wav).to(self.dtype).to(self.device)
+        if p_len is None:
+            p_len = x.shape[0]//self.hop_length
+        else:
+            assert abs(p_len-x.shape[0]//self.hop_length) < 4, "pad length error"
+        f0 = self.rmvpe.infer_from_audio(x,self.sampling_rate,self.threshold)
+        if torch.all(f0 == 0):
+            rtn = f0.cpu().numpy() if p_len is None else np.zeros(p_len)
+            return rtn,rtn
+        return self.post_process(x,self.sampling_rate,f0,p_len)[0]
+
+    def compute_f0_uv(self,wav,p_len=None):
+        x = torch.FloatTensor(wav).to(self.dtype).to(self.device)
+        if p_len is None:
+            p_len = x.shape[0]//self.hop_length
+        else:
+            assert abs(p_len-x.shape[0]//self.hop_length) < 4, "pad length error"
+        f0 = self.rmvpe.infer_from_audio(x,self.sampling_rate,self.threshold)
+        if torch.all(f0 == 0):
+            rtn = f0.cpu().numpy() if p_len is None else np.zeros(p_len)
+            return rtn,rtn
+        return self.post_process(x,self.sampling_rate,f0,p_len)
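
For reference, a minimal usage sketch, not part of the commit. It assumes the uploaded file sits at modules/F0Predictor/RMVPEF0Predictor.py (as its relative import of .rmvpe suggests), that the pretrain/rmvpe.pt checkpoint referenced in __init__ is present, and that librosa is available to load a hypothetical input.wav.

    import librosa
    import torch

    from modules.F0Predictor.RMVPEF0Predictor import RMVPEF0Predictor

    # Load a mono waveform at the predictor's default sampling rate (44.1 kHz).
    wav, sr = librosa.load("input.wav", sr=44100)

    predictor = RMVPEF0Predictor(
        hop_length=512,
        sampling_rate=sr,
        device="cuda" if torch.cuda.is_available() else "cpu",
    )

    # compute_f0_uv returns per-frame F0 in Hz plus a voiced/unvoiced mask (1.0 = voiced).
    f0, uv = predictor.compute_f0_uv(wav)
    print(f0.shape, uv.shape)

compute_f0 returns only the F0 track (it drops the voiced/unvoiced mask), so compute_f0_uv is the variant shown here.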