jbetker committed
Commit 33e4bc7
1 Parent(s): 9043dde

integrate new autoregressive model and fix new diffusion bug

api.py CHANGED
@@ -117,13 +117,14 @@ def do_spectrogram_diffusion(diffusion_model, diffuser, mel_codes, conditioning_
             cond_mels.append(cond_mel)
         cond_mels = torch.stack(cond_mels, dim=1)
 
-        output_shape = (mel_codes.shape[0], 100, mel_codes.shape[-1]*4)
-        precomputed_embeddings = diffusion_model.timestep_independent(mel_codes, cond_mels, False)
+        output_seq_len = mel_codes.shape[-1]*4*24000//22050 # This diffusion model converts from 22kHz spectrogram codes to a 24kHz spectrogram signal.
+        output_shape = (mel_codes.shape[0], 100, output_seq_len)
+        precomputed_embeddings = diffusion_model.timestep_independent(mel_codes, cond_mels, output_seq_len, False)
 
         noise = torch.randn(output_shape, device=mel_codes.device) * temperature
         mel = diffuser.p_sample_loop(diffusion_model, output_shape, noise=noise,
                                      model_kwargs={'precomputed_aligned_embeddings': precomputed_embeddings})
-        return denormalize_tacotron_mel(mel)[:,:,:mel_codes.shape[-1]*4]
+        return denormalize_tacotron_mel(mel)[:,:,:output_seq_len]
 
 
 class TextToSpeech:
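
The new output_seq_len accounts for the autoregressive codes living in the 22.05 kHz domain while the diffusion decoder produces a 24 kHz spectrogram, so the 4x-upsampled code length is rescaled by 24000/22050. A minimal sketch of that arithmetic; the code count of 100 is an illustrative value, not taken from the commit:

    # Illustrative only: how the diffusion output length is derived from the code length.
    num_mel_codes = 100                               # hypothetical value of mel_codes.shape[-1]
    frames_22k = num_mel_codes * 4                    # each code covers 4 spectrogram frames in the 22.05 kHz domain
    output_seq_len = frames_22k * 24000 // 22050      # 435 frames once rescaled to the 24 kHz output rate
    output_shape = (1, 100, output_seq_len)           # (batch, mel channels, frames), as in do_spectrogram_diffusion
    print(output_seq_len, output_shape)
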
api_new_autoregressive.py ADDED
@@ -0,0 +1,245 @@
+import argparse
+import os
+import random
+from urllib import request
+
+import torch
+import torch.nn.functional as F
+import torchaudio
+import progressbar
+import ocotillo
+
+from models.diffusion_decoder import DiffusionTts
+from models.autoregressive import UnifiedVoice
+from tqdm import tqdm
+
+from models.arch_util import TorchMelSpectrogram
+from models.new_autoregressive import AutoregressiveCodegen
+from models.text_voice_clip import VoiceCLIP
+from models.vocoder import UnivNetGenerator
+from utils.audio import load_audio, wav_to_univnet_mel, denormalize_tacotron_mel
+from utils.diffusion import SpacedDiffusion, space_timesteps, get_named_beta_schedule
+from utils.tokenizer import VoiceBpeTokenizer, lev_distance
+
+
+pbar = None
+def download_models():
+    MODELS = {
+        'clip.pth': 'https://huggingface.co/jbetker/tortoise-tts-clip/resolve/main/pytorch-model.bin',
+        'diffusion.pth': 'https://huggingface.co/jbetker/tortoise-tts-diffusion-v1/resolve/main/pytorch-model.bin',
+        'autoregressive.pth': 'https://huggingface.co/jbetker/tortoise-tts-autoregressive/resolve/main/pytorch-model.bin'
+    }
+    os.makedirs('.models', exist_ok=True)
+    def show_progress(block_num, block_size, total_size):
+        global pbar
+        if pbar is None:
+            pbar = progressbar.ProgressBar(maxval=total_size)
+            pbar.start()
+
+        downloaded = block_num * block_size
+        if downloaded < total_size:
+            pbar.update(downloaded)
+        else:
+            pbar.finish()
+            pbar = None
+    for model_name, url in MODELS.items():
+        if os.path.exists(f'.models/{model_name}'):
+            continue
+        print(f'Downloading {model_name} from {url}...')
+        request.urlretrieve(url, f'.models/{model_name}', show_progress)
+        print('Done.')
+
+
+def pad_or_truncate(t, length):
+    if t.shape[-1] == length:
+        return t
+    elif t.shape[-1] < length:
+        return F.pad(t, (0, length-t.shape[-1]))
+    else:
+        return t[..., :length]
+
+
+def load_discrete_vocoder_diffuser(trained_diffusion_steps=4000, desired_diffusion_steps=200, cond_free=True, cond_free_k=1):
+    """
+    Helper function to load a GaussianDiffusion instance configured for use as a vocoder.
+    """
+    return SpacedDiffusion(use_timesteps=space_timesteps(trained_diffusion_steps, [desired_diffusion_steps]), model_mean_type='epsilon',
+                           model_var_type='learned_range', loss_type='mse', betas=get_named_beta_schedule('linear', trained_diffusion_steps),
+                           conditioning_free=cond_free, conditioning_free_k=cond_free_k)
+
+
+def load_conditioning(clip, cond_length=132300):
+    gap = clip.shape[-1] - cond_length
+    if gap < 0:
+        clip = F.pad(clip, pad=(0, abs(gap)))
+    elif gap > 0:
+        rand_start = random.randint(0, gap)
+        clip = clip[:, rand_start:rand_start + cond_length]
+    mel_clip = TorchMelSpectrogram()(clip.unsqueeze(0)).squeeze(0)
+    return mel_clip.unsqueeze(0).cuda()
+
+
+def fix_autoregressive_output(codes, stop_token):
+    """
+    This function performs some padding on coded audio that fixes a mismatch issue between what the diffusion model was
+    trained on and what the autoregressive code generator creates (which has no padding or end).
+    This is highly specific to the DVAE being used, so this particular coding will not necessarily work if used with
+    a different DVAE. This can be inferred by feeding a audio clip padded with lots of zeros on the end through the DVAE
+    and copying out the last few codes.
+
+    Failing to do this padding will produce speech with a harsh end that sounds like "BLAH" or similar.
+    """
+    # Strip off the autoregressive stop token and add padding.
+    stop_token_indices = (codes == stop_token).nonzero()
+    if len(stop_token_indices) == 0:
+        print("No stop tokens found, enjoy that output of yours!")
+        return codes
+    else:
+        codes[stop_token_indices] = 83
+    stm = stop_token_indices.min().item()
+    codes[stm:] = 83
+    if stm - 3 < codes.shape[0]:
+        codes[-3] = 45
+        codes[-2] = 45
+        codes[-1] = 248
+
+    return codes
+
+
+def do_spectrogram_diffusion(diffusion_model, diffuser, mel_codes, conditioning_samples, temperature=1):
+    """
+    Uses the specified diffusion model to convert discrete codes into a spectrogram.
+    """
+    with torch.no_grad():
+        cond_mels = []
+        for sample in conditioning_samples:
+            sample = pad_or_truncate(sample, 102400)
+            cond_mel = wav_to_univnet_mel(sample.to(mel_codes.device), do_normalization=False)
+            cond_mels.append(cond_mel)
+        cond_mels = torch.stack(cond_mels, dim=1)
+
+        output_seq_len = mel_codes.shape[-1]*4*24000//22050 # This diffusion model converts from 22kHz spectrogram codes to a 24kHz spectrogram signal.
+        output_shape = (mel_codes.shape[0], 100, output_seq_len)
+        precomputed_embeddings = diffusion_model.timestep_independent(mel_codes, cond_mels, output_seq_len, False)
+
+        noise = torch.randn(output_shape, device=mel_codes.device) * temperature
+        mel = diffuser.p_sample_loop(diffusion_model, output_shape, noise=noise,
+                                     model_kwargs={'precomputed_aligned_embeddings': precomputed_embeddings})
+        return denormalize_tacotron_mel(mel)[:,:,:output_seq_len]
+
+
+class TextToSpeech:
+    def __init__(self, autoregressive_batch_size=32):
+        self.autoregressive_batch_size = autoregressive_batch_size
+        self.tokenizer = VoiceBpeTokenizer()
+        download_models()
+
+        self.autoregressive = AutoregressiveCodegen(512, 12).cpu().eval()
+        self.autoregressive.load_state_dict(torch.load('D:\\dlas\\experiments\\train_autoregressive_codegen\\models\\23000_codegen_ema.pth'))
+
+        self.clip = VoiceCLIP(dim_text=512, dim_speech=512, dim_latent=512, num_text_tokens=256, text_enc_depth=12,
+                              text_seq_len=350, text_heads=8,
+                              num_speech_tokens=8192, speech_enc_depth=12, speech_heads=8, speech_seq_len=430,
+                              use_xformers=True).cpu().eval()
+        self.clip.load_state_dict(torch.load('.models/clip.pth'))
+
+        self.diffusion = DiffusionTts(model_channels=1024, num_layers=10, in_channels=100, out_channels=200,
+                                      in_latent_channels=1024, in_tokens=8193, dropout=0, use_fp16=False, num_heads=16,
+                                      layer_drop=0, unconditioned_percentage=0).cpu().eval()
+        self.diffusion.load_state_dict(torch.load('.models/diffusion.pth'))
+
+        self.vocoder = UnivNetGenerator().cpu()
+        self.vocoder.load_state_dict(torch.load('.models/vocoder.pth')['model_g'])
+        self.vocoder.eval(inference=True)
+
+    def tts(self, text, voice_samples, k=1,
+            # autoregressive generation parameters follow
+            num_autoregressive_samples=512, temperature=.5, length_penalty=2, repetition_penalty=2.0, top_p=.5,
+            typical_sampling=False, typical_mass=.9,
+            # diffusion generation parameters follow
+            diffusion_iterations=100, cond_free=True, cond_free_k=2, diffusion_temperature=.7,):
+        text = torch.IntTensor(self.tokenizer.encode(text)).unsqueeze(0).cuda()
+        text = F.pad(text, (0, 1))  # This may not be necessary.
+
+        conds = []
+        if not isinstance(voice_samples, list):
+            voice_samples = [voice_samples]
+        for vs in voice_samples:
+            conds.append(load_conditioning(vs))
+        conds = torch.stack(conds, dim=1)
+
+        diffuser = load_discrete_vocoder_diffuser(desired_diffusion_steps=diffusion_iterations, cond_free=cond_free, cond_free_k=cond_free_k)
+
+        with torch.no_grad():
+            samples = []
+            num_batches = num_autoregressive_samples // self.autoregressive_batch_size
+            stop_mel_token = self.autoregressive.STOP_TOKEN
+            self.autoregressive = self.autoregressive.cuda()
+            for _ in tqdm(range(num_batches)):
+                codes = self.autoregressive.generate(conds, text,
+                                                     do_sample=True,
+                                                     top_p=top_p,
+                                                     temperature=temperature,
+                                                     num_return_sequences=self.autoregressive_batch_size,
+                                                     length_penalty=length_penalty,
+                                                     repetition_penalty=repetition_penalty,
+                                                     typical_sampling=typical_sampling,
+                                                     typical_mass=typical_mass)
+                padding_needed = 250 - codes.shape[1]
+                codes = F.pad(codes, (0, padding_needed), value=stop_mel_token)
+                samples.append(codes)
+            #self.autoregressive = self.autoregressive.cpu()
+
+            clip_results = []
+            self.clip = self.clip.cuda()
+            for batch in samples:
+                for i in range(batch.shape[0]):
+                    batch[i] = fix_autoregressive_output(batch[i], stop_mel_token)
+                bad_toks = batch >= 8192
+                batch = batch * bad_toks.logical_not()
+                clip_results.append(self.clip(text.repeat(batch.shape[0], 1), batch, return_loss=False))
+            clip_results = torch.cat(clip_results, dim=0)
+            samples = torch.cat(samples, dim=0)
+            best_results = samples[torch.topk(clip_results, k=k).indices]
+            self.clip = self.clip.cpu()
+            del samples
+
+            print("Performing vocoding..")
+            wav_candidates = []
+            self.diffusion = self.diffusion.cuda()
+            self.vocoder = self.vocoder.cuda()
+            for b in range(best_results.shape[0]):
+                code = best_results[b].unsqueeze(0)
+                mel = do_spectrogram_diffusion(self.diffusion, diffuser, code, voice_samples, temperature=diffusion_temperature)
+                wav = self.vocoder.inference(mel)
+                wav_candidates.append(wav.cpu())
+            self.diffusion = self.diffusion.cpu()
+            self.vocoder = self.vocoder.cpu()
+
+            if len(wav_candidates) > 1:
+                return wav_candidates
+            return wav_candidates[0]
+
+    def refine_for_intellibility(self, wav_candidates, corresponding_codes, output_path):
+        """
+        Further refine the remaining candidates using a ASR model to pick out the ones that are the most understandable.
+        TODO: finish this function
+        :param wav_candidates:
+        :return:
+        """
+        transcriber = ocotillo.Transcriber(on_cuda=True)
+        transcriptions = transcriber.transcribe_batch(torch.cat(wav_candidates, dim=0).squeeze(1), 24000)
+        best = 99999999
+        for i, transcription in enumerate(transcriptions):
+            dist = lev_distance(transcription, args.text.lower())
+            if dist < best:
+                best = dist
+                best_codes = corresponding_codes[i].unsqueeze(0)
+                best_wav = wav_candidates[i]
+        del transcriber
+        torchaudio.save(os.path.join(output_path, f'{voice}_poor.wav'), best_wav.squeeze(0).cpu(), 24000)
+
+        # Perform diffusion again with the high-quality diffuser.
+        mel = do_spectrogram_diffusion(diffusion, final_diffuser, best_codes, cond_diffusion, mean=False)
+        wav = vocoder.inference(mel)
+        torchaudio.save(os.path.join(args.output_path, f'{voice}.wav'), wav.squeeze(0).cpu(), 24000)
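
Since do_tts.py (below) now imports TextToSpeech from this file, end-to-end use looks roughly like the following sketch. The clip path, the 22.05 kHz load rate for the reference audio, and the output path are illustrative assumptions, not values taken from this commit:

    # Hypothetical usage sketch of the new API.
    import torchaudio
    from api_new_autoregressive import TextToSpeech
    from utils.audio import load_audio

    tts = TextToSpeech(autoregressive_batch_size=16)
    conds = [load_audio('voices/example/1.wav', 22050)]    # hypothetical reference clip(s) of the target voice
    wav = tts.tts("I am a language model that has learned to speak.", conds,
                  num_autoregressive_samples=32, k=1)      # k=1 returns a single waveform tensor
    torchaudio.save('results/example.wav', wav.squeeze(0).cpu(), 24000)  # diffusion/vocoder output is 24 kHz
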
do_tts.py CHANGED
@@ -5,7 +5,7 @@ import torch
 import torch.nn.functional as F
 import torchaudio
 
-from api import TextToSpeech, load_conditioning
+from api_new_autoregressive import TextToSpeech, load_conditioning
 from utils.audio import load_audio
 from utils.tokenizer import VoiceBpeTokenizer
 
@@ -28,7 +28,7 @@ if __name__ == '__main__':
     parser = argparse.ArgumentParser()
     parser.add_argument('-text', type=str, help='Text to speak.', default="I am a language model that has learned to speak.")
     parser.add_argument('-voice', type=str, help='Use a preset conditioning voice (defined above). Overrides cond_path.', default='dotrice,harris,lescault,otto,atkins,grace,kennard,mol')
-    parser.add_argument('-num_samples', type=int, help='How many total outputs the autoregressive transformer should produce.', default=512)
+    parser.add_argument('-num_samples', type=int, help='How many total outputs the autoregressive transformer should produce.', default=32)
     parser.add_argument('-batch_size', type=int, help='How many samples to process at once in the autoregressive model.', default=16)
     parser.add_argument('-num_diffusion_samples', type=int, help='Number of outputs that progress to the diffusion stage.', default=16)
     parser.add_argument('-output_path', type=str, help='Where to store outputs.', default='results/')
models/diffusion_decoder.py CHANGED
@@ -212,7 +212,7 @@ class DiffusionTts(nn.Module):
         }
         return groups
 
-    def timestep_independent(self, aligned_conditioning, conditioning_input, return_code_pred):
+    def timestep_independent(self, aligned_conditioning, conditioning_input, expected_seq_len, return_code_pred):
         # Shuffle aligned_latent to BxCxS format
         if is_latent(aligned_conditioning):
             aligned_conditioning = aligned_conditioning.permute(0, 2, 1)
@@ -227,7 +227,7 @@
         cond_emb = conds.mean(dim=-1)
         cond_scale, cond_shift = torch.chunk(cond_emb, 2, dim=1)
         if is_latent(aligned_conditioning):
-            code_emb = self.latent_converter(aligned_conditioning)
+            code_emb = self.autoregressive_latent_converter(aligned_conditioning)
         else:
             code_emb = self.code_embedding(aligned_conditioning).permute(0, 2, 1)
             code_emb = self.code_converter(code_emb)
@@ -240,7 +240,7 @@
                                                device=code_emb.device) < self.unconditioned_percentage
             code_emb = torch.where(unconditioned_batches, self.unconditioned_embedding.repeat(aligned_conditioning.shape[0], 1, 1),
                                    code_emb)
-        expanded_code_emb = F.interpolate(code_emb, size=aligned_conditioning.shape[-1]*4, mode='nearest')
+        expanded_code_emb = F.interpolate(code_emb, size=expected_seq_len, mode='nearest')
 
         if not return_code_pred:
             return expanded_code_emb
@@ -250,7 +250,6 @@
         mel_pred = mel_pred * unconditioned_batches.logical_not()
         return expanded_code_emb, mel_pred
 
-
     def forward(self, x, timesteps, aligned_conditioning=None, conditioning_input=None, precomputed_aligned_embeddings=None, conditioning_free=False, return_code_pred=False):
         """
         Apply the model to an input batch.
@@ -275,11 +274,12 @@
         if precomputed_aligned_embeddings is not None:
             code_emb = precomputed_aligned_embeddings
         else:
-            code_emb, mel_pred = self.timestep_independent(aligned_conditioning, conditioning_input, True)
+            code_emb, mel_pred = self.timestep_independent(aligned_conditioning, conditioning_input, x.shape[-1], True)
             if is_latent(aligned_conditioning):
                 unused_params.extend(list(self.code_converter.parameters()) + list(self.code_embedding.parameters()))
             else:
                 unused_params.extend(list(self.latent_converter.parameters()))
+
         unused_params.append(self.unconditioned_embedding)
 
         time_emb = self.time_embed(timestep_embedding(timesteps, self.model_channels))
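
The change threads an explicit expected_seq_len through timestep_independent (forward now passes x.shape[-1]) instead of hard-coding codes*4, which is what lets api.py request the 24 kHz output length. A rough sketch of the new call pattern, using the constructor arguments from api_new_autoregressive.py; the tensor sizes, the single stacked conditioning mel, and the timestep value are made up for illustration:

    import torch
    from models.diffusion_decoder import DiffusionTts

    model = DiffusionTts(model_channels=1024, num_layers=10, in_channels=100, out_channels=200,
                         in_latent_channels=1024, in_tokens=8193, dropout=0, use_fp16=False,
                         num_heads=16, layer_drop=0, unconditioned_percentage=0).eval()

    mel_codes = torch.randint(0, 8192, (1, 100))       # illustrative discrete codes from the autoregressive model
    cond_mels = torch.randn(1, 1, 100, 400)            # illustrative conditioning mel(s), stacked on dim=1 as in api.py
    expected_seq_len = mel_codes.shape[-1] * 4 * 24000 // 22050

    with torch.no_grad():
        # New signature: the target output length is passed in rather than inferred from the codes.
        emb = model.timestep_independent(mel_codes, cond_mels, expected_seq_len, False)
        x = torch.randn(1, 100, expected_seq_len)      # noisy 100-channel spectrogram at the target length
        out = model(x, torch.tensor([0]), precomputed_aligned_embeddings=emb)
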
models/new_autoregressive.py ADDED
@@ -0,0 +1,293 @@
+import functools
+
+import torch
+import torch.nn as nn
+import torch.nn.functional as F
+from transformers import GPT2PreTrainedModel, GPT2Config
+from transformers.modeling_outputs import CausalLMOutputWithCrossAttentions
+from x_transformers import TransformerWrapper, Encoder, Decoder
+
+from models.arch_util import AttentionBlock
+
+
+class InferenceModel(GPT2PreTrainedModel):
+    """
+    Implementation of GPT2PreTrainedModel from transformers, which allows us to use their generation library with
+    this transformer.
+    """
+    def __init__(self, model):
+        super().__init__(GPT2Config())
+        self.transformer = model
+        self.context = None
+
+    def parallelize(self, device_map=None):
+        # Not implemented.
+        pass
+
+    def deparallelize(self):
+        # Not implemented.
+        pass
+
+    def get_output_embeddings(self):
+        assert False, "Unsupported operation."
+
+    def set_output_embeddings(self, new_embeddings):
+        assert False, "Unsupported operation."
+
+    def store_context(self, context):
+        self.context = context
+
+    def prepare_inputs_for_generation(self, input_ids, past=None, **kwargs):
+        token_type_ids = kwargs.get("token_type_ids", None)
+        # only last token for inputs_ids if past is defined in kwargs
+        if past:
+            input_ids = input_ids[:, -1].unsqueeze(-1)
+            if token_type_ids is not None:
+                token_type_ids = token_type_ids[:, -1].unsqueeze(-1)
+
+        attention_mask = kwargs.get("attention_mask", None)
+        position_ids = kwargs.get("position_ids", None)
+
+        if attention_mask is not None and position_ids is None:
+            # create position_ids on the fly for batch generation
+            position_ids = attention_mask.long().cumsum(-1) - 1
+            position_ids.masked_fill_(attention_mask == 0, 1)
+            if past:
+                position_ids = position_ids[:, -1].unsqueeze(-1)
+        else:
+            position_ids = None
+        return {
+            "input_ids": input_ids,
+            "past_key_values": past,
+            "use_cache": kwargs.get("use_cache"),
+            "position_ids": position_ids,
+            "attention_mask": attention_mask,
+            "token_type_ids": token_type_ids,
+        }
+
+    def forward(
+        self,
+        input_ids=None,
+        past_key_values=None,
+        attention_mask=None,
+        token_type_ids=None,
+        position_ids=None,
+        head_mask=None,
+        inputs_embeds=None,
+        encoder_hidden_states=None,
+        encoder_attention_mask=None,
+        labels=None,
+        use_cache=None,
+        output_attentions=None,
+        output_hidden_states=None,
+        return_dict=None,
+    ):
+        assert self.context is not None
+        assert inputs_embeds is None  # Not supported by this inference model.
+        assert labels is None  # Training not supported by this inference model.
+        return_dict = return_dict if return_dict is not None else self.config.use_return_dict
+
+        hidden_states = self.transformer.decoder(input_ids, context=self.context, return_embeddings=True)
+        logits = self.transformer.decoder.transformer.to_logits(hidden_states)
+
+        if not return_dict:
+            return (logits, )
+
+        return CausalLMOutputWithCrossAttentions(
+            loss=None,
+            logits=logits,
+            past_key_values=None,
+            hidden_states=hidden_states,
+            attentions=None,
+            cross_attentions=None,
+        )
+
+    @staticmethod
+    def _reorder_cache(past, beam_idx):
+        """
+        This function is used to re-order the :obj:`past_key_values` cache if
+        :meth:`~transformers.PreTrainedModel.beam_search` or :meth:`~transformers.PreTrainedModel.beam_sample` is
+        called. This is required to match :obj:`past_key_values` with the correct beam_idx at every generation step.
+        """
+        return tuple(
+            tuple(past_state.index_select(0, beam_idx.to(past_state.device)) for past_state in layer_past)
+            for layer_past in past
+        )
+
+
+class ResBlock(nn.Module):
+    """
+    Basic residual convolutional block that uses GroupNorm.
+    """
+    def __init__(self, chan):
+        super().__init__()
+        self.net = nn.Sequential(
+            nn.Conv1d(chan, chan, kernel_size=3, padding=1),
+            nn.GroupNorm(chan//8, chan),
+            nn.ReLU(),
+            nn.Conv1d(chan, chan, kernel_size=3, padding=1),
+            nn.GroupNorm(chan//8, chan)
+        )
+
+    def forward(self, x):
+        return F.relu(self.net(x) + x)
+
+
+class ConditioningEncoder(nn.Module):
+    def __init__(self,
+                 spec_dim,
+                 embedding_dim,
+                 attn_blocks=6,
+                 num_attn_heads=4,
+                 do_checkpointing=False):
+        super().__init__()
+        attn = []
+        self.init = nn.Sequential(nn.Conv1d(spec_dim, embedding_dim//4, kernel_size=5, padding=2),
+                                  nn.Conv1d(embedding_dim//4, embedding_dim//2, kernel_size=3, padding=1, stride=2),
+                                  ResBlock(embedding_dim//2),
+                                  nn.Conv1d(embedding_dim//2, embedding_dim, kernel_size=3, padding=1, stride=2))
+        for a in range(attn_blocks):
+            attn.append(AttentionBlock(embedding_dim, num_attn_heads, do_checkpoint=do_checkpointing))
+        self.attn = nn.Sequential(*attn)
+        self.dim = embedding_dim
+
+    def forward(self, x):
+        h = self.init(x)
+        h = self.attn(h)
+        return h.mean(dim=2)
+
+
+class CheckpointedLayer(nn.Module):
+    """
+    Wraps a module. When forward() is called, passes kwargs that require_grad through torch.checkpoint() and bypasses
+    checkpoint for all other args.
+    """
+    def __init__(self, wrap):
+        super().__init__()
+        self.wrap = wrap
+
+    def forward(self, x, *args, **kwargs):
+        for k, v in kwargs.items():
+            assert not (isinstance(v, torch.Tensor) and v.requires_grad)  # This would screw up checkpointing.
+        partial = functools.partial(self.wrap, **kwargs)
+        return torch.utils.checkpoint.checkpoint(partial, x, *args)
+
+
+class CheckpointedXTransformerWrapper(nn.Module):
+    """
+    Wraps a TransformerWrapper and applies CheckpointedLayer to each layer.
+    """
+    def __init__(self, checkpoint=True, **xtransformer_kwargs):
+        super().__init__()
+        self.transformer = TransformerWrapper(**xtransformer_kwargs)
+
+        if not checkpoint:
+            return
+        for i in range(len(self.transformer.attn_layers.layers)):
+            n, b, r = self.transformer.attn_layers.layers[i]
+            self.transformer.attn_layers.layers[i] = nn.ModuleList([n, CheckpointedLayer(b), r])
+
+    def forward(self, x, **kwargs):
+        return self.transformer(x, **kwargs)
+
+
+class AutoregressiveCodegen(nn.Module):
+    def __init__(self, model_dim, depth, num_text_tokens=256, num_mel_tokens=8194, max_text_tokens=4000,
+                 max_mel_tokens=4000, dropout=.1):
+        super().__init__()
+
+        self.START_TOKEN=8192
+        self.STOP_TOKEN=8193
+        self.max_mel_tokens = max_mel_tokens
+        self.minicoder = ConditioningEncoder(80, model_dim, do_checkpointing=False)
+        self.encoder = CheckpointedXTransformerWrapper(
+            num_tokens=num_text_tokens,
+            max_seq_len=max_text_tokens,
+            attn_layers = Encoder(
+                depth=depth//2,
+                heads=model_dim//64,
+                dim=model_dim,
+                attn_dropout=dropout,
+                ff_dropout=dropout,
+                use_rmsnorm=True,
+                ff_glu=True,
+                ff_mult=1,
+                rotary_pos_emb=True,
+                rel_pos_bias=True,
+            ))
+        self.decoder = CheckpointedXTransformerWrapper(
+            num_tokens=num_mel_tokens,
+            max_seq_len=max_mel_tokens,
+            attn_layers=Decoder(
+                depth=depth,
+                heads=model_dim//64,
+                dim=model_dim,
+                attn_dropout=dropout,
+                ff_dropout=dropout,
+                use_rmsnorm=True,
+                ff_glu=True,
+                ff_mult=1,
+                rotary_pos_emb=True,
+                rel_pos_bias=True,
+                cross_attend=True,
+            ))
+
+    def get_grad_norm_parameter_groups(self):
+        return {
+            'encoder': list(self.encoder.parameters()),
+            'decoder': list(self.decoder.parameters()),
+            'minicoder': list(self.minicoder.parameters()),
+        }
+
+    def forward(self, text_codes, conditioning_signal, mel_codes, wav_lengths, return_loss=True):
+        # Format mel_codes with a stop token on the end.
+        mel_lengths = wav_lengths // 1024 + 1
+        for b in range(mel_codes.shape[0]):
+            mel_codes[b, mel_lengths[b]:] = self.STOP_TOKEN
+        mel_codes = F.pad(mel_codes, (0, 1), value=self.STOP_TOKEN)
+
+        # Build the context
+        if len(conditioning_signal.shape) != 4:
+            conditioning_signal = conditioning_signal.unsqueeze(1)
+        cond_embs = []
+        for i in range(conditioning_signal.shape[1]):
+            cond_embs.append(self.minicoder(conditioning_signal[:, i]))
+        cond_emb = torch.stack(cond_embs, dim=1).mean(dim=1, keepdim=True)
+        enc_text = self.encoder(text_codes, return_embeddings=True)
+        context = torch.cat([cond_emb, enc_text], dim=1)
+
+        # Execute the decoder
+        dec_inputs = F.pad(mel_codes, (1,0), value=self.START_TOKEN)[:, :-1]
+        dec = self.decoder(dec_inputs, context=context)
+        if not return_loss:
+            return dec
+        loss_mel = F.cross_entropy(dec.permute(0,2,1), mel_codes)
+        return loss_mel
+
+    def generate(self, conditioning_signal, text_codes, **hf_generate_kwargs):
+        if not hasattr(self, 'inference_model'):
+            self.inference_model = InferenceModel(self)
+
+        if len(conditioning_signal.shape) != 4:
+            conditioning_signal = conditioning_signal.unsqueeze(1)
+        cond_embs = []
+        for i in range(conditioning_signal.shape[1]):
+            cond_embs.append(self.minicoder(conditioning_signal[:, i]))
+        cond_emb = torch.stack(cond_embs, dim=1).mean(dim=1, keepdim=True)
+        enc_text = self.encoder(text_codes, return_embeddings=True)
+        context = torch.cat([cond_emb, enc_text], dim=1)
+        self.inference_model.store_context(context)
+
+        gen = self.inference_model.generate(bos_token_id=self.START_TOKEN, pad_token_id=self.STOP_TOKEN, eos_token_id=self.STOP_TOKEN,
+                                            max_length=250, output_attentions=False, return_dict_in_generate=True,
+                                            **hf_generate_kwargs)
+        return gen.sequences
+
+
+if __name__ == '__main__':
+    codegen = AutoregressiveCodegen(1024, 20)
+    codegen.generate(torch.randn((1,80,120)), torch.randint(0,256,(1,200)))
+    codegen(torch.randint(0,256, (2,200)),
+            torch.randn(2,80,120),
+            torch.randint(0,8192, (2,350)),
+            torch.tensor([192,350]))
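
For reference, the label formatting at the top of AutoregressiveCodegen.forward() can be traced with a toy example; the code values below are made up, while the token constants are the ones defined in the class above:

    import torch
    import torch.nn.functional as F

    START_TOKEN, STOP_TOKEN = 8192, 8193

    mel_codes = torch.tensor([[10, 11, 12, 13, 0, 0]])   # toy batch: only the first 4 codes are real audio
    mel_lengths = torch.tensor([4])                      # would come from wav_lengths // 1024 + 1

    mel_codes[0, mel_lengths[0]:] = STOP_TOKEN                # -> [[10, 11, 12, 13, 8193, 8193]]
    mel_codes = F.pad(mel_codes, (0, 1), value=STOP_TOKEN)    # -> [[10, 11, 12, 13, 8193, 8193, 8193]]

    dec_inputs = F.pad(mel_codes, (1, 0), value=START_TOKEN)[:, :-1]   # shift right for teacher forcing
    # dec_inputs: [[8192, 10, 11, 12, 13, 8193, 8193]]; mel_codes remain the cross-entropy targets.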