Modifications to support "v1.5"
Browse files- do_tts.py +40 -32
- models/autoregressive.py +18 -3
- models/{discrete_diffusion_vocoder.py → diffusion_decoder.py} +305 -175
- models/dvae.py +0 -390
- models/vocoder.py +325 -0
- requirements.txt +2 -1
- utils/audio.py +85 -1
- utils/diffusion.py +18 -0
- utils/stft.py +193 -0
do_tts.py
CHANGED
@@ -8,14 +8,14 @@ import torch.nn.functional as F
|
|
8 |
import torchaudio
|
9 |
import progressbar
|
10 |
|
11 |
-
from models.
|
12 |
from models.autoregressive import UnifiedVoice
|
13 |
from tqdm import tqdm
|
14 |
|
15 |
from models.arch_util import TorchMelSpectrogram
|
16 |
-
from models.discrete_diffusion_vocoder import DiscreteDiffusionVocoder
|
17 |
from models.text_voice_clip import VoiceCLIP
|
18 |
-
from
|
|
|
19 |
from utils.diffusion import SpacedDiffusion, space_timesteps, get_named_beta_schedule
|
20 |
from utils.tokenizer import VoiceBpeTokenizer
|
21 |
|
@@ -23,7 +23,6 @@ pbar = None
|
|
23 |
def download_models():
|
24 |
MODELS = {
|
25 |
'clip.pth': 'https://huggingface.co/jbetker/tortoise-tts-clip/resolve/main/pytorch-model.bin',
|
26 |
-
'dvae.pth': 'https://huggingface.co/jbetker/voice-dvae/resolve/main/pytorch_model.bin',
|
27 |
'diffusion.pth': 'https://huggingface.co/jbetker/tortoise-tts-diffusion-v1/resolve/main/pytorch-model.bin',
|
28 |
'autoregressive.pth': 'https://huggingface.co/jbetker/tortoise-tts-autoregressive/resolve/main/pytorch-model.bin'
|
29 |
}
|
@@ -47,12 +46,14 @@ def download_models():
|
|
47 |
request.urlretrieve(url, f'.models/{model_name}', show_progress)
|
48 |
print('Done.')
|
49 |
|
|
|
50 |
def load_discrete_vocoder_diffuser(trained_diffusion_steps=4000, desired_diffusion_steps=200):
|
51 |
"""
|
52 |
Helper function to load a GaussianDiffusion instance configured for use as a vocoder.
|
53 |
"""
|
54 |
return SpacedDiffusion(use_timesteps=space_timesteps(trained_diffusion_steps, [desired_diffusion_steps]), model_mean_type='epsilon',
|
55 |
-
model_var_type='learned_range', loss_type='mse', betas=get_named_beta_schedule('linear', trained_diffusion_steps)
|
|
|
56 |
|
57 |
|
58 |
def load_conditioning(path, sample_rate=22050, cond_length=132300):
|
@@ -94,26 +95,26 @@ def fix_autoregressive_output(codes, stop_token):
|
|
94 |
return codes
|
95 |
|
96 |
|
97 |
-
def do_spectrogram_diffusion(diffusion_model,
|
98 |
"""
|
99 |
Uses the specified diffusion model and DVAE model to convert the provided MEL & conditioning inputs into an audio clip.
|
100 |
"""
|
101 |
with torch.no_grad():
|
102 |
-
|
103 |
-
|
104 |
-
|
105 |
-
|
106 |
-
dsl = 2048 // spectrogram_compression_factor
|
107 |
gap = dsl - (msl % dsl)
|
108 |
if gap > 0:
|
109 |
-
mel = torch.nn.functional.pad(
|
110 |
|
111 |
-
output_shape = (mel.shape[0],
|
112 |
if mean:
|
113 |
-
|
114 |
-
model_kwargs={'
|
115 |
else:
|
116 |
-
|
|
|
117 |
|
118 |
|
119 |
if __name__ == '__main__':
|
@@ -145,12 +146,6 @@ if __name__ == '__main__':
|
|
145 |
download_models()
|
146 |
|
147 |
for voice in args.voice.split(','):
|
148 |
-
print("Loading GPT TTS..")
|
149 |
-
autoregressive = UnifiedVoice(max_mel_tokens=300, max_text_tokens=200, max_conditioning_inputs=2, layers=30, model_dim=1024,
|
150 |
-
heads=16, number_text_tokens=256, start_text_token=255, checkpointing=False, train_solo_embeddings=False).cuda().eval()
|
151 |
-
autoregressive.load_state_dict(torch.load('.models/autoregressive.pth'))
|
152 |
-
stop_mel_token = autoregressive.stop_mel_token
|
153 |
-
|
154 |
print("Loading data..")
|
155 |
tokenizer = VoiceBpeTokenizer()
|
156 |
text = torch.IntTensor(tokenizer.encode(args.text)).unsqueeze(0).cuda()
|
@@ -160,7 +155,15 @@ if __name__ == '__main__':
|
|
160 |
for cond_path in cond_paths:
|
161 |
c, cond_wav = load_conditioning(cond_path)
|
162 |
conds.append(c)
|
163 |
-
conds = torch.stack(conds, dim=1)
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
164 |
|
165 |
with torch.no_grad():
|
166 |
print("Performing autoregressive inference..")
|
@@ -194,20 +197,25 @@ if __name__ == '__main__':
|
|
194 |
# Delete the autoregressive and clip models to free up GPU memory
|
195 |
del samples, clip
|
196 |
|
197 |
-
print("Loading DVAE..")
|
198 |
-
dvae = DiscreteVAE(positional_dims=1, channels=80, hidden_dim=512, num_resnet_blocks=3, codebook_dim=512, num_tokens=8192, num_layers=2,
|
199 |
-
record_codes=True, kernel_size=3, use_transposed_convs=False).cuda().eval()
|
200 |
-
dvae.load_state_dict(torch.load('.models/dvae.pth'), strict=False)
|
201 |
print("Loading Diffusion Model..")
|
202 |
-
diffusion =
|
203 |
-
|
204 |
-
|
|
|
|
|
205 |
diffusion.load_state_dict(torch.load('.models/diffusion.pth'))
|
|
|
|
|
|
|
|
|
|
|
|
|
206 |
diffuser = load_discrete_vocoder_diffuser(desired_diffusion_steps=100)
|
207 |
|
208 |
print("Performing vocoding..")
|
209 |
# Perform vocoding on each batch element separately: The diffusion model is very memory (and compute!) intensive.
|
210 |
for b in range(best_results.shape[0]):
|
211 |
code = best_results[b].unsqueeze(0)
|
212 |
-
|
213 |
-
|
|
|
|
8 |
import torchaudio
|
9 |
import progressbar
|
10 |
|
11 |
+
from models.diffusion_decoder import DiffusionTts
|
12 |
from models.autoregressive import UnifiedVoice
|
13 |
from tqdm import tqdm
|
14 |
|
15 |
from models.arch_util import TorchMelSpectrogram
|
|
|
16 |
from models.text_voice_clip import VoiceCLIP
|
17 |
+
from models.vocoder import UnivNetGenerator
|
18 |
+
from utils.audio import load_audio, wav_to_univnet_mel, denormalize_tacotron_mel
|
19 |
from utils.diffusion import SpacedDiffusion, space_timesteps, get_named_beta_schedule
|
20 |
from utils.tokenizer import VoiceBpeTokenizer
|
21 |
|
|
|
23 |
def download_models():
|
24 |
MODELS = {
|
25 |
'clip.pth': 'https://huggingface.co/jbetker/tortoise-tts-clip/resolve/main/pytorch-model.bin',
|
|
|
26 |
'diffusion.pth': 'https://huggingface.co/jbetker/tortoise-tts-diffusion-v1/resolve/main/pytorch-model.bin',
|
27 |
'autoregressive.pth': 'https://huggingface.co/jbetker/tortoise-tts-autoregressive/resolve/main/pytorch-model.bin'
|
28 |
}
|
|
|
46 |
request.urlretrieve(url, f'.models/{model_name}', show_progress)
|
47 |
print('Done.')
|
48 |
|
49 |
+
|
50 |
def load_discrete_vocoder_diffuser(trained_diffusion_steps=4000, desired_diffusion_steps=200):
|
51 |
"""
|
52 |
Helper function to load a GaussianDiffusion instance configured for use as a vocoder.
|
53 |
"""
|
54 |
return SpacedDiffusion(use_timesteps=space_timesteps(trained_diffusion_steps, [desired_diffusion_steps]), model_mean_type='epsilon',
|
55 |
+
model_var_type='learned_range', loss_type='mse', betas=get_named_beta_schedule('linear', trained_diffusion_steps),
|
56 |
+
conditioning_free=True, conditioning_free_k=1)
|
57 |
|
58 |
|
59 |
def load_conditioning(path, sample_rate=22050, cond_length=132300):
|
|
|
95 |
return codes
|
96 |
|
97 |
|
98 |
+
def do_spectrogram_diffusion(diffusion_model, diffuser, mel_codes, conditioning_input, mean=False):
|
99 |
"""
|
100 |
Uses the specified diffusion model and DVAE model to convert the provided MEL & conditioning inputs into an audio clip.
|
101 |
"""
|
102 |
with torch.no_grad():
|
103 |
+
cond_mel = wav_to_univnet_mel(conditioning_input.squeeze(1), do_normalization=False)
|
104 |
+
# Pad MEL to multiples of 32
|
105 |
+
msl = mel_codes.shape[-1]
|
106 |
+
dsl = 32
|
|
|
107 |
gap = dsl - (msl % dsl)
|
108 |
if gap > 0:
|
109 |
+
mel = torch.nn.functional.pad(mel_codes, (0, gap))
|
110 |
|
111 |
+
output_shape = (mel.shape[0], 100, mel.shape[-1]*4)
|
112 |
if mean:
|
113 |
+
mel = diffuser.p_sample_loop(diffusion_model, output_shape, noise=torch.zeros(output_shape, device=mel_codes.device),
|
114 |
+
model_kwargs={'aligned_conditioning': mel_codes, 'conditioning_input': cond_mel})
|
115 |
else:
|
116 |
+
mel = diffuser.p_sample_loop(diffusion_model, output_shape, model_kwargs={'aligned_conditioning': mel_codes, 'conditioning_input': cond_mel})
|
117 |
+
return denormalize_tacotron_mel(mel)[:,:,:msl*4]
|
118 |
|
119 |
|
120 |
if __name__ == '__main__':
|
|
|
146 |
download_models()
|
147 |
|
148 |
for voice in args.voice.split(','):
|
|
|
|
|
|
|
|
|
|
|
|
|
149 |
print("Loading data..")
|
150 |
tokenizer = VoiceBpeTokenizer()
|
151 |
text = torch.IntTensor(tokenizer.encode(args.text)).unsqueeze(0).cuda()
|
|
|
155 |
for cond_path in cond_paths:
|
156 |
c, cond_wav = load_conditioning(cond_path)
|
157 |
conds.append(c)
|
158 |
+
conds = torch.stack(conds, dim=1)
|
159 |
+
cond_diffusion = cond_wav[:, :88200] # The diffusion model expects <= 88200 conditioning samples.
|
160 |
+
|
161 |
+
print("Loading GPT TTS..")
|
162 |
+
autoregressive = UnifiedVoice(max_mel_tokens=300, max_text_tokens=200, max_conditioning_inputs=2, layers=30, model_dim=1024,
|
163 |
+
heads=16, number_text_tokens=256, start_text_token=255, checkpointing=False, train_solo_embeddings=False,
|
164 |
+
average_conditioning_embeddings=True).cuda().eval()
|
165 |
+
autoregressive.load_state_dict(torch.load('.models/autoregressive.pth'))
|
166 |
+
stop_mel_token = autoregressive.stop_mel_token
|
167 |
|
168 |
with torch.no_grad():
|
169 |
print("Performing autoregressive inference..")
|
|
|
197 |
# Delete the autoregressive and clip models to free up GPU memory
|
198 |
del samples, clip
|
199 |
|
|
|
|
|
|
|
|
|
200 |
print("Loading Diffusion Model..")
|
201 |
+
diffusion = DiffusionTts(model_channels=512, in_channels=100, out_channels=200, in_latent_channels=1024,
|
202 |
+
channel_mult=[1, 2, 3, 4], num_res_blocks=[3, 3, 3, 3], token_conditioning_resolutions=[1,4,8],
|
203 |
+
dropout=0, attention_resolutions=[4,8], num_heads=8, kernel_size=3, scale_factor=2,
|
204 |
+
time_embed_dim_multiplier=4, unconditioned_percentage=0, conditioning_dim_factor=2,
|
205 |
+
conditioning_expansion=1)
|
206 |
diffusion.load_state_dict(torch.load('.models/diffusion.pth'))
|
207 |
+
diffusion = diffusion.cuda().eval()
|
208 |
+
print("Loading vocoder..")
|
209 |
+
vocoder = UnivNetGenerator()
|
210 |
+
vocoder.load_state_dict(torch.load('.models/vocoder.pth')['model_g'])
|
211 |
+
vocoder = vocoder.cuda()
|
212 |
+
vocoder.eval(inference=True)
|
213 |
diffuser = load_discrete_vocoder_diffuser(desired_diffusion_steps=100)
|
214 |
|
215 |
print("Performing vocoding..")
|
216 |
# Perform vocoding on each batch element separately: The diffusion model is very memory (and compute!) intensive.
|
217 |
for b in range(best_results.shape[0]):
|
218 |
code = best_results[b].unsqueeze(0)
|
219 |
+
mel = do_spectrogram_diffusion(diffusion, diffuser, code, cond_diffusion, mean=False)
|
220 |
+
wav = vocoder.inference(mel)
|
221 |
+
torchaudio.save(os.path.join(args.output_path, f'{voice}_{b}.wav'), wav.squeeze(0).cpu(), 24000)
|
models/autoregressive.py
CHANGED
@@ -192,7 +192,8 @@ class ConditioningEncoder(nn.Module):
|
|
192 |
embedding_dim,
|
193 |
attn_blocks=6,
|
194 |
num_attn_heads=4,
|
195 |
-
do_checkpointing=False
|
|
|
196 |
super().__init__()
|
197 |
attn = []
|
198 |
self.init = nn.Conv1d(spec_dim, embedding_dim, kernel_size=1)
|
@@ -201,11 +202,15 @@ class ConditioningEncoder(nn.Module):
|
|
201 |
self.attn = nn.Sequential(*attn)
|
202 |
self.dim = embedding_dim
|
203 |
self.do_checkpointing = do_checkpointing
|
|
|
204 |
|
205 |
def forward(self, x):
|
206 |
h = self.init(x)
|
207 |
h = self.attn(h)
|
208 |
-
|
|
|
|
|
|
|
209 |
|
210 |
|
211 |
class LearnedPositionEmbeddings(nn.Module):
|
@@ -275,7 +280,7 @@ class UnifiedVoice(nn.Module):
|
|
275 |
mel_length_compression=1024, number_text_tokens=256,
|
276 |
start_text_token=255, stop_text_token=0, number_mel_codes=8194, start_mel_token=8192,
|
277 |
stop_mel_token=8193, train_solo_embeddings=False, use_mel_codes_as_input=True,
|
278 |
-
checkpointing=True):
|
279 |
"""
|
280 |
Args:
|
281 |
layers: Number of layers in transformer stack.
|
@@ -294,6 +299,7 @@ class UnifiedVoice(nn.Module):
|
|
294 |
train_solo_embeddings:
|
295 |
use_mel_codes_as_input:
|
296 |
checkpointing:
|
|
|
297 |
"""
|
298 |
super().__init__()
|
299 |
|
@@ -311,6 +317,7 @@ class UnifiedVoice(nn.Module):
|
|
311 |
self.max_conditioning_inputs = max_conditioning_inputs
|
312 |
self.mel_length_compression = mel_length_compression
|
313 |
self.conditioning_encoder = ConditioningEncoder(80, model_dim, num_attn_heads=heads)
|
|
|
314 |
self.text_embedding = nn.Embedding(self.number_text_tokens, model_dim)
|
315 |
if use_mel_codes_as_input:
|
316 |
self.mel_embedding = nn.Embedding(self.number_mel_codes, model_dim)
|
@@ -408,6 +415,8 @@ class UnifiedVoice(nn.Module):
|
|
408 |
for j in range(speech_conditioning_input.shape[1]):
|
409 |
conds.append(self.conditioning_encoder(speech_conditioning_input[:, j]))
|
410 |
conds = torch.stack(conds, dim=1)
|
|
|
|
|
411 |
|
412 |
text_inputs, text_targets = self.build_aligned_inputs_and_targets(text_inputs, self.start_text_token, self.stop_text_token)
|
413 |
text_emb = self.text_embedding(text_inputs) + self.text_pos_embedding(text_inputs)
|
@@ -446,6 +455,8 @@ class UnifiedVoice(nn.Module):
|
|
446 |
for j in range(speech_conditioning_input.shape[1]):
|
447 |
conds.append(self.conditioning_encoder(speech_conditioning_input[:, j]))
|
448 |
conds = torch.stack(conds, dim=1)
|
|
|
|
|
449 |
|
450 |
text_inputs, text_targets = self.build_aligned_inputs_and_targets(text_inputs, self.start_text_token, self.stop_text_token)
|
451 |
text_emb = self.text_embedding(text_inputs) + self.text_pos_embedding(text_inputs) + self.text_solo_embedding
|
@@ -472,6 +483,8 @@ class UnifiedVoice(nn.Module):
|
|
472 |
for j in range(speech_conditioning_input.shape[1]):
|
473 |
conds.append(self.conditioning_encoder(speech_conditioning_input[:, j]))
|
474 |
conds = torch.stack(conds, dim=1)
|
|
|
|
|
475 |
|
476 |
mel_codes, mel_targets = self.build_aligned_inputs_and_targets(mel_codes, self.start_mel_token, self.stop_mel_token)
|
477 |
if raw_mels is not None:
|
@@ -508,6 +521,8 @@ class UnifiedVoice(nn.Module):
|
|
508 |
for j in range(speech_conditioning_input.shape[1]):
|
509 |
conds.append(self.conditioning_encoder(speech_conditioning_input[:, j]))
|
510 |
conds = torch.stack(conds, dim=1)
|
|
|
|
|
511 |
|
512 |
emb = torch.cat([conds, text_emb], dim=1)
|
513 |
self.inference_model.store_mel_emb(emb)
|
|
|
192 |
embedding_dim,
|
193 |
attn_blocks=6,
|
194 |
num_attn_heads=4,
|
195 |
+
do_checkpointing=False,
|
196 |
+
mean=False):
|
197 |
super().__init__()
|
198 |
attn = []
|
199 |
self.init = nn.Conv1d(spec_dim, embedding_dim, kernel_size=1)
|
|
|
202 |
self.attn = nn.Sequential(*attn)
|
203 |
self.dim = embedding_dim
|
204 |
self.do_checkpointing = do_checkpointing
|
205 |
+
self.mean = mean
|
206 |
|
207 |
def forward(self, x):
|
208 |
h = self.init(x)
|
209 |
h = self.attn(h)
|
210 |
+
if self.mean:
|
211 |
+
return h.mean(dim=2)
|
212 |
+
else:
|
213 |
+
return h[:, :, 0]
|
214 |
|
215 |
|
216 |
class LearnedPositionEmbeddings(nn.Module):
|
|
|
280 |
mel_length_compression=1024, number_text_tokens=256,
|
281 |
start_text_token=255, stop_text_token=0, number_mel_codes=8194, start_mel_token=8192,
|
282 |
stop_mel_token=8193, train_solo_embeddings=False, use_mel_codes_as_input=True,
|
283 |
+
checkpointing=True, average_conditioning_embeddings=False):
|
284 |
"""
|
285 |
Args:
|
286 |
layers: Number of layers in transformer stack.
|
|
|
299 |
train_solo_embeddings:
|
300 |
use_mel_codes_as_input:
|
301 |
checkpointing:
|
302 |
+
average_conditioning_embeddings: Whether or not conditioning embeddings should be averaged, instead of fed piecewise into the model.
|
303 |
"""
|
304 |
super().__init__()
|
305 |
|
|
|
317 |
self.max_conditioning_inputs = max_conditioning_inputs
|
318 |
self.mel_length_compression = mel_length_compression
|
319 |
self.conditioning_encoder = ConditioningEncoder(80, model_dim, num_attn_heads=heads)
|
320 |
+
self.average_conditioning_embeddings = average_conditioning_embeddings
|
321 |
self.text_embedding = nn.Embedding(self.number_text_tokens, model_dim)
|
322 |
if use_mel_codes_as_input:
|
323 |
self.mel_embedding = nn.Embedding(self.number_mel_codes, model_dim)
|
|
|
415 |
for j in range(speech_conditioning_input.shape[1]):
|
416 |
conds.append(self.conditioning_encoder(speech_conditioning_input[:, j]))
|
417 |
conds = torch.stack(conds, dim=1)
|
418 |
+
if self.average_conditioning_embeddings:
|
419 |
+
conds = conds.mean(dim=1).unsqueeze(1)
|
420 |
|
421 |
text_inputs, text_targets = self.build_aligned_inputs_and_targets(text_inputs, self.start_text_token, self.stop_text_token)
|
422 |
text_emb = self.text_embedding(text_inputs) + self.text_pos_embedding(text_inputs)
|
|
|
455 |
for j in range(speech_conditioning_input.shape[1]):
|
456 |
conds.append(self.conditioning_encoder(speech_conditioning_input[:, j]))
|
457 |
conds = torch.stack(conds, dim=1)
|
458 |
+
if self.average_conditioning_embeddings:
|
459 |
+
conds = conds.mean(dim=1).unsqueeze(1)
|
460 |
|
461 |
text_inputs, text_targets = self.build_aligned_inputs_and_targets(text_inputs, self.start_text_token, self.stop_text_token)
|
462 |
text_emb = self.text_embedding(text_inputs) + self.text_pos_embedding(text_inputs) + self.text_solo_embedding
|
|
|
483 |
for j in range(speech_conditioning_input.shape[1]):
|
484 |
conds.append(self.conditioning_encoder(speech_conditioning_input[:, j]))
|
485 |
conds = torch.stack(conds, dim=1)
|
486 |
+
if self.average_conditioning_embeddings:
|
487 |
+
conds = conds.mean(dim=1).unsqueeze(1)
|
488 |
|
489 |
mel_codes, mel_targets = self.build_aligned_inputs_and_targets(mel_codes, self.start_mel_token, self.stop_mel_token)
|
490 |
if raw_mels is not None:
|
|
|
521 |
for j in range(speech_conditioning_input.shape[1]):
|
522 |
conds.append(self.conditioning_encoder(speech_conditioning_input[:, j]))
|
523 |
conds = torch.stack(conds, dim=1)
|
524 |
+
if self.average_conditioning_embeddings:
|
525 |
+
conds = conds.mean(dim=1).unsqueeze(1)
|
526 |
|
527 |
emb = torch.cat([conds, text_emb], dim=1)
|
528 |
self.inference_model.store_mel_emb(emb)
|
models/{discrete_diffusion_vocoder.py → diffusion_decoder.py}
RENAMED
@@ -3,17 +3,36 @@ This model is based on OpenAI's UNet from improved diffusion, with modifications
|
|
3 |
and an audio conditioning input. It has also been simplified somewhat.
|
4 |
Credit: https://github.com/openai/improved-diffusion
|
5 |
"""
|
6 |
-
|
7 |
-
|
8 |
import math
|
9 |
from abc import abstractmethod
|
10 |
|
11 |
import torch
|
12 |
import torch.nn as nn
|
|
|
|
|
|
|
|
|
|
|
13 |
|
14 |
from models.arch_util import normalization, zero_module, Downsample, Upsample, AudioMiniEncoder, AttentionBlock
|
15 |
|
16 |
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
17 |
def timestep_embedding(timesteps, dim, max_period=10000):
|
18 |
"""
|
19 |
Create sinusoidal timestep embeddings.
|
@@ -62,63 +81,36 @@ class TimestepEmbedSequential(nn.Sequential, TimestepBlock):
|
|
62 |
return x
|
63 |
|
64 |
|
65 |
-
class
|
66 |
-
"""
|
67 |
-
A residual block that can optionally change the number of channels.
|
68 |
-
|
69 |
-
:param channels: the number of input channels.
|
70 |
-
:param emb_channels: the number of timestep embedding channels.
|
71 |
-
:param dropout: the rate of dropout.
|
72 |
-
:param out_channels: if specified, the number of out channels.
|
73 |
-
:param use_conv: if True and out_channels is specified, use a spatial
|
74 |
-
convolution instead of a smaller 1x1 convolution to change the
|
75 |
-
channels in the skip connection.
|
76 |
-
:param dims: determines if the signal is 1D, 2D, or 3D.
|
77 |
-
:param up: if True, use this block for upsampling.
|
78 |
-
:param down: if True, use this block for downsampling.
|
79 |
-
"""
|
80 |
-
|
81 |
def __init__(
|
82 |
self,
|
83 |
channels,
|
84 |
emb_channels,
|
85 |
dropout,
|
86 |
out_channels=None,
|
87 |
-
use_conv=False,
|
88 |
-
use_scale_shift_norm=False,
|
89 |
-
up=False,
|
90 |
-
down=False,
|
91 |
kernel_size=3,
|
|
|
|
|
92 |
):
|
93 |
super().__init__()
|
94 |
self.channels = channels
|
95 |
self.emb_channels = emb_channels
|
96 |
self.dropout = dropout
|
97 |
self.out_channels = out_channels or channels
|
98 |
-
self.use_conv = use_conv
|
99 |
self.use_scale_shift_norm = use_scale_shift_norm
|
100 |
-
padding = 1
|
|
|
|
|
101 |
|
102 |
self.in_layers = nn.Sequential(
|
103 |
normalization(channels),
|
104 |
nn.SiLU(),
|
105 |
-
nn.Conv1d(channels, self.out_channels,
|
106 |
)
|
107 |
|
108 |
-
self.updown = up or down
|
109 |
-
|
110 |
-
if up:
|
111 |
-
self.h_upd = Upsample(channels, False, dims)
|
112 |
-
self.x_upd = Upsample(channels, False, dims)
|
113 |
-
elif down:
|
114 |
-
self.h_upd = Downsample(channels, False, dims)
|
115 |
-
self.x_upd = Downsample(channels, False, dims)
|
116 |
-
else:
|
117 |
-
self.h_upd = self.x_upd = nn.Identity()
|
118 |
-
|
119 |
self.emb_layers = nn.Sequential(
|
120 |
nn.SiLU(),
|
121 |
-
|
122 |
emb_channels,
|
123 |
2 * self.out_channels if use_scale_shift_norm else self.out_channels,
|
124 |
),
|
@@ -134,22 +126,23 @@ class TimestepResBlock(TimestepBlock):
|
|
134 |
|
135 |
if self.out_channels == channels:
|
136 |
self.skip_connection = nn.Identity()
|
137 |
-
elif use_conv:
|
138 |
-
self.skip_connection = nn.Conv1d(
|
139 |
-
channels, self.out_channels, kernel_size, padding=padding
|
140 |
-
)
|
141 |
else:
|
142 |
-
self.skip_connection = nn.Conv1d(channels, self.out_channels,
|
143 |
|
144 |
def forward(self, x, emb):
|
145 |
-
|
146 |
-
|
147 |
-
|
148 |
-
|
149 |
-
|
150 |
-
|
151 |
-
|
152 |
-
|
|
|
|
|
|
|
|
|
|
|
153 |
emb_out = self.emb_layers(emb).type(h.dtype)
|
154 |
while len(emb_out.shape) < len(h.shape):
|
155 |
emb_out = emb_out[..., None]
|
@@ -164,37 +157,52 @@ class TimestepResBlock(TimestepBlock):
|
|
164 |
return self.skip_connection(x) + h
|
165 |
|
166 |
|
167 |
-
class
|
168 |
-
|
|
|
|
|
|
|
|
|
169 |
super().__init__()
|
170 |
-
self.
|
171 |
-
|
172 |
-
|
173 |
-
|
174 |
-
|
|
|
|
|
175 |
|
|
|
|
|
176 |
"""
|
177 |
-
|
178 |
-
|
179 |
-
:param x: bxcxS waveform latent
|
180 |
-
:param codes: bxN discrete codes, N <= S
|
181 |
"""
|
182 |
-
def
|
183 |
-
|
184 |
-
|
185 |
-
|
186 |
-
|
187 |
-
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
188 |
|
189 |
|
190 |
-
class
|
191 |
"""
|
192 |
The full UNet model with attention and timestep embedding.
|
193 |
|
194 |
-
Customized to be conditioned on
|
|
|
195 |
|
196 |
:param in_channels: channels in the input Tensor.
|
197 |
-
:param
|
198 |
:param model_channels: base channel count for the model.
|
199 |
:param out_channels: channels in the output Tensor.
|
200 |
:param num_res_blocks: number of residual blocks per downsample.
|
@@ -206,7 +214,6 @@ class DiscreteDiffusionVocoder(nn.Module):
|
|
206 |
:param channel_mult: channel multiplier for each level of the UNet.
|
207 |
:param conv_resample: if True, use learned convolutions for upsampling and
|
208 |
downsampling.
|
209 |
-
:param dims: determines if the signal is 1D, 2D, or 3D.
|
210 |
:param num_heads: the number of attention heads in each attention layer.
|
211 |
:param num_heads_channels: if specified, ignore num_heads and instead use
|
212 |
a fixed channel width per attention head.
|
@@ -222,34 +229,43 @@ class DiscreteDiffusionVocoder(nn.Module):
|
|
222 |
self,
|
223 |
model_channels,
|
224 |
in_channels=1,
|
|
|
|
|
|
|
|
|
225 |
out_channels=2, # mean and variance
|
226 |
-
dvae_dim=512,
|
227 |
dropout=0,
|
228 |
# res 1, 2, 4, 8,16,32,64,128,256,512, 1K, 2K
|
229 |
channel_mult= (1,1.5,2, 3, 4, 6, 8, 12, 16, 24, 32, 48),
|
230 |
num_res_blocks=(1, 1, 1, 1, 1, 2, 2, 2, 2, 2, 2, 2),
|
231 |
# spec_cond: 1, 0, 0, 1, 0, 0, 1, 0, 0, 1, 0, 0)
|
232 |
# attn: 0, 0, 0, 0, 0, 0, 0, 0, 0, 1, 1, 1
|
233 |
-
|
234 |
attention_resolutions=(512,1024,2048),
|
235 |
conv_resample=True,
|
236 |
-
dims=1,
|
237 |
use_fp16=False,
|
238 |
num_heads=1,
|
239 |
num_head_channels=-1,
|
240 |
num_heads_upsample=-1,
|
241 |
-
use_scale_shift_norm=False,
|
242 |
-
resblock_updown=False,
|
243 |
kernel_size=3,
|
244 |
scale_factor=2,
|
245 |
-
conditioning_inputs_provided=True,
|
246 |
time_embed_dim_multiplier=4,
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
247 |
):
|
248 |
super().__init__()
|
249 |
|
250 |
if num_heads_upsample == -1:
|
251 |
num_heads_upsample = num_heads
|
252 |
|
|
|
|
|
253 |
self.in_channels = in_channels
|
254 |
self.model_channels = model_channels
|
255 |
self.out_channels = out_channels
|
@@ -257,53 +273,110 @@ class DiscreteDiffusionVocoder(nn.Module):
|
|
257 |
self.dropout = dropout
|
258 |
self.channel_mult = channel_mult
|
259 |
self.conv_resample = conv_resample
|
260 |
-
self.dtype = torch.float16 if use_fp16 else torch.float32
|
261 |
self.num_heads = num_heads
|
262 |
self.num_head_channels = num_head_channels
|
263 |
self.num_heads_upsample = num_heads_upsample
|
264 |
-
self.
|
265 |
-
|
|
|
|
|
|
|
|
|
266 |
padding = 1 if kernel_size == 3 else 2
|
|
|
267 |
|
268 |
time_embed_dim = model_channels * time_embed_dim_multiplier
|
269 |
self.time_embed = nn.Sequential(
|
270 |
-
|
271 |
nn.SiLU(),
|
272 |
-
|
273 |
)
|
274 |
|
275 |
-
|
276 |
-
|
277 |
-
|
278 |
-
|
279 |
-
|
280 |
-
|
281 |
-
nn.
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
282 |
)
|
283 |
-
|
284 |
-
|
285 |
-
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
286 |
self._feature_size = model_channels
|
287 |
input_block_chans = [model_channels]
|
288 |
ch = model_channels
|
289 |
ds = 1
|
290 |
|
291 |
for level, (mult, num_blocks) in enumerate(zip(channel_mult, num_res_blocks)):
|
292 |
-
if ds in
|
293 |
-
|
294 |
-
|
295 |
-
|
296 |
-
|
297 |
|
298 |
for _ in range(num_blocks):
|
299 |
layers = [
|
300 |
-
|
301 |
ch,
|
302 |
time_embed_dim,
|
303 |
dropout,
|
304 |
out_channels=int(mult * model_channels),
|
305 |
-
use_scale_shift_norm=use_scale_shift_norm,
|
306 |
kernel_size=kernel_size,
|
|
|
|
|
307 |
)
|
308 |
]
|
309 |
ch = int(mult * model_channels)
|
@@ -315,54 +388,44 @@ class DiscreteDiffusionVocoder(nn.Module):
|
|
315 |
num_head_channels=num_head_channels,
|
316 |
)
|
317 |
)
|
318 |
-
|
319 |
-
layer.level = 2 ** level
|
320 |
-
self.input_blocks.append(layer)
|
321 |
self._feature_size += ch
|
322 |
input_block_chans.append(ch)
|
323 |
if level != len(channel_mult) - 1:
|
324 |
out_ch = ch
|
325 |
-
|
326 |
-
|
327 |
-
|
328 |
-
|
329 |
-
dropout,
|
330 |
-
out_channels=out_ch,
|
331 |
-
use_scale_shift_norm=use_scale_shift_norm,
|
332 |
-
down=True,
|
333 |
-
kernel_size=kernel_size,
|
334 |
-
)
|
335 |
-
if resblock_updown
|
336 |
-
else Downsample(
|
337 |
-
ch, conv_resample, out_channels=out_ch, factor=scale_factor
|
338 |
)
|
339 |
)
|
340 |
-
|
341 |
-
self.input_blocks.append(upblk)
|
342 |
ch = out_ch
|
343 |
input_block_chans.append(ch)
|
344 |
ds *= 2
|
345 |
self._feature_size += ch
|
346 |
|
347 |
self.middle_block = TimestepEmbedSequential(
|
348 |
-
|
349 |
ch,
|
350 |
time_embed_dim,
|
351 |
dropout,
|
352 |
-
use_scale_shift_norm=use_scale_shift_norm,
|
353 |
kernel_size=kernel_size,
|
|
|
|
|
354 |
),
|
355 |
AttentionBlock(
|
356 |
ch,
|
357 |
num_heads=num_heads,
|
358 |
num_head_channels=num_head_channels,
|
359 |
),
|
360 |
-
|
361 |
ch,
|
362 |
time_embed_dim,
|
363 |
dropout,
|
364 |
-
use_scale_shift_norm=use_scale_shift_norm,
|
365 |
kernel_size=kernel_size,
|
|
|
|
|
366 |
),
|
367 |
)
|
368 |
self._feature_size += ch
|
@@ -372,13 +435,14 @@ class DiscreteDiffusionVocoder(nn.Module):
|
|
372 |
for i in range(num_blocks + 1):
|
373 |
ich = input_block_chans.pop()
|
374 |
layers = [
|
375 |
-
|
376 |
ch + ich,
|
377 |
time_embed_dim,
|
378 |
dropout,
|
379 |
out_channels=int(model_channels * mult),
|
380 |
-
use_scale_shift_norm=use_scale_shift_norm,
|
381 |
kernel_size=kernel_size,
|
|
|
|
|
382 |
)
|
383 |
]
|
384 |
ch = int(model_channels * mult)
|
@@ -393,22 +457,10 @@ class DiscreteDiffusionVocoder(nn.Module):
|
|
393 |
if level and i == num_blocks:
|
394 |
out_ch = ch
|
395 |
layers.append(
|
396 |
-
|
397 |
-
ch,
|
398 |
-
time_embed_dim,
|
399 |
-
dropout,
|
400 |
-
out_channels=out_ch,
|
401 |
-
use_scale_shift_norm=use_scale_shift_norm,
|
402 |
-
up=True,
|
403 |
-
kernel_size=kernel_size,
|
404 |
-
)
|
405 |
-
if resblock_updown
|
406 |
-
else Upsample(ch, conv_resample, out_channels=out_ch, factor=scale_factor)
|
407 |
)
|
408 |
ds //= 2
|
409 |
-
|
410 |
-
layer.level = 2 ** level
|
411 |
-
self.output_blocks.append(layer)
|
412 |
self._feature_size += ch
|
413 |
|
414 |
self.out = nn.Sequential(
|
@@ -417,52 +469,130 @@ class DiscreteDiffusionVocoder(nn.Module):
|
|
417 |
zero_module(nn.Conv1d(model_channels, out_channels, kernel_size, padding=padding)),
|
418 |
)
|
419 |
|
420 |
-
def
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
421 |
"""
|
422 |
Apply the model to an input batch.
|
423 |
|
424 |
:param x: an [N x C x ...] Tensor of inputs.
|
425 |
:param timesteps: a 1-D batch of timesteps.
|
426 |
-
:param
|
|
|
|
|
|
|
427 |
:return: an [N x C x ...] Tensor of outputs.
|
428 |
"""
|
429 |
-
assert
|
430 |
-
if self.
|
431 |
-
assert
|
432 |
-
|
433 |
-
|
434 |
-
|
435 |
-
|
436 |
-
|
437 |
-
|
438 |
-
|
439 |
-
|
440 |
-
|
441 |
-
|
442 |
-
|
443 |
-
|
444 |
-
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
445 |
else:
|
446 |
-
|
447 |
-
|
448 |
-
|
449 |
-
|
450 |
-
|
451 |
-
|
452 |
-
|
453 |
-
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
454 |
|
455 |
|
456 |
-
# Test for ~4 second audio clip at 22050Hz
|
457 |
if __name__ == '__main__':
|
458 |
-
clip = torch.randn(2, 1,
|
459 |
-
|
460 |
-
|
461 |
-
|
462 |
-
|
463 |
-
|
464 |
-
|
465 |
-
|
466 |
-
|
467 |
-
|
468 |
-
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
3 |
and an audio conditioning input. It has also been simplified somewhat.
|
4 |
Credit: https://github.com/openai/improved-diffusion
|
5 |
"""
|
6 |
+
import functools
|
|
|
7 |
import math
|
8 |
from abc import abstractmethod
|
9 |
|
10 |
import torch
|
11 |
import torch.nn as nn
|
12 |
+
import torch.nn.functional as F
|
13 |
+
from torch import autocast
|
14 |
+
from torch.nn import Linear
|
15 |
+
from torch.utils.checkpoint import checkpoint
|
16 |
+
from x_transformers import ContinuousTransformerWrapper, Encoder
|
17 |
|
18 |
from models.arch_util import normalization, zero_module, Downsample, Upsample, AudioMiniEncoder, AttentionBlock
|
19 |
|
20 |
|
21 |
+
def is_latent(t):
|
22 |
+
return t.dtype == torch.float
|
23 |
+
|
24 |
+
|
25 |
+
def is_sequence(t):
|
26 |
+
return t.dtype == torch.long
|
27 |
+
|
28 |
+
|
29 |
+
def ceil_multiple(base, multiple):
|
30 |
+
res = base % multiple
|
31 |
+
if res == 0:
|
32 |
+
return base
|
33 |
+
return base + (multiple - res)
|
34 |
+
|
35 |
+
|
36 |
def timestep_embedding(timesteps, dim, max_period=10000):
|
37 |
"""
|
38 |
Create sinusoidal timestep embeddings.
|
|
|
81 |
return x
|
82 |
|
83 |
|
84 |
+
class ResBlock(TimestepBlock):
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
85 |
def __init__(
|
86 |
self,
|
87 |
channels,
|
88 |
emb_channels,
|
89 |
dropout,
|
90 |
out_channels=None,
|
|
|
|
|
|
|
|
|
91 |
kernel_size=3,
|
92 |
+
efficient_config=True,
|
93 |
+
use_scale_shift_norm=False,
|
94 |
):
|
95 |
super().__init__()
|
96 |
self.channels = channels
|
97 |
self.emb_channels = emb_channels
|
98 |
self.dropout = dropout
|
99 |
self.out_channels = out_channels or channels
|
|
|
100 |
self.use_scale_shift_norm = use_scale_shift_norm
|
101 |
+
padding = {1: 0, 3: 1, 5: 2}[kernel_size]
|
102 |
+
eff_kernel = 1 if efficient_config else 3
|
103 |
+
eff_padding = 0 if efficient_config else 1
|
104 |
|
105 |
self.in_layers = nn.Sequential(
|
106 |
normalization(channels),
|
107 |
nn.SiLU(),
|
108 |
+
nn.Conv1d(channels, self.out_channels, eff_kernel, padding=eff_padding),
|
109 |
)
|
110 |
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
111 |
self.emb_layers = nn.Sequential(
|
112 |
nn.SiLU(),
|
113 |
+
Linear(
|
114 |
emb_channels,
|
115 |
2 * self.out_channels if use_scale_shift_norm else self.out_channels,
|
116 |
),
|
|
|
126 |
|
127 |
if self.out_channels == channels:
|
128 |
self.skip_connection = nn.Identity()
|
|
|
|
|
|
|
|
|
129 |
else:
|
130 |
+
self.skip_connection = nn.Conv1d(channels, self.out_channels, eff_kernel, padding=eff_padding)
|
131 |
|
132 |
def forward(self, x, emb):
|
133 |
+
"""
|
134 |
+
Apply the block to a Tensor, conditioned on a timestep embedding.
|
135 |
+
|
136 |
+
:param x: an [N x C x ...] Tensor of features.
|
137 |
+
:param emb: an [N x emb_channels] Tensor of timestep embeddings.
|
138 |
+
:return: an [N x C x ...] Tensor of outputs.
|
139 |
+
"""
|
140 |
+
return checkpoint(
|
141 |
+
self._forward, x, emb
|
142 |
+
)
|
143 |
+
|
144 |
+
def _forward(self, x, emb):
|
145 |
+
h = self.in_layers(x)
|
146 |
emb_out = self.emb_layers(emb).type(h.dtype)
|
147 |
while len(emb_out.shape) < len(h.shape):
|
148 |
emb_out = emb_out[..., None]
|
|
|
157 |
return self.skip_connection(x) + h
|
158 |
|
159 |
|
160 |
+
class CheckpointedLayer(nn.Module):
|
161 |
+
"""
|
162 |
+
Wraps a module. When forward() is called, passes kwargs that require_grad through torch.checkpoint() and bypasses
|
163 |
+
checkpoint for all other args.
|
164 |
+
"""
|
165 |
+
def __init__(self, wrap):
|
166 |
super().__init__()
|
167 |
+
self.wrap = wrap
|
168 |
+
|
169 |
+
def forward(self, x, *args, **kwargs):
|
170 |
+
for k, v in kwargs.items():
|
171 |
+
assert not (isinstance(v, torch.Tensor) and v.requires_grad) # This would screw up checkpointing.
|
172 |
+
partial = functools.partial(self.wrap, **kwargs)
|
173 |
+
return torch.utils.checkpoint.checkpoint(partial, x, *args)
|
174 |
|
175 |
+
|
176 |
+
class CheckpointedXTransformerEncoder(nn.Module):
|
177 |
"""
|
178 |
+
Wraps a ContinuousTransformerWrapper and applies CheckpointedLayer to each layer and permutes from channels-mid
|
179 |
+
to channels-last that XTransformer expects.
|
|
|
|
|
180 |
"""
|
181 |
+
def __init__(self, needs_permute=True, **xtransformer_kwargs):
|
182 |
+
super().__init__()
|
183 |
+
self.transformer = ContinuousTransformerWrapper(**xtransformer_kwargs)
|
184 |
+
self.needs_permute = needs_permute
|
185 |
+
|
186 |
+
for i in range(len(self.transformer.attn_layers.layers)):
|
187 |
+
n, b, r = self.transformer.attn_layers.layers[i]
|
188 |
+
self.transformer.attn_layers.layers[i] = nn.ModuleList([n, CheckpointedLayer(b), r])
|
189 |
+
|
190 |
+
def forward(self, x, **kwargs):
|
191 |
+
if self.needs_permute:
|
192 |
+
x = x.permute(0,2,1)
|
193 |
+
h = self.transformer(x, **kwargs)
|
194 |
+
return h.permute(0,2,1)
|
195 |
|
196 |
|
197 |
+
class DiffusionTts(nn.Module):
|
198 |
"""
|
199 |
The full UNet model with attention and timestep embedding.
|
200 |
|
201 |
+
Customized to be conditioned on an aligned prior derived from a autoregressive
|
202 |
+
GPT-style model.
|
203 |
|
204 |
:param in_channels: channels in the input Tensor.
|
205 |
+
:param in_latent_channels: channels from the input latent.
|
206 |
:param model_channels: base channel count for the model.
|
207 |
:param out_channels: channels in the output Tensor.
|
208 |
:param num_res_blocks: number of residual blocks per downsample.
|
|
|
214 |
:param channel_mult: channel multiplier for each level of the UNet.
|
215 |
:param conv_resample: if True, use learned convolutions for upsampling and
|
216 |
downsampling.
|
|
|
217 |
:param num_heads: the number of attention heads in each attention layer.
|
218 |
:param num_heads_channels: if specified, ignore num_heads and instead use
|
219 |
a fixed channel width per attention head.
|
|
|
229 |
self,
|
230 |
model_channels,
|
231 |
in_channels=1,
|
232 |
+
in_latent_channels=1024,
|
233 |
+
in_tokens=8193,
|
234 |
+
conditioning_dim_factor=8,
|
235 |
+
conditioning_expansion=4,
|
236 |
out_channels=2, # mean and variance
|
|
|
237 |
dropout=0,
|
238 |
# res 1, 2, 4, 8,16,32,64,128,256,512, 1K, 2K
|
239 |
channel_mult= (1,1.5,2, 3, 4, 6, 8, 12, 16, 24, 32, 48),
|
240 |
num_res_blocks=(1, 1, 1, 1, 1, 2, 2, 2, 2, 2, 2, 2),
|
241 |
# spec_cond: 1, 0, 0, 1, 0, 0, 1, 0, 0, 1, 0, 0)
|
242 |
# attn: 0, 0, 0, 0, 0, 0, 0, 0, 0, 1, 1, 1
|
243 |
+
token_conditioning_resolutions=(1,16,),
|
244 |
attention_resolutions=(512,1024,2048),
|
245 |
conv_resample=True,
|
|
|
246 |
use_fp16=False,
|
247 |
num_heads=1,
|
248 |
num_head_channels=-1,
|
249 |
num_heads_upsample=-1,
|
|
|
|
|
250 |
kernel_size=3,
|
251 |
scale_factor=2,
|
|
|
252 |
time_embed_dim_multiplier=4,
|
253 |
+
freeze_main_net=False,
|
254 |
+
efficient_convs=True, # Uses kernels with width of 1 in several places rather than 3.
|
255 |
+
use_scale_shift_norm=True,
|
256 |
+
# Parameters for regularization.
|
257 |
+
unconditioned_percentage=.1, # This implements a mechanism similar to what is used in classifier-free training.
|
258 |
+
# Parameters for super-sampling.
|
259 |
+
super_sampling=False,
|
260 |
+
super_sampling_max_noising_factor=.1,
|
261 |
):
|
262 |
super().__init__()
|
263 |
|
264 |
if num_heads_upsample == -1:
|
265 |
num_heads_upsample = num_heads
|
266 |
|
267 |
+
if super_sampling:
|
268 |
+
in_channels *= 2 # In super-sampling mode, the LR input is concatenated directly onto the input.
|
269 |
self.in_channels = in_channels
|
270 |
self.model_channels = model_channels
|
271 |
self.out_channels = out_channels
|
|
|
273 |
self.dropout = dropout
|
274 |
self.channel_mult = channel_mult
|
275 |
self.conv_resample = conv_resample
|
|
|
276 |
self.num_heads = num_heads
|
277 |
self.num_head_channels = num_head_channels
|
278 |
self.num_heads_upsample = num_heads_upsample
|
279 |
+
self.super_sampling_enabled = super_sampling
|
280 |
+
self.super_sampling_max_noising_factor = super_sampling_max_noising_factor
|
281 |
+
self.unconditioned_percentage = unconditioned_percentage
|
282 |
+
self.enable_fp16 = use_fp16
|
283 |
+
self.alignment_size = 2 ** (len(channel_mult)+1)
|
284 |
+
self.freeze_main_net = freeze_main_net
|
285 |
padding = 1 if kernel_size == 3 else 2
|
286 |
+
down_kernel = 1 if efficient_convs else 3
|
287 |
|
288 |
time_embed_dim = model_channels * time_embed_dim_multiplier
|
289 |
self.time_embed = nn.Sequential(
|
290 |
+
Linear(model_channels, time_embed_dim),
|
291 |
nn.SiLU(),
|
292 |
+
Linear(time_embed_dim, time_embed_dim),
|
293 |
)
|
294 |
|
295 |
+
conditioning_dim = model_channels * conditioning_dim_factor
|
296 |
+
# Either code_converter or latent_converter is used, depending on what type of conditioning data is fed.
|
297 |
+
# This model is meant to be able to be trained on both for efficiency purposes - it is far less computationally
|
298 |
+
# complex to generate tokens, while generating latents will normally mean propagating through a deep autoregressive
|
299 |
+
# transformer network.
|
300 |
+
self.code_converter = nn.Sequential(
|
301 |
+
nn.Embedding(in_tokens, conditioning_dim),
|
302 |
+
CheckpointedXTransformerEncoder(
|
303 |
+
needs_permute=False,
|
304 |
+
max_seq_len=-1,
|
305 |
+
use_pos_emb=False,
|
306 |
+
attn_layers=Encoder(
|
307 |
+
dim=conditioning_dim,
|
308 |
+
depth=3,
|
309 |
+
heads=num_heads,
|
310 |
+
ff_dropout=dropout,
|
311 |
+
attn_dropout=dropout,
|
312 |
+
use_rmsnorm=True,
|
313 |
+
ff_glu=True,
|
314 |
+
rotary_emb_dim=True,
|
315 |
+
)
|
316 |
+
))
|
317 |
+
self.latent_converter = nn.Conv1d(in_latent_channels, conditioning_dim, 1)
|
318 |
+
self.aligned_latent_padding_embedding = nn.Parameter(torch.randn(1,in_latent_channels,1))
|
319 |
+
if in_channels > 60: # It's a spectrogram.
|
320 |
+
self.contextual_embedder = nn.Sequential(nn.Conv1d(in_channels,conditioning_dim,3,padding=1,stride=2),
|
321 |
+
CheckpointedXTransformerEncoder(
|
322 |
+
needs_permute=True,
|
323 |
+
max_seq_len=-1,
|
324 |
+
use_pos_emb=False,
|
325 |
+
attn_layers=Encoder(
|
326 |
+
dim=conditioning_dim,
|
327 |
+
depth=4,
|
328 |
+
heads=num_heads,
|
329 |
+
ff_dropout=dropout,
|
330 |
+
attn_dropout=dropout,
|
331 |
+
use_rmsnorm=True,
|
332 |
+
ff_glu=True,
|
333 |
+
rotary_emb_dim=True,
|
334 |
+
)
|
335 |
+
))
|
336 |
+
else:
|
337 |
+
self.contextual_embedder = AudioMiniEncoder(1, conditioning_dim, base_channels=32, depth=6, resnet_blocks=1,
|
338 |
+
attn_blocks=3, num_attn_heads=8, dropout=dropout, downsample_factor=4, kernel_size=5)
|
339 |
+
self.conditioning_conv = nn.Conv1d(conditioning_dim*2, conditioning_dim, 1)
|
340 |
+
self.unconditioned_embedding = nn.Parameter(torch.randn(1,conditioning_dim,1))
|
341 |
+
self.conditioning_timestep_integrator = TimestepEmbedSequential(
|
342 |
+
ResBlock(conditioning_dim, time_embed_dim, dropout, out_channels=conditioning_dim, kernel_size=1, use_scale_shift_norm=use_scale_shift_norm),
|
343 |
+
AttentionBlock(conditioning_dim, num_heads=num_heads, num_head_channels=num_head_channels),
|
344 |
+
ResBlock(conditioning_dim, time_embed_dim, dropout, out_channels=conditioning_dim, kernel_size=1, use_scale_shift_norm=use_scale_shift_norm),
|
345 |
+
AttentionBlock(conditioning_dim, num_heads=num_heads, num_head_channels=num_head_channels),
|
346 |
+
ResBlock(conditioning_dim, time_embed_dim, dropout, out_channels=conditioning_dim, kernel_size=1, use_scale_shift_norm=use_scale_shift_norm),
|
347 |
)
|
348 |
+
self.conditioning_expansion = conditioning_expansion
|
349 |
+
|
350 |
+
self.input_blocks = nn.ModuleList(
|
351 |
+
[
|
352 |
+
TimestepEmbedSequential(
|
353 |
+
nn.Conv1d(in_channels, model_channels, kernel_size, padding=padding)
|
354 |
+
)
|
355 |
+
]
|
356 |
+
)
|
357 |
+
token_conditioning_blocks = []
|
358 |
self._feature_size = model_channels
|
359 |
input_block_chans = [model_channels]
|
360 |
ch = model_channels
|
361 |
ds = 1
|
362 |
|
363 |
for level, (mult, num_blocks) in enumerate(zip(channel_mult, num_res_blocks)):
|
364 |
+
if ds in token_conditioning_resolutions:
|
365 |
+
token_conditioning_block = nn.Conv1d(conditioning_dim, ch, 1)
|
366 |
+
token_conditioning_block.weight.data *= .02
|
367 |
+
self.input_blocks.append(token_conditioning_block)
|
368 |
+
token_conditioning_blocks.append(token_conditioning_block)
|
369 |
|
370 |
for _ in range(num_blocks):
|
371 |
layers = [
|
372 |
+
ResBlock(
|
373 |
ch,
|
374 |
time_embed_dim,
|
375 |
dropout,
|
376 |
out_channels=int(mult * model_channels),
|
|
|
377 |
kernel_size=kernel_size,
|
378 |
+
efficient_config=efficient_convs,
|
379 |
+
use_scale_shift_norm=use_scale_shift_norm,
|
380 |
)
|
381 |
]
|
382 |
ch = int(mult * model_channels)
|
|
|
388 |
num_head_channels=num_head_channels,
|
389 |
)
|
390 |
)
|
391 |
+
self.input_blocks.append(TimestepEmbedSequential(*layers))
|
|
|
|
|
392 |
self._feature_size += ch
|
393 |
input_block_chans.append(ch)
|
394 |
if level != len(channel_mult) - 1:
|
395 |
out_ch = ch
|
396 |
+
self.input_blocks.append(
|
397 |
+
TimestepEmbedSequential(
|
398 |
+
Downsample(
|
399 |
+
ch, conv_resample, out_channels=out_ch, factor=scale_factor, ksize=down_kernel, pad=0 if down_kernel == 1 else 1
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
400 |
)
|
401 |
)
|
402 |
+
)
|
|
|
403 |
ch = out_ch
|
404 |
input_block_chans.append(ch)
|
405 |
ds *= 2
|
406 |
self._feature_size += ch
|
407 |
|
408 |
self.middle_block = TimestepEmbedSequential(
|
409 |
+
ResBlock(
|
410 |
ch,
|
411 |
time_embed_dim,
|
412 |
dropout,
|
|
|
413 |
kernel_size=kernel_size,
|
414 |
+
efficient_config=efficient_convs,
|
415 |
+
use_scale_shift_norm=use_scale_shift_norm,
|
416 |
),
|
417 |
AttentionBlock(
|
418 |
ch,
|
419 |
num_heads=num_heads,
|
420 |
num_head_channels=num_head_channels,
|
421 |
),
|
422 |
+
ResBlock(
|
423 |
ch,
|
424 |
time_embed_dim,
|
425 |
dropout,
|
|
|
426 |
kernel_size=kernel_size,
|
427 |
+
efficient_config=efficient_convs,
|
428 |
+
use_scale_shift_norm=use_scale_shift_norm,
|
429 |
),
|
430 |
)
|
431 |
self._feature_size += ch
|
|
|
435 |
for i in range(num_blocks + 1):
|
436 |
ich = input_block_chans.pop()
|
437 |
layers = [
|
438 |
+
ResBlock(
|
439 |
ch + ich,
|
440 |
time_embed_dim,
|
441 |
dropout,
|
442 |
out_channels=int(model_channels * mult),
|
|
|
443 |
kernel_size=kernel_size,
|
444 |
+
efficient_config=efficient_convs,
|
445 |
+
use_scale_shift_norm=use_scale_shift_norm,
|
446 |
)
|
447 |
]
|
448 |
ch = int(model_channels * mult)
|
|
|
457 |
if level and i == num_blocks:
|
458 |
out_ch = ch
|
459 |
layers.append(
|
460 |
+
Upsample(ch, conv_resample, out_channels=out_ch, factor=scale_factor)
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
461 |
)
|
462 |
ds //= 2
|
463 |
+
self.output_blocks.append(TimestepEmbedSequential(*layers))
|
|
|
|
|
464 |
self._feature_size += ch
|
465 |
|
466 |
self.out = nn.Sequential(
|
|
|
469 |
zero_module(nn.Conv1d(model_channels, out_channels, kernel_size, padding=padding)),
|
470 |
)
|
471 |
|
472 |
+
def fix_alignment(self, x, aligned_conditioning):
|
473 |
+
"""
|
474 |
+
The UNet requires that the input <x> is a certain multiple of 2, defined by the UNet depth. Enforce this by
|
475 |
+
padding both <x> and <aligned_conditioning> before forward propagation and removing the padding before returning.
|
476 |
+
"""
|
477 |
+
cm = ceil_multiple(x.shape[-1], self.alignment_size)
|
478 |
+
if cm != 0:
|
479 |
+
pc = (cm-x.shape[-1])/x.shape[-1]
|
480 |
+
x = F.pad(x, (0,cm-x.shape[-1]))
|
481 |
+
# Also fix aligned_latent, which is aligned to x.
|
482 |
+
if is_latent(aligned_conditioning):
|
483 |
+
aligned_conditioning = torch.cat([aligned_conditioning,
|
484 |
+
self.aligned_latent_padding_embedding.repeat(x.shape[0], 1, int(pc * aligned_conditioning.shape[-1]))], dim=-1)
|
485 |
+
else:
|
486 |
+
aligned_conditioning = F.pad(aligned_conditioning, (0, int(pc*aligned_conditioning.shape[-1])))
|
487 |
+
return x, aligned_conditioning
|
488 |
+
|
489 |
+
def forward(self, x, timesteps, aligned_conditioning, conditioning_input, lr_input=None, conditioning_free=False):
|
490 |
"""
|
491 |
Apply the model to an input batch.
|
492 |
|
493 |
:param x: an [N x C x ...] Tensor of inputs.
|
494 |
:param timesteps: a 1-D batch of timesteps.
|
495 |
+
:param aligned_conditioning: an aligned latent or sequence of tokens providing useful data about the sample to be produced.
|
496 |
+
:param conditioning_input: a full-resolution audio clip that is used as a reference to the style you want decoded.
|
497 |
+
:param lr_input: for super-sampling models, a guidance audio clip at a lower sampling rate.
|
498 |
+
:param conditioning_free: When set, all conditioning inputs (including tokens and conditioning_input) will not be considered.
|
499 |
:return: an [N x C x ...] Tensor of outputs.
|
500 |
"""
|
501 |
+
assert conditioning_input is not None
|
502 |
+
if self.super_sampling_enabled:
|
503 |
+
assert lr_input is not None
|
504 |
+
if self.training and self.super_sampling_max_noising_factor > 0:
|
505 |
+
noising_factor = random.uniform(0,self.super_sampling_max_noising_factor)
|
506 |
+
lr_input = torch.randn_like(lr_input) * noising_factor + lr_input
|
507 |
+
lr_input = F.interpolate(lr_input, size=(x.shape[-1],), mode='nearest')
|
508 |
+
x = torch.cat([x, lr_input], dim=1)
|
509 |
+
|
510 |
+
# Shuffle aligned_latent to BxCxS format
|
511 |
+
if is_latent(aligned_conditioning):
|
512 |
+
aligned_conditioning = aligned_conditioning.permute(0, 2, 1)
|
513 |
+
|
514 |
+
# Fix input size to the proper multiple of 2 so we don't get alignment errors going down and back up the U-net.
|
515 |
+
orig_x_shape = x.shape[-1]
|
516 |
+
x, aligned_conditioning = self.fix_alignment(x, aligned_conditioning)
|
517 |
+
|
518 |
+
with autocast(x.device.type, enabled=self.enable_fp16):
|
519 |
+
hs = []
|
520 |
+
time_emb = self.time_embed(timestep_embedding(timesteps, self.model_channels))
|
521 |
+
|
522 |
+
# Note: this block does not need to repeated on inference, since it is not timestep-dependent.
|
523 |
+
if conditioning_free:
|
524 |
+
code_emb = self.unconditioned_embedding.repeat(x.shape[0], 1, 1)
|
525 |
else:
|
526 |
+
cond_emb = self.contextual_embedder(conditioning_input)
|
527 |
+
if len(cond_emb.shape) == 3: # Just take the first element.
|
528 |
+
cond_emb = cond_emb[:, :, 0]
|
529 |
+
if is_latent(aligned_conditioning):
|
530 |
+
code_emb = self.latent_converter(aligned_conditioning)
|
531 |
+
else:
|
532 |
+
code_emb = self.code_converter(aligned_conditioning)
|
533 |
+
cond_emb = cond_emb.unsqueeze(-1).repeat(1, 1, code_emb.shape[-1])
|
534 |
+
code_emb = self.conditioning_conv(torch.cat([cond_emb, code_emb], dim=1))
|
535 |
+
# Mask out the conditioning branch for whole batch elements, implementing something similar to classifier-free guidance.
|
536 |
+
if self.training and self.unconditioned_percentage > 0:
|
537 |
+
unconditioned_batches = torch.rand((code_emb.shape[0], 1, 1),
|
538 |
+
device=code_emb.device) < self.unconditioned_percentage
|
539 |
+
code_emb = torch.where(unconditioned_batches, self.unconditioned_embedding.repeat(x.shape[0], 1, 1),
|
540 |
+
code_emb)
|
541 |
+
|
542 |
+
# Everything after this comment is timestep dependent.
|
543 |
+
code_emb = torch.repeat_interleave(code_emb, self.conditioning_expansion, dim=-1)
|
544 |
+
code_emb = self.conditioning_timestep_integrator(code_emb, time_emb)
|
545 |
+
|
546 |
+
first = True
|
547 |
+
time_emb = time_emb.float()
|
548 |
+
h = x
|
549 |
+
for k, module in enumerate(self.input_blocks):
|
550 |
+
if isinstance(module, nn.Conv1d):
|
551 |
+
h_tok = F.interpolate(module(code_emb), size=(h.shape[-1]), mode='nearest')
|
552 |
+
h = h + h_tok
|
553 |
+
else:
|
554 |
+
with autocast(x.device.type, enabled=self.enable_fp16 and not first):
|
555 |
+
# First block has autocast disabled to allow a high precision signal to be properly vectorized.
|
556 |
+
h = module(h, time_emb)
|
557 |
+
hs.append(h)
|
558 |
+
first = False
|
559 |
+
h = self.middle_block(h, time_emb)
|
560 |
+
for module in self.output_blocks:
|
561 |
+
h = torch.cat([h, hs.pop()], dim=1)
|
562 |
+
h = module(h, time_emb)
|
563 |
+
|
564 |
+
# Last block also has autocast disabled for high-precision outputs.
|
565 |
+
h = h.float()
|
566 |
+
out = self.out(h)
|
567 |
+
|
568 |
+
# Involve probabilistic or possibly unused parameters in loss so we don't get DDP errors.
|
569 |
+
extraneous_addition = 0
|
570 |
+
params = [self.aligned_latent_padding_embedding, self.unconditioned_embedding] + list(self.latent_converter.parameters())
|
571 |
+
for p in params:
|
572 |
+
extraneous_addition = extraneous_addition + p.mean()
|
573 |
+
out = out + extraneous_addition * 0
|
574 |
+
|
575 |
+
return out[:, :, :orig_x_shape]
|
576 |
|
577 |
|
|
|
578 |
if __name__ == '__main__':
|
579 |
+
clip = torch.randn(2, 1, 32868)
|
580 |
+
aligned_latent = torch.randn(2,388,1024)
|
581 |
+
aligned_sequence = torch.randint(0,8192,(2,388))
|
582 |
+
cond = torch.randn(2, 1, 44000)
|
583 |
+
ts = torch.LongTensor([600, 600])
|
584 |
+
model = DiffusionTts(128,
|
585 |
+
channel_mult=[1,1.5,2, 3, 4, 6, 8],
|
586 |
+
num_res_blocks=[2, 2, 2, 2, 2, 2, 1],
|
587 |
+
token_conditioning_resolutions=[1,4,16,64],
|
588 |
+
attention_resolutions=[],
|
589 |
+
num_heads=8,
|
590 |
+
kernel_size=3,
|
591 |
+
scale_factor=2,
|
592 |
+
time_embed_dim_multiplier=4,
|
593 |
+
super_sampling=False,
|
594 |
+
efficient_convs=False)
|
595 |
+
# Test with latent aligned conditioning
|
596 |
+
o = model(clip, ts, aligned_latent, cond)
|
597 |
+
# Test with sequence aligned conditioning
|
598 |
+
o = model(clip, ts, aligned_sequence, cond)
|
models/dvae.py
DELETED
@@ -1,390 +0,0 @@
|
|
1 |
-
import functools
|
2 |
-
from math import sqrt
|
3 |
-
|
4 |
-
import torch
|
5 |
-
import torch.distributed as distributed
|
6 |
-
import torch.nn as nn
|
7 |
-
import torch.nn.functional as F
|
8 |
-
from einops import rearrange
|
9 |
-
|
10 |
-
|
11 |
-
def default(val, d):
|
12 |
-
return val if val is not None else d
|
13 |
-
|
14 |
-
|
15 |
-
def eval_decorator(fn):
|
16 |
-
def inner(model, *args, **kwargs):
|
17 |
-
was_training = model.training
|
18 |
-
model.eval()
|
19 |
-
out = fn(model, *args, **kwargs)
|
20 |
-
model.train(was_training)
|
21 |
-
return out
|
22 |
-
return inner
|
23 |
-
|
24 |
-
|
25 |
-
# Quantizer implemented by the rosinality vqvae repo.
|
26 |
-
# Credit: https://github.com/rosinality/vq-vae-2-pytorch
|
27 |
-
class Quantize(nn.Module):
|
28 |
-
def __init__(self, dim, n_embed, decay=0.99, eps=1e-5, balancing_heuristic=False, new_return_order=False):
|
29 |
-
super().__init__()
|
30 |
-
|
31 |
-
self.dim = dim
|
32 |
-
self.n_embed = n_embed
|
33 |
-
self.decay = decay
|
34 |
-
self.eps = eps
|
35 |
-
|
36 |
-
self.balancing_heuristic = balancing_heuristic
|
37 |
-
self.codes = None
|
38 |
-
self.max_codes = 64000
|
39 |
-
self.codes_full = False
|
40 |
-
self.new_return_order = new_return_order
|
41 |
-
|
42 |
-
embed = torch.randn(dim, n_embed)
|
43 |
-
self.register_buffer("embed", embed)
|
44 |
-
self.register_buffer("cluster_size", torch.zeros(n_embed))
|
45 |
-
self.register_buffer("embed_avg", embed.clone())
|
46 |
-
|
47 |
-
def forward(self, input, return_soft_codes=False):
|
48 |
-
if self.balancing_heuristic and self.codes_full:
|
49 |
-
h = torch.histc(self.codes, bins=self.n_embed, min=0, max=self.n_embed) / len(self.codes)
|
50 |
-
mask = torch.logical_or(h > .9, h < .01).unsqueeze(1)
|
51 |
-
ep = self.embed.permute(1,0)
|
52 |
-
ea = self.embed_avg.permute(1,0)
|
53 |
-
rand_embed = torch.randn_like(ep) * mask
|
54 |
-
self.embed = (ep * ~mask + rand_embed).permute(1,0)
|
55 |
-
self.embed_avg = (ea * ~mask + rand_embed).permute(1,0)
|
56 |
-
self.cluster_size = self.cluster_size * ~mask.squeeze()
|
57 |
-
if torch.any(mask):
|
58 |
-
print(f"Reset {torch.sum(mask)} embedding codes.")
|
59 |
-
self.codes = None
|
60 |
-
self.codes_full = False
|
61 |
-
|
62 |
-
flatten = input.reshape(-1, self.dim)
|
63 |
-
dist = (
|
64 |
-
flatten.pow(2).sum(1, keepdim=True)
|
65 |
-
- 2 * flatten @ self.embed
|
66 |
-
+ self.embed.pow(2).sum(0, keepdim=True)
|
67 |
-
)
|
68 |
-
soft_codes = -dist
|
69 |
-
_, embed_ind = soft_codes.max(1)
|
70 |
-
embed_onehot = F.one_hot(embed_ind, self.n_embed).type(flatten.dtype)
|
71 |
-
embed_ind = embed_ind.view(*input.shape[:-1])
|
72 |
-
quantize = self.embed_code(embed_ind)
|
73 |
-
|
74 |
-
if self.balancing_heuristic:
|
75 |
-
if self.codes is None:
|
76 |
-
self.codes = embed_ind.flatten()
|
77 |
-
else:
|
78 |
-
self.codes = torch.cat([self.codes, embed_ind.flatten()])
|
79 |
-
if len(self.codes) > self.max_codes:
|
80 |
-
self.codes = self.codes[-self.max_codes:]
|
81 |
-
self.codes_full = True
|
82 |
-
|
83 |
-
if self.training:
|
84 |
-
embed_onehot_sum = embed_onehot.sum(0)
|
85 |
-
embed_sum = flatten.transpose(0, 1) @ embed_onehot
|
86 |
-
|
87 |
-
if distributed.is_initialized() and distributed.get_world_size() > 1:
|
88 |
-
distributed.all_reduce(embed_onehot_sum)
|
89 |
-
distributed.all_reduce(embed_sum)
|
90 |
-
|
91 |
-
self.cluster_size.data.mul_(self.decay).add_(
|
92 |
-
embed_onehot_sum, alpha=1 - self.decay
|
93 |
-
)
|
94 |
-
self.embed_avg.data.mul_(self.decay).add_(embed_sum, alpha=1 - self.decay)
|
95 |
-
n = self.cluster_size.sum()
|
96 |
-
cluster_size = (
|
97 |
-
(self.cluster_size + self.eps) / (n + self.n_embed * self.eps) * n
|
98 |
-
)
|
99 |
-
embed_normalized = self.embed_avg / cluster_size.unsqueeze(0)
|
100 |
-
self.embed.data.copy_(embed_normalized)
|
101 |
-
|
102 |
-
diff = (quantize.detach() - input).pow(2).mean()
|
103 |
-
quantize = input + (quantize - input).detach()
|
104 |
-
|
105 |
-
if return_soft_codes:
|
106 |
-
return quantize, diff, embed_ind, soft_codes.view(input.shape[:-1] + (-1,))
|
107 |
-
elif self.new_return_order:
|
108 |
-
return quantize, embed_ind, diff
|
109 |
-
else:
|
110 |
-
return quantize, diff, embed_ind
|
111 |
-
|
112 |
-
def embed_code(self, embed_id):
|
113 |
-
return F.embedding(embed_id, self.embed.transpose(0, 1))
|
114 |
-
|
115 |
-
|
116 |
-
# Fits a soft-discretized input to a normal-PDF across the specified dimension.
|
117 |
-
# In other words, attempts to force the discretization function to have a mean equal utilization across all discrete
|
118 |
-
# values with the specified expected variance.
|
119 |
-
class DiscretizationLoss(nn.Module):
|
120 |
-
def __init__(self, discrete_bins, dim, expected_variance, store_past=0):
|
121 |
-
super().__init__()
|
122 |
-
self.discrete_bins = discrete_bins
|
123 |
-
self.dim = dim
|
124 |
-
self.dist = torch.distributions.Normal(0, scale=expected_variance)
|
125 |
-
if store_past > 0:
|
126 |
-
self.record_past = True
|
127 |
-
self.register_buffer("accumulator_index", torch.zeros(1, dtype=torch.long, device='cpu'))
|
128 |
-
self.register_buffer("accumulator_filled", torch.zeros(1, dtype=torch.long, device='cpu'))
|
129 |
-
self.register_buffer("accumulator", torch.zeros(store_past, discrete_bins))
|
130 |
-
else:
|
131 |
-
self.record_past = False
|
132 |
-
|
133 |
-
def forward(self, x):
|
134 |
-
other_dims = set(range(len(x.shape)))-set([self.dim])
|
135 |
-
averaged = x.sum(dim=tuple(other_dims)) / x.sum()
|
136 |
-
averaged = averaged - averaged.mean()
|
137 |
-
|
138 |
-
if self.record_past:
|
139 |
-
acc_count = self.accumulator.shape[0]
|
140 |
-
avg = averaged.detach().clone()
|
141 |
-
if self.accumulator_filled > 0:
|
142 |
-
averaged = torch.mean(self.accumulator, dim=0) * (acc_count-1) / acc_count + \
|
143 |
-
averaged / acc_count
|
144 |
-
|
145 |
-
# Also push averaged into the accumulator.
|
146 |
-
self.accumulator[self.accumulator_index] = avg
|
147 |
-
self.accumulator_index += 1
|
148 |
-
if self.accumulator_index >= acc_count:
|
149 |
-
self.accumulator_index *= 0
|
150 |
-
if self.accumulator_filled <= 0:
|
151 |
-
self.accumulator_filled += 1
|
152 |
-
|
153 |
-
return torch.sum(-self.dist.log_prob(averaged))
|
154 |
-
|
155 |
-
|
156 |
-
class ResBlock(nn.Module):
|
157 |
-
def __init__(self, chan, conv, activation):
|
158 |
-
super().__init__()
|
159 |
-
self.net = nn.Sequential(
|
160 |
-
conv(chan, chan, 3, padding = 1),
|
161 |
-
activation(),
|
162 |
-
conv(chan, chan, 3, padding = 1),
|
163 |
-
activation(),
|
164 |
-
conv(chan, chan, 1)
|
165 |
-
)
|
166 |
-
|
167 |
-
def forward(self, x):
|
168 |
-
return self.net(x) + x
|
169 |
-
|
170 |
-
|
171 |
-
class UpsampledConv(nn.Module):
|
172 |
-
def __init__(self, conv, *args, **kwargs):
|
173 |
-
super().__init__()
|
174 |
-
assert 'stride' in kwargs.keys()
|
175 |
-
self.stride = kwargs['stride']
|
176 |
-
del kwargs['stride']
|
177 |
-
self.conv = conv(*args, **kwargs)
|
178 |
-
|
179 |
-
def forward(self, x):
|
180 |
-
up = nn.functional.interpolate(x, scale_factor=self.stride, mode='nearest')
|
181 |
-
return self.conv(up)
|
182 |
-
|
183 |
-
|
184 |
-
# DiscreteVAE partially derived from lucidrains DALLE implementation
|
185 |
-
# Credit: https://github.com/lucidrains/DALLE-pytorch
|
186 |
-
class DiscreteVAE(nn.Module):
|
187 |
-
def __init__(
|
188 |
-
self,
|
189 |
-
positional_dims=2,
|
190 |
-
num_tokens = 512,
|
191 |
-
codebook_dim = 512,
|
192 |
-
num_layers = 3,
|
193 |
-
num_resnet_blocks = 0,
|
194 |
-
hidden_dim = 64,
|
195 |
-
channels = 3,
|
196 |
-
stride = 2,
|
197 |
-
kernel_size = 4,
|
198 |
-
use_transposed_convs = True,
|
199 |
-
encoder_norm = False,
|
200 |
-
activation = 'relu',
|
201 |
-
smooth_l1_loss = False,
|
202 |
-
straight_through = False,
|
203 |
-
normalization = None, # ((0.5,) * 3, (0.5,) * 3),
|
204 |
-
record_codes = False,
|
205 |
-
discretization_loss_averaging_steps = 100,
|
206 |
-
lr_quantizer_args = {},
|
207 |
-
):
|
208 |
-
super().__init__()
|
209 |
-
has_resblocks = num_resnet_blocks > 0
|
210 |
-
|
211 |
-
self.num_tokens = num_tokens
|
212 |
-
self.num_layers = num_layers
|
213 |
-
self.straight_through = straight_through
|
214 |
-
self.positional_dims = positional_dims
|
215 |
-
self.discrete_loss = DiscretizationLoss(num_tokens, 2, 1 / (num_tokens*2), discretization_loss_averaging_steps)
|
216 |
-
|
217 |
-
assert positional_dims > 0 and positional_dims < 3 # This VAE only supports 1d and 2d inputs for now.
|
218 |
-
if positional_dims == 2:
|
219 |
-
conv = nn.Conv2d
|
220 |
-
conv_transpose = nn.ConvTranspose2d
|
221 |
-
else:
|
222 |
-
conv = nn.Conv1d
|
223 |
-
conv_transpose = nn.ConvTranspose1d
|
224 |
-
if not use_transposed_convs:
|
225 |
-
conv_transpose = functools.partial(UpsampledConv, conv)
|
226 |
-
|
227 |
-
if activation == 'relu':
|
228 |
-
act = nn.ReLU
|
229 |
-
elif activation == 'silu':
|
230 |
-
act = nn.SiLU
|
231 |
-
else:
|
232 |
-
assert NotImplementedError()
|
233 |
-
|
234 |
-
|
235 |
-
enc_layers = []
|
236 |
-
dec_layers = []
|
237 |
-
|
238 |
-
if num_layers > 0:
|
239 |
-
enc_chans = [hidden_dim * 2 ** i for i in range(num_layers)]
|
240 |
-
dec_chans = list(reversed(enc_chans))
|
241 |
-
|
242 |
-
enc_chans = [channels, *enc_chans]
|
243 |
-
|
244 |
-
dec_init_chan = codebook_dim if not has_resblocks else dec_chans[0]
|
245 |
-
dec_chans = [dec_init_chan, *dec_chans]
|
246 |
-
|
247 |
-
enc_chans_io, dec_chans_io = map(lambda t: list(zip(t[:-1], t[1:])), (enc_chans, dec_chans))
|
248 |
-
|
249 |
-
pad = (kernel_size - 1) // 2
|
250 |
-
for (enc_in, enc_out), (dec_in, dec_out) in zip(enc_chans_io, dec_chans_io):
|
251 |
-
enc_layers.append(nn.Sequential(conv(enc_in, enc_out, kernel_size, stride = stride, padding = pad), act()))
|
252 |
-
if encoder_norm:
|
253 |
-
enc_layers.append(nn.GroupNorm(8, enc_out))
|
254 |
-
dec_layers.append(nn.Sequential(conv_transpose(dec_in, dec_out, kernel_size, stride = stride, padding = pad), act()))
|
255 |
-
dec_out_chans = dec_chans[-1]
|
256 |
-
innermost_dim = dec_chans[0]
|
257 |
-
else:
|
258 |
-
enc_layers.append(nn.Sequential(conv(channels, hidden_dim, 1), act()))
|
259 |
-
dec_out_chans = hidden_dim
|
260 |
-
innermost_dim = hidden_dim
|
261 |
-
|
262 |
-
for _ in range(num_resnet_blocks):
|
263 |
-
dec_layers.insert(0, ResBlock(innermost_dim, conv, act))
|
264 |
-
enc_layers.append(ResBlock(innermost_dim, conv, act))
|
265 |
-
|
266 |
-
if num_resnet_blocks > 0:
|
267 |
-
dec_layers.insert(0, conv(codebook_dim, innermost_dim, 1))
|
268 |
-
|
269 |
-
|
270 |
-
enc_layers.append(conv(innermost_dim, codebook_dim, 1))
|
271 |
-
dec_layers.append(conv(dec_out_chans, channels, 1))
|
272 |
-
|
273 |
-
self.encoder = nn.Sequential(*enc_layers)
|
274 |
-
self.decoder = nn.Sequential(*dec_layers)
|
275 |
-
|
276 |
-
self.loss_fn = F.smooth_l1_loss if smooth_l1_loss else F.mse_loss
|
277 |
-
self.codebook = Quantize(codebook_dim, num_tokens, new_return_order=True)
|
278 |
-
|
279 |
-
# take care of normalization within class
|
280 |
-
self.normalization = normalization
|
281 |
-
self.record_codes = record_codes
|
282 |
-
if record_codes:
|
283 |
-
self.codes = torch.zeros((1228800,), dtype=torch.long)
|
284 |
-
self.code_ind = 0
|
285 |
-
self.total_codes = 0
|
286 |
-
self.internal_step = 0
|
287 |
-
|
288 |
-
def norm(self, images):
|
289 |
-
if not self.normalization is not None:
|
290 |
-
return images
|
291 |
-
|
292 |
-
means, stds = map(lambda t: torch.as_tensor(t).to(images), self.normalization)
|
293 |
-
arrange = 'c -> () c () ()' if self.positional_dims == 2 else 'c -> () c ()'
|
294 |
-
means, stds = map(lambda t: rearrange(t, arrange), (means, stds))
|
295 |
-
images = images.clone()
|
296 |
-
images.sub_(means).div_(stds)
|
297 |
-
return images
|
298 |
-
|
299 |
-
def get_debug_values(self, step, __):
|
300 |
-
if self.record_codes and self.total_codes > 0:
|
301 |
-
# Report annealing schedule
|
302 |
-
return {'histogram_codes': self.codes[:self.total_codes]}
|
303 |
-
else:
|
304 |
-
return {}
|
305 |
-
|
306 |
-
@torch.no_grad()
|
307 |
-
@eval_decorator
|
308 |
-
def get_codebook_indices(self, images):
|
309 |
-
img = self.norm(images)
|
310 |
-
logits = self.encoder(img).permute((0,2,3,1) if len(img.shape) == 4 else (0,2,1))
|
311 |
-
sampled, codes, _ = self.codebook(logits)
|
312 |
-
self.log_codes(codes)
|
313 |
-
return codes
|
314 |
-
|
315 |
-
def decode(
|
316 |
-
self,
|
317 |
-
img_seq
|
318 |
-
):
|
319 |
-
self.log_codes(img_seq)
|
320 |
-
if hasattr(self.codebook, 'embed_code'):
|
321 |
-
image_embeds = self.codebook.embed_code(img_seq)
|
322 |
-
else:
|
323 |
-
image_embeds = F.embedding(img_seq, self.codebook.codebook)
|
324 |
-
b, n, d = image_embeds.shape
|
325 |
-
|
326 |
-
kwargs = {}
|
327 |
-
if self.positional_dims == 1:
|
328 |
-
arrange = 'b n d -> b d n'
|
329 |
-
else:
|
330 |
-
h = w = int(sqrt(n))
|
331 |
-
arrange = 'b (h w) d -> b d h w'
|
332 |
-
kwargs = {'h': h, 'w': w}
|
333 |
-
image_embeds = rearrange(image_embeds, arrange, **kwargs)
|
334 |
-
images = [image_embeds]
|
335 |
-
for layer in self.decoder:
|
336 |
-
images.append(layer(images[-1]))
|
337 |
-
return images[-1], images[-2]
|
338 |
-
|
339 |
-
def infer(self, img):
|
340 |
-
img = self.norm(img)
|
341 |
-
logits = self.encoder(img).permute((0,2,3,1) if len(img.shape) == 4 else (0,2,1))
|
342 |
-
sampled, codes, commitment_loss = self.codebook(logits)
|
343 |
-
return self.decode(codes)
|
344 |
-
|
345 |
-
# Note: This module is not meant to be run in forward() except while training. It has special logic which performs
|
346 |
-
# evaluation using quantized values when it detects that it is being run in eval() mode, which will be substantially
|
347 |
-
# more lossy (but useful for determining network performance).
|
348 |
-
def forward(
|
349 |
-
self,
|
350 |
-
img
|
351 |
-
):
|
352 |
-
img = self.norm(img)
|
353 |
-
logits = self.encoder(img).permute((0,2,3,1) if len(img.shape) == 4 else (0,2,1))
|
354 |
-
sampled, codes, commitment_loss = self.codebook(logits)
|
355 |
-
sampled = sampled.permute((0,3,1,2) if len(img.shape) == 4 else (0,2,1))
|
356 |
-
|
357 |
-
if self.training:
|
358 |
-
out = sampled
|
359 |
-
for d in self.decoder:
|
360 |
-
out = d(out)
|
361 |
-
self.log_codes(codes)
|
362 |
-
else:
|
363 |
-
# This is non-differentiable, but gives a better idea of how the network is actually performing.
|
364 |
-
out, _ = self.decode(codes)
|
365 |
-
|
366 |
-
# reconstruction loss
|
367 |
-
recon_loss = self.loss_fn(img, out, reduction='none')
|
368 |
-
|
369 |
-
return recon_loss, commitment_loss, out
|
370 |
-
|
371 |
-
def log_codes(self, codes):
|
372 |
-
# This is so we can debug the distribution of codes being learned.
|
373 |
-
if self.record_codes and self.internal_step % 10 == 0:
|
374 |
-
codes = codes.flatten()
|
375 |
-
l = codes.shape[0]
|
376 |
-
i = self.code_ind if (self.codes.shape[0] - self.code_ind) > l else self.codes.shape[0] - l
|
377 |
-
self.codes[i:i+l] = codes.cpu()
|
378 |
-
self.code_ind = self.code_ind + l
|
379 |
-
if self.code_ind >= self.codes.shape[0]:
|
380 |
-
self.code_ind = 0
|
381 |
-
self.total_codes += 1
|
382 |
-
self.internal_step += 1
|
383 |
-
|
384 |
-
|
385 |
-
if __name__ == '__main__':
|
386 |
-
v = DiscreteVAE(channels=80, normalization=None, positional_dims=1, num_tokens=8192, codebook_dim=2048,
|
387 |
-
hidden_dim=512, num_resnet_blocks=3, kernel_size=3, num_layers=1, use_transposed_convs=False)
|
388 |
-
r,l,o=v(torch.randn(1,80,256))
|
389 |
-
v.decode(torch.randint(0,8192,(1,256)))
|
390 |
-
print(o.shape, l.shape)
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
models/vocoder.py
ADDED
@@ -0,0 +1,325 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
1 |
+
import torch
|
2 |
+
import torch.nn as nn
|
3 |
+
import torch.nn.functional as F
|
4 |
+
|
5 |
+
MAX_WAV_VALUE = 32768.0
|
6 |
+
|
7 |
+
class KernelPredictor(torch.nn.Module):
|
8 |
+
''' Kernel predictor for the location-variable convolutions'''
|
9 |
+
|
10 |
+
def __init__(
|
11 |
+
self,
|
12 |
+
cond_channels,
|
13 |
+
conv_in_channels,
|
14 |
+
conv_out_channels,
|
15 |
+
conv_layers,
|
16 |
+
conv_kernel_size=3,
|
17 |
+
kpnet_hidden_channels=64,
|
18 |
+
kpnet_conv_size=3,
|
19 |
+
kpnet_dropout=0.0,
|
20 |
+
kpnet_nonlinear_activation="LeakyReLU",
|
21 |
+
kpnet_nonlinear_activation_params={"negative_slope": 0.1},
|
22 |
+
):
|
23 |
+
'''
|
24 |
+
Args:
|
25 |
+
cond_channels (int): number of channel for the conditioning sequence,
|
26 |
+
conv_in_channels (int): number of channel for the input sequence,
|
27 |
+
conv_out_channels (int): number of channel for the output sequence,
|
28 |
+
conv_layers (int): number of layers
|
29 |
+
'''
|
30 |
+
super().__init__()
|
31 |
+
|
32 |
+
self.conv_in_channels = conv_in_channels
|
33 |
+
self.conv_out_channels = conv_out_channels
|
34 |
+
self.conv_kernel_size = conv_kernel_size
|
35 |
+
self.conv_layers = conv_layers
|
36 |
+
|
37 |
+
kpnet_kernel_channels = conv_in_channels * conv_out_channels * conv_kernel_size * conv_layers # l_w
|
38 |
+
kpnet_bias_channels = conv_out_channels * conv_layers # l_b
|
39 |
+
|
40 |
+
self.input_conv = nn.Sequential(
|
41 |
+
nn.utils.weight_norm(nn.Conv1d(cond_channels, kpnet_hidden_channels, 5, padding=2, bias=True)),
|
42 |
+
getattr(nn, kpnet_nonlinear_activation)(**kpnet_nonlinear_activation_params),
|
43 |
+
)
|
44 |
+
|
45 |
+
self.residual_convs = nn.ModuleList()
|
46 |
+
padding = (kpnet_conv_size - 1) // 2
|
47 |
+
for _ in range(3):
|
48 |
+
self.residual_convs.append(
|
49 |
+
nn.Sequential(
|
50 |
+
nn.Dropout(kpnet_dropout),
|
51 |
+
nn.utils.weight_norm(
|
52 |
+
nn.Conv1d(kpnet_hidden_channels, kpnet_hidden_channels, kpnet_conv_size, padding=padding,
|
53 |
+
bias=True)),
|
54 |
+
getattr(nn, kpnet_nonlinear_activation)(**kpnet_nonlinear_activation_params),
|
55 |
+
nn.utils.weight_norm(
|
56 |
+
nn.Conv1d(kpnet_hidden_channels, kpnet_hidden_channels, kpnet_conv_size, padding=padding,
|
57 |
+
bias=True)),
|
58 |
+
getattr(nn, kpnet_nonlinear_activation)(**kpnet_nonlinear_activation_params),
|
59 |
+
)
|
60 |
+
)
|
61 |
+
self.kernel_conv = nn.utils.weight_norm(
|
62 |
+
nn.Conv1d(kpnet_hidden_channels, kpnet_kernel_channels, kpnet_conv_size, padding=padding, bias=True))
|
63 |
+
self.bias_conv = nn.utils.weight_norm(
|
64 |
+
nn.Conv1d(kpnet_hidden_channels, kpnet_bias_channels, kpnet_conv_size, padding=padding, bias=True))
|
65 |
+
|
66 |
+
def forward(self, c):
|
67 |
+
'''
|
68 |
+
Args:
|
69 |
+
c (Tensor): the conditioning sequence (batch, cond_channels, cond_length)
|
70 |
+
'''
|
71 |
+
batch, _, cond_length = c.shape
|
72 |
+
c = self.input_conv(c)
|
73 |
+
for residual_conv in self.residual_convs:
|
74 |
+
residual_conv.to(c.device)
|
75 |
+
c = c + residual_conv(c)
|
76 |
+
k = self.kernel_conv(c)
|
77 |
+
b = self.bias_conv(c)
|
78 |
+
kernels = k.contiguous().view(
|
79 |
+
batch,
|
80 |
+
self.conv_layers,
|
81 |
+
self.conv_in_channels,
|
82 |
+
self.conv_out_channels,
|
83 |
+
self.conv_kernel_size,
|
84 |
+
cond_length,
|
85 |
+
)
|
86 |
+
bias = b.contiguous().view(
|
87 |
+
batch,
|
88 |
+
self.conv_layers,
|
89 |
+
self.conv_out_channels,
|
90 |
+
cond_length,
|
91 |
+
)
|
92 |
+
|
93 |
+
return kernels, bias
|
94 |
+
|
95 |
+
def remove_weight_norm(self):
|
96 |
+
nn.utils.remove_weight_norm(self.input_conv[0])
|
97 |
+
nn.utils.remove_weight_norm(self.kernel_conv)
|
98 |
+
nn.utils.remove_weight_norm(self.bias_conv)
|
99 |
+
for block in self.residual_convs:
|
100 |
+
nn.utils.remove_weight_norm(block[1])
|
101 |
+
nn.utils.remove_weight_norm(block[3])
|
102 |
+
|
103 |
+
|
104 |
+
class LVCBlock(torch.nn.Module):
|
105 |
+
'''the location-variable convolutions'''
|
106 |
+
|
107 |
+
def __init__(
|
108 |
+
self,
|
109 |
+
in_channels,
|
110 |
+
cond_channels,
|
111 |
+
stride,
|
112 |
+
dilations=[1, 3, 9, 27],
|
113 |
+
lReLU_slope=0.2,
|
114 |
+
conv_kernel_size=3,
|
115 |
+
cond_hop_length=256,
|
116 |
+
kpnet_hidden_channels=64,
|
117 |
+
kpnet_conv_size=3,
|
118 |
+
kpnet_dropout=0.0,
|
119 |
+
):
|
120 |
+
super().__init__()
|
121 |
+
|
122 |
+
self.cond_hop_length = cond_hop_length
|
123 |
+
self.conv_layers = len(dilations)
|
124 |
+
self.conv_kernel_size = conv_kernel_size
|
125 |
+
|
126 |
+
self.kernel_predictor = KernelPredictor(
|
127 |
+
cond_channels=cond_channels,
|
128 |
+
conv_in_channels=in_channels,
|
129 |
+
conv_out_channels=2 * in_channels,
|
130 |
+
conv_layers=len(dilations),
|
131 |
+
conv_kernel_size=conv_kernel_size,
|
132 |
+
kpnet_hidden_channels=kpnet_hidden_channels,
|
133 |
+
kpnet_conv_size=kpnet_conv_size,
|
134 |
+
kpnet_dropout=kpnet_dropout,
|
135 |
+
kpnet_nonlinear_activation_params={"negative_slope": lReLU_slope}
|
136 |
+
)
|
137 |
+
|
138 |
+
self.convt_pre = nn.Sequential(
|
139 |
+
nn.LeakyReLU(lReLU_slope),
|
140 |
+
nn.utils.weight_norm(nn.ConvTranspose1d(in_channels, in_channels, 2 * stride, stride=stride,
|
141 |
+
padding=stride // 2 + stride % 2, output_padding=stride % 2)),
|
142 |
+
)
|
143 |
+
|
144 |
+
self.conv_blocks = nn.ModuleList()
|
145 |
+
for dilation in dilations:
|
146 |
+
self.conv_blocks.append(
|
147 |
+
nn.Sequential(
|
148 |
+
nn.LeakyReLU(lReLU_slope),
|
149 |
+
nn.utils.weight_norm(nn.Conv1d(in_channels, in_channels, conv_kernel_size,
|
150 |
+
padding=dilation * (conv_kernel_size - 1) // 2, dilation=dilation)),
|
151 |
+
nn.LeakyReLU(lReLU_slope),
|
152 |
+
)
|
153 |
+
)
|
154 |
+
|
155 |
+
def forward(self, x, c):
|
156 |
+
''' forward propagation of the location-variable convolutions.
|
157 |
+
Args:
|
158 |
+
x (Tensor): the input sequence (batch, in_channels, in_length)
|
159 |
+
c (Tensor): the conditioning sequence (batch, cond_channels, cond_length)
|
160 |
+
|
161 |
+
Returns:
|
162 |
+
Tensor: the output sequence (batch, in_channels, in_length)
|
163 |
+
'''
|
164 |
+
_, in_channels, _ = x.shape # (B, c_g, L')
|
165 |
+
|
166 |
+
x = self.convt_pre(x) # (B, c_g, stride * L')
|
167 |
+
kernels, bias = self.kernel_predictor(c)
|
168 |
+
|
169 |
+
for i, conv in enumerate(self.conv_blocks):
|
170 |
+
output = conv(x) # (B, c_g, stride * L')
|
171 |
+
|
172 |
+
k = kernels[:, i, :, :, :, :] # (B, 2 * c_g, c_g, kernel_size, cond_length)
|
173 |
+
b = bias[:, i, :, :] # (B, 2 * c_g, cond_length)
|
174 |
+
|
175 |
+
output = self.location_variable_convolution(output, k, b,
|
176 |
+
hop_size=self.cond_hop_length) # (B, 2 * c_g, stride * L'): LVC
|
177 |
+
x = x + torch.sigmoid(output[:, :in_channels, :]) * torch.tanh(
|
178 |
+
output[:, in_channels:, :]) # (B, c_g, stride * L'): GAU
|
179 |
+
|
180 |
+
return x
|
181 |
+
|
182 |
+
def location_variable_convolution(self, x, kernel, bias, dilation=1, hop_size=256):
|
183 |
+
''' perform location-variable convolution operation on the input sequence (x) using the local convolution kernl.
|
184 |
+
Time: 414 μs ± 309 ns per loop (mean ± std. dev. of 7 runs, 1000 loops each), test on NVIDIA V100.
|
185 |
+
Args:
|
186 |
+
x (Tensor): the input sequence (batch, in_channels, in_length).
|
187 |
+
kernel (Tensor): the local convolution kernel (batch, in_channel, out_channels, kernel_size, kernel_length)
|
188 |
+
bias (Tensor): the bias for the local convolution (batch, out_channels, kernel_length)
|
189 |
+
dilation (int): the dilation of convolution.
|
190 |
+
hop_size (int): the hop_size of the conditioning sequence.
|
191 |
+
Returns:
|
192 |
+
(Tensor): the output sequence after performing local convolution. (batch, out_channels, in_length).
|
193 |
+
'''
|
194 |
+
batch, _, in_length = x.shape
|
195 |
+
batch, _, out_channels, kernel_size, kernel_length = kernel.shape
|
196 |
+
assert in_length == (kernel_length * hop_size), "length of (x, kernel) is not matched"
|
197 |
+
|
198 |
+
padding = dilation * int((kernel_size - 1) / 2)
|
199 |
+
x = F.pad(x, (padding, padding), 'constant', 0) # (batch, in_channels, in_length + 2*padding)
|
200 |
+
x = x.unfold(2, hop_size + 2 * padding, hop_size) # (batch, in_channels, kernel_length, hop_size + 2*padding)
|
201 |
+
|
202 |
+
if hop_size < dilation:
|
203 |
+
x = F.pad(x, (0, dilation), 'constant', 0)
|
204 |
+
x = x.unfold(3, dilation,
|
205 |
+
dilation) # (batch, in_channels, kernel_length, (hop_size + 2*padding)/dilation, dilation)
|
206 |
+
x = x[:, :, :, :, :hop_size]
|
207 |
+
x = x.transpose(3, 4) # (batch, in_channels, kernel_length, dilation, (hop_size + 2*padding)/dilation)
|
208 |
+
x = x.unfold(4, kernel_size, 1) # (batch, in_channels, kernel_length, dilation, _, kernel_size)
|
209 |
+
|
210 |
+
o = torch.einsum('bildsk,biokl->bolsd', x, kernel)
|
211 |
+
o = o.to(memory_format=torch.channels_last_3d)
|
212 |
+
bias = bias.unsqueeze(-1).unsqueeze(-1).to(memory_format=torch.channels_last_3d)
|
213 |
+
o = o + bias
|
214 |
+
o = o.contiguous().view(batch, out_channels, -1)
|
215 |
+
|
216 |
+
return o
|
217 |
+
|
218 |
+
def remove_weight_norm(self):
|
219 |
+
self.kernel_predictor.remove_weight_norm()
|
220 |
+
nn.utils.remove_weight_norm(self.convt_pre[1])
|
221 |
+
for block in self.conv_blocks:
|
222 |
+
nn.utils.remove_weight_norm(block[1])
|
223 |
+
|
224 |
+
|
225 |
+
class UnivNetGenerator(nn.Module):
|
226 |
+
"""UnivNet Generator"""
|
227 |
+
|
228 |
+
def __init__(self, noise_dim=64, channel_size=32, dilations=[1,3,9,27], strides=[8,8,4], lReLU_slope=.2, kpnet_conv_size=3,
|
229 |
+
# Below are MEL configurations options that this generator requires.
|
230 |
+
hop_length=256, n_mel_channels=100):
|
231 |
+
super(UnivNetGenerator, self).__init__()
|
232 |
+
self.mel_channel = n_mel_channels
|
233 |
+
self.noise_dim = noise_dim
|
234 |
+
self.hop_length = hop_length
|
235 |
+
channel_size = channel_size
|
236 |
+
kpnet_conv_size = kpnet_conv_size
|
237 |
+
|
238 |
+
self.res_stack = nn.ModuleList()
|
239 |
+
hop_length = 1
|
240 |
+
for stride in strides:
|
241 |
+
hop_length = stride * hop_length
|
242 |
+
self.res_stack.append(
|
243 |
+
LVCBlock(
|
244 |
+
channel_size,
|
245 |
+
n_mel_channels,
|
246 |
+
stride=stride,
|
247 |
+
dilations=dilations,
|
248 |
+
lReLU_slope=lReLU_slope,
|
249 |
+
cond_hop_length=hop_length,
|
250 |
+
kpnet_conv_size=kpnet_conv_size
|
251 |
+
)
|
252 |
+
)
|
253 |
+
|
254 |
+
self.conv_pre = \
|
255 |
+
nn.utils.weight_norm(nn.Conv1d(noise_dim, channel_size, 7, padding=3, padding_mode='reflect'))
|
256 |
+
|
257 |
+
self.conv_post = nn.Sequential(
|
258 |
+
nn.LeakyReLU(lReLU_slope),
|
259 |
+
nn.utils.weight_norm(nn.Conv1d(channel_size, 1, 7, padding=3, padding_mode='reflect')),
|
260 |
+
nn.Tanh(),
|
261 |
+
)
|
262 |
+
|
263 |
+
def forward(self, c, z):
|
264 |
+
'''
|
265 |
+
Args:
|
266 |
+
c (Tensor): the conditioning sequence of mel-spectrogram (batch, mel_channels, in_length)
|
267 |
+
z (Tensor): the noise sequence (batch, noise_dim, in_length)
|
268 |
+
|
269 |
+
'''
|
270 |
+
z = self.conv_pre(z) # (B, c_g, L)
|
271 |
+
|
272 |
+
for res_block in self.res_stack:
|
273 |
+
res_block.to(z.device)
|
274 |
+
z = res_block(z, c) # (B, c_g, L * s_0 * ... * s_i)
|
275 |
+
|
276 |
+
z = self.conv_post(z) # (B, 1, L * 256)
|
277 |
+
|
278 |
+
return z
|
279 |
+
|
280 |
+
def eval(self, inference=False):
|
281 |
+
super(UnivNetGenerator, self).eval()
|
282 |
+
# don't remove weight norm while validation in training loop
|
283 |
+
if inference:
|
284 |
+
self.remove_weight_norm()
|
285 |
+
|
286 |
+
def remove_weight_norm(self):
|
287 |
+
print('Removing weight norm...')
|
288 |
+
|
289 |
+
nn.utils.remove_weight_norm(self.conv_pre)
|
290 |
+
|
291 |
+
for layer in self.conv_post:
|
292 |
+
if len(layer.state_dict()) != 0:
|
293 |
+
nn.utils.remove_weight_norm(layer)
|
294 |
+
|
295 |
+
for res_block in self.res_stack:
|
296 |
+
res_block.remove_weight_norm()
|
297 |
+
|
298 |
+
def inference(self, c, z=None):
|
299 |
+
# pad input mel with zeros to cut artifact
|
300 |
+
# see https://github.com/seungwonpark/melgan/issues/8
|
301 |
+
zero = torch.full((c.shape[0], self.mel_channel, 10), -11.5129).to(c.device)
|
302 |
+
mel = torch.cat((c, zero), dim=2)
|
303 |
+
|
304 |
+
if z is None:
|
305 |
+
z = torch.randn(c.shape[0], self.noise_dim, mel.size(2)).to(mel.device)
|
306 |
+
|
307 |
+
audio = self.forward(mel, z)
|
308 |
+
audio = audio[:, :, :-(self.hop_length * 10)]
|
309 |
+
audio = audio.clamp(min=-1, max=1)
|
310 |
+
return audio
|
311 |
+
|
312 |
+
|
313 |
+
if __name__ == '__main__':
|
314 |
+
model = UnivNetGenerator()
|
315 |
+
|
316 |
+
c = torch.randn(3, 100, 10)
|
317 |
+
z = torch.randn(3, 64, 10)
|
318 |
+
print(c.shape)
|
319 |
+
|
320 |
+
y = model(c, z)
|
321 |
+
print(y.shape)
|
322 |
+
assert y.shape == torch.Size([3, 1, 2560])
|
323 |
+
|
324 |
+
pytorch_total_params = sum(p.numel() for p in model.parameters() if p.requires_grad)
|
325 |
+
print(pytorch_total_params)
|
requirements.txt
CHANGED
@@ -6,4 +6,5 @@ tokenizers
|
|
6 |
inflect
|
7 |
progressbar
|
8 |
einops
|
9 |
-
unidecode
|
|
|
|
6 |
inflect
|
7 |
progressbar
|
8 |
einops
|
9 |
+
unidecode
|
10 |
+
x-transformers
|
utils/audio.py
CHANGED
@@ -3,6 +3,8 @@ import torchaudio
|
|
3 |
import numpy as np
|
4 |
from scipy.io.wavfile import read
|
5 |
|
|
|
|
|
6 |
|
7 |
def load_wav_to_torch(full_path):
|
8 |
sampling_rate, data = read(full_path)
|
@@ -43,4 +45,86 @@ def load_audio(audiopath, sampling_rate):
|
|
43 |
print(f"Error with {audiopath}. Max={audio.max()} min={audio.min()}")
|
44 |
audio.clip_(-1, 1)
|
45 |
|
46 |
-
return audio.unsqueeze(0)
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
3 |
import numpy as np
|
4 |
from scipy.io.wavfile import read
|
5 |
|
6 |
+
from utils.stft import STFT
|
7 |
+
|
8 |
|
9 |
def load_wav_to_torch(full_path):
|
10 |
sampling_rate, data = read(full_path)
|
|
|
45 |
print(f"Error with {audiopath}. Max={audio.max()} min={audio.min()}")
|
46 |
audio.clip_(-1, 1)
|
47 |
|
48 |
+
return audio.unsqueeze(0)
|
49 |
+
|
50 |
+
|
51 |
+
TACOTRON_MEL_MAX = 2.3143386840820312
|
52 |
+
TACOTRON_MEL_MIN = -11.512925148010254
|
53 |
+
|
54 |
+
|
55 |
+
def denormalize_tacotron_mel(norm_mel):
|
56 |
+
return ((norm_mel+1)/2)*(TACOTRON_MEL_MAX-TACOTRON_MEL_MIN)+TACOTRON_MEL_MIN
|
57 |
+
|
58 |
+
|
59 |
+
def normalize_tacotron_mel(mel):
|
60 |
+
return 2 * ((mel - TACOTRON_MEL_MIN) / (TACOTRON_MEL_MAX - TACOTRON_MEL_MIN)) - 1
|
61 |
+
|
62 |
+
|
63 |
+
def dynamic_range_compression(x, C=1, clip_val=1e-5):
|
64 |
+
"""
|
65 |
+
PARAMS
|
66 |
+
------
|
67 |
+
C: compression factor
|
68 |
+
"""
|
69 |
+
return torch.log(torch.clamp(x, min=clip_val) * C)
|
70 |
+
|
71 |
+
|
72 |
+
def dynamic_range_decompression(x, C=1):
|
73 |
+
"""
|
74 |
+
PARAMS
|
75 |
+
------
|
76 |
+
C: compression factor used to compress
|
77 |
+
"""
|
78 |
+
return torch.exp(x) / C
|
79 |
+
|
80 |
+
|
81 |
+
class TacotronSTFT(torch.nn.Module):
|
82 |
+
def __init__(self, filter_length=1024, hop_length=256, win_length=1024,
|
83 |
+
n_mel_channels=80, sampling_rate=22050, mel_fmin=0.0,
|
84 |
+
mel_fmax=8000.0):
|
85 |
+
super(TacotronSTFT, self).__init__()
|
86 |
+
self.n_mel_channels = n_mel_channels
|
87 |
+
self.sampling_rate = sampling_rate
|
88 |
+
self.stft_fn = STFT(filter_length, hop_length, win_length)
|
89 |
+
from librosa.filters import mel as librosa_mel_fn
|
90 |
+
mel_basis = librosa_mel_fn(
|
91 |
+
sampling_rate, filter_length, n_mel_channels, mel_fmin, mel_fmax)
|
92 |
+
mel_basis = torch.from_numpy(mel_basis).float()
|
93 |
+
self.register_buffer('mel_basis', mel_basis)
|
94 |
+
|
95 |
+
def spectral_normalize(self, magnitudes):
|
96 |
+
output = dynamic_range_compression(magnitudes)
|
97 |
+
return output
|
98 |
+
|
99 |
+
def spectral_de_normalize(self, magnitudes):
|
100 |
+
output = dynamic_range_decompression(magnitudes)
|
101 |
+
return output
|
102 |
+
|
103 |
+
def mel_spectrogram(self, y):
|
104 |
+
"""Computes mel-spectrograms from a batch of waves
|
105 |
+
PARAMS
|
106 |
+
------
|
107 |
+
y: Variable(torch.FloatTensor) with shape (B, T) in range [-1, 1]
|
108 |
+
|
109 |
+
RETURNS
|
110 |
+
-------
|
111 |
+
mel_output: torch.FloatTensor of shape (B, n_mel_channels, T)
|
112 |
+
"""
|
113 |
+
assert(torch.min(y.data) >= -10)
|
114 |
+
assert(torch.max(y.data) <= 10)
|
115 |
+
y = torch.clip(y, min=-1, max=1)
|
116 |
+
|
117 |
+
magnitudes, phases = self.stft_fn.transform(y)
|
118 |
+
magnitudes = magnitudes.data
|
119 |
+
mel_output = torch.matmul(self.mel_basis, magnitudes)
|
120 |
+
mel_output = self.spectral_normalize(mel_output)
|
121 |
+
return mel_output
|
122 |
+
|
123 |
+
|
124 |
+
def wav_to_univnet_mel(wav, do_normalization=False):
|
125 |
+
stft = TacotronSTFT(1024, 256, 1024, 100, 24000, 0, 12000)
|
126 |
+
stft = stft.cuda()
|
127 |
+
mel = stft.mel_spectrogram(wav)
|
128 |
+
if do_normalization:
|
129 |
+
mel = normalize_tacotron_mel(mel)
|
130 |
+
return mel
|
utils/diffusion.py
CHANGED
@@ -197,11 +197,17 @@ class GaussianDiffusion:
|
|
197 |
model_var_type,
|
198 |
loss_type,
|
199 |
rescale_timesteps=False,
|
|
|
|
|
|
|
200 |
):
|
201 |
self.model_mean_type = ModelMeanType(model_mean_type)
|
202 |
self.model_var_type = ModelVarType(model_var_type)
|
203 |
self.loss_type = LossType(loss_type)
|
204 |
self.rescale_timesteps = rescale_timesteps
|
|
|
|
|
|
|
205 |
|
206 |
# Use float64 for accuracy.
|
207 |
betas = np.array(betas, dtype=np.float64)
|
@@ -332,10 +338,14 @@ class GaussianDiffusion:
|
|
332 |
B, C = x.shape[:2]
|
333 |
assert t.shape == (B,)
|
334 |
model_output = model(x, self._scale_timesteps(t), **model_kwargs)
|
|
|
|
|
335 |
|
336 |
if self.model_var_type in [ModelVarType.LEARNED, ModelVarType.LEARNED_RANGE]:
|
337 |
assert model_output.shape == (B, C * 2, *x.shape[2:])
|
338 |
model_output, model_var_values = th.split(model_output, C, dim=1)
|
|
|
|
|
339 |
if self.model_var_type == ModelVarType.LEARNED:
|
340 |
model_log_variance = model_var_values
|
341 |
model_variance = th.exp(model_log_variance)
|
@@ -364,6 +374,14 @@ class GaussianDiffusion:
|
|
364 |
model_variance = _extract_into_tensor(model_variance, t, x.shape)
|
365 |
model_log_variance = _extract_into_tensor(model_log_variance, t, x.shape)
|
366 |
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
367 |
def process_xstart(x):
|
368 |
if denoised_fn is not None:
|
369 |
x = denoised_fn(x)
|
|
|
197 |
model_var_type,
|
198 |
loss_type,
|
199 |
rescale_timesteps=False,
|
200 |
+
conditioning_free=False,
|
201 |
+
conditioning_free_k=1,
|
202 |
+
ramp_conditioning_free=True,
|
203 |
):
|
204 |
self.model_mean_type = ModelMeanType(model_mean_type)
|
205 |
self.model_var_type = ModelVarType(model_var_type)
|
206 |
self.loss_type = LossType(loss_type)
|
207 |
self.rescale_timesteps = rescale_timesteps
|
208 |
+
self.conditioning_free = conditioning_free
|
209 |
+
self.conditioning_free_k = conditioning_free_k
|
210 |
+
self.ramp_conditioning_free = ramp_conditioning_free
|
211 |
|
212 |
# Use float64 for accuracy.
|
213 |
betas = np.array(betas, dtype=np.float64)
|
|
|
338 |
B, C = x.shape[:2]
|
339 |
assert t.shape == (B,)
|
340 |
model_output = model(x, self._scale_timesteps(t), **model_kwargs)
|
341 |
+
if self.conditioning_free:
|
342 |
+
model_output_no_conditioning = model(x, self._scale_timesteps(t), conditioning_free=True, **model_kwargs)
|
343 |
|
344 |
if self.model_var_type in [ModelVarType.LEARNED, ModelVarType.LEARNED_RANGE]:
|
345 |
assert model_output.shape == (B, C * 2, *x.shape[2:])
|
346 |
model_output, model_var_values = th.split(model_output, C, dim=1)
|
347 |
+
if self.conditioning_free:
|
348 |
+
model_output_no_conditioning, _ = th.split(model_output_no_conditioning, C, dim=1)
|
349 |
if self.model_var_type == ModelVarType.LEARNED:
|
350 |
model_log_variance = model_var_values
|
351 |
model_variance = th.exp(model_log_variance)
|
|
|
374 |
model_variance = _extract_into_tensor(model_variance, t, x.shape)
|
375 |
model_log_variance = _extract_into_tensor(model_log_variance, t, x.shape)
|
376 |
|
377 |
+
if self.conditioning_free:
|
378 |
+
if self.ramp_conditioning_free:
|
379 |
+
assert t.shape[0] == 1 # This should only be used in inference.
|
380 |
+
cfk = self.conditioning_free_k * (1 - self._scale_timesteps(t)[0].item() / self.num_timesteps)
|
381 |
+
else:
|
382 |
+
cfk = self.conditioning_free_k
|
383 |
+
model_output = (1 + cfk) * model_output - cfk * model_output_no_conditioning
|
384 |
+
|
385 |
def process_xstart(x):
|
386 |
if denoised_fn is not None:
|
387 |
x = denoised_fn(x)
|
utils/stft.py
ADDED
@@ -0,0 +1,193 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
1 |
+
"""
|
2 |
+
BSD 3-Clause License
|
3 |
+
|
4 |
+
Copyright (c) 2017, Prem Seetharaman
|
5 |
+
All rights reserved.
|
6 |
+
|
7 |
+
* Redistribution and use in source and binary forms, with or without
|
8 |
+
modification, are permitted provided that the following conditions are met:
|
9 |
+
|
10 |
+
* Redistributions of source code must retain the above copyright notice,
|
11 |
+
this list of conditions and the following disclaimer.
|
12 |
+
|
13 |
+
* Redistributions in binary form must reproduce the above copyright notice, this
|
14 |
+
list of conditions and the following disclaimer in the
|
15 |
+
documentation and/or other materials provided with the distribution.
|
16 |
+
|
17 |
+
* Neither the name of the copyright holder nor the names of its
|
18 |
+
contributors may be used to endorse or promote products derived from this
|
19 |
+
software without specific prior written permission.
|
20 |
+
|
21 |
+
THIS SOFTWARE IS PROVIDED BY THE COPYRIGHT HOLDERS AND CONTRIBUTORS "AS IS" AND
|
22 |
+
ANY EXPRESS OR IMPLIED WARRANTIES, INCLUDING, BUT NOT LIMITED TO, THE IMPLIED
|
23 |
+
WARRANTIES OF MERCHANTABILITY AND FITNESS FOR A PARTICULAR PURPOSE ARE
|
24 |
+
DISCLAIMED. IN NO EVENT SHALL THE COPYRIGHT HOLDER OR CONTRIBUTORS BE LIABLE FOR
|
25 |
+
ANY DIRECT, INDIRECT, INCIDENTAL, SPECIAL, EXEMPLARY, OR CONSEQUENTIAL DAMAGES
|
26 |
+
(INCLUDING, BUT NOT LIMITED TO, PROCUREMENT OF SUBSTITUTE GOODS OR SERVICES;
|
27 |
+
LOSS OF USE, DATA, OR PROFITS; OR BUSINESS INTERRUPTION) HOWEVER CAUSED AND ON
|
28 |
+
ANY THEORY OF LIABILITY, WHETHER IN CONTRACT, STRICT LIABILITY, OR TORT
|
29 |
+
(INCLUDING NEGLIGENCE OR OTHERWISE) ARISING IN ANY WAY OUT OF THE USE OF THIS
|
30 |
+
SOFTWARE, EVEN IF ADVISED OF THE POSSIBILITY OF SUCH DAMAGE.
|
31 |
+
"""
|
32 |
+
|
33 |
+
import torch
|
34 |
+
import numpy as np
|
35 |
+
import torch.nn.functional as F
|
36 |
+
from torch.autograd import Variable
|
37 |
+
from scipy.signal import get_window
|
38 |
+
from librosa.util import pad_center, tiny
|
39 |
+
import librosa.util as librosa_util
|
40 |
+
|
41 |
+
|
42 |
+
def window_sumsquare(window, n_frames, hop_length=200, win_length=800,
|
43 |
+
n_fft=800, dtype=np.float32, norm=None):
|
44 |
+
"""
|
45 |
+
# from librosa 0.6
|
46 |
+
Compute the sum-square envelope of a window function at a given hop length.
|
47 |
+
|
48 |
+
This is used to estimate modulation effects induced by windowing
|
49 |
+
observations in short-time fourier transforms.
|
50 |
+
|
51 |
+
Parameters
|
52 |
+
----------
|
53 |
+
window : string, tuple, number, callable, or list-like
|
54 |
+
Window specification, as in `get_window`
|
55 |
+
|
56 |
+
n_frames : int > 0
|
57 |
+
The number of analysis frames
|
58 |
+
|
59 |
+
hop_length : int > 0
|
60 |
+
The number of samples to advance between frames
|
61 |
+
|
62 |
+
win_length : [optional]
|
63 |
+
The length of the window function. By default, this matches `n_fft`.
|
64 |
+
|
65 |
+
n_fft : int > 0
|
66 |
+
The length of each analysis frame.
|
67 |
+
|
68 |
+
dtype : np.dtype
|
69 |
+
The data type of the output
|
70 |
+
|
71 |
+
Returns
|
72 |
+
-------
|
73 |
+
wss : np.ndarray, shape=`(n_fft + hop_length * (n_frames - 1))`
|
74 |
+
The sum-squared envelope of the window function
|
75 |
+
"""
|
76 |
+
if win_length is None:
|
77 |
+
win_length = n_fft
|
78 |
+
|
79 |
+
n = n_fft + hop_length * (n_frames - 1)
|
80 |
+
x = np.zeros(n, dtype=dtype)
|
81 |
+
|
82 |
+
# Compute the squared window at the desired length
|
83 |
+
win_sq = get_window(window, win_length, fftbins=True)
|
84 |
+
win_sq = librosa_util.normalize(win_sq, norm=norm)**2
|
85 |
+
win_sq = librosa_util.pad_center(win_sq, n_fft)
|
86 |
+
|
87 |
+
# Fill the envelope
|
88 |
+
for i in range(n_frames):
|
89 |
+
sample = i * hop_length
|
90 |
+
x[sample:min(n, sample + n_fft)] += win_sq[:max(0, min(n_fft, n - sample))]
|
91 |
+
return x
|
92 |
+
|
93 |
+
|
94 |
+
class STFT(torch.nn.Module):
|
95 |
+
"""adapted from Prem Seetharaman's https://github.com/pseeth/pytorch-stft"""
|
96 |
+
def __init__(self, filter_length=800, hop_length=200, win_length=800,
|
97 |
+
window='hann'):
|
98 |
+
super(STFT, self).__init__()
|
99 |
+
self.filter_length = filter_length
|
100 |
+
self.hop_length = hop_length
|
101 |
+
self.win_length = win_length
|
102 |
+
self.window = window
|
103 |
+
self.forward_transform = None
|
104 |
+
scale = self.filter_length / self.hop_length
|
105 |
+
fourier_basis = np.fft.fft(np.eye(self.filter_length))
|
106 |
+
|
107 |
+
cutoff = int((self.filter_length / 2 + 1))
|
108 |
+
fourier_basis = np.vstack([np.real(fourier_basis[:cutoff, :]),
|
109 |
+
np.imag(fourier_basis[:cutoff, :])])
|
110 |
+
|
111 |
+
forward_basis = torch.FloatTensor(fourier_basis[:, None, :])
|
112 |
+
inverse_basis = torch.FloatTensor(
|
113 |
+
np.linalg.pinv(scale * fourier_basis).T[:, None, :])
|
114 |
+
|
115 |
+
if window is not None:
|
116 |
+
assert(filter_length >= win_length)
|
117 |
+
# get window and zero center pad it to filter_length
|
118 |
+
fft_window = get_window(window, win_length, fftbins=True)
|
119 |
+
fft_window = pad_center(fft_window, filter_length)
|
120 |
+
fft_window = torch.from_numpy(fft_window).float()
|
121 |
+
|
122 |
+
# window the bases
|
123 |
+
forward_basis *= fft_window
|
124 |
+
inverse_basis *= fft_window
|
125 |
+
|
126 |
+
self.register_buffer('forward_basis', forward_basis.float())
|
127 |
+
self.register_buffer('inverse_basis', inverse_basis.float())
|
128 |
+
|
129 |
+
def transform(self, input_data):
|
130 |
+
num_batches = input_data.size(0)
|
131 |
+
num_samples = input_data.size(1)
|
132 |
+
|
133 |
+
self.num_samples = num_samples
|
134 |
+
|
135 |
+
# similar to librosa, reflect-pad the input
|
136 |
+
input_data = input_data.view(num_batches, 1, num_samples)
|
137 |
+
input_data = F.pad(
|
138 |
+
input_data.unsqueeze(1),
|
139 |
+
(int(self.filter_length / 2), int(self.filter_length / 2), 0, 0),
|
140 |
+
mode='reflect')
|
141 |
+
input_data = input_data.squeeze(1)
|
142 |
+
|
143 |
+
forward_transform = F.conv1d(
|
144 |
+
input_data,
|
145 |
+
Variable(self.forward_basis, requires_grad=False),
|
146 |
+
stride=self.hop_length,
|
147 |
+
padding=0)
|
148 |
+
|
149 |
+
cutoff = int((self.filter_length / 2) + 1)
|
150 |
+
real_part = forward_transform[:, :cutoff, :]
|
151 |
+
imag_part = forward_transform[:, cutoff:, :]
|
152 |
+
|
153 |
+
magnitude = torch.sqrt(real_part**2 + imag_part**2)
|
154 |
+
phase = torch.autograd.Variable(
|
155 |
+
torch.atan2(imag_part.data, real_part.data))
|
156 |
+
|
157 |
+
return magnitude, phase
|
158 |
+
|
159 |
+
def inverse(self, magnitude, phase):
|
160 |
+
recombine_magnitude_phase = torch.cat(
|
161 |
+
[magnitude*torch.cos(phase), magnitude*torch.sin(phase)], dim=1)
|
162 |
+
|
163 |
+
inverse_transform = F.conv_transpose1d(
|
164 |
+
recombine_magnitude_phase,
|
165 |
+
Variable(self.inverse_basis, requires_grad=False),
|
166 |
+
stride=self.hop_length,
|
167 |
+
padding=0)
|
168 |
+
|
169 |
+
if self.window is not None:
|
170 |
+
window_sum = window_sumsquare(
|
171 |
+
self.window, magnitude.size(-1), hop_length=self.hop_length,
|
172 |
+
win_length=self.win_length, n_fft=self.filter_length,
|
173 |
+
dtype=np.float32)
|
174 |
+
# remove modulation effects
|
175 |
+
approx_nonzero_indices = torch.from_numpy(
|
176 |
+
np.where(window_sum > tiny(window_sum))[0])
|
177 |
+
window_sum = torch.autograd.Variable(
|
178 |
+
torch.from_numpy(window_sum), requires_grad=False)
|
179 |
+
window_sum = window_sum.cuda() if magnitude.is_cuda else window_sum
|
180 |
+
inverse_transform[:, :, approx_nonzero_indices] /= window_sum[approx_nonzero_indices]
|
181 |
+
|
182 |
+
# scale by hop ratio
|
183 |
+
inverse_transform *= float(self.filter_length) / self.hop_length
|
184 |
+
|
185 |
+
inverse_transform = inverse_transform[:, :, int(self.filter_length/2):]
|
186 |
+
inverse_transform = inverse_transform[:, :, :-int(self.filter_length/2):]
|
187 |
+
|
188 |
+
return inverse_transform
|
189 |
+
|
190 |
+
def forward(self, input_data):
|
191 |
+
self.magnitude, self.phase = self.transform(input_data)
|
192 |
+
reconstruction = self.inverse(self.magnitude, self.phase)
|
193 |
+
return reconstruction
|