jbetker commited on
Commit
9ad0f0e
1 Parent(s): 31f7372

Modifications to support "v1.5"

Browse files
do_tts.py CHANGED
@@ -8,14 +8,14 @@ import torch.nn.functional as F
8
  import torchaudio
9
  import progressbar
10
 
11
- from models.dvae import DiscreteVAE
12
  from models.autoregressive import UnifiedVoice
13
  from tqdm import tqdm
14
 
15
  from models.arch_util import TorchMelSpectrogram
16
- from models.discrete_diffusion_vocoder import DiscreteDiffusionVocoder
17
  from models.text_voice_clip import VoiceCLIP
18
- from utils.audio import load_audio
 
19
  from utils.diffusion import SpacedDiffusion, space_timesteps, get_named_beta_schedule
20
  from utils.tokenizer import VoiceBpeTokenizer
21
 
@@ -23,7 +23,6 @@ pbar = None
23
  def download_models():
24
  MODELS = {
25
  'clip.pth': 'https://huggingface.co/jbetker/tortoise-tts-clip/resolve/main/pytorch-model.bin',
26
- 'dvae.pth': 'https://huggingface.co/jbetker/voice-dvae/resolve/main/pytorch_model.bin',
27
  'diffusion.pth': 'https://huggingface.co/jbetker/tortoise-tts-diffusion-v1/resolve/main/pytorch-model.bin',
28
  'autoregressive.pth': 'https://huggingface.co/jbetker/tortoise-tts-autoregressive/resolve/main/pytorch-model.bin'
29
  }
@@ -47,12 +46,14 @@ def download_models():
47
  request.urlretrieve(url, f'.models/{model_name}', show_progress)
48
  print('Done.')
49
 
 
50
  def load_discrete_vocoder_diffuser(trained_diffusion_steps=4000, desired_diffusion_steps=200):
51
  """
52
  Helper function to load a GaussianDiffusion instance configured for use as a vocoder.
53
  """
54
  return SpacedDiffusion(use_timesteps=space_timesteps(trained_diffusion_steps, [desired_diffusion_steps]), model_mean_type='epsilon',
55
- model_var_type='learned_range', loss_type='mse', betas=get_named_beta_schedule('linear', trained_diffusion_steps))
 
56
 
57
 
58
  def load_conditioning(path, sample_rate=22050, cond_length=132300):
@@ -94,26 +95,26 @@ def fix_autoregressive_output(codes, stop_token):
94
  return codes
95
 
96
 
97
- def do_spectrogram_diffusion(diffusion_model, dvae_model, diffuser, mel_codes, conditioning_input, spectrogram_compression_factor=128, mean=False):
98
  """
99
  Uses the specified diffusion model and DVAE model to convert the provided MEL & conditioning inputs into an audio clip.
100
  """
101
  with torch.no_grad():
102
- mel = dvae_model.decode(mel_codes)[0]
103
-
104
- # Pad MEL to multiples of 2048//spectrogram_compression_factor
105
- msl = mel.shape[-1]
106
- dsl = 2048 // spectrogram_compression_factor
107
  gap = dsl - (msl % dsl)
108
  if gap > 0:
109
- mel = torch.nn.functional.pad(mel, (0, gap))
110
 
111
- output_shape = (mel.shape[0], 1, mel.shape[-1] * spectrogram_compression_factor)
112
  if mean:
113
- return diffuser.p_sample_loop(diffusion_model, output_shape, noise=torch.zeros(output_shape, device=mel_codes.device),
114
- model_kwargs={'spectrogram': mel, 'conditioning_input': conditioning_input})
115
  else:
116
- return diffuser.p_sample_loop(diffusion_model, output_shape, model_kwargs={'spectrogram': mel, 'conditioning_input': conditioning_input})
 
117
 
118
 
119
  if __name__ == '__main__':
@@ -145,12 +146,6 @@ if __name__ == '__main__':
145
  download_models()
146
 
147
  for voice in args.voice.split(','):
148
- print("Loading GPT TTS..")
149
- autoregressive = UnifiedVoice(max_mel_tokens=300, max_text_tokens=200, max_conditioning_inputs=2, layers=30, model_dim=1024,
150
- heads=16, number_text_tokens=256, start_text_token=255, checkpointing=False, train_solo_embeddings=False).cuda().eval()
151
- autoregressive.load_state_dict(torch.load('.models/autoregressive.pth'))
152
- stop_mel_token = autoregressive.stop_mel_token
153
-
154
  print("Loading data..")
155
  tokenizer = VoiceBpeTokenizer()
156
  text = torch.IntTensor(tokenizer.encode(args.text)).unsqueeze(0).cuda()
@@ -160,7 +155,15 @@ if __name__ == '__main__':
160
  for cond_path in cond_paths:
161
  c, cond_wav = load_conditioning(cond_path)
162
  conds.append(c)
163
- conds = torch.stack(conds, dim=1) # And just use the last cond_wav for the diffusion model.
 
 
 
 
 
 
 
 
164
 
165
  with torch.no_grad():
166
  print("Performing autoregressive inference..")
@@ -194,20 +197,25 @@ if __name__ == '__main__':
194
  # Delete the autoregressive and clip models to free up GPU memory
195
  del samples, clip
196
 
197
- print("Loading DVAE..")
198
- dvae = DiscreteVAE(positional_dims=1, channels=80, hidden_dim=512, num_resnet_blocks=3, codebook_dim=512, num_tokens=8192, num_layers=2,
199
- record_codes=True, kernel_size=3, use_transposed_convs=False).cuda().eval()
200
- dvae.load_state_dict(torch.load('.models/dvae.pth'), strict=False)
201
  print("Loading Diffusion Model..")
202
- diffusion = DiscreteDiffusionVocoder(model_channels=128, dvae_dim=80, channel_mult=[1, 1, 1.5, 2, 3, 4, 6, 8, 8, 8, 8], num_res_blocks=[1, 2, 2, 2, 2, 2, 2, 2, 2, 1, 1],
203
- spectrogram_conditioning_resolutions=[2,512], attention_resolutions=[512,1024], num_heads=4, kernel_size=3, scale_factor=2,
204
- conditioning_inputs_provided=True, time_embed_dim_multiplier=4).cuda().eval()
 
 
205
  diffusion.load_state_dict(torch.load('.models/diffusion.pth'))
 
 
 
 
 
 
206
  diffuser = load_discrete_vocoder_diffuser(desired_diffusion_steps=100)
207
 
208
  print("Performing vocoding..")
209
  # Perform vocoding on each batch element separately: The diffusion model is very memory (and compute!) intensive.
210
  for b in range(best_results.shape[0]):
211
  code = best_results[b].unsqueeze(0)
212
- wav = do_spectrogram_diffusion(diffusion, dvae, diffuser, code, cond_wav, spectrogram_compression_factor=256, mean=True)
213
- torchaudio.save(os.path.join(args.output_path, f'{voice}_{b}.wav'), wav.squeeze(0).cpu(), 22050)
 
 
8
  import torchaudio
9
  import progressbar
10
 
11
+ from models.diffusion_decoder import DiffusionTts
12
  from models.autoregressive import UnifiedVoice
13
  from tqdm import tqdm
14
 
15
  from models.arch_util import TorchMelSpectrogram
 
16
  from models.text_voice_clip import VoiceCLIP
17
+ from models.vocoder import UnivNetGenerator
18
+ from utils.audio import load_audio, wav_to_univnet_mel, denormalize_tacotron_mel
19
  from utils.diffusion import SpacedDiffusion, space_timesteps, get_named_beta_schedule
20
  from utils.tokenizer import VoiceBpeTokenizer
21
 
 
23
  def download_models():
24
  MODELS = {
25
  'clip.pth': 'https://huggingface.co/jbetker/tortoise-tts-clip/resolve/main/pytorch-model.bin',
 
26
  'diffusion.pth': 'https://huggingface.co/jbetker/tortoise-tts-diffusion-v1/resolve/main/pytorch-model.bin',
27
  'autoregressive.pth': 'https://huggingface.co/jbetker/tortoise-tts-autoregressive/resolve/main/pytorch-model.bin'
28
  }
 
46
  request.urlretrieve(url, f'.models/{model_name}', show_progress)
47
  print('Done.')
48
 
49
+
50
  def load_discrete_vocoder_diffuser(trained_diffusion_steps=4000, desired_diffusion_steps=200):
51
  """
52
  Helper function to load a GaussianDiffusion instance configured for use as a vocoder.
53
  """
54
  return SpacedDiffusion(use_timesteps=space_timesteps(trained_diffusion_steps, [desired_diffusion_steps]), model_mean_type='epsilon',
55
+ model_var_type='learned_range', loss_type='mse', betas=get_named_beta_schedule('linear', trained_diffusion_steps),
56
+ conditioning_free=True, conditioning_free_k=1)
57
 
58
 
59
  def load_conditioning(path, sample_rate=22050, cond_length=132300):
 
95
  return codes
96
 
97
 
98
+ def do_spectrogram_diffusion(diffusion_model, diffuser, mel_codes, conditioning_input, mean=False):
99
  """
100
  Uses the specified diffusion model and DVAE model to convert the provided MEL & conditioning inputs into an audio clip.
101
  """
102
  with torch.no_grad():
103
+ cond_mel = wav_to_univnet_mel(conditioning_input.squeeze(1), do_normalization=False)
104
+ # Pad MEL to multiples of 32
105
+ msl = mel_codes.shape[-1]
106
+ dsl = 32
 
107
  gap = dsl - (msl % dsl)
108
  if gap > 0:
109
+ mel = torch.nn.functional.pad(mel_codes, (0, gap))
110
 
111
+ output_shape = (mel.shape[0], 100, mel.shape[-1]*4)
112
  if mean:
113
+ mel = diffuser.p_sample_loop(diffusion_model, output_shape, noise=torch.zeros(output_shape, device=mel_codes.device),
114
+ model_kwargs={'aligned_conditioning': mel_codes, 'conditioning_input': cond_mel})
115
  else:
116
+ mel = diffuser.p_sample_loop(diffusion_model, output_shape, model_kwargs={'aligned_conditioning': mel_codes, 'conditioning_input': cond_mel})
117
+ return denormalize_tacotron_mel(mel)[:,:,:msl*4]
118
 
119
 
120
  if __name__ == '__main__':
 
146
  download_models()
147
 
148
  for voice in args.voice.split(','):
 
 
 
 
 
 
149
  print("Loading data..")
150
  tokenizer = VoiceBpeTokenizer()
151
  text = torch.IntTensor(tokenizer.encode(args.text)).unsqueeze(0).cuda()
 
155
  for cond_path in cond_paths:
156
  c, cond_wav = load_conditioning(cond_path)
157
  conds.append(c)
158
+ conds = torch.stack(conds, dim=1)
159
+ cond_diffusion = cond_wav[:, :88200] # The diffusion model expects <= 88200 conditioning samples.
160
+
161
+ print("Loading GPT TTS..")
162
+ autoregressive = UnifiedVoice(max_mel_tokens=300, max_text_tokens=200, max_conditioning_inputs=2, layers=30, model_dim=1024,
163
+ heads=16, number_text_tokens=256, start_text_token=255, checkpointing=False, train_solo_embeddings=False,
164
+ average_conditioning_embeddings=True).cuda().eval()
165
+ autoregressive.load_state_dict(torch.load('.models/autoregressive.pth'))
166
+ stop_mel_token = autoregressive.stop_mel_token
167
 
168
  with torch.no_grad():
169
  print("Performing autoregressive inference..")
 
197
  # Delete the autoregressive and clip models to free up GPU memory
198
  del samples, clip
199
 
 
 
 
 
200
  print("Loading Diffusion Model..")
201
+ diffusion = DiffusionTts(model_channels=512, in_channels=100, out_channels=200, in_latent_channels=1024,
202
+ channel_mult=[1, 2, 3, 4], num_res_blocks=[3, 3, 3, 3], token_conditioning_resolutions=[1,4,8],
203
+ dropout=0, attention_resolutions=[4,8], num_heads=8, kernel_size=3, scale_factor=2,
204
+ time_embed_dim_multiplier=4, unconditioned_percentage=0, conditioning_dim_factor=2,
205
+ conditioning_expansion=1)
206
  diffusion.load_state_dict(torch.load('.models/diffusion.pth'))
207
+ diffusion = diffusion.cuda().eval()
208
+ print("Loading vocoder..")
209
+ vocoder = UnivNetGenerator()
210
+ vocoder.load_state_dict(torch.load('.models/vocoder.pth')['model_g'])
211
+ vocoder = vocoder.cuda()
212
+ vocoder.eval(inference=True)
213
  diffuser = load_discrete_vocoder_diffuser(desired_diffusion_steps=100)
214
 
215
  print("Performing vocoding..")
216
  # Perform vocoding on each batch element separately: The diffusion model is very memory (and compute!) intensive.
217
  for b in range(best_results.shape[0]):
218
  code = best_results[b].unsqueeze(0)
219
+ mel = do_spectrogram_diffusion(diffusion, diffuser, code, cond_diffusion, mean=False)
220
+ wav = vocoder.inference(mel)
221
+ torchaudio.save(os.path.join(args.output_path, f'{voice}_{b}.wav'), wav.squeeze(0).cpu(), 24000)
models/autoregressive.py CHANGED
@@ -192,7 +192,8 @@ class ConditioningEncoder(nn.Module):
192
  embedding_dim,
193
  attn_blocks=6,
194
  num_attn_heads=4,
195
- do_checkpointing=False):
 
196
  super().__init__()
197
  attn = []
198
  self.init = nn.Conv1d(spec_dim, embedding_dim, kernel_size=1)
@@ -201,11 +202,15 @@ class ConditioningEncoder(nn.Module):
201
  self.attn = nn.Sequential(*attn)
202
  self.dim = embedding_dim
