[WIP] Upload folder using huggingface_hub (multi-commit 513e361cfa57256a8e357aacb50b4a5aa4111aae72a68ff40bac250e9ec1f525)

#2
Files changed (4)
  1. README.md +71 -12
  2. app.py +45 -158
  3. hparams.py +0 -167
  4. synthesis.py +0 -66
README.md CHANGED
@@ -1,12 +1,71 @@
1
- ---
2
- title: Tts Rvc Autopst
3
- emoji: 💬
4
- colorFrom: yellow
5
- colorTo: purple
6
- sdk: gradio
7
- sdk_version: 4.36.1
8
- app_file: app.py
9
- pinned: false
10
- ---
11
-
12
- An example chatbot using [Gradio](https://gradio.app), [`huggingface_hub`](https://huggingface.co/docs/huggingface_hub/v0.22.2/en/index), and the [Hugging Face Inference API](https://huggingface.co/docs/api-inference/index).
1
+ # Global Prosody Style Transfer Without Text Transcriptions
2
+
3
+ This repository provides a PyTorch implementation of [AutoPST](https://arxiv.org/abs/2106.08519), which enables unsupervised global prosody conversion without text transcriptions.
4
+
5
+ The short video below explains the main concepts of our work. If you find this work useful and use it in your research, please consider citing our paper.
6
+
7
+ [![AutoPST](./assets/cover.png)](https://youtu.be/wow2DRuJ69c/)
8
+
9
+ ```
10
+ @InProceedings{pmlr-v139-qian21b,
11
+ title = {Global Prosody Style Transfer Without Text Transcriptions},
12
+ author = {Qian, Kaizhi and Zhang, Yang and Chang, Shiyu and Xiong, Jinjun and Gan, Chuang and Cox, David and Hasegawa-Johnson, Mark},
13
+ booktitle = {Proceedings of the 38th International Conference on Machine Learning},
14
+ pages = {8650--8660},
15
+ year = {2021},
16
+ editor = {Meila, Marina and Zhang, Tong},
17
+ volume = {139},
18
+ series = {Proceedings of Machine Learning Research},
19
+ month = {18--24 Jul},
20
+ publisher = {PMLR},
21
+ url = {http://proceedings.mlr.press/v139/qian21b.html}
22
+ }
23
+
24
+ ```
25
+
26
+
27
+ ## Audio Demo
28
+
29
+ The audio demo for AutoPST can be found [here](https://auspicious3000.github.io/AutoPST-Demo/)
30
+
31
+ ## Dependencies
32
+ - Python 3.6
33
+ - Numpy
34
+ - Scipy
35
+ - PyTorch == v1.6.0
36
+ - librosa
37
+ - pysptk
38
+ - soundfile
39
+ - wavenet_vocoder ```pip install wavenet_vocoder==0.1.1```
40
+   For more information, please refer to https://github.com/r9y9/wavenet_vocoder.
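A quick way to confirm your environment matches this list is to try importing each package. The snippet below is a convenience sketch, not part of the repository; the module names are the import names assumed from the list above.

```python
# Convenience sketch: check that the dependencies listed above are importable.
# The names below are the assumed import names (e.g. `soundfile` for SoundFile).
import importlib

for name in ["numpy", "scipy", "torch", "librosa", "pysptk", "soundfile", "wavenet_vocoder"]:
    try:
        module = importlib.import_module(name)
        print(f"{name}: {getattr(module, '__version__', 'version unknown')}")
    except ImportError as exc:
        print(f"{name}: missing ({exc})")
```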
41
+
42
+
43
+ ## To Run Demo
44
+
45
+ Download [pre-trained models](https://drive.google.com/file/d/1ji3Bk6YGvXkPqFu1hLOAJp_SKw-vHGrp/view?usp=sharing) to ```assets```
46
+
47
+ Download the same WaveNet vocoder model as in [AutoVC](https://github.com/auspicious3000/autovc) to ```assets```
48
+
49
+ The fast and high-quality HiFi-GAN v1 (https://github.com/jik876/hifi-gan) pre-trained model is now available [here](https://drive.google.com/file/d/1n76jHs8k1sDQ3Eh5ajXwdxuY_EZw4N9N/view?usp=sharing).
50
+
51
+ Please refer to [AutoVC](https://github.com/auspicious3000/autovc) if you have any problems with the vocoder part, since both projects share the same vocoder scripts.
52
+
53
+ Run ```demo.ipynb```
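For orientation, the sketch below condenses the conversion path that the demo and this PR's `app.py` follow: load the AutoPST predictor, convert a source utterance's mel-cepstrum with a target speaker embedding via `infer_onmt`, then vocode the resulting mel-spectrogram. It is a sketch under stated assumptions, not the official demo: the AutoPST checkpoint filename, the `test_vctk.meta` keys, and the `Predictor(hparams)` constructor call are assumptions to be adjusted to the files you downloaded.

```python
# Condensed inference sketch, mirroring the demo / app.py code shown in this PR.
import pickle
import torch
import soundfile as sf

from onmt_modules.misc import sequence_mask
from model_autopst import Generator_2 as Predictor
from hparams_autopst import hparams
from synthesis import build_model, wavegen

device = torch.device("cuda" if torch.cuda.is_available() else "cpu")

# 1. Load the prosody-conversion model (assumed constructor; checkpoint name is a placeholder).
P = Predictor(hparams).eval().to(device)
ckpt = torch.load("assets/autopst_checkpoint.ckpt", map_location="cpu")  # placeholder path
P.load_state_dict(ckpt["model"], strict=True)

# 2. Pick a source utterance and a target speaker embedding from the test metadata.
dict_test = pickle.load(open("assets/test_vctk.meta", "rb"))
src_spk, tgt_spk, uttr_id = "p231", "p270", "001"  # hypothetical keys
cep_real, _ = dict_test[src_spk][uttr_id]
_, spk_emb = dict_test[tgt_spk][uttr_id]

# 3. Convert: first 14 mel-cepstrum coefficients + target speaker embedding -> mel-spectrogram.
cep_real_A = torch.from_numpy(cep_real).unsqueeze(0).to(device)
len_real_A = torch.tensor(cep_real_A.size(1)).unsqueeze(0).to(device)
real_mask_A = sequence_mask(len_real_A, cep_real_A.size(1)).float()
spk_emb_B = torch.from_numpy(spk_emb).unsqueeze(0).to(device)
with torch.no_grad():
    spect_output, len_spect = P.infer_onmt(
        cep_real_A.transpose(2, 1)[:, :14, :], real_mask_A, len_real_A, spk_emb_B)
uttr_tgt = spect_output[:len_spect[0], 0, :].cpu().numpy()

# 4. Vocode with the WaveNet vocoder (same checkpoint as AutoVC) and save the waveform.
vocoder = build_model().to(device)
voc_ckpt = torch.load("assets/checkpoint_step001000000_ema.pth", map_location="cpu")
vocoder.load_state_dict(voc_ckpt["state_dict"])
waveform = wavegen(vocoder, c=uttr_tgt)
sf.write("converted.wav", waveform, 16000)
```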
54
+
55
+
56
+ ## To Train
57
+
58
+ Download [training data](https://drive.google.com/file/d/1H1dyA80qREKLHybqnYaqBRRsacIdFbnE/view?usp=sharing) to ```assets```.
59
+ The provided training data is very small and is intended for code verification only.
60
+ Please use the scripts to prepare your own data for training.
61
+
62
+ 1. Prepare training data: ```python prepare_train_data.py```
63
+
64
+ 2. Train 1st Stage: ```python main_1.py```
65
+
66
+ 3. Train 2nd Stage: ```python main_2.py```
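If you prefer to drive the three steps from a single place, the wrapper below is an illustrative sketch only; it assumes each script is invoked with no extra command-line arguments, exactly as in the commands above.

```python
# Minimal driver sketch: run the three training steps above in order.
import subprocess
import sys

for script in ["prepare_train_data.py", "main_1.py", "main_2.py"]:
    print(f"=== Running {script} ===")
    subprocess.run([sys.executable, script], check=True)
```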
67
+
68
+
69
+ ## Final Words
70
+
71
+ This project is part of ongoing research. We hope this repo is useful for your research. If you need any help or have suggestions for improving the framework, please open an issue and we will do our best to get back to you as soon as possible.
app.py CHANGED
@@ -13,11 +13,9 @@ import numpy as np
13
  import torch
14
  import torch.nn.functional as F
15
  from collections import OrderedDict
16
- from onmt_modules.misc import sequence_mask
17
- from model_autopst import Generator_2 as Predictor
18
- from hparams_autopst import hparams
19
- from model_sea import Generator
20
- from hparams_sea import hparams as sea_hparams
21
 
22
  device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
23
 
@@ -27,7 +25,7 @@ checkpoint = torch.load(hf_hub_download(repo_id="jonathanjordan21/AutoPST", file
27
  P.load_state_dict(checkpoint['model'], strict=True)
28
  print('Loaded predictor .....................................................')
29
 
30
- dict_test = pickle.load(open('./assets/test_vctk.meta', 'rb'))
31
 
32
  spect_vc = OrderedDict()
33
 
@@ -66,17 +64,13 @@ import torch
66
  import librosa
67
  import pickle
68
  import os
69
- from synthesis import build_model
70
- from synthesis import wavegen
71
 
72
  model = build_model().to(device)
73
  checkpoint = torch.load(hf_hub_download(repo_id="jonathanjordan21/AutoPST", filename="checkpoint_step001000000_ema.pth"), map_location=torch.device('cpu'))
74
  model.load_state_dict(checkpoint["state_dict"])
75
 
76
- # sea_checkpoint = torch.load(hf_hub_download(repo_id="jonathanjordan21/AutoPST", filename='sea.ckpt'), map_location=lambda storage, loc: storage)
77
- # gen =Generator(sea_hparams)
78
- # gen.load_state_dict(sea_checkpoint['model'], strict=True)
79
-
80
  # for name, sp in spect_vc.items():
81
 
82
  # print(name)
@@ -87,164 +81,57 @@ model.load_state_dict(checkpoint["state_dict"])
87
 
88
 
89
 
90
- # def respond(
91
- # message,
92
- # history: list[tuple[str, str]],
93
- # system_message,
94
- # max_tokens,
95
- # temperature,
96
- # top_p,
97
- # ):
98
- # messages = [{"role": "system", "content": system_message}]
99
 
100
- # for val in history:
101
- # if val[0]:
102
- # messages.append({"role": "user", "content": val[0]})
103
- # if val[1]:
104
- # messages.append({"role": "assistant", "content": val[1]})
105
 
106
- # messages.append({"role": "user", "content": message})
107
 
108
- # response = ""
109
 
110
- # for message in client.chat_completion(
111
- # messages,
112
- # max_tokens=max_tokens,
113
- # stream=True,
114
- # temperature=temperature,
115
- # top_p=top_p,
116
- # ):
117
- # token = message.choices[0].delta.content
118
 
119
- # response += token
120
- # yield response
121
 
122
  """
123
  For information on how to customize the ChatInterface, peruse the gradio docs: https://www.gradio.app/docs/chatinterface
124
  """
125
- # demo = gr.ChatInterface(
126
- # respond,
127
- # additional_inputs=[
128
- # gr.Textbox(value="You are a friendly Chatbot.", label="System message"),
129
- # gr.Slider(minimum=1, maximum=2048, value=512, step=1, label="Max new tokens"),
130
- # gr.Slider(minimum=0.1, maximum=4.0, value=0.7, step=0.1, label="Temperature"),
131
- # gr.Slider(
132
- # minimum=0.1,
133
- # maximum=1.0,
134
- # value=0.95,
135
- # step=0.05,
136
- # label="Top-p (nucleus sampling)",
137
- # ),
138
- # ],
139
- # )
140
-
141
- import os
142
- import pickle
143
- import numpy as np
144
- import soundfile as sf
145
- from scipy import signal
146
- from scipy.signal import get_window
147
- from librosa.filters import mel
148
- from numpy.random import RandomState
149
-
150
-
151
- def butter_highpass(cutoff, fs, order=5):
152
- nyq = 0.5 * fs
153
- normal_cutoff = cutoff / nyq
154
- b, a = signal.butter(order, normal_cutoff, btype='high', analog=False)
155
- return b, a
156
-
157
-
158
- def pySTFT(x, fft_length=1024, hop_length=256):
159
-
160
- x = np.pad(x, int(fft_length//2), mode='reflect')
161
-
162
- noverlap = fft_length - hop_length
163
- shape = x.shape[:-1]+((x.shape[-1]-noverlap)//hop_length, fft_length)
164
- strides = x.strides[:-1]+(hop_length*x.strides[-1], x.strides[-1])
165
- result = np.lib.stride_tricks.as_strided(x, shape=shape,
166
- strides=strides)
167
-
168
- fft_window = get_window('hann', fft_length, fftbins=True)
169
- result = np.fft.rfft(fft_window * result, n=fft_length).T
170
-
171
- return np.abs(result)
172
-
173
-
174
- def create_sp(cep_real, spk_emb):
175
- # cep_real, spk_emb = dict_test[uttr[0]][uttr[2]]
176
- cep_real_A = torch.from_numpy(cep_real).unsqueeze(0).to(device)
177
- len_real_A = torch.tensor(cep_real_A.size(1)).unsqueeze(0).to(device)
178
- real_mask_A = sequence_mask(len_real_A, cep_real_A.size(1)).float()
179
-
180
- # _, spk_emb = dict_test[uttr[1]][uttr[2]]
181
- spk_emb_B = torch.from_numpy(spk_emb).unsqueeze(0).to(device)
182
-
183
- with torch.no_grad():
184
- spect_output, len_spect = P.infer_onmt(cep_real_A.transpose(2,1)[:,:14,:],
185
- real_mask_A,
186
- len_real_A,
187
- spk_emb_B)
188
-
189
- uttr_tgt = spect_output[:len_spect[0],0,:].cpu().numpy()
190
- return uttr_tgt
191
-
192
- def create_mel(x):
193
- mel_basis = mel(sr=16000, n_fft=1024, fmin=90, fmax=7600, n_mels=80).T
194
- min_level = np.exp(-100 / 20 * np.log(10))
195
- b, a = butter_highpass(30, 16000, order=5)
196
-
197
- mfcc_mean, mfcc_std, dctmx = pickle.load(open('assets/mfcc_stats.pkl', 'rb'))
198
- spk2emb = pickle.load(open('assets/spk2emb_82.pkl', 'rb'))
199
-
200
- if x.shape[0] % 256 == 0:
201
- x = np.concatenate((x, np.array([1e-06])), axis=0)
202
- y = signal.filtfilt(b, a, x)
203
- D = pySTFT(y * 0.96).T
204
- D_mel = np.dot(D, mel_basis)
205
- D_db = 20 * np.log10(np.maximum(min_level, D_mel))
206
-
207
- # mel sp
208
- S = (D_db + 80) / 100
209
-
210
- # mel cep
211
- cc_tmp = S.dot(dctmx)
212
- cc_norm = (cc_tmp - mfcc_mean) / mfcc_std
213
- S = np.clip(S, 0, 1)
214
-
215
- # teacher code
216
- # cc_torch = torch.from_numpy(cc_norm[:,0:20].astype(np.float32)).unsqueeze(0).to(device)
217
- # with torch.no_grad():
218
- # codes = gen.encode(cc_torch, torch.ones_like(cc_torch[:,:,0])).squeeze(0)
219
- return S, cc_norm
220
-
221
- def transcribe(audio, spk):
222
- sr, y = audio
223
- y = librosa.resample(y, orig_sr=sr, target_sr=16000)
224
- y = y.astype(np.float32)
225
- y /= np.max(np.abs(y))
226
-
227
- spk_emb = np.zeros((82,))
228
- spk_emb[int(spk)-1] = 1
229
-
230
- mel_sp, mel_cep = create_mel(y)
231
- sp = create_sp(mel_cep, spk_emb)
232
- waveform = wavegen(model, c=sp)
233
- return 16000, waveform
234
-
235
- # return transcriber({"sampling_rate": sr, "raw": y})["text"]
236
-
237
-
238
- demo = gr.Interface(
239
- transcribe,
240
- [
241
- gr.Audio(),
242
- gr.Slider(1, 82, value=21, label="Count", step=1, info="Choose between 1 and 82")
243
  ],
244
- "audio",
245
  )
246
 
247
 
248
-
249
  if __name__ == "__main__":
250
  demo.launch()
 
13
  import torch
14
  import torch.nn.functional as F
15
  from collections import OrderedDict
16
+ from AutoPST.onmt_modules.misc import sequence_mask
17
+ from AutoPST.model_autopst import Generator_2 as Predictor
18
+ from AutoPST.hparams_autopst import hparams
19
 
20
  device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
21
 
 
25
  P.load_state_dict(checkpoint['model'], strict=True)
26
  print('Loaded predictor .....................................................')
27
 
28
+ dict_test = pickle.load(open('./AutoPST/assets/test_vctk.meta', 'rb'))
29
 
30
  spect_vc = OrderedDict()
31
 
 
64
  import librosa
65
  import pickle
66
  import os
67
+ from AutoPST.synthesis import build_model
68
+ from AutoPST.synthesis import wavegen
69
 
70
  model = build_model().to(device)
71
  checkpoint = torch.load(hf_hub_download(repo_id="jonathanjordan21/AutoPST", filename="checkpoint_step001000000_ema.pth"), map_location=torch.device('cpu'))
72
  model.load_state_dict(checkpoint["state_dict"])
73
 
74
  # for name, sp in spect_vc.items():
75
 
76
  # print(name)
 
81
 
82
 
83
 
84
+ def respond(
85
+ message,
86
+ history: list[tuple[str, str]],
87
+ system_message,
88
+ max_tokens,
89
+ temperature,
90
+ top_p,
91
+ ):
92
+ messages = [{"role": "system", "content": system_message}]
93
 
94
+ for val in history:
95
+ if val[0]:
96
+ messages.append({"role": "user", "content": val[0]})
97
+ if val[1]:
98
+ messages.append({"role": "assistant", "content": val[1]})
99
 
100
+ messages.append({"role": "user", "content": message})
101
 
102
+ response = ""
103
 
104
+ for message in client.chat_completion(
105
+ messages,
106
+ max_tokens=max_tokens,
107
+ stream=True,
108
+ temperature=temperature,
109
+ top_p=top_p,
110
+ ):
111
+ token = message.choices[0].delta.content
112
 
113
+ response += token
114
+ yield response
115
 
116
  """
117
  For information on how to customize the ChatInterface, peruse the gradio docs: https://www.gradio.app/docs/chatinterface
118
  """
119
+ demo = gr.ChatInterface(
120
+ respond,
121
+ additional_inputs=[
122
+ gr.Textbox(value="You are a friendly Chatbot.", label="System message"),
123
+ gr.Slider(minimum=1, maximum=2048, value=512, step=1, label="Max new tokens"),
124
+ gr.Slider(minimum=0.1, maximum=4.0, value=0.7, step=0.1, label="Temperature"),
125
+ gr.Slider(
126
+ minimum=0.1,
127
+ maximum=1.0,
128
+ value=0.95,
129
+ step=0.05,
130
+ label="Top-p (nucleus sampling)",
131
+ ),
132
  ],
 
133
  )
134
 
135
 
 
136
  if __name__ == "__main__":
137
  demo.launch()
hparams.py DELETED
@@ -1,167 +0,0 @@
1
- class Map(dict):
2
- """
3
- Example:
4
- m = Map({'first_name': 'Eduardo'}, last_name='Pool', age=24, sports=['Soccer'])
5
-
6
- Credits to epool:
7
- https://stackoverflow.com/questions/2352181/how-to-use-a-dot-to-access-members-of-dictionary
8
- """
9
-
10
- def __init__(self, *args, **kwargs):
11
- super(Map, self).__init__(*args, **kwargs)
12
- for arg in args:
13
- if isinstance(arg, dict):
14
- for k, v in arg.items():
15
- self[k] = v
16
-
17
- if kwargs:
18
- for k, v in kwargs.iteritems():
19
- self[k] = v
20
-
21
- def __getattr__(self, attr):
22
- return self.get(attr)
23
-
24
- def __setattr__(self, key, value):
25
- self.__setitem__(key, value)
26
-
27
- def __setitem__(self, key, value):
28
- super(Map, self).__setitem__(key, value)
29
- self.__dict__.update({key: value})
30
-
31
- def __delattr__(self, item):
32
- self.__delitem__(item)
33
-
34
- def __delitem__(self, key):
35
- super(Map, self).__delitem__(key)
36
- del self.__dict__[key]
37
-
38
-
39
- # Default hyperparameters:
40
- hparams = Map({
41
- 'name': "wavenet_vocoder",
42
-
43
- # Convenient model builder
44
- 'builder': "wavenet",
45
-
46
- # Input type:
47
- # 1. raw [-1, 1]
48
- # 2. mulaw [-1, 1]
49
- # 3. mulaw-quantize [0, mu]
50
- # If input_type is raw or mulaw, network assumes scalar input and
51
- # discretized mixture of logistic distributions output, otherwise one-hot
52
- # input and softmax output are assumed.
53
- # **NOTE**: if you change the one of the two parameters below, you need to
54
- # re-run preprocessing before training.
55
- 'input_type': "raw",
56
- 'quantize_channels': 65536, # 65536 or 256
57
-
58
- # Audio:
59
- 'sample_rate': 16000,
60
- # this is only valid for mulaw is True
61
- 'silence_threshold': 2,
62
- 'num_mels': 80,
63
- 'fmin': 125,
64
- 'fmax': 7600,
65
- 'fft_size': 1024,
66
- # shift can be specified by either hop_size or frame_shift_ms
67
- 'hop_size': 256,
68
- 'frame_shift_ms': None,
69
- 'min_level_db': -100,
70
- 'ref_level_db': 20,
71
- # whether to rescale waveform or not.
72
- # Let x is an input waveform, rescaled waveform y is given by:
73
- # y = x / np.abs(x).max() * rescaling_max
74
- 'rescaling': True,
75
- 'rescaling_max': 0.999,
76
- # mel-spectrogram is normalized to [0, 1] for each utterance and clipping may
77
- # happen depends on min_level_db and ref_level_db, causing clipping noise.
78
- # If False, assertion is added to ensure no clipping happens.o0
79
- 'allow_clipping_in_normalization': True,
80
-
81
- # Mixture of logistic distributions:
82
- 'log_scale_min': float(-32.23619130191664),
83
-
84
- # Model:
85
- # This should equal to `quantize_channels` if mu-law quantize enabled
86
- # otherwise num_mixture * 3 (pi, mean, log_scale)
87
- 'out_channels': 10 * 3,
88
- 'layers': 24,
89
- 'stacks': 4,
90
- 'residual_channels': 512,
91
- 'gate_channels': 512, # split into 2 gropus internally for gated activation
92
- 'skip_out_channels': 256,
93
- 'dropout': 1 - 0.95,
94
- 'kernel_size': 3,
95
- # If True, apply weight normalization as same as DeepVoice3
96
- 'weight_normalization': True,
97
- # Use legacy code or not. Default is True since we already provided a model
98
- # based on the legacy code that can generate high-quality audio.
99
- # Ref: https://github.com/r9y9/wavenet_vocoder/pull/73
100
- 'legacy': True,
101
-
102
- # Local conditioning (set negative value to disable))
103
- 'cin_channels': 80,
104
- # If True, use transposed convolutions to upsample conditional features,
105
- # otherwise repeat features to adjust time resolution
106
- 'upsample_conditional_features': True,
107
- # should np.prod(upsample_scales) == hop_size
108
- 'upsample_scales': [4, 4, 4, 4],
109
- # Freq axis kernel size for upsampling network
110
- 'freq_axis_kernel_size': 3,
111
-
112
- # Global conditioning (set negative value to disable)
113
- # currently limited for speaker embedding
114
- # this should only be enabled for multi-speaker dataset
115
- 'gin_channels': -1, # i.e., speaker embedding dim
116
- 'n_speakers': -1,
117
-
118
- # Data loader
119
- 'pin_memory': True,
120
- 'num_workers': 2,
121
-
122
- # train/test
123
- # test size can be specified as portion or num samples
124
- 'test_size': 0.0441, # 50 for CMU ARCTIC single speaker
125
- 'test_num_samples': None,
126
- 'random_state': 1234,
127
-
128
- # Loss
129
-
130
- # Training:
131
- 'batch_size': 2,
132
- 'adam_beta1': 0.9,
133
- 'adam_beta2': 0.999,
134
- 'adam_eps': 1e-8,
135
- 'amsgrad': False,
136
- 'initial_learning_rate': 1e-3,
137
- # see lrschedule.py for available lr_schedule
138
- 'lr_schedule': "noam_learning_rate_decay",
139
- 'lr_schedule_kwargs': {}, # {"anneal_rate": 0.5, "anneal_interval": 50000},
140
- 'nepochs': 2000,
141
- 'weight_decay': 0.0,
142
- 'clip_thresh': -1,
143
- # max time steps can either be specified as sec or steps
144
- # if both are None, then full audio samples are used in a batch
145
- 'max_time_sec': None,
146
- 'max_time_steps': 8000,
147
- # Hold moving averaged parameters and use them for evaluation
148
- 'exponential_moving_average': True,
149
- # averaged = decay * averaged + (1 - decay) * x
150
- 'ema_decay': 0.9999,
151
-
152
- # Save
153
- # per-step intervals
154
- 'checkpoint_interval': 10000,
155
- 'train_eval_interval': 10000,
156
- # per-epoch interval
157
- 'test_eval_epoch_interval': 5,
158
- 'save_optimizer_state': True,
159
-
160
- # Eval:
161
- })
162
-
163
-
164
- def hparams_debug_string():
165
- values = hparams.values()
166
- hp = [' %s: %s' % (name, values[name]) for name in sorted(values)]
167
- return 'Hyperparameters:\n' + '\n'.join(hp)
synthesis.py DELETED
@@ -1,66 +0,0 @@
1
- import torch
2
- from tqdm import tqdm
3
- import librosa
4
- from hparams import hparams
5
- from wavenet_vocoder import builder
6
-
7
- torch.set_num_threads(4)
8
- use_cuda = torch.cuda.is_available()
9
- device = torch.device("cuda" if use_cuda else "cpu")
10
-
11
-
12
- def build_model():
13
-
14
- model = getattr(builder, hparams.builder)(
15
- out_channels=hparams.out_channels,
16
- layers=hparams.layers,
17
- stacks=hparams.stacks,
18
- residual_channels=hparams.residual_channels,
19
- gate_channels=hparams.gate_channels,
20
- skip_out_channels=hparams.skip_out_channels,
21
- cin_channels=hparams.cin_channels,
22
- gin_channels=hparams.gin_channels,
23
- weight_normalization=hparams.weight_normalization,
24
- n_speakers=hparams.n_speakers,
25
- dropout=hparams.dropout,
26
- kernel_size=hparams.kernel_size,
27
- upsample_conditional_features=hparams.upsample_conditional_features,
28
- upsample_scales=hparams.upsample_scales,
29
- freq_axis_kernel_size=hparams.freq_axis_kernel_size,
30
- scalar_input=True,
31
- legacy=hparams.legacy,
32
- )
33
- return model
34
-
35
-
36
-
37
- def wavegen(model, c=None, tqdm=tqdm):
38
- """Generate waveform samples by WaveNet.
39
-
40
- """
41
-
42
- model.eval()
43
- model.make_generation_fast_()
44
-
45
- Tc = c.shape[0]
46
- upsample_factor = hparams.hop_size
47
- # Overwrite length according to feature size
48
- length = Tc * upsample_factor
49
-
50
- # B x C x T
51
- c = torch.FloatTensor(c.T).unsqueeze(0)
52
-
53
- initial_input = torch.zeros(1, 1, 1).fill_(0.0)
54
-
55
- # Transform data to GPU
56
- initial_input = initial_input.to(device)
57
- c = None if c is None else c.to(device)
58
-
59
- with torch.no_grad():
60
- y_hat = model.incremental_forward(
61
- initial_input, c=c, g=None, T=length, tqdm=tqdm, softmax=True, quantize=True,
62
- log_scale_min=hparams.log_scale_min)
63
-
64
- y_hat = y_hat.view(-1).cpu().data.numpy()
65
-
66
- return y_hat