203
  self.do_checkpointing = do_checkpointing
 
204
 
205
  def forward(self, x):
206
  h = self.init(x)
207
  h = self.attn(h)
208
- return h[:, :, 0]
 
 
 
209
 
210
 
211
  class LearnedPositionEmbeddings(nn.Module):
@@ -275,7 +280,7 @@ class UnifiedVoice(nn.Module):
275
  mel_length_compression=1024, number_text_tokens=256,
276
  start_text_token=255, stop_text_token=0, number_mel_codes=8194, start_mel_token=8192,
277
  stop_mel_token=8193, train_solo_embeddings=False, use_mel_codes_as_input=True,
278
- checkpointing=True):
279
  """
280
  Args:
281
  layers: Number of layers in transformer stack.
@@ -294,6 +299,7 @@ class UnifiedVoice(nn.Module):
294
  train_solo_embeddings:
295
  use_mel_codes_as_input:
296
  checkpointing:
 
297
  """
298
  super().__init__()
299
 
@@ -311,6 +317,7 @@ class UnifiedVoice(nn.Module):
311
  self.max_conditioning_inputs = max_conditioning_inputs
312
  self.mel_length_compression = mel_length_compression
313
  self.conditioning_encoder = ConditioningEncoder(80, model_dim, num_attn_heads=heads)
 
314
  self.text_embedding = nn.Embedding(self.number_text_tokens, model_dim)
315
  if use_mel_codes_as_input:
316
  self.mel_embedding = nn.Embedding(self.number_mel_codes, model_dim)
@@ -408,6 +415,8 @@ class UnifiedVoice(nn.Module):
408
  for j in range(speech_conditioning_input.shape[1]):
409
  conds.append(self.conditioning_encoder(speech_conditioning_input[:, j]))
410
  conds = torch.stack(conds, dim=1)
 
 
411
 
412
  text_inputs, text_targets = self.build_aligned_inputs_and_targets(text_inputs, self.start_text_token, self.stop_text_token)
413
  text_emb = self.text_embedding(text_inputs) + self.text_pos_embedding(text_inputs)
@@ -446,6 +455,8 @@ class UnifiedVoice(nn.Module):
446
  for j in range(speech_conditioning_input.shape[1]):
447
  conds.append(self.conditioning_encoder(speech_conditioning_input[:, j]))
448
  conds = torch.stack(conds, dim=1)
 
 
449
 
450
  text_inputs, text_targets = self.build_aligned_inputs_and_targets(text_inputs, self.start_text_token, self.stop_text_token)
451
  text_emb = self.text_embedding(text_inputs) + self.text_pos_embedding(text_inputs) + self.text_solo_embedding
@@ -472,6 +483,8 @@ class UnifiedVoice(nn.Module):
472
  for j in range(speech_conditioning_input.shape[1]):
473
  conds.append(self.conditioning_encoder(speech_conditioning_input[:, j]))
474
  conds = torch.stack(conds, dim=1)
 
 
475
 
476
  mel_codes, mel_targets = self.build_aligned_inputs_and_targets(mel_codes, self.start_mel_token, self.stop_mel_token)
477
  if raw_mels is not None:
@@ -508,6 +521,8 @@ class UnifiedVoice(nn.Module):
508
  for j in range(speech_conditioning_input.shape[1]):
509
  conds.append(self.conditioning_encoder(speech_conditioning_input[:, j]))
510
  conds = torch.stack(conds, dim=1)
 
 
511
 
512
  emb = torch.cat([conds, text_emb], dim=1)
513
  self.inference_model.store_mel_emb(emb)
 
192
  embedding_dim,
193
  attn_blocks=6,
194
  num_attn_heads=4,
195
+ do_checkpointing=False,
196
+ mean=False):
197
  super().__init__()
198
  attn = []
199
  self.init = nn.Conv1d(spec_dim, embedding_dim, kernel_size=1)
 
202
  self.attn = nn.Sequential(*attn)
203
  self.dim = embedding_dim
204
  self.do_checkpointing = do_checkpointing
205
+ self.mean = mean
206
 
207
  def forward(self, x):
208
  h = self.init(x)
209
  h = self.attn(h)
210
+ if self.mean:
211
+ return h.mean(dim=2)
212
+ else:
213
+ return h[:, :, 0]
214
 
215
 
216
  class LearnedPositionEmbeddings(nn.Module):
 
280
  mel_length_compression=1024, number_text_tokens=256,
281
  start_text_token=255, stop_text_token=0, number_mel_codes=8194, start_mel_token=8192,
282
  stop_mel_token=8193, train_solo_embeddings=False, use_mel_codes_as_input=True,
283
+ checkpointing=True, average_conditioning_embeddings=False):
284
  """
285
  Args:
286
  layers: Number of layers in transformer stack.
 
299
  train_solo_embeddings:
300
  use_mel_codes_as_input:
301
  checkpointing:
302
+ average_conditioning_embeddings: Whether or not conditioning embeddings should be averaged, instead of fed piecewise into the model.
303
  """
304
  super().__init__()
305
 
 
317
  self.max_conditioning_inputs = max_conditioning_inputs
318
  self.mel_length_compression = mel_length_compression
319
  self.conditioning_encoder = ConditioningEncoder(80, model_dim, num_attn_heads=heads)
320
+ self.average_conditioning_embeddings = average_conditioning_embeddings
321
  self.text_embedding = nn.Embedding(self.number_text_tokens, model_dim)
322
  if use_mel_codes_as_input:
323
  self.mel_embedding = nn.Embedding(self.number_mel_codes, model_dim)
 
415
  for j in range(speech_conditioning_input.shape[1]):
416
  conds.append(self.conditioning_encoder(speech_conditioning_input[:, j]))
417
  conds = torch.stack(conds, dim=1)
418
+ if self.average_conditioning_embeddings:
419
+ conds = conds.mean(dim=1).unsqueeze(1)
420
 
421
  text_inputs, text_targets = self.build_aligned_inputs_and_targets(text_inputs, self.start_text_token, self.stop_text_token)
422
  text_emb = self.text_embedding(text_inputs) + self.text_pos_embedding(text_inputs)
 
455
  for j in range(speech_conditioning_input.shape[1]):
456
  conds.append(self.conditioning_encoder(speech_conditioning_input[:, j]))
457
  conds = torch.stack(conds, dim=1)
458
+ if self.average_conditioning_embeddings:
459
+ conds = conds.mean(dim=1).unsqueeze(1)
460
 
461
  text_inputs, text_targets = self.build_aligned_inputs_and_targets(text_inputs, self.start_text_token, self.stop_text_token)
462
  text_emb = self.text_embedding(text_inputs) + self.text_pos_embedding(text_inputs) + self.text_solo_embedding
 
483
  for j in range(speech_conditioning_input.shape[1]):
484
  conds.append(self.conditioning_encoder(speech_conditioning_input[:, j]))
485
  conds = torch.stack(conds, dim=1)
486
+ if self.average_conditioning_embeddings:
487
+ conds = conds.mean(dim=1).unsqueeze(1)
488
 
489
  mel_codes, mel_targets = self.build_aligned_inputs_and_targets(mel_codes, self.start_mel_token, self.stop_mel_token)
490
  if raw_mels is not None:
 
521
  for j in range(speech_conditioning_input.shape[1]):
522
  conds.append(self.conditioning_encoder(speech_conditioning_input[:, j]))
523
  conds = torch.stack(conds, dim=1)
524
+ if self.average_conditioning_embeddings:
525
+ conds = conds.mean(dim=1).unsqueeze(1)
526
 
527
  emb = torch.cat([conds, text_emb], dim=1)
528
  self.inference_model.store_mel_emb(emb)
models/{discrete_diffusion_vocoder.py → diffusion_decoder.py} RENAMED
@@ -3,17 +3,36 @@ This model is based on OpenAI's UNet from improved diffusion, with modifications
3
  and an audio conditioning input. It has also been simplified somewhat.
4
  Credit: https://github.com/openai/improved-diffusion
5
  """
6
-
7
-
8
  import math
9
  from abc import abstractmethod
10
 
11
  import torch
12
  import torch.nn as nn
 
 
 
 
 
13
 
14
  from models.arch_util import normalization, zero_module, Downsample, Upsample, AudioMiniEncoder, AttentionBlock
15
 
16
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
17
  def timestep_embedding(timesteps, dim, max_period=10000):
18
  """
19
  Create sinusoidal timestep embeddings.
@@ -62,63 +81,36 @@ class TimestepEmbedSequential(nn.Sequential, TimestepBlock):
62
  return x
63
 
64
 
65
- class TimestepResBlock(TimestepBlock):
66
- """
67
- A residual block that can optionally change the number of channels.
68
-
69
- :param channels: the number of input channels.
70
- :param emb_channels: the number of timestep embedding channels.
71
- :param dropout: the rate of dropout.
72
- :param out_channels: if specified, the number of out channels.
73
- :param use_conv: if True and out_channels is specified, use a spatial
74
- convolution instead of a smaller 1x1 convolution to change the
75
- channels in the skip connection.
76
- :param dims: determines if the signal is 1D, 2D, or 3D.
77
- :param up: if True, use this block for upsampling.
78
- :param down: if True, use this block for downsampling.
79
- """
80
-
81
  def __init__(
82
  self,
83
  channels,
84
  emb_channels,
85
  dropout,
86
  out_channels=None,
87
- use_conv=False,
88
- use_scale_shift_norm=False,
89
- up=False,
90
- down=False,
91
  kernel_size=3,
 
 
92
  ):
93
  super().__init__()
94
  self.channels = channels
95
  self.emb_channels = emb_channels
96
  self.dropout = dropout
97
  self.out_channels = out_channels or channels
98
- self.use_conv = use_conv
99
  self.use_scale_shift_norm = use_scale_shift_norm
100
- padding = 1 if kernel_size == 3 else (2 if kernel_size == 5 else 0)
 
 
101
 
102
  self.in_layers = nn.Sequential(
103
  normalization(channels),
104
  nn.SiLU(),
105
- nn.Conv1d(channels, self.out_channels, kernel_size, padding=padding),
106
  )
107
 
108
- self.updown = up or down
109
-
110
- if up:
111
- self.h_upd = Upsample(channels, False, dims)
112
- self.x_upd = Upsample(channels, False, dims)
113
- elif down:
114
- self.h_upd = Downsample(channels, False, dims)
115
- self.x_upd = Downsample(channels, False, dims)
116
- else:
117
- self.h_upd = self.x_upd = nn.Identity()
118
-
119
  self.emb_layers = nn.Sequential(
120
  nn.SiLU(),
121
- nn.Linear(
122
  emb_channels,
123
  2 * self.out_channels if use_scale_shift_norm else self.out_channels,
124
  ),
@@ -134,22 +126,23 @@ class TimestepResBlock(TimestepBlock):
134
 
135
  if self.out_channels == channels:
136
  self.skip_connection = nn.Identity()
137
- elif use_conv:
138
- self.skip_connection = nn.Conv1d(
139
- channels, self.out_channels, kernel_size, padding=padding
140
- )
141
  else:
142
- self.skip_connection = nn.Conv1d(channels, self.out_channels, 1)
143
 
144
  def forward(self, x, emb):
145
- if self.updown:
146
- in_rest, in_conv = self.in_layers[:-1], self.in_layers[-1]
147
- h = in_rest(x)
148
- h = self.h_upd(h)
149
- x = self.x_upd(x)
150
- h = in_conv(h)
151
- else:
152
- h = self.in_layers(x)
 
 
 
 
 
153
  emb_out = self.emb_layers(emb).type(h.dtype)
154
  while len(emb_out.shape) < len(h.shape):
155
  emb_out = emb_out[..., None]
@@ -164,37 +157,52 @@ class TimestepResBlock(TimestepBlock):
164
  return self.skip_connection(x) + h
165
 
166
 
167
- class DiscreteSpectrogramConditioningBlock(nn.Module):
168
- def __init__(self, dvae_channels, channels, level):
 
 
 
 
169
  super().__init__()
170
- self.intg = nn.Sequential(nn.Conv1d(dvae_channels, channels, kernel_size=1),
171
- normalization(channels),
172
- nn.SiLU(),
173
- nn.Conv1d(channels, channels, kernel_size=3))
174
- self.level = level
 
 
175
 
 
 
176
  """
177
- Embeds the given codes and concatenates them onto x. Return shape is the same as x.shape.
178
-
179
- :param x: bxcxS waveform latent
180
- :param codes: bxN discrete codes, N <= S
181
  """
182
- def forward(self, x, dvae_in):
183
- b, c, S = x.shape
184
- _, q, N = dvae_in.shape
185
- emb = self.intg(dvae_in)
186
- emb = nn.functional.interpolate(emb, size=(S,), mode='nearest')
187
- return torch.cat([x, emb], dim=1)
 
 
 
 
 
 
 
 
188
 
189
 
190
- class DiscreteDiffusionVocoder(nn.Module):
191
  """
192
  The full UNet model with attention and timestep embedding.
193
 
194
- Customized to be conditioned on a spectrogram prior.
 
195
 
196
  :param in_channels: channels in the input Tensor.
197
- :param spectrogram_channels: channels in the conditioning spectrogram.
198
  :param model_channels: base channel count for the model.
199
  :param out_channels: channels in the output Tensor.
200
  :param num_res_blocks: number of residual blocks per downsample.
@@ -206,7 +214,6 @@ class DiscreteDiffusionVocoder(nn.Module):
206
  :param channel_mult: channel multiplier for each level of the UNet.
207
  :param conv_resample: if True, use learned convolutions for upsampling and
208
  downsampling.
209
- :param dims: determines if the signal is 1D, 2D, or 3D.
210
  :param num_heads: the number of attention heads in each attention layer.
211
  :param num_heads_channels: if specified, ignore num_heads and instead use
212
  a fixed channel width per attention head.
@@ -222,34 +229,43 @@ class DiscreteDiffusionVocoder(nn.Module):
222
  self,
223
  model_channels,
224
  in_channels=1,
 
 
 
 
225
  out_channels=2, # mean and variance
226
- dvae_dim=512,
227
  dropout=0,
228
  # res 1, 2, 4, 8,16,32,64,128,256,512, 1K, 2K
229
  channel_mult= (1,1.5,2, 3, 4, 6, 8, 12, 16, 24, 32, 48),
230
  num_res_blocks=(1, 1, 1, 1, 1, 2, 2, 2, 2, 2, 2, 2),
231
  # spec_cond: 1, 0, 0, 1, 0, 0, 1, 0, 0, 1, 0, 0)
232
  # attn: 0, 0, 0, 0, 0, 0, 0, 0, 0, 1, 1, 1
233
- spectrogram_conditioning_resolutions=(512,),
234
  attention_resolutions=(512,1024,2048),
235
  conv_resample=True,
236
- dims=1,
237
  use_fp16=False,
238
  num_heads=1,
239
  num_head_channels=-1,
240
  num_heads_upsample=-1,
241
- use_scale_shift_norm=False,
242
- resblock_updown=False,
243
  kernel_size=3,
244
  scale_factor=2,
245
- conditioning_inputs_provided=True,
246
  time_embed_dim_multiplier=4,
 
 
 
 
 
 
 
 
247
  ):
248
  super().__init__()
249
 
250
  if num_heads_upsample == -1:
251
  num_heads_upsample = num_heads
252
 
 
 
253
  self.in_channels = in_channels
254
  self.model_channels = model_channels
255
  self.out_channels = out_channels
@@ -257,53 +273,110 @@ class DiscreteDiffusionVocoder(nn.Module):
257
  self.dropout = dropout
258
  self.channel_mult = channel_mult
259
  self.conv_resample = conv_resample
260
- self.dtype = torch.float16 if use_fp16 else torch.float32
261
  self.num_heads = num_heads
262
  self.num_head_channels = num_head_channels
263
  self.num_heads_upsample = num_heads_upsample
264
- self.dims = dims
265
-
 
 
 
 
266
  padding = 1 if kernel_size == 3 else 2
 
267
 
268
  time_embed_dim = model_channels * time_embed_dim_multiplier
269
  self.time_embed = nn.Sequential(
270
- nn.Linear(model_channels, time_embed_dim),
271
  nn.SiLU(),
272
- nn.Linear(time_embed_dim, time_embed_dim),
273
  )
274
 
275
- self.conditioning_enabled = conditioning_inputs_provided
276
- if conditioning_inputs_provided:
277
- self.contextual_embedder = AudioMiniEncoder(in_channels, time_embed_dim, base_channels=32, depth=6, resnet_blocks=1,
278
- attn_blocks=2, num_attn_heads=2, dropout=dropout, downsample_factor=4, kernel_size=5)
279
-
280
- seqlyr = TimestepEmbedSequential(
281
- nn.Conv1d(in_channels, model_channels, kernel_size, padding=padding)
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
282
  )
283
- seqlyr.level = 0
284
- self.input_blocks = nn.ModuleList([seqlyr])
285
- spectrogram_blocks = []
 
 
 
 
 
 
 
286
  self._feature_size = model_channels
287
  input_block_chans = [model_channels]
288
  ch = model_channels
289
  ds = 1
290
 
291
  for level, (mult, num_blocks) in enumerate(zip(channel_mult, num_res_blocks)):
292
- if ds in spectrogram_conditioning_resolutions:
293
- spec_cond_block = DiscreteSpectrogramConditioningBlock(dvae_dim, ch, 2 ** level)
294
- self.input_blocks.append(spec_cond_block)
295
- spectrogram_blocks.append(spec_cond_block)
296
- ch *= 2
297
 
298
  for _ in range(num_blocks):
299
  layers = [
300
- TimestepResBlock(
301
  ch,
302
  time_embed_dim,
303
  dropout,
304
  out_channels=int(mult * model_channels),
305
- use_scale_shift_norm=use_scale_shift_norm,
306
  kernel_size=kernel_size,
 
 
307
  )
308
  ]
309
  ch = int(mult * model_channels)
@@ -315,54 +388,44 @@ class DiscreteDiffusionVocoder(nn.Module):
315
  num_head_channels=num_head_channels,
316
  )
317
  )
318
- layer = TimestepEmbedSequential(*layers)
319
- layer.level = 2 ** level
320
- self.input_blocks.append(layer)
321
  self._feature_size += ch
322
  input_block_chans.append(ch)
323
  if level != len(channel_mult) - 1:
324
  out_ch = ch
325
- upblk = TimestepEmbedSequential(
326
- TimestepResBlock(
327
- ch,
328
- time_embed_dim,
329
- dropout,
330
- out_channels=out_ch,
331
- use_scale_shift_norm=use_scale_shift_norm,
332
- down=True,
333
- kernel_size=kernel_size,
334
- )
335
- if resblock_updown
336
- else Downsample(
337
- ch, conv_resample, out_channels=out_ch, factor=scale_factor
338
  )
339
  )
340
- upblk.level = 2 ** level
341
- self.input_blocks.append(upblk)
342
  ch = out_ch
343
  input_block_chans.append(ch)
344
  ds *= 2
345
  self._feature_size += ch
346
 
347
  self.middle_block = TimestepEmbedSequential(
348
- TimestepResBlock(
349
  ch,
350
  time_embed_dim,
351
  dropout,
352
- use_scale_shift_norm=use_scale_shift_norm,
353
  kernel_size=kernel_size,
 
 
354
  ),
355
  AttentionBlock(
356
  ch,
357
  num_heads=num_heads,
358
  num_head_channels=num_head_channels,
359
  ),
360
- TimestepResBlock(
361
  ch,
362
  time_embed_dim,
363
  dropout,
364
- use_scale_shift_norm=use_scale_shift_norm,
365
  kernel_size=kernel_size,
 
 
366
  ),
367
  )
368
  self._feature_size += ch
@@ -372,13 +435,14 @@ class DiscreteDiffusionVocoder(nn.Module):
372
  for i in range(num_blocks + 1):
373
  ich = input_block_chans.pop()
374
  layers = [
375
- TimestepResBlock(
376
  ch + ich,
377
  time_embed_dim,
378
  dropout,
379
  out_channels=int(model_channels * mult),
380
- use_scale_shift_norm=use_scale_shift_norm,
381
  kernel_size=kernel_size,
 
 
382
  )
383
  ]
384
  ch = int(model_channels * mult)
@@ -393,22 +457,10 @@ class DiscreteDiffusionVocoder(nn.Module):
393
  if level and i == num_blocks:
394
  out_ch = ch
395
  layers.append(
396
- TimestepResBlock(
397
- ch,
398
- time_embed_dim,
399
- dropout,
400
- out_channels=out_ch,
401
- use_scale_shift_norm=use_scale_shift_norm,
402
- up=True,
403
- kernel_size=kernel_size,
404
- )
405
- if resblock_updown
406
- else Upsample(ch, conv_resample, out_channels=out_ch, factor=scale_factor)
407
  )
408
  ds //= 2
409
- layer = TimestepEmbedSequential(*layers)
410
- layer.level = 2 ** level
411
- self.output_blocks.append(layer)
412
  self._feature_size += ch
413
 
414
  self.out = nn.Sequential(
@@ -417,52 +469,130 @@ class DiscreteDiffusionVocoder(nn.Module):
417
  zero_module(nn.Conv1d(model_channels, out_channels, kernel_size, padding=padding)),
418
  )
419
 
420
- def forward(self, x, timesteps, spectrogram, conditioning_input=None):
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
421
  """
422
  Apply the model to an input batch.
423
 
424
  :param x: an [N x C x ...] Tensor of inputs.
425
  :param timesteps: a 1-D batch of timesteps.
426
- :param y: an [N] Tensor of labels, if class-conditional.
 
 
 
427
  :return: an [N x C x ...] Tensor of outputs.
428
  """
429
- assert x.shape[-1] % 2048 == 0 # This model operates at base//2048 at it's bottom levels, thus this requirement.
430
- if self.conditioning_enabled:
431
- assert conditioning_input is not None
432
-
433
- hs = []
434
- emb1 = self.time_embed(timestep_embedding(timesteps, self.model_channels))
435
- if self.conditioning_enabled:
436
- emb2 = self.contextual_embedder(conditioning_input)
437
- emb = emb1 + emb2
438
- else:
439
- emb = emb1
440
-
441
- h = x.type(self.dtype)
442
- for k, module in enumerate(self.input_blocks):
443
- if isinstance(module, DiscreteSpectrogramConditioningBlock):
444
- h = module(h, spectrogram)
 
 
 
 
 
 
 
 
445
  else:
446
- h = module(h, emb)
447
- hs.append(h)
448
- h = self.middle_block(h, emb)
449
- for module in self.output_blocks:
450
- h = torch.cat([h, hs.pop()], dim=1)
451
- h = module(h, emb)
452
- h = h.type(x.dtype)
453
- return self.out(h)
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
454
 
455
 
456
- # Test for ~4 second audio clip at 22050Hz
457
  if __name__ == '__main__':
458
- clip = torch.randn(2, 1, 40960)
459
- spec = torch.randn(2,80,160)
460
- cond = torch.randn(2, 1, 40960)
461
- ts = torch.LongTensor([555, 556])
462
- model = DiscreteDiffusionVocoder(model_channels=128, channel_mult=[1, 1, 1.5, 2, 3, 4, 6, 8, 8, 8, 8],
463
- num_res_blocks=[1,2, 2, 2, 2, 2, 2, 2, 2, 1, 1 ], spectrogram_conditioning_resolutions=[2,512],
464
- dropout=.05, attention_resolutions=[512,1024], num_heads=4, kernel_size=3, scale_factor=2,
465
- conditioning_inputs_provided=True, conditioning_input_dim=80, time_embed_dim_multiplier=4,
466
- dvae_dim=80)
467
-
468
- print(model(clip, ts, spec, cond).shape)
 
 
 
 
 
 
 
 
 
 
3
  and an audio conditioning input. It has also been simplified somewhat.
4
  Credit: https://github.com/openai/improved-diffusion
5
  """
6
+ import functools
 
7
  import math
8
  from abc import abstractmethod
9
 
10
  import torch
11
  import torch.nn as nn
12
+ import torch.nn.functional as F
13
+ from torch import autocast
14
+ from torch.nn import Linear
15
+ from torch.utils.checkpoint import checkpoint
16
+ from x_transformers import ContinuousTransformerWrapper, Encoder
17
 
18
  from models.arch_util import normalization, zero_module, Downsample, Upsample, AudioMiniEncoder, AttentionBlock
19
 
20
 
21
+ def is_latent(t):
22
+ return t.dtype == torch.float
23
+
24
+
25
+ def is_sequence(t):
26
+ return t.dtype == torch.long
27
+
28
+
29
+ def ceil_multiple(base, multiple):
30
+ res = base % multiple
31
+ if res == 0:
32
+ return base
33
+ return base + (multiple - res)
34
+
35
+
36
  def timestep_embedding(timesteps, dim, max_period=10000):
37
  """
38
  Create sinusoidal timestep embeddings.
 
81
  return x
82
 
83
 
84
+ class ResBlock(TimestepBlock):
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
85
  def __init__(
86
  self,
87
  channels,
88
  emb_channels,
89
  dropout,
90
  out_channels=None,
 
 
 
 
91
  kernel_size=3,
92
+ efficient_config=True,
93
+ use_scale_shift_norm=False,
94
  ):
95
  super().__init__()
96
  self.channels = channels
97
  self.emb_channels = emb_channels
98
  self.dropout = dropout
99
  self.out_channels = out_channels or channels
 
100
  self.use_scale_shift_norm = use_scale_shift_norm
101
+ padding = {1: 0, 3: 1, 5: 2}[kernel_size]
102
+ eff_kernel = 1 if efficient_config else 3
103
+ eff_padding = 0 if efficient_config else 1
104
 
105
  self.in_layers = nn.Sequential(
106
  normalization(channels),
107
  nn.SiLU(),
108
+ nn.Conv1d(channels, self.out_channels, eff_kernel, padding=eff_padding),
109
  )
110
 
 
 
 
 
 
 
 
 
 
 
 
111
  self.emb_layers = nn.Sequential(
112
  nn.SiLU(),
113
+ Linear(
114
  emb_channels,
115
  2 * self.out_channels if use_scale_shift_norm else self.out_channels,
116
  ),
 
126
 
127
  if self.out_channels == channels:
128
  self.skip_connection = nn.Identity()
 
 
 
 
129
  else:
130
+ self.skip_connection = nn.Conv1d(channels, self.out_channels, eff_kernel, padding=eff_padding)
131
 
132
  def forward(self, x, emb):
133
+ """
134
+ Apply the block to a Tensor, conditioned on a timestep embedding.
135
+
136
+ :param x: an [N x C x ...] Tensor of features.
137
+ :param emb: an [N x emb_channels] Tensor of timestep embeddings.
138
+ :return: an [N x C x ...] Tensor of outputs.
139
+ """
140
+ return checkpoint(
141
+ self._forward, x, emb
142
+ )
143
+
144
+ def _forward(self, x, emb):
145
+ h = self.in_layers(x)
146
  emb_out = self.emb_layers(emb).type(h.dtype)
147
  while len(emb_out.shape) < len(h.shape):
148
  emb_out = emb_out[..., None]
 
157
  return self.skip_connection(x) + h
158
 
159
 
160
+ class CheckpointedLayer(nn.Module):
161
+ """
162
+ Wraps a module. When forward() is called, passes kwargs that require_grad through torch.checkpoint() and bypasses
163
+ checkpoint for all other args.
164
+ """
165
+ def __init__(self, wrap):
166
  super().__init__()
167
+ self.wrap = wrap
168
+
169
+ def forward(self, x, *args, **kwargs):
170
+ for k, v in kwargs.items():
171
+ assert not (isinstance(v, torch.Tensor) and v.requires_grad) # This would screw up checkpointing.
172
+ partial = functools.partial(self.wrap, **kwargs)
173
+ return torch.utils.checkpoint.checkpoint(partial, x, *args)
174
 
175
+
176
+ class CheckpointedXTransformerEncoder(nn.Module):
177
  """
178
+ Wraps a ContinuousTransformerWrapper and applies CheckpointedLayer to each layer and permutes from channels-mid
179
+ to channels-last that XTransformer expects.
 
 
180
  """
181
+ def __init__(self, needs_permute=True, **xtransformer_kwargs):
182
+ super().__init__()
183
+ self.transformer = ContinuousTransformerWrapper(**xtransformer_kwargs)
184
+ self.needs_permute = needs_permute
185
+
186
+ for i in range(len(self.transformer.attn_layers.layers)):
187
+ n, b, r = self.transformer.attn_layers.layers[i]
188
+ self.transformer.attn_layers.layers[i] = nn.ModuleList([n, CheckpointedLayer(b), r])
189
+
190
+ def forward(self, x, **kwargs):
191
+ if self.needs_permute:
192
+ x = x.permute(0,2,1)
193
+ h = self.transformer(x, **kwargs)
194
+ return h.permute(0,2,1)
195
 
196
 
197
+ class DiffusionTts(nn.Module):
198
  """
199
  The full UNet model with attention and timestep embedding.
200
 
201
+ Customized to be conditioned on an aligned prior derived from a autoregressive
202
+ GPT-style model.
203
 
204
  :param in_channels: channels in the input Tensor.
205
+ :param in_latent_channels: channels from the input latent.
206
  :param model_channels: base channel count for the model.
207
  :param out_channels: channels in the output Tensor.
208
  :param num_res_blocks: number of residual blocks per downsample.
 
214
  :param channel_mult: channel multiplier for each level of the UNet.
215
  :param conv_resample: if True, use learned convolutions for upsampling and
216
  downsampling.
 
217
  :param num_heads: the number of attention heads in each attention layer.
218
  :param num_heads_channels: if specified, ignore num_heads and instead use
219
  a fixed channel width per attention head.
 
229
  self,
230
  model_channels,
231
  in_channels=1,
232
+ in_latent_channels=1024,
233
+ in_tokens=8193,
234
+ conditioning_dim_factor=8,
235
+ conditioning_expansion=4,
236
  out_channels=2, # mean and variance
 
237
  dropout=0,
238
  # res 1, 2, 4, 8,16,32,64,128,256,512, 1K, 2K
239
  channel_mult= (1,1.5,2, 3, 4, 6, 8, 12, 16, 24, 32, 48),
240
  num_res_blocks=(1, 1, 1, 1, 1, 2, 2, 2, 2, 2, 2, 2),
241
  # spec_cond: 1, 0, 0, 1, 0, 0, 1, 0, 0, 1, 0, 0)
242
  # attn: 0, 0, 0, 0, 0, 0, 0, 0, 0, 1, 1, 1
243
+ token_conditioning_resolutions=(1,16,),
244
  attention_resolutions=(512,1024,2048),
245
  conv_resample=True,
 
246
  use_fp16=False,
247
  num_heads=1,
248
  num_head_channels=-1,
249
  num_heads_upsample=-1,
 
 
250
  kernel_size=3,
251
  scale_factor=2,
 
252
  time_embed_dim_multiplier=4,
253
+ freeze_main_net=False,
254
+ efficient_convs=True, # Uses kernels with width of 1 in several places rather than 3.
255
+ use_scale_shift_norm=True,
256
+ # Parameters for regularization.
257
+ unconditioned_percentage=.1, # This implements a mechanism similar to what is used in classifier-free training.
258
+ # Parameters for super-sampling.
259
+ super_sampling=False,
260
+ super_sampling_max_noising_factor=.1,
261
  ):
262
  super().__init__()
263
 
264
  if num_heads_upsample == -1:
265
  num_heads_upsample = num_heads
266
 
267
+ if super_sampling:
268
+ in_channels *= 2 # In super-sampling mode, the LR input is concatenated directly onto the input.
269
  self.in_channels = in_channels
270
  self.model_channels = model_channels
271
  self.out_channels = out_channels
 
273
  self.dropout = dropout
274
  self.channel_mult = channel_mult
275
  self.conv_resample = conv_resample
 
276
  self.num_heads = num_heads
277
  self.num_head_channels = num_head_channels
278
  self.num_heads_upsample = num_heads_upsample
279
+ self.super_sampling_enabled = super_sampling
280
+ self.super_sampling_max_noising_factor = super_sampling_max_noising_factor
281
+ self.unconditioned_percentage = unconditioned_percentage
282
+ self.enable_fp16 = use_fp16
283
+ self.alignment_size = 2 ** (len(channel_mult)+1)
284
+ self.freeze_main_net = freeze_main_net
285
  padding = 1 if kernel_size == 3 else 2
286
+ down_kernel = 1 if efficient_convs else 3
287
 
288
  time_embed_dim = model_channels * time_embed_dim_multiplier
289
  self.time_embed = nn.Sequential(
290
+ Linear(model_channels, time_embed_dim),
291
  nn.SiLU(),
292
+ Linear(time_embed_dim, time_embed_dim),
293
  )
294
 
295
+ conditioning_dim = model_channels * conditioning_dim_factor
296
+ # Either code_converter or latent_converter is used, depending on what type of conditioning data is fed.
297
+ # This model is meant to be able to be trained on both for efficiency purposes - it is far less computationally
298
+ # complex to generate tokens, while generating latents will normally mean propagating through a deep autoregressive
299
+ # transformer network.
300
+ self.code_converter = nn.Sequential(
301
+ nn.Embedding(in_tokens, conditioning_dim),
302
+ CheckpointedXTransformerEncoder(
303
+ needs_permute=False,
304
+ max_seq_len=-1,
305
+ use_pos_emb=False,
306
+ attn_layers=Encoder(
307
+ dim=conditioning_dim,
308
+ depth=3,
309
+ heads=num_heads,
310
+ ff_dropout=dropout,
311
+ attn_dropout=dropout,
312
+ use_rmsnorm=True,
313
+ ff_glu=True,
314
+ rotary_emb_dim=True,
315
+ )
316
+ ))
317
+ self.latent_converter = nn.Conv1d(in_latent_channels, conditioning_dim, 1)
318
+ self.aligned_latent_padding_embedding = nn.Parameter(torch.randn(1,in_latent_channels,1))
319
+ if in_channels > 60: # It's a spectrogram.
320
+ self.contextual_embedder = nn.Sequential(nn.Conv1d(in_channels,conditioning_dim,3,padding=1,stride=2),
321
+ CheckpointedXTransformerEncoder(
322
+ needs_permute=True,
323
+ max_seq_len=-1,
324
+ use_pos_emb=False,
325
+ attn_layers=Encoder(
326
+ dim=conditioning_dim,
327
+ depth=4,
328
+ heads=num_heads,
329
+ ff_dropout=dropout,
330
+ attn_dropout=dropout,
331
+ use_rmsnorm=True,
332
+ ff_glu=True,
333
+ rotary_emb_dim=True,
334
+ )
335
+ ))
336
+ else:
337
+ self.contextual_embedder = AudioMiniEncoder(1, conditioning_dim, base_channels=32, depth=6, resnet_blocks=1,
338
+ attn_blocks=3, num_attn_heads=8, dropout=dropout, downsample_factor=4, kernel_size=5)
339
+ self.conditioning_conv = nn.Conv1d(conditioning_dim*2, conditioning_dim, 1)
340
+ self.unconditioned_embedding = nn.Parameter(torch.randn(1,conditioning_dim,1))
341
+ self.conditioning_timestep_integrator = TimestepEmbedSequential(
342
+ ResBlock(conditioning_dim, time_embed_dim, dropout, out_channels=conditioning_dim, kernel_size=1, use_scale_shift_norm=use_scale_shift_norm),
343
+ AttentionBlock(conditioning_dim, num_heads=num_heads, num_head_channels=num_head_channels),
344
+ ResBlock(conditioning_dim, time_embed_dim, dropout, out_channels=conditioning_dim, kernel_size=1, use_scale_shift_norm=use_scale_shift_norm),
345
+ AttentionBlock(conditioning_dim, num_heads=num_heads, num_head_channels=num_head_channels),
346
+ ResBlock(conditioning_dim, time_embed_dim, dropout, out_channels=conditioning_dim, kernel_size=1, use_scale_shift_norm=use_scale_shift_norm),
347
  )
348
+ self.conditioning_expansion = conditioning_expansion
349
+
350
+ self.input_blocks = nn.ModuleList(
351
+ [
352
+ TimestepEmbedSequential(
353
+ nn.Conv1d(in_channels, model_channels, kernel_size, padding=padding)
354
+ )
355
+ ]
356
+ )
357
+ token_conditioning_blocks = []
358
  self._feature_size = model_channels
359
  input_block_chans = [model_channels]
360
  ch = model_channels
361
  ds = 1
362
 
363
  for level, (mult, num_blocks) in enumerate(zip(channel_mult, num_res_blocks)):
364
+ if ds in token_conditioning_resolutions:
365
+ token_conditioning_block = nn.Conv1d(conditioning_dim, ch, 1)
366
+ token_conditioning_block.weight.data *= .02
367
+ self.input_blocks.append(token_conditioning_block)
368
+ token_conditioning_blocks.append(token_conditioning_block)
369
 
370
  for _ in range(num_blocks):
371
  layers = [
372
+ ResBlock(
373
  ch,
374
  time_embed_dim,
375
  dropout,
376
  out_channels=int(mult * model_channels),
 
377
  kernel_size=kernel_size,
378
+ efficient_config=efficient_convs,
379
+ use_scale_shift_norm=use_scale_shift_norm,
380
  )
381
  ]
382
  ch = int(mult * model_channels)
 
388
  num_head_channels=num_head_channels,
389
  )
390
  )
391
+ self.input_blocks.append(TimestepEmbedSequential(*layers))
 
 
392
  self._feature_size += ch
393
  input_block_chans.append(ch)
394
  if level != len(channel_mult) - 1:
395
  out_ch = ch
396
+ self.input_blocks.append(
397
+ TimestepEmbedSequential(
398
+ Downsample(
399
+ ch, conv_resample, out_channels=out_ch, factor=scale_factor, ksize=down_kernel, pad=0 if down_kernel == 1 else 1
 
 
 
 
 
 
 
 
 
400
  )
401
  )
402
+ )
 
403
  ch = out_ch
404
  input_block_chans.append(ch)
405
  ds *= 2
406
  self._feature_size += ch
407
 
408
  self.middle_block = TimestepEmbedSequential(
409
+ ResBlock(
410
  ch,
411
  time_embed_dim,
412
  dropout,
 
413
  kernel_size=kernel_size,
414
+ efficient_config=efficient_convs,
415
+ use_scale_shift_norm=use_scale_shift_norm,
416
  ),
417
  AttentionBlock(
418
  ch,
419
  num_heads=num_heads,
420
  num_head_channels=num_head_channels,
421
  ),
422
+ ResBlock(
423
  ch,
424
  time_embed_dim,
425
  dropout,
 
426
  kernel_size=kernel_size,
427
+ efficient_config=efficient_convs,
428
+ use_scale_shift_norm=use_scale_shift_norm,
429
  ),
430
  )
431
  self._feature_size += ch
 
435
  for i in range(num_blocks + 1):
436
  ich = input_block_chans.pop()
437
  layers = [
438
+ ResBlock(
439
  ch + ich,
440
  time_embed_dim,
441
  dropout,
442
  out_channels=int(model_channels * mult),
 
443
  kernel_size=kernel_size,
444
+ efficient_config=efficient_convs,
445
+ use_scale_shift_norm=use_scale_shift_norm,
446
  )
447
  ]
448
  ch = int(model_channels * mult)
 
457
  if level and i == num_blocks:
458
  out_ch = ch
459
  layers.append(
460
+ Upsample(ch, conv_resample, out_channels=out_ch, factor=scale_factor)
 
 
 
 
 
 
 
 
 
 
461
  )
462
  ds //= 2
463
+ self.output_blocks.append(TimestepEmbedSequential(*layers))
 
 
464
  self._feature_size += ch
465
 
466
  self.out = nn.Sequential(
 
469
  zero_module(nn.Conv1d(model_channels, out_channels, kernel_size, padding=padding)),
470
  )
471
 
472
+ def fix_alignment(self, x, aligned_conditioning):
473
+ """
474
+ The UNet requires that the input <x> is a certain multiple of 2, defined by the UNet depth. Enforce this by
475
+ padding both <x> and <aligned_conditioning> before forward propagation and removing the padding before returning.
476
+ """
477
+ cm = ceil_multiple(x.shape[-1], self.alignment_size)
478
+ if cm != 0:
479
+ pc = (cm-x.shape[-1])/x.shape[-1]
480
+ x = F.pad(x, (0,cm-x.shape[-1]))
481
+ # Also fix aligned_latent, which is aligned to x.
482
+ if is_latent(aligned_conditioning):
483
+ aligned_conditioning = torch.cat([aligned_conditioning,
484
+ self.aligned_latent_padding_embedding.repeat(x.shape[0], 1, int(pc * aligned_conditioning.shape[-1]))], dim=-1)
485
+ else:
486
+ aligned_conditioning = F.pad(aligned_conditioning, (0, int(pc*aligned_conditioning.shape[-1])))
487
+ return x, aligned_conditioning
488
+
489
+ def forward(self, x, timesteps, aligned_conditioning, conditioning_input, lr_input=None, conditioning_free=False):
490
  """
491
  Apply the model to an input batch.
492
 
493
  :param x: an [N x C x ...] Tensor of inputs.
494
  :param timesteps: a 1-D batch of timesteps.
495
+ :param aligned_conditioning: an aligned latent or sequence of tokens providing useful data about the sample to be produced.
496
+ :param conditioning_input: a full-resolution audio clip that is used as a reference to the style you want decoded.
497
+ :param lr_input: for super-sampling models, a guidance audio clip at a lower sampling rate.
498
+ :param conditioning_free: When set, all conditioning inputs (including tokens and conditioning_input) will not be considered.
499
  :return: an [N x C x ...] Tensor of outputs.
500
  """
501
+ assert conditioning_input is not None
502
+ if self.super_sampling_enabled:
503
+ assert lr_input is not None
504
+ if self.training and self.super_sampling_max_noising_factor > 0:
505
+ noising_factor = random.uniform(0,self.super_sampling_max_noising_factor)
506
+ lr_input = torch.randn_like(lr_input) * noising_factor + lr_input
507
+ lr_input = F.interpolate(lr_input, size=(x.shape[-1],), mode='nearest')
508
+ x = torch.cat([x, lr_input], dim=1)
509
+
510
+ # Shuffle aligned_latent to BxCxS format
511
+ if is_latent(aligned_conditioning):
512
+ aligned_conditioning = aligned_conditioning.permute(0, 2, 1)
513
+
514
+ # Fix input size to the proper multiple of 2 so we don't get alignment errors going down and back up the U-net.
515
+ orig_x_shape = x.shape[-1]
516
+ x, aligned_conditioning = self.fix_alignment(x, aligned_conditioning)
517
+
518
+ with autocast(x.device.type, enabled=self.enable_fp16):
519
+ hs = []
520
+ time_emb = self.time_embed(timestep_embedding(timesteps, self.model_channels))
521
+
522
+ # Note: this block does not need to repeated on inference, since it is not timestep-dependent.
523
+ if conditioning_free:
524
+ code_emb = self.unconditioned_embedding.repeat(x.shape[0], 1, 1)
525
  else:
526
+ cond_emb = self.contextual_embedder(conditioning_input)
527
+ if len(cond_emb.shape) == 3: # Just take the first element.
528
+ cond_emb = cond_emb[:, :, 0]
529
+ if is_latent(aligned_conditioning):
530
+ code_emb = self.latent_converter(aligned_conditioning)
531
+ else:
532
+ code_emb = self.code_converter(aligned_conditioning)
533
+ cond_emb = cond_emb.unsqueeze(-1).repeat(1, 1, code_emb.shape[-1])
534
+ code_emb = self.conditioning_conv(torch.cat([cond_emb, code_emb], dim=1))
535
+ # Mask out the conditioning branch for whole batch elements, implementing something similar to classifier-free guidance.
536
+ if self.training and self.unconditioned_percentage > 0:
537
+ unconditioned_batches = torch.rand((code_emb.shape[0], 1, 1),
538
+ device=code_emb.device) < self.unconditioned_percentage
539
+ code_emb = torch.where(unconditioned_batches, self.unconditioned_embedding.repeat(x.shape[0], 1, 1),
540
+ code_emb)
541
+
542
+ # Everything after this comment is timestep dependent.
543
+ code_emb = torch.repeat_interleave(code_emb, self.conditioning_expansion, dim=-1)
544
+ code_emb = self.conditioning_timestep_integrator(code_emb, time_emb)
545
+
546
+ first = True
547
+ time_emb = time_emb.float()
548
+ h = x
549
+ for k, module in enumerate(self.input_blocks):
550
+ if isinstance(module, nn.Conv1d):
551
+ h_tok = F.interpolate(module(code_emb), size=(h.shape[-1]), mode='nearest')
552
+ h = h + h_tok
553
+ else:
554
+ with autocast(x.device.type, enabled=self.enable_fp16 and not first):
555
+ # First block has autocast disabled to allow a high precision signal to be properly vectorized.
556
+ h = module(h, time_emb)
557
+ hs.append(h)
558
+ first = False
559
+ h = self.middle_block(h, time_emb)
560
+ for module in self.output_blocks:
561
+ h = torch.cat([h, hs.pop()], dim=1)
562
+ h = module(h, time_emb)
563
+
564
+ # Last block also has autocast disabled for high-precision outputs.
565
+ h = h.float()
566
+ out = self.out(h)
567
+
568
+ # Involve probabilistic or possibly unused parameters in loss so we don't get DDP errors.
569
+ extraneous_addition = 0
570
+ params = [self.aligned_latent_padding_embedding, self.unconditioned_embedding] + list(self.latent_converter.parameters())
571
+ for p in params:
572
+ extraneous_addition = extraneous_addition + p.mean()
573
+ out = out + extraneous_addition * 0
574
+
575
+ return out[:, :, :orig_x_shape]
576
 
577
 
 
578
  if __name__ == '__main__':
579
+ clip = torch.randn(2, 1, 32868)
580
+ aligned_latent = torch.randn(2,388,1024)
581
+ aligned_sequence = torch.randint(0,8192,(2,388))
582
+ cond = torch.randn(2, 1, 44000)
583
+ ts = torch.LongTensor([600, 600])
584
+ model = DiffusionTts(128,
585
+ channel_mult=[1,1.5,2, 3, 4, 6, 8],
586
+ num_res_blocks=[2, 2, 2, 2, 2, 2, 1],
587
+ token_conditioning_resolutions=[1,4,16,64],
588
+ attention_resolutions=[],
589
+ num_heads=8,
590
+ kernel_size=3,
591
+ scale_factor=2,
592
+ time_embed_dim_multiplier=4,
593
+ super_sampling=False,
594
+ efficient_convs=False)
595
+ # Test with latent aligned conditioning
596
+ o = model(clip, ts, aligned_latent, cond)
597
+ # Test with sequence aligned conditioning
598
+ o = model(clip, ts, aligned_sequence, cond)
models/dvae.py DELETED
@@ -1,390 +0,0 @@
1
- import functools
2
- from math import sqrt
3
-
4
- import torch
5
- import torch.distributed as distributed
6
- import torch.nn as nn
7
- import torch.nn.functional as F
8
- from einops import rearrange
9
-
10
-
11
- def default(val, d):
12
- return val if val is not None else d
13
-
14
-
15
- def eval_decorator(fn):
16
- def inner(model, *args, **kwargs):
17
- was_training = model.training
18
- model.eval()
19
- out = fn(model, *args, **kwargs)
20
- model.train(was_training)
21
- return out
22
- return inner
23
-
24
-
25
- # Quantizer implemented by the rosinality vqvae repo.
26
- # Credit: https://github.com/rosinality/vq-vae-2-pytorch
27
- class Quantize(nn.Module):
28
- def __init__(self, dim, n_embed, decay=0.99, eps=1e-5, balancing_heuristic=False, new_return_order=False):
29
- super().__init__()
30
-
31
- self.dim = dim
32
- self.n_embed = n_embed
33
- self.decay = decay
34
- self.eps = eps
35
-
36
- self.balancing_heuristic = balancing_heuristic
37
- self.codes = None
38
- self.max_codes = 64000
39
- self.codes_full = False
40
- self.new_return_order = new_return_order
41
-
42
- embed = torch.randn(dim, n_embed)
43
- self.register_buffer("embed", embed)
44
- self.register_buffer("cluster_size", torch.zeros(n_embed))
45
- self.register_buffer("embed_avg", embed.clone())
46
-
47
- def forward(self, input, return_soft_codes=False):
48
- if self.balancing_heuristic and self.codes_full:
49
- h = torch.histc(self.codes, bins=self.n_embed, min=0, max=self.n_embed) / len(self.codes)
50
- mask = torch.logical_or(h > .9, h < .01).unsqueeze(1)
51
- ep = self.embed.permute(1,0)
52
- ea = self.embed_avg.permute(1,0)
53
- rand_embed = torch.randn_like(ep) * mask
54
- self.embed = (ep * ~mask + rand_embed).permute(1,0)
55
- self.embed_avg = (ea * ~mask + rand_embed).permute(1,0)
56
- self.cluster_size = self.cluster_size * ~mask.squeeze()
57
- if torch.any(mask):
58
- print(f"Reset {torch.sum(mask)} embedding codes.")
59
- self.codes = None
60
- self.codes_full = False
61
-
62
- flatten = input.reshape(-1, self.dim)
63
- dist = (
64
- flatten.pow(2).sum(1, keepdim=True)
65
- - 2 * flatten @ self.embed
66
- + self.embed.pow(2).sum(0, keepdim=True)
67
- )
68
- soft_codes = -dist
69
- _, embed_ind = soft_codes.max(1)
70
- embed_onehot = F.one_hot(embed_ind, self.n_embed).type(flatten.dtype)
71
- embed_ind = embed_ind.view(*input.shape[:-1])
72
- quantize = self.embed_code(embed_ind)
73
-
74
- if self.balancing_heuristic:
75
- if self.codes is None:
76
- self.codes = embed_ind.flatten()
77
- else:
78
- self.codes = torch.cat([self.codes, embed_ind.flatten()])
79
- if len(self.codes) > self.max_codes:
80
- self.codes = self.codes[-self.max_codes:]
81
- self.codes_full = True
82
-
83
- if self.training:
84
- embed_onehot_sum = embed_onehot.sum(0)
85
- embed_sum = flatten.transpose(0, 1) @ embed_onehot
86
-
87
- if distributed.is_initialized() and distributed.get_world_size() > 1:
88
- distributed.all_reduce(embed_onehot_sum)
89
- distributed.all_reduce(embed_sum)
90
-
91
- self.cluster_size.data.mul_(self.decay).add_(
92
- embed_onehot_sum, alpha=1 - self.decay
93
- )
94
- self.embed_avg.data.mul_(self.decay).add_(embed_sum, alpha=1 - self.decay)
95
- n = self.cluster_size.sum()
96
- cluster_size = (
97
- (self.cluster_size + self.eps) / (n + self.n_embed * self.eps) * n
98
- )
99
- embed_normalized = self.embed_avg / cluster_size.unsqueeze(0)
100
- self.embed.data.copy_(embed_normalized)
101
-
102
- diff = (quantize.detach() - input).pow(2).mean()
103
- quantize = input + (quantize - input).detach()
104
-
105
- if return_soft_codes:
106
- return quantize, diff, embed_ind, soft_codes.view(input.shape[:-1] + (-1,))
107
- elif self.new_return_order:
108
- return quantize, embed_ind, diff
109
- else:
110
- return quantize, diff, embed_ind
111
-
112
- def embed_code(self, embed_id):
113
- return F.embedding(embed_id, self.embed.transpose(0, 1))
114
-
115
-
116
- # Fits a soft-discretized input to a normal-PDF across the specified dimension.
117
- # In other words, attempts to force the discretization function to have a mean equal utilization across all discrete
118
- # values with the specified expected variance.
119
- class DiscretizationLoss(nn.Module):
120
- def __init__(self, discrete_bins, dim, expected_variance, store_past=0):
121
- super().__init__()
122
- self.discrete_bins = discrete_bins
123
- self.dim = dim
124
- self.dist = torch.distributions.Normal(0, scale=expected_variance)
125
- if store_past > 0:
126
- self.record_past = True
127
- self.register_buffer("accumulator_index", torch.zeros(1, dtype=torch.long, device='cpu'))
128
- self.register_buffer("accumulator_filled", torch.zeros(1, dtype=torch.long, device='cpu'))
129
- self.register_buffer("accumulator", torch.zeros(store_past, discrete_bins))
130
- else:
131
- self.record_past = False
132
-
133
- def forward(self, x):
134
- other_dims = set(range(len(x.shape)))-set([self.dim])
135
- averaged = x.sum(dim=tuple(other_dims)) / x.sum()
136
- averaged = averaged - averaged.mean()
137
-
138
- if self.record_past:
139
- acc_count = self.accumulator.shape[0]
140
- avg = averaged.detach().clone()
141
- if self.accumulator_filled > 0:
142
- averaged = torch.mean(self.accumulator, dim=0) * (acc_count-1) / acc_count + \
143
- averaged / acc_count
144
-
145
- # Also push averaged into the accumulator.
146
- self.accumulator[self.accumulator_index] = avg
147
- self.accumulator_index += 1
148
- if self.accumulator_index >= acc_count:
149
- self.accumulator_index *= 0
150
- if self.accumulator_filled <= 0:
151
- self.accumulator_filled += 1
152
-
153
- return torch.sum(-self.dist.log_prob(averaged))
154
-
155
-
156
- class ResBlock(nn.Module):
157
- def __init__(self, chan, conv, activation):
158
- super().__init__()
159
- self.net = nn.Sequential(
160
- conv(chan, chan, 3, padding = 1),
161
- activation(),
162
- conv(chan, chan, 3, padding = 1),
163
- activation(),
164
- conv(chan, chan, 1)
165
- )
166
-
167
- def forward(self, x):
168
- return self.net(x) + x
169
-
170
-
171
- class UpsampledConv(nn.Module):
172
- def __init__(self, conv, *args, **kwargs):
173
- super().__init__()
174
- assert 'stride' in kwargs.keys()
175
- self.stride = kwargs['stride']
176
- del kwargs['stride']
177
- self.conv = conv(*args, **kwargs)
178
-
179
- def forward(self, x):
180
- up = nn.functional.interpolate(x, scale_factor=self.stride, mode='nearest')
181
- return self.conv(up)
182
-
183
-
184
- # DiscreteVAE partially derived from lucidrains DALLE implementation
185
- # Credit: https://github.com/lucidrains/DALLE-pytorch
186
- class DiscreteVAE(nn.Module):
187
- def __init__(
188
- self,
189
- positional_dims=2,
190
- num_tokens = 512,
191
- codebook_dim = 512,
192
- num_layers = 3,
193
- num_resnet_blocks = 0,
194
- hidden_dim = 64,
195
- channels = 3,
196
- stride = 2,
197
- kernel_size = 4,
198
- use_transposed_convs = True,
199
- encoder_norm = False,
200
- activation = 'relu',
201
- smooth_l1_loss = False,
202
- straight_through = False,
203
- normalization = None, # ((0.5,) * 3, (0.5,) * 3),
204
- record_codes = False,
205
- discretization_loss_averaging_steps = 100,
206
- lr_quantizer_args = {},
207
- ):
208
- super().__init__()
209
- has_resblocks = num_resnet_blocks > 0
210
-
211
- self.num_tokens = num_tokens
212
- self.num_layers = num_layers
213
- self.straight_through = straight_through
214
- self.positional_dims = positional_dims
215
- self.discrete_loss = DiscretizationLoss(num_tokens, 2, 1 / (num_tokens*2), discretization_loss_averaging_steps)
216
-
217
- assert positional_dims > 0 and positional_dims < 3 # This VAE only supports 1d and 2d inputs for now.
218
- if positional_dims == 2:
219
- conv = nn.Conv2d
220
- conv_transpose = nn.ConvTranspose2d
221
- else:
222
- conv = nn.Conv1d
223
- conv_transpose = nn.ConvTranspose1d
224
- if not use_transposed_convs:
225
- conv_transpose = functools.partial(UpsampledConv, conv)
226
-
227
- if activation == 'relu':
228
- act = nn.ReLU
229
- elif activation == 'silu':
230
- act = nn.SiLU
231
- else:
232
- assert NotImplementedError()
233
-
234
-
235
- enc_layers = []
236
- dec_layers = []
237
-
238
- if num_layers > 0:
239
- enc_chans = [hidden_dim * 2 ** i for i in range(num_layers)]
240
- dec_chans = list(reversed(enc_chans))
241
-
242
- enc_chans = [channels, *enc_chans]
243
-
244
- dec_init_chan = codebook_dim if not has_resblocks else dec_chans[0]
245
- dec_chans = [dec_init_chan, *dec_chans]
246
-
247
- enc_chans_io, dec_chans_io = map(lambda t: list(zip(t[:-1], t[1:])), (enc_chans, dec_chans))
248
-
249
- pad = (kernel_size - 1) // 2
250
- for (enc_in, enc_out), (dec_in, dec_out) in zip(enc_chans_io, dec_chans_io):
251
- enc_layers.append(nn.Sequential(conv(enc_in, enc_out, kernel_size, stride = stride, padding = pad), act()))
252
- if encoder_norm:
253
- enc_layers.append(nn.GroupNorm(8, enc_out))
254
- dec_layers.append(nn.Sequential(conv_transpose(dec_in, dec_out, kernel_size, stride = stride, padding = pad), act()))
255
- dec_out_chans = dec_chans[-1]
256
- innermost_dim = dec_chans[0]
257
- else:
258
- enc_layers.append(nn.Sequential(conv(channels, hidden_dim, 1), act()))
259
- dec_out_chans = hidden_dim
260
- innermost_dim = hidden_dim
261
-
262
- for _ in range(num_resnet_blocks):
263
- dec_layers.insert(0, ResBlock(innermost_dim, conv, act))
264
- enc_layers.append(ResBlock(innermost_dim, conv, act))
265
-
266
- if num_resnet_blocks > 0:
267
- dec_layers.insert(0, conv(codebook_dim, innermost_dim, 1))
268
-
269
-
270
- enc_layers.append(conv(innermost_dim, codebook_dim, 1))
271
- dec_layers.append(conv(dec_out_chans, channels, 1))
272
-
273
- self.encoder = nn.Sequential(*enc_layers)
274
- self.decoder = nn.Sequential(*dec_layers)
275
-
276
- self.loss_fn = F.smooth_l1_loss if smooth_l1_loss else F.mse_loss
277
- self.codebook = Quantize(codebook_dim, num_tokens, new_return_order=True)
278
-
279
- # take care of normalization within class
280
- self.normalization = normalization
281
- self.record_codes = record_codes
282
- if record_codes:
283
- self.codes = torch.zeros((1228800,), dtype=torch.long)
284
- self.code_ind = 0
285
- self.total_codes = 0
286
- self.internal_step = 0
287
-
288
- def norm(self, images):
289
- if not self.normalization is not None:
290
- return images
291
-
292
- means, stds = map(lambda t: torch.as_tensor(t).to(images), self.normalization)
293
- arrange = 'c -> () c () ()' if self.positional_dims == 2 else 'c -> () c ()'
294
- means, stds = map(lambda t: rearrange(t, arrange), (means, stds))
295
- images = images.clone()
296
- images.sub_(means).div_(stds)
297
- return images
298
-
299
- def get_debug_values(self, step, __):
300
- if self.record_codes and self.total_codes > 0:
301
- # Report annealing schedule
302
- return {'histogram_codes': self.codes[:self.total_codes]}
303
- else:
304
- return {}
305
-
306
- @torch.no_grad()
307
- @eval_decorator
308
- def get_codebook_indices(self, images):
309
- img = self.norm(images)
310
- logits = self.encoder(img).permute((0,2,3,1) if len(img.shape) == 4 else (0,2,1))
311
- sampled, codes, _ = self.codebook(logits)
312
- self.log_codes(codes)
313
- return codes
314
-
315
- def decode(
316
- self,
317
- img_seq
318
- ):
319
- self.log_codes(img_seq)
320
- if hasattr(self.codebook, 'embed_code'):
321
- image_embeds = self.codebook.embed_code(img_seq)
322
- else:
323
- image_embeds = F.embedding(img_seq, self.codebook.codebook)
324
- b, n, d = image_embeds.shape
325
-
326
- kwargs = {}
327
- if self.positional_dims == 1:
328
- arrange = 'b n d -> b d n'
329
- else:
330
- h = w = int(sqrt(n))
331
- arrange = 'b (h w) d -> b d h w'
332
- kwargs = {'h': h, 'w': w}
333
- image_embeds = rearrange(image_embeds, arrange, **kwargs)
334
- images = [image_embeds]
335
- for layer in self.decoder:
336
- images.append(layer(images[-1]))
337
- return images[-1], images[-2]
338
-
339
- def infer(self, img):
340
- img = self.norm(img)
341
- logits = self.encoder(img).permute((0,2,3,1) if len(img.shape) == 4 else (0,2,1))
342
- sampled, codes, commitment_loss = self.codebook(logits)
343
- return self.decode(codes)
344
-
345
- # Note: This module is not meant to be run in forward() except while training. It has special logic which performs
346
- # evaluation using quantized values when it detects that it is being run in eval() mode, which will be substantially
347
- # more lossy (but useful for determining network performance).
348
- def forward(
349
- self,
350
- img
351
- ):
352
- img = self.norm(img)
353
- logits = self.encoder(img).permute((0,2,3,1) if len(img.shape) == 4 else (0,2,1))
354
- sampled, codes, commitment_loss = self.codebook(logits)
355
- sampled = sampled.permute((0,3,1,2) if len(img.shape) == 4 else (0,2,1))
356
-
357
- if self.training:
358
- out = sampled
359
- for d in self.decoder:
360
- out = d(out)
361
- self.log_codes(codes)
362
- else:
363
- # This is non-differentiable, but gives a better idea of how the network is actually performing.
364
- out, _ = self.decode(codes)
365
-
366
- # reconstruction loss
367
- recon_loss = self.loss_fn(img, out, reduction='none')
368
-
369
- return recon_loss, commitment_loss, out
370
-
371
- def log_codes(self, codes):
372
- # This is so we can debug the distribution of codes being learned.
373
- if self.record_codes and self.internal_step % 10 == 0:
374
- codes = codes.flatten()
375
- l = codes.shape[0]
376
- i = self.code_ind if (self.codes.shape[0] - self.code_ind) > l else self.codes.shape[0] - l
377
- self.codes[i:i+l] = codes.cpu()
378
- self.code_ind = self.code_ind + l
379
- if self.code_ind >= self.codes.shape[0]:
380
- self.code_ind = 0
381
- self.total_codes += 1
382
- self.internal_step += 1
383
-
384
-
385
- if __name__ == '__main__':
386
- v = DiscreteVAE(channels=80, normalization=None, positional_dims=1, num_tokens=8192, codebook_dim=2048,
387
- hidden_dim=512, num_resnet_blocks=3, kernel_size=3, num_layers=1, use_transposed_convs=False)
388
- r,l,o=v(torch.randn(1,80,256))
389
- v.decode(torch.randint(0,8192,(1,256)))
390
- print(o.shape, l.shape)
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
models/vocoder.py ADDED
@@ -0,0 +1,325 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ import torch
2
+ import torch.nn as nn
3
+ import torch.nn.functional as F
4
+
5
+ MAX_WAV_VALUE = 32768.0
6
+
7
+ class KernelPredictor(torch.nn.Module):
8
+ ''' Kernel predictor for the location-variable convolutions'''
9
+
10
+ def __init__(
11
+ self,
12
+ cond_channels,
13
+ conv_in_channels,
14
+ conv_out_channels,
15
+ conv_layers,
16
+ conv_kernel_size=3,
17
+ kpnet_hidden_channels=64,
18
+ kpnet_conv_size=3,
19
+ kpnet_dropout=0.0,
20
+ kpnet_nonlinear_activation="LeakyReLU",
21
+ kpnet_nonlinear_activation_params={"negative_slope": 0.1},
22
+ ):
23
+ '''
24
+ Args:
25
+ cond_channels (int): number of channel for the conditioning sequence,
26
+ conv_in_channels (int): number of channel for the input sequence,
27
+ conv_out_channels (int): number of channel for the output sequence,
28
+ conv_layers (int): number of layers
29
+ '''
30
+ super().__init__()
31
+
32
+ self.conv_in_channels = conv_in_channels
33
+ self.conv_out_channels = conv_out_channels
34
+ self.conv_kernel_size = conv_kernel_size
35
+ self.conv_layers = conv_layers
36
+
37
+ kpnet_kernel_channels = conv_in_channels * conv_out_channels * conv_kernel_size * conv_layers # l_w
38
+ kpnet_bias_channels = conv_out_channels * conv_layers # l_b
39
+
40
+ self.input_conv = nn.Sequential(
41
+ nn.utils.weight_norm(nn.Conv1d(cond_channels, kpnet_hidden_channels, 5, padding=2, bias=True)),
42
+ getattr(nn, kpnet_nonlinear_activation)(**kpnet_nonlinear_activation_params),
43
+ )
44
+
45
+ self.residual_convs = nn.ModuleList()
46
+ padding = (kpnet_conv_size - 1) // 2
47
+ for _ in range(3):
48
+ self.residual_convs.append(
49
+ nn.Sequential(
50
+ nn.Dropout(kpnet_dropout),
51
+ nn.utils.weight_norm(
52
+ nn.Conv1d(kpnet_hidden_channels, kpnet_hidden_channels, kpnet_conv_size, padding=padding,
53
+ bias=True)),
54
+ getattr(nn, kpnet_nonlinear_activation)(**kpnet_nonlinear_activation_params),
55
+ nn.utils.weight_norm(
56
+ nn.Conv1d(kpnet_hidden_channels, kpnet_hidden_channels, kpnet_conv_size, padding=padding,
57
+ bias=True)),
58
+ getattr(nn, kpnet_nonlinear_activation)(**kpnet_nonlinear_activation_params),
59
+ )
60
+ )
61
+ self.kernel_conv = nn.utils.weight_norm(
62
+ nn.Conv1d(kpnet_hidden_channels, kpnet_kernel_channels, kpnet_conv_size, padding=padding, bias=True))
63
+ self.bias_conv = nn.utils.weight_norm(
64
+ nn.Conv1d(kpnet_hidden_channels, kpnet_bias_channels, kpnet_conv_size, padding=padding, bias=True))
65
+
66
+ def forward(self, c):
67
+ '''
68
+ Args:
69
+ c (Tensor): the conditioning sequence (batch, cond_channels, cond_length)
70
+ '''
71
+ batch, _, cond_length = c.shape
72
+ c = self.input_conv(c)
73
+ for residual_conv in self.residual_convs:
74
+ residual_conv.to(c.device)
75
+ c = c + residual_conv(c)
76
+ k = self.kernel_conv(c)
77
+ b = self.bias_conv(c)
78
+ kernels = k.contiguous().view(
79
+ batch,
80
+ self.conv_layers,
81
+ self.conv_in_channels,
82
+ self.conv_out_channels,
83
+ self.conv_kernel_size,
84
+ cond_length,
85
+ )
86
+ bias = b.contiguous().view(
87
+ batch,
88
+ self.conv_layers,
89
+ self.conv_out_channels,
90
+ cond_length,
91
+ )
92
+
93
+ return kernels, bias
94
+
95
+ def remove_weight_norm(self):
96
+ nn.utils.remove_weight_norm(self.input_conv[0])
97
+ nn.utils.remove_weight_norm(self.kernel_conv)
98
+ nn.utils.remove_weight_norm(self.bias_conv)
99
+ for block in self.residual_convs:
100
+ nn.utils.remove_weight_norm(block[1])
101
+ nn.utils.remove_weight_norm(block[3])
102
+
103
+
104
+ class LVCBlock(torch.nn.Module):
105
+ '''the location-variable convolutions'''
106
+
107
+ def __init__(
108
+ self,
109
+ in_channels,
110
+ cond_channels,
111
+ stride,
112
+ dilations=[1, 3, 9, 27],
113
+ lReLU_slope=0.2,
114
+ conv_kernel_size=3,
115
+ cond_hop_length=256,
116
+ kpnet_hidden_channels=64,
117
+ kpnet_conv_size=3,
118
+ kpnet_dropout=0.0,
119
+ ):
120
+ super().__init__()
121
+
122
+ self.cond_hop_length = cond_hop_length
123
+ self.conv_layers = len(dilations)
124
+ self.conv_kernel_size = conv_kernel_size
125
+
126
+ self.kernel_predictor = KernelPredictor(
127
+ cond_channels=cond_channels,
128
+ conv_in_channels=in_channels,
129
+ conv_out_channels=2 * in_channels,
130
+ conv_layers=len(dilations),
131
+ conv_kernel_size=conv_kernel_size,
132
+ kpnet_hidden_channels=kpnet_hidden_channels,
133
+ kpnet_conv_size=kpnet_conv_size,
134
+ kpnet_dropout=kpnet_dropout,
135
+ kpnet_nonlinear_activation_params={"negative_slope": lReLU_slope}
136
+ )
137
+
138
+ self.convt_pre = nn.Sequential(
139
+ nn.LeakyReLU(lReLU_slope),
140
+ nn.utils.weight_norm(nn.ConvTranspose1d(in_channels, in_channels, 2 * stride, stride=stride,
141
+ padding=stride // 2 + stride % 2, output_padding=stride % 2)),
142
+ )
143
+
144
+ self.conv_blocks = nn.ModuleList()
145
+ for dilation in dilations:
146
+ self.conv_blocks.append(
147
+ nn.Sequential(
148
+ nn.LeakyReLU(lReLU_slope),
149
+ nn.utils.weight_norm(nn.Conv1d(in_channels, in_channels, conv_kernel_size,
150
+ padding=dilation * (conv_kernel_size - 1) // 2, dilation=dilation)),
151
+ nn.LeakyReLU(lReLU_slope),
152
+ )
153
+ )
154
+
155
+ def forward(self, x, c):
156
+ ''' forward propagation of the location-variable convolutions.
157
+ Args:
158
+ x (Tensor): the input sequence (batch, in_channels, in_length)
159
+ c (Tensor): the conditioning sequence (batch, cond_channels, cond_length)
160
+
161
+ Returns:
162
+ Tensor: the output sequence (batch, in_channels, in_length)
163
+ '''
164
+ _, in_channels, _ = x.shape # (B, c_g, L')
165
+
166
+ x = self.convt_pre(x) # (B, c_g, stride * L')
167
+ kernels, bias = self.kernel_predictor(c)
168
+
169
+ for i, conv in enumerate(self.conv_blocks):
170
+ output = conv(x) # (B, c_g, stride * L')
171
+
172
+ k = kernels[:, i, :, :, :, :] # (B, 2 * c_g, c_g, kernel_size, cond_length)
173
+ b = bias[:, i, :, :] # (B, 2 * c_g, cond_length)
174
+
175
+ output = self.location_variable_convolution(output, k, b,
176
+ hop_size=self.cond_hop_length) # (B, 2 * c_g, stride * L'): LVC
177
+ x = x + torch.sigmoid(output[:, :in_channels, :]) * torch.tanh(
178
+ output[:, in_channels:, :]) # (B, c_g, stride * L'): GAU
179
+
180
+ return x
181
+
182
+ def location_variable_convolution(self, x, kernel, bias, dilation=1, hop_size=256):
183
+ ''' perform location-variable convolution operation on the input sequence (x) using the local convolution kernl.
184
+ Time: 414 μs ± 309 ns per loop (mean ± std. dev. of 7 runs, 1000 loops each), test on NVIDIA V100.
185
+ Args:
186
+ x (Tensor): the input sequence (batch, in_channels, in_length).
187
+ kernel (Tensor): the local convolution kernel (batch, in_channel, out_channels, kernel_size, kernel_length)
188
+ bias (Tensor): the bias for the local convolution (batch, out_channels, kernel_length)
189
+ dilation (int): the dilation of convolution.
190
+ hop_size (int): the hop_size of the conditioning sequence.
191
+ Returns:
192
+ (Tensor): the output sequence after performing local convolution. (batch, out_channels, in_length).
193
+ '''
194
+ batch, _, in_length = x.shape
195
+ batch, _, out_channels, kernel_size, kernel_length = kernel.shape
196
+ assert in_length == (kernel_length * hop_size), "length of (x, kernel) is not matched"
197
+
198
+ padding = dilation * int((kernel_size - 1) / 2)
199
+ x = F.pad(x, (padding, padding), 'constant', 0) # (batch, in_channels, in_length + 2*padding)
200
+ x = x.unfold(2, hop_size + 2 * padding, hop_size) # (batch, in_channels, kernel_length, hop_size + 2*padding)
201
+
202
+ if hop_size < dilation:
203
+ x = F.pad(x, (0, dilation), 'constant', 0)
204
+ x = x.unfold(3, dilation,
205
+ dilation) # (batch, in_channels, kernel_length, (hop_size + 2*padding)/dilation, dilation)
206
+ x = x[:, :, :, :, :hop_size]
207
+ x = x.transpose(3, 4) # (batch, in_channels, kernel_length, dilation, (hop_size + 2*padding)/dilation)
208
+ x = x.unfold(4, kernel_size, 1) # (batch, in_channels, kernel_length, dilation, _, kernel_size)
209
+
210
+ o = torch.einsum('bildsk,biokl->bolsd', x, kernel)
211
+ o = o.to(memory_format=torch.channels_last_3d)
212
+ bias = bias.unsqueeze(-1).unsqueeze(-1).to(memory_format=torch.channels_last_3d)
213
+ o = o + bias
214
+ o = o.contiguous().view(batch, out_channels, -1)
215
+
216
+ return o
217
+
218
+ def remove_weight_norm(self):
219
+ self.kernel_predictor.remove_weight_norm()
220
+ nn.utils.remove_weight_norm(self.convt_pre[1])
221
+ for block in self.conv_blocks:
222
+ nn.utils.remove_weight_norm(block[1])
223
+
224
+
225
+ class UnivNetGenerator(nn.Module):
226
+ """UnivNet Generator"""
227
+
228
+ def __init__(self, noise_dim=64, channel_size=32, dilations=[1,3,9,27], strides=[8,8,4], lReLU_slope=.2, kpnet_conv_size=3,
229
+ # Below are MEL configurations options that this generator requires.
230
+ hop_length=256, n_mel_channels=100):
231
+ super(UnivNetGenerator, self).__init__()
232
+ self.mel_channel = n_mel_channels
233
+ self.noise_dim = noise_dim
234
+ self.hop_length = hop_length
235
+ channel_size = channel_size
236
+ kpnet_conv_size = kpnet_conv_size
237
+
238
+ self.res_stack = nn.ModuleList()
239
+ hop_length = 1
240
+ for stride in strides:
241
+ hop_length = stride * hop_length
242
+ self.res_stack.append(
243
+ LVCBlock(
244
+ channel_size,
245
+ n_mel_channels,
246
+ stride=stride,
247
+ dilations=dilations,
248
+ lReLU_slope=lReLU_slope,
249
+ cond_hop_length=hop_length,
250
+ kpnet_conv_size=kpnet_conv_size
251
+ )
252
+ )
253
+
254
+ self.conv_pre = \
255
+ nn.utils.weight_norm(nn.Conv1d(noise_dim, channel_size, 7, padding=3, padding_mode='reflect'))
256
+
257
+ self.conv_post = nn.Sequential(
258
+ nn.LeakyReLU(lReLU_slope),
259
+ nn.utils.weight_norm(nn.Conv1d(channel_size, 1, 7, padding=3, padding_mode='reflect')),
260
+ nn.Tanh(),
261
+ )
262
+
263
+ def forward(self, c, z):
264
+ '''
265
+ Args:
266
+ c (Tensor): the conditioning sequence of mel-spectrogram (batch, mel_channels, in_length)
267
+ z (Tensor): the noise sequence (batch, noise_dim, in_length)
268
+
269
+ '''
270
+ z = self.conv_pre(z) # (B, c_g, L)
271
+
272
+ for res_block in self.res_stack:
273
+ res_block.to(z.device)
274
+ z = res_block(z, c) # (B, c_g, L * s_0 * ... * s_i)
275
+
276
+ z = self.conv_post(z) # (B, 1, L * 256)
277
+
278
+ return z
279
+
280
+ def eval(self, inference=False):
281
+ super(UnivNetGenerator, self).eval()
282
+ # don't remove weight norm while validation in training loop
283
+ if inference:
284
+ self.remove_weight_norm()
285
+
286
+ def remove_weight_norm(self):
287
+ print('Removing weight norm...')
288
+
289
+ nn.utils.remove_weight_norm(self.conv_pre)
290
+
291
+ for layer in self.conv_post:
292
+ if len(layer.state_dict()) != 0:
293
+ nn.utils.remove_weight_norm(layer)
294
+
295
+ for res_block in self.res_stack:
296
+ res_block.remove_weight_norm()
297
+
298
+ def inference(self, c, z=None):
299
+ # pad input mel with zeros to cut artifact
300
+ # see https://github.com/seungwonpark/melgan/issues/8
301
+ zero = torch.full((c.shape[0], self.mel_channel, 10), -11.5129).to(c.device)
302
+ mel = torch.cat((c, zero), dim=2)
303
+
304
+ if z is None:
305
+ z = torch.randn(c.shape[0], self.noise_dim, mel.size(2)).to(mel.device)
306
+
307
+ audio = self.forward(mel, z)
308
+ audio = audio[:, :, :-(self.hop_length * 10)]
309
+ audio = audio.clamp(min=-1, max=1)
310
+ return audio
311
+
312
+
313
+ if __name__ == '__main__':
314
+ model = UnivNetGenerator()
315
+
316
+ c = torch.randn(3, 100, 10)
317
+ z = torch.randn(3, 64, 10)
318
+ print(c.shape)
319
+
320
+ y = model(c, z)
321
+ print(y.shape)
322
+ assert y.shape == torch.Size([3, 1, 2560])
323
+
324
+ pytorch_total_params = sum(p.numel() for p in model.parameters() if p.requires_grad)
325
+ print(pytorch_total_params)
requirements.txt CHANGED
@@ -6,4 +6,5 @@ tokenizers
6
  inflect
7
  progressbar
8
  einops
9
- unidecode
 
 
6
  inflect
7
  progressbar
8
  einops
9
+ unidecode
10
+ x-transformers
utils/audio.py CHANGED
@@ -3,6 +3,8 @@ import torchaudio
3
  import numpy as np
4
  from scipy.io.wavfile import read
5
 
 
 
6
 
7
  def load_wav_to_torch(full_path):
8
  sampling_rate, data = read(full_path)
@@ -43,4 +45,86 @@ def load_audio(audiopath, sampling_rate):
43
  print(f"Error with {audiopath}. Max={audio.max()} min={audio.min()}")
44
  audio.clip_(-1, 1)
45
 
46
- return audio.unsqueeze(0)
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
3
  import numpy as np
4
  from scipy.io.wavfile import read
5
 
6
+ from utils.stft import STFT
7
+
8
 
9
  def load_wav_to_torch(full_path):
10
  sampling_rate, data = read(full_path)
 
45
  print(f"Error with {audiopath}. Max={audio.max()} min={audio.min()}")
46
  audio.clip_(-1, 1)
47
 
48
+ return audio.unsqueeze(0)
49
+
50
+
51
+ TACOTRON_MEL_MAX = 2.3143386840820312
52
+ TACOTRON_MEL_MIN = -11.512925148010254
53
+
54
+
55
+ def denormalize_tacotron_mel(norm_mel):
56
+ return ((norm_mel+1)/2)*(TACOTRON_MEL_MAX-TACOTRON_MEL_MIN)+TACOTRON_MEL_MIN
57
+
58
+
59
+ def normalize_tacotron_mel(mel):
60
+ return 2 * ((mel - TACOTRON_MEL_MIN) / (TACOTRON_MEL_MAX - TACOTRON_MEL_MIN)) - 1
61
+
62
+
63
+ def dynamic_range_compression(x, C=1, clip_val=1e-5):
64
+ """
65
+ PARAMS
66
+ ------
67
+ C: compression factor
68
+ """
69
+ return torch.log(torch.clamp(x, min=clip_val) * C)
70
+
71
+
72
+ def dynamic_range_decompression(x, C=1):
73
+ """
74
+ PARAMS
75
+ ------
76
+ C: compression factor used to compress
77
+ """
78
+ return torch.exp(x) / C
79
+
80
+
81
+ class TacotronSTFT(torch.nn.Module):
82
+ def __init__(self, filter_length=1024, hop_length=256, win_length=1024,
83
+ n_mel_channels=80, sampling_rate=22050, mel_fmin=0.0,
84
+ mel_fmax=8000.0):
85
+ super(TacotronSTFT, self).__init__()
86
+ self.n_mel_channels = n_mel_channels
87
+ self.sampling_rate = sampling_rate
88
+ self.stft_fn = STFT(filter_length, hop_length, win_length)
89
+ from librosa.filters import mel as librosa_mel_fn
90
+ mel_basis = librosa_mel_fn(
91
+ sampling_rate, filter_length, n_mel_channels, mel_fmin, mel_fmax)
92
+ mel_basis = torch.from_numpy(mel_basis).float()
93
+ self.register_buffer('mel_basis', mel_basis)
94
+
95
+ def spectral_normalize(self, magnitudes):
96
+ output = dynamic_range_compression(magnitudes)
97
+ return output
98
+
99
+ def spectral_de_normalize(self, magnitudes):
100
+ output = dynamic_range_decompression(magnitudes)
101
+ return output
102
+
103
+ def mel_spectrogram(self, y):
104
+ """Computes mel-spectrograms from a batch of waves
105
+ PARAMS
106
+ ------
107
+ y: Variable(torch.FloatTensor) with shape (B, T) in range [-1, 1]
108
+
109
+ RETURNS
110
+ -------
111
+ mel_output: torch.FloatTensor of shape (B, n_mel_channels, T)
112
+ """
113
+ assert(torch.min(y.data) >= -10)
114
+ assert(torch.max(y.data) <= 10)
115
+ y = torch.clip(y, min=-1, max=1)
116
+
117
+ magnitudes, phases = self.stft_fn.transform(y)
118
+ magnitudes = magnitudes.data
119
+ mel_output = torch.matmul(self.mel_basis, magnitudes)
120
+ mel_output = self.spectral_normalize(mel_output)
121
+ return mel_output
122
+
123
+
124
+ def wav_to_univnet_mel(wav, do_normalization=False):
125
+ stft = TacotronSTFT(1024, 256, 1024, 100, 24000, 0, 12000)
126
+ stft = stft.cuda()
127
+ mel = stft.mel_spectrogram(wav)
128
+ if do_normalization:
129
+ mel = normalize_tacotron_mel(mel)
130
+ return mel
utils/diffusion.py CHANGED
@@ -197,11 +197,17 @@ class GaussianDiffusion:
197
  model_var_type,
198
  loss_type,
199
  rescale_timesteps=False,
 
 
 
200
  ):
201
  self.model_mean_type = ModelMeanType(model_mean_type)
202
  self.model_var_type = ModelVarType(model_var_type)
203
  self.loss_type = LossType(loss_type)
204
  self.rescale_timesteps = rescale_timesteps
 
 
 
205
 
206
  # Use float64 for accuracy.
207
  betas = np.array(betas, dtype=np.float64)
@@ -332,10 +338,14 @@ class GaussianDiffusion:
332
  B, C = x.shape[:2]
333
  assert t.shape == (B,)
334
  model_output = model(x, self._scale_timesteps(t), **model_kwargs)
 
 
335
 
336
  if self.model_var_type in [ModelVarType.LEARNED, ModelVarType.LEARNED_RANGE]:
337
  assert model_output.shape == (B, C * 2, *x.shape[2:])
338
  model_output, model_var_values = th.split(model_output, C, dim=1)
 
 
339
  if self.model_var_type == ModelVarType.LEARNED:
340
  model_log_variance = model_var_values
341
  model_variance = th.exp(model_log_variance)
@@ -364,6 +374,14 @@ class GaussianDiffusion:
364
  model_variance = _extract_into_tensor(model_variance, t, x.shape)
365
  model_log_variance = _extract_into_tensor(model_log_variance, t, x.shape)
366
 
 
 
 
 
 
 
 
 
367
  def process_xstart(x):
368
  if denoised_fn is not None:
369
  x = denoised_fn(x)
 
197
  model_var_type,
198
  loss_type,
199
  rescale_timesteps=False,
200
+ conditioning_free=False,
201
+ conditioning_free_k=1,
202
+ ramp_conditioning_free=True,
203
  ):
204
  self.model_mean_type = ModelMeanType(model_mean_type)
205
  self.model_var_type = ModelVarType(model_var_type)
206
  self.loss_type = LossType(loss_type)
207
  self.rescale_timesteps = rescale_timesteps
208
+ self.conditioning_free = conditioning_free
209
+ self.conditioning_free_k = conditioning_free_k
210
+ self.ramp_conditioning_free = ramp_conditioning_free
211
 
212
  # Use float64 for accuracy.
213
  betas = np.array(betas, dtype=np.float64)
 
338
  B, C = x.shape[:2]
339
  assert t.shape == (B,)
340
  model_output = model(x, self._scale_timesteps(t), **model_kwargs)
341
+ if self.conditioning_free:
342
+ model_output_no_conditioning = model(x, self._scale_timesteps(t), conditioning_free=True, **model_kwargs)
343
 
344
  if self.model_var_type in [ModelVarType.LEARNED, ModelVarType.LEARNED_RANGE]:
345
  assert model_output.shape == (B, C * 2, *x.shape[2:])
346
  model_output, model_var_values = th.split(model_output, C, dim=1)
347
+ if self.conditioning_free:
348
+ model_output_no_conditioning, _ = th.split(model_output_no_conditioning, C, dim=1)
349
  if self.model_var_type == ModelVarType.LEARNED:
350
  model_log_variance = model_var_values
351
  model_variance = th.exp(model_log_variance)
 
374
  model_variance = _extract_into_tensor(model_variance, t, x.shape)
375
  model_log_variance = _extract_into_tensor(model_log_variance, t, x.shape)
376
 
377
+ if self.conditioning_free:
378
+ if self.ramp_conditioning_free:
379
+ assert t.shape[0] == 1 # This should only be used in inference.
380
+ cfk = self.conditioning_free_k * (1 - self._scale_timesteps(t)[0].item() / self.num_timesteps)
381
+ else:
382
+ cfk = self.conditioning_free_k
383
+ model_output = (1 + cfk) * model_output - cfk * model_output_no_conditioning
384
+
385
  def process_xstart(x):
386
  if denoised_fn is not None:
387
  x = denoised_fn(x)
utils/stft.py ADDED
@@ -0,0 +1,193 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ """
2
+ BSD 3-Clause License
3
+
4
+ Copyright (c) 2017, Prem Seetharaman
5
+ All rights reserved.
6
+
7
+ * Redistribution and use in source and binary forms, with or without
8
+ modification, are permitted provided that the following conditions are met:
9
+
10
+ * Redistributions of source code must retain the above copyright notice,
11
+ this list of conditions and the following disclaimer.
12
+
13
+ * Redistributions in binary form must reproduce the above copyright notice, this
14
+ list of conditions and the following disclaimer in the
15
+ documentation and/or other materials provided with the distribution.
16
+
17
+ * Neither the name of the copyright holder nor the names of its
18
+ contributors may be used to endorse or promote products derived from this
19
+ software without specific prior written permission.
20
+
21
+ THIS SOFTWARE IS PROVIDED BY THE COPYRIGHT HOLDERS AND CONTRIBUTORS "AS IS" AND
22
+ ANY EXPRESS OR IMPLIED WARRANTIES, INCLUDING, BUT NOT LIMITED TO, THE IMPLIED
23
+ WARRANTIES OF MERCHANTABILITY AND FITNESS FOR A PARTICULAR PURPOSE ARE
24
+ DISCLAIMED. IN NO EVENT SHALL THE COPYRIGHT HOLDER OR CONTRIBUTORS BE LIABLE FOR
25
+ ANY DIRECT, INDIRECT, INCIDENTAL, SPECIAL, EXEMPLARY, OR CONSEQUENTIAL DAMAGES
26
+ (INCLUDING, BUT NOT LIMITED TO, PROCUREMENT OF SUBSTITUTE GOODS OR SERVICES;
27
+ LOSS OF USE, DATA, OR PROFITS; OR BUSINESS INTERRUPTION) HOWEVER CAUSED AND ON
28
+ ANY THEORY OF LIABILITY, WHETHER IN CONTRACT, STRICT LIABILITY, OR TORT
29
+ (INCLUDING NEGLIGENCE OR OTHERWISE) ARISING IN ANY WAY OUT OF THE USE OF THIS
30
+ SOFTWARE, EVEN IF ADVISED OF THE POSSIBILITY OF SUCH DAMAGE.
31
+ """
32
+
33
+ import torch
34
+ import numpy as np
35
+ import torch.nn.functional as F
36
+ from torch.autograd import Variable
37
+ from scipy.signal import get_window
38
+ from librosa.util import pad_center, tiny
39
+ import librosa.util as librosa_util
40
+
41
+
42
+ def window_sumsquare(window, n_frames, hop_length=200, win_length=800,
43
+ n_fft=800, dtype=np.float32, norm=None):
44
+ """
45
+ # from librosa 0.6
46
+ Compute the sum-square envelope of a window function at a given hop length.
47
+
48
+ This is used to estimate modulation effects induced by windowing
49
+ observations in short-time fourier transforms.
50
+
51
+ Parameters
52
+ ----------
53
+ window : string, tuple, number, callable, or list-like
54
+ Window specification, as in `get_window`
55
+
56
+ n_frames : int > 0
57
+ The number of analysis frames
58
+
59
+ hop_length : int > 0
60
+ The number of samples to advance between frames
61
+
62
+ win_length : [optional]
63
+ The length of the window function. By default, this matches `n_fft`.
64
+
65
+ n_fft : int > 0
66
+ The length of each analysis frame.
67
+
68
+ dtype : np.dtype
69
+ The data type of the output
70
+
71
+ Returns
72
+ -------
73
+ wss : np.ndarray, shape=`(n_fft + hop_length * (n_frames - 1))`
74
+ The sum-squared envelope of the window function
75
+ """
76
+ if win_length is None:
77
+ win_length = n_fft
78
+
79
+ n = n_fft + hop_length * (n_frames - 1)
80
+ x = np.zeros(n, dtype=dtype)
81
+
82
+ # Compute the squared window at the desired length
83
+ win_sq = get_window(window, win_length, fftbins=True)
84
+ win_sq = librosa_util.normalize(win_sq, norm=norm)**2
85
+ win_sq = librosa_util.pad_center(win_sq, n_fft)
86
+
87
+ # Fill the envelope
88
+ for i in range(n_frames):
89
+ sample = i * hop_length
90
+ x[sample:min(n, sample + n_fft)] += win_sq[:max(0, min(n_fft, n - sample))]
91
+ return x
92
+
93
+
94
+ class STFT(torch.nn.Module):
95
+ """adapted from Prem Seetharaman's https://github.com/pseeth/pytorch-stft"""
96
+ def __init__(self, filter_length=800, hop_length=200, win_length=800,
97
+ window='hann'):
98
+ super(STFT, self).__init__()
99
+ self.filter_length = filter_length
100
+ self.hop_length = hop_length
101
+ self.win_length = win_length
102
+ self.window = window
103
+ self.forward_transform = None
104
+ scale = self.filter_length / self.hop_length
105
+ fourier_basis = np.fft.fft(np.eye(self.filter_length))
106
+
107
+ cutoff = int((self.filter_length / 2 + 1))
108
+ fourier_basis = np.vstack([np.real(fourier_basis[:cutoff, :]),
109
+ np.imag(fourier_basis[:cutoff, :])])
110
+
111
+ forward_basis = torch.FloatTensor(fourier_basis[:, None, :])
112
+ inverse_basis = torch.FloatTensor(
113
+ np.linalg.pinv(scale * fourier_basis).T[:, None, :])
114
+
115
+ if window is not None:
116
+ assert(filter_length >= win_length)
117
+ # get window and zero center pad it to filter_length
118
+ fft_window = get_window(window, win_length, fftbins=True)
119
+ fft_window = pad_center(fft_window, filter_length)
120
+ fft_window = torch.from_numpy(fft_window).float()
121
+
122
+ # window the bases
123
+ forward_basis *= fft_window
124
+ inverse_basis *= fft_window
125
+
126
+ self.register_buffer('forward_basis', forward_basis.float())
127
+ self.register_buffer('inverse_basis', inverse_basis.float())
128
+
129
+ def transform(self, input_data):
130
+ num_batches = input_data.size(0)
131
+ num_samples = input_data.size(1)
132
+
133
+ self.num_samples = num_samples
134
+
135
+ # similar to librosa, reflect-pad the input
136
+ input_data = input_data.view(num_batches, 1, num_samples)
137
+ input_data = F.pad(
138
+ input_data.unsqueeze(1),
139
+ (int(self.filter_length / 2), int(self.filter_length / 2), 0, 0),
140
+ mode='reflect')
141
+ input_data = input_data.squeeze(1)
142
+
143
+ forward_transform = F.conv1d(
144
+ input_data,
145
+ Variable(self.forward_basis, requires_grad=False),
146
+ stride=self.hop_length,
147
+ padding=0)
148
+
149
+ cutoff = int((self.filter_length / 2) + 1)
150
+ real_part = forward_transform[:, :cutoff, :]
151
+ imag_part = forward_transform[:, cutoff:, :]
152
+
153
+ magnitude = torch.sqrt(real_part**2 + imag_part**2)
154
+ phase = torch.autograd.Variable(
155
+ torch.atan2(imag_part.data, real_part.data))
156
+
157
+ return magnitude, phase
158
+
159
+ def inverse(self, magnitude, phase):
160
+ recombine_magnitude_phase = torch.cat(
161
+ [magnitude*torch.cos(phase), magnitude*torch.sin(phase)], dim=1)
162
+
163
+ inverse_transform = F.conv_transpose1d(
164
+ recombine_magnitude_phase,
165
+ Variable(self.inverse_basis, requires_grad=False),
166
+ stride=self.hop_length,
167
+ padding=0)
168
+
169
+ if self.window is not None:
170
+ window_sum = window_sumsquare(
171
+ self.window, magnitude.size(-1), hop_length=self.hop_length,
172
+ win_length=self.win_length, n_fft=self.filter_length,
173
+ dtype=np.float32)
174
+ # remove modulation effects
175
+ approx_nonzero_indices = torch.from_numpy(
176
+ np.where(window_sum > tiny(window_sum))[0])
177
+ window_sum = torch.autograd.Variable(
178
+ torch.from_numpy(window_sum), requires_grad=False)
179
+ window_sum = window_sum.cuda() if magnitude.is_cuda else window_sum
180
+ inverse_transform[:, :, approx_nonzero_indices] /= window_sum[approx_nonzero_indices]
181
+
182
+ # scale by hop ratio
183
+ inverse_transform *= float(self.filter_length) / self.hop_length
184
+
185
+ inverse_transform = inverse_transform[:, :, int(self.filter_length/2):]
186
+ inverse_transform = inverse_transform[:, :, :-int(self.filter_length/2):]
187
+
188
+ return inverse_transform
189
+
190
+ def forward(self, input_data):
191
+ self.magnitude, self.phase = self.transform(input_data)
192
+ reconstruction = self.inverse(self.magnitude, self.phase)
193
+ return reconstruction