Upgrade CLIP model and add eval_multiple
Browse files- api.py +214 -0
- do_tts.py +5 -9
- eval_multiple.py +33 -0
- models/arch_util.py +45 -1
- models/diffusion_decoder.py +2 -38
- models/text_voice_clip.py +51 -19
api.py
ADDED
@@ -0,0 +1,214 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
1 |
+
import argparse
|
2 |
+
import os
|
3 |
+
import random
|
4 |
+
from urllib import request
|
5 |
+
|
6 |
+
import torch
|
7 |
+
import torch.nn.functional as F
|
8 |
+
import torchaudio
|
9 |
+
import progressbar
|
10 |
+
import ocotillo
|
11 |
+
|
12 |
+
from models.diffusion_decoder import DiffusionTts
|
13 |
+
from models.autoregressive import UnifiedVoice
|
14 |
+
from tqdm import tqdm
|
15 |
+
|
16 |
+
from models.arch_util import TorchMelSpectrogram
|
17 |
+
from models.text_voice_clip import VoiceCLIP
|
18 |
+
from models.vocoder import UnivNetGenerator
|
19 |
+
from utils.audio import load_audio, wav_to_univnet_mel, denormalize_tacotron_mel
|
20 |
+
from utils.diffusion import SpacedDiffusion, space_timesteps, get_named_beta_schedule
|
21 |
+
from utils.tokenizer import VoiceBpeTokenizer, lev_distance
|
22 |
+
|
23 |
+
|
24 |
+
pbar = None
|
25 |
+
def download_models():
|
26 |
+
MODELS = {
|
27 |
+
'clip.pth': 'https://huggingface.co/jbetker/tortoise-tts-clip/resolve/main/pytorch-model.bin',
|
28 |
+
'diffusion.pth': 'https://huggingface.co/jbetker/tortoise-tts-diffusion-v1/resolve/main/pytorch-model.bin',
|
29 |
+
'autoregressive.pth': 'https://huggingface.co/jbetker/tortoise-tts-autoregressive/resolve/main/pytorch-model.bin'
|
30 |
+
}
|
31 |
+
os.makedirs('.models', exist_ok=True)
|
32 |
+
def show_progress(block_num, block_size, total_size):
|
33 |
+
global pbar
|
34 |
+
if pbar is None:
|
35 |
+
pbar = progressbar.ProgressBar(maxval=total_size)
|
36 |
+
pbar.start()
|
37 |
+
|
38 |
+
downloaded = block_num * block_size
|
39 |
+
if downloaded < total_size:
|
40 |
+
pbar.update(downloaded)
|
41 |
+
else:
|
42 |
+
pbar.finish()
|
43 |
+
pbar = None
|
44 |
+
for model_name, url in MODELS.items():
|
45 |
+
if os.path.exists(f'.models/{model_name}'):
|
46 |
+
continue
|
47 |
+
print(f'Downloading {model_name} from {url}...')
|
48 |
+
request.urlretrieve(url, f'.models/{model_name}', show_progress)
|
49 |
+
print('Done.')
|
50 |
+
|
51 |
+
|
52 |
+
def load_discrete_vocoder_diffuser(trained_diffusion_steps=4000, desired_diffusion_steps=200, cond_free=True):
|
53 |
+
"""
|
54 |
+
Helper function to load a GaussianDiffusion instance configured for use as a vocoder.
|
55 |
+
"""
|
56 |
+
return SpacedDiffusion(use_timesteps=space_timesteps(trained_diffusion_steps, [desired_diffusion_steps]), model_mean_type='epsilon',
|
57 |
+
model_var_type='learned_range', loss_type='mse', betas=get_named_beta_schedule('linear', trained_diffusion_steps),
|
58 |
+
conditioning_free=cond_free, conditioning_free_k=1)
|
59 |
+
|
60 |
+
|
61 |
+
def load_conditioning(clip, cond_length=132300):
|
62 |
+
gap = clip.shape[-1] - cond_length
|
63 |
+
if gap < 0:
|
64 |
+
clip = F.pad(clip, pad=(0, abs(gap)))
|
65 |
+
elif gap > 0:
|
66 |
+
rand_start = random.randint(0, gap)
|
67 |
+
clip = clip[:, rand_start:rand_start + cond_length]
|
68 |
+
mel_clip = TorchMelSpectrogram()(clip.unsqueeze(0)).squeeze(0)
|
69 |
+
return mel_clip.unsqueeze(0).cuda()
|
70 |
+
|
71 |
+
|
72 |
+
def fix_autoregressive_output(codes, stop_token):
|
73 |
+
"""
|
74 |
+
This function performs some padding on coded audio that fixes a mismatch issue between what the diffusion model was
|
75 |
+
trained on and what the autoregressive code generator creates (which has no padding or end).
|
76 |
+
This is highly specific to the DVAE being used, so this particular coding will not necessarily work if used with
|
77 |
+
a different DVAE. This can be inferred by feeding a audio clip padded with lots of zeros on the end through the DVAE
|
78 |
+
and copying out the last few codes.
|
79 |
+
|
80 |
+
Failing to do this padding will produce speech with a harsh end that sounds like "BLAH" or similar.
|
81 |
+
"""
|
82 |
+
# Strip off the autoregressive stop token and add padding.
|
83 |
+
stop_token_indices = (codes == stop_token).nonzero()
|
84 |
+
if len(stop_token_indices) == 0:
|
85 |
+
print("No stop tokens found, enjoy that output of yours!")
|
86 |
+
return codes
|
87 |
+
else:
|
88 |
+
codes[stop_token_indices] = 83
|
89 |
+
stm = stop_token_indices.min().item()
|
90 |
+
codes[stm:] = 83
|
91 |
+
if stm - 3 < codes.shape[0]:
|
92 |
+
codes[-3] = 45
|
93 |
+
codes[-2] = 45
|
94 |
+
codes[-1] = 248
|
95 |
+
|
96 |
+
return codes
|
97 |
+
|
98 |
+
|
99 |
+
def do_spectrogram_diffusion(diffusion_model, diffuser, mel_codes, conditioning_input, mean=False):
|
100 |
+
"""
|
101 |
+
Uses the specified diffusion model and DVAE model to convert the provided MEL & conditioning inputs into an audio clip.
|
102 |
+
"""
|
103 |
+
with torch.no_grad():
|
104 |
+
cond_mel = wav_to_univnet_mel(conditioning_input.squeeze(1), do_normalization=False)
|
105 |
+
# Pad MEL to multiples of 32
|
106 |
+
msl = mel_codes.shape[-1]
|
107 |
+
dsl = 32
|
108 |
+
gap = dsl - (msl % dsl)
|
109 |
+
if gap > 0:
|
110 |
+
mel = torch.nn.functional.pad(mel_codes, (0, gap))
|
111 |
+
|
112 |
+
output_shape = (mel.shape[0], 100, mel.shape[-1]*4)
|
113 |
+
precomputed_embeddings = diffusion_model.timestep_independent(mel_codes, cond_mel)
|
114 |
+
if mean:
|
115 |
+
mel = diffuser.p_sample_loop(diffusion_model, output_shape, noise=torch.zeros(output_shape, device=mel_codes.device),
|
116 |
+
model_kwargs={'precomputed_aligned_embeddings': precomputed_embeddings})
|
117 |
+
else:
|
118 |
+
mel = diffuser.p_sample_loop(diffusion_model, output_shape, model_kwargs={'precomputed_aligned_embeddings': precomputed_embeddings})
|
119 |
+
return denormalize_tacotron_mel(mel)[:,:,:msl*4]
|
120 |
+
|
121 |
+
|
122 |
+
class TextToSpeech:
|
123 |
+
def __init__(self, autoregressive_batch_size=32):
|
124 |
+
self.autoregressive_batch_size = autoregressive_batch_size
|
125 |
+
self.tokenizer = VoiceBpeTokenizer()
|
126 |
+
download_models()
|
127 |
+
|
128 |
+
self.autoregressive = UnifiedVoice(max_mel_tokens=300, max_text_tokens=200, max_conditioning_inputs=2, layers=30,
|
129 |
+
model_dim=1024,
|
130 |
+
heads=16, number_text_tokens=256, start_text_token=255, checkpointing=False,
|
131 |
+
train_solo_embeddings=False,
|
132 |
+
average_conditioning_embeddings=True).cpu().eval()
|
133 |
+
self.autoregressive.load_state_dict(torch.load('.models/autoregressive.pth'))
|
134 |
+
|
135 |
+
self.clip = VoiceCLIP(dim_text=512, dim_speech=512, dim_latent=512, num_text_tokens=256, text_enc_depth=12,
|
136 |
+
text_seq_len=350, text_heads=8,
|
137 |
+
num_speech_tokens=8192, speech_enc_depth=12, speech_heads=8, speech_seq_len=430,
|
138 |
+
use_xformers=True).cpu().eval()
|
139 |
+
self.clip.load_state_dict(torch.load('.models/clip.pth'))
|
140 |
+
|
141 |
+
self.diffusion = DiffusionTts(model_channels=512, in_channels=100, out_channels=200, in_latent_channels=1024,
|
142 |
+
channel_mult=[1, 2, 3, 4], num_res_blocks=[3, 3, 3, 3],
|
143 |
+
token_conditioning_resolutions=[1, 4, 8],
|
144 |
+
dropout=0, attention_resolutions=[4, 8], num_heads=8, kernel_size=3, scale_factor=2,
|
145 |
+
time_embed_dim_multiplier=4, unconditioned_percentage=0, conditioning_dim_factor=2,
|
146 |
+
conditioning_expansion=1).cpu().eval()
|
147 |
+
self.diffusion.load_state_dict(torch.load('.models/diffusion.pth'))
|
148 |
+
|
149 |
+
self.vocoder = UnivNetGenerator().cpu()
|
150 |
+
self.vocoder.load_state_dict(torch.load('.models/vocoder.pth')['model_g'])
|
151 |
+
self.vocoder.eval(inference=True)
|
152 |
+
|
153 |
+
def tts(self, text, voice_samples, num_autoregressive_samples=512, k=1, diffusion_iterations=100, cond_free=True):
|
154 |
+
text = torch.IntTensor(self.tokenizer.encode(text)).unsqueeze(0).cuda()
|
155 |
+
text = F.pad(text, (0, 1)) # This may not be necessary.
|
156 |
+
|
157 |
+
conds = []
|
158 |
+
if not isinstance(voice_samples, list):
|
159 |
+
voice_samples = [voice_samples]
|
160 |
+
for vs in voice_samples:
|
161 |
+
conds.append(load_conditioning(vs))
|
162 |
+
conds = torch.stack(conds, dim=1)
|
163 |
+
cond_diffusion = voice_samples[0].cuda()
|
164 |
+
# The diffusion model expects = 88200 conditioning samples.
|
165 |
+
if cond_diffusion.shape[-1] < 88200:
|
166 |
+
cond_diffusion = F.pad(cond_diffusion, (0, 88200-cond_diffusion.shape[-1]))
|
167 |
+
else:
|
168 |
+
cond_diffusion = cond_diffusion[:, :88200]
|
169 |
+
|
170 |
+
diffuser = load_discrete_vocoder_diffuser(desired_diffusion_steps=diffusion_iterations, cond_free=cond_free)
|
171 |
+
|
172 |
+
with torch.no_grad():
|
173 |
+
samples = []
|
174 |
+
num_batches = num_autoregressive_samples // self.autoregressive_batch_size
|
175 |
+
stop_mel_token = self.autoregressive.stop_mel_token
|
176 |
+
self.autoregressive = self.autoregressive.cuda()
|
177 |
+
for b in tqdm(range(num_batches)):
|
178 |
+
codes = self.autoregressive.inference_speech(conds, text, num_beams=1, repetition_penalty=1.0, do_sample=True,
|
179 |
+
top_k=50, top_p=.95,
|
180 |
+
temperature=.9,
|
181 |
+
num_return_sequences=self.autoregressive_batch_size,
|
182 |
+
length_penalty=1)
|
183 |
+
padding_needed = 250 - codes.shape[1]
|
184 |
+
codes = F.pad(codes, (0, padding_needed), value=stop_mel_token)
|
185 |
+
samples.append(codes)
|
186 |
+
self.autoregressive = self.autoregressive.cpu()
|
187 |
+
|
188 |
+
clip_results = []
|
189 |
+
self.clip = self.clip.cuda()
|
190 |
+
for batch in samples:
|
191 |
+
for i in range(batch.shape[0]):
|
192 |
+
batch[i] = fix_autoregressive_output(batch[i], stop_mel_token)
|
193 |
+
clip_results.append(self.clip(text.repeat(batch.shape[0], 1), batch, return_loss=False))
|
194 |
+
clip_results = torch.cat(clip_results, dim=0)
|
195 |
+
samples = torch.cat(samples, dim=0)
|
196 |
+
best_results = samples[torch.topk(clip_results, k=k).indices]
|
197 |
+
self.clip = self.clip.cpu()
|
198 |
+
del samples
|
199 |
+
|
200 |
+
print("Performing vocoding..")
|
201 |
+
wav_candidates = []
|
202 |
+
self.diffusion = self.diffusion.cuda()
|
203 |
+
self.vocoder = self.vocoder.cuda()
|
204 |
+
for b in range(best_results.shape[0]):
|
205 |
+
code = best_results[b].unsqueeze(0)
|
206 |
+
mel = do_spectrogram_diffusion(self.diffusion, diffuser, code, cond_diffusion, mean=False)
|
207 |
+
wav = self.vocoder.inference(mel)
|
208 |
+
wav_candidates.append(wav.cpu())
|
209 |
+
self.diffusion = self.diffusion.cpu()
|
210 |
+
self.vocoder = self.vocoder.cpu()
|
211 |
+
|
212 |
+
if len(wav_candidates) > 1:
|
213 |
+
return wav_candidates
|
214 |
+
return wav_candidates[0]
|
do_tts.py
CHANGED
@@ -138,8 +138,8 @@ if __name__ == '__main__':
|
|
138 |
parser = argparse.ArgumentParser()
|
139 |
parser.add_argument('-text', type=str, help='Text to speak.', default="I am a language model that has learned to speak.")
|
140 |
parser.add_argument('-voice', type=str, help='Use a preset conditioning voice (defined above). Overrides cond_path.', default='dotrice,harris,lescault,otto,atkins,grace,kennard,mol')
|
141 |
-
parser.add_argument('-num_samples', type=int, help='How many total outputs the autoregressive transformer should produce.', default=
|
142 |
-
parser.add_argument('-num_batches', type=int, help='How many batches those samples should be produced over.', default=
|
143 |
parser.add_argument('-num_diffusion_samples', type=int, help='Number of outputs that progress to the diffusion stage.', default=16)
|
144 |
parser.add_argument('-output_path', type=str, help='Where to store outputs.', default='results/')
|
145 |
args = parser.parse_args()
|
@@ -179,19 +179,15 @@ if __name__ == '__main__':
|
|
179 |
del autoregressive
|
180 |
|
181 |
print("Loading CLIP..")
|
182 |
-
clip = VoiceCLIP(dim_text=512, dim_speech=512, dim_latent=512, num_text_tokens=256, text_enc_depth=
|
183 |
-
num_speech_tokens=8192, speech_enc_depth=
|
184 |
clip.load_state_dict(torch.load('.models/clip.pth'))
|
185 |
print("Performing CLIP filtering..")
|
186 |
clip_results = []
|
187 |
for batch in samples:
|
188 |
for i in range(batch.shape[0]):
|
189 |
batch[i] = fix_autoregressive_output(batch[i], stop_mel_token)
|
190 |
-
text
|
191 |
-
clip_results.append(clip(text.repeat(batch.shape[0], 1),
|
192 |
-
torch.full((batch.shape[0],), fill_value=text.shape[1]-1, dtype=torch.long, device='cuda'),
|
193 |
-
batch, torch.full((batch.shape[0],), fill_value=batch.shape[1]*1024, dtype=torch.long, device='cuda'),
|
194 |
-
return_loss=False))
|
195 |
clip_results = torch.cat(clip_results, dim=0)
|
196 |
samples = torch.cat(samples, dim=0)
|
197 |
best_results = samples[torch.topk(clip_results, k=args.num_diffusion_samples).indices]
|
|
|
138 |
parser = argparse.ArgumentParser()
|
139 |
parser.add_argument('-text', type=str, help='Text to speak.', default="I am a language model that has learned to speak.")
|
140 |
parser.add_argument('-voice', type=str, help='Use a preset conditioning voice (defined above). Overrides cond_path.', default='dotrice,harris,lescault,otto,atkins,grace,kennard,mol')
|
141 |
+
parser.add_argument('-num_samples', type=int, help='How many total outputs the autoregressive transformer should produce.', default=512)
|
142 |
+
parser.add_argument('-num_batches', type=int, help='How many batches those samples should be produced over.', default=16)
|
143 |
parser.add_argument('-num_diffusion_samples', type=int, help='Number of outputs that progress to the diffusion stage.', default=16)
|
144 |
parser.add_argument('-output_path', type=str, help='Where to store outputs.', default='results/')
|
145 |
args = parser.parse_args()
|
|
|
179 |
del autoregressive
|
180 |
|
181 |
print("Loading CLIP..")
|
182 |
+
clip = VoiceCLIP(dim_text=512, dim_speech=512, dim_latent=512, num_text_tokens=256, text_enc_depth=12, text_seq_len=350, text_heads=8,
|
183 |
+
num_speech_tokens=8192, speech_enc_depth=12, speech_heads=8, speech_seq_len=430, use_xformers=True).cuda().eval()
|
184 |
clip.load_state_dict(torch.load('.models/clip.pth'))
|
185 |
print("Performing CLIP filtering..")
|
186 |
clip_results = []
|
187 |
for batch in samples:
|
188 |
for i in range(batch.shape[0]):
|
189 |
batch[i] = fix_autoregressive_output(batch[i], stop_mel_token)
|
190 |
+
clip_results.append(clip(text.repeat(batch.shape[0], 1), batch, return_loss=False))
|
|
|
|
|
|
|
|
|
191 |
clip_results = torch.cat(clip_results, dim=0)
|
192 |
samples = torch.cat(samples, dim=0)
|
193 |
best_results = samples[torch.topk(clip_results, k=args.num_diffusion_samples).indices]
|
eval_multiple.py
ADDED
@@ -0,0 +1,33 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
1 |
+
import os
|
2 |
+
|
3 |
+
import torchaudio
|
4 |
+
|
5 |
+
from api import TextToSpeech
|
6 |
+
from utils.audio import load_audio
|
7 |
+
|
8 |
+
if __name__ == '__main__':
|
9 |
+
fname = 'Y:\\libritts\\test-clean\\transcribed-brief-w2v.tsv'
|
10 |
+
outpath = 'D:\\tmp\\tortoise-tts-eval\\baseline'
|
11 |
+
outpath_real = 'D:\\tmp\\tortoise-tts-eval\\real'
|
12 |
+
|
13 |
+
os.makedirs(outpath, exist_ok=True)
|
14 |
+
os.makedirs(outpath_real, exist_ok=True)
|
15 |
+
with open(fname, 'r', encoding='utf-8') as f:
|
16 |
+
lines = [l.strip().split('\t') for l in f.readlines()]
|
17 |
+
|
18 |
+
recorder = open(os.path.join(outpath, 'transcript.tsv'), 'w', encoding='utf-8')
|
19 |
+
tts = TextToSpeech()
|
20 |
+
for e, line in enumerate(lines):
|
21 |
+
transcript = line[0]
|
22 |
+
if len(transcript) > 120:
|
23 |
+
continue # We need to support this, but cannot yet.
|
24 |
+
path = os.path.join(os.path.dirname(fname), line[1])
|
25 |
+
cond_audio = load_audio(path, 22050)
|
26 |
+
torchaudio.save(os.path.join(outpath_real, os.path.basename(line[1])), cond_audio, 22050)
|
27 |
+
sample = tts.tts(transcript, [cond_audio, cond_audio], num_autoregressive_samples=512, k=1, diffusion_iterations=200, cond_free=True)
|
28 |
+
down = torchaudio.functional.resample(sample, 24000, 22050)
|
29 |
+
fout_path = os.path.join(outpath, os.path.basename(line[1]))
|
30 |
+
torchaudio.save(fout_path, down.squeeze(0), 22050)
|
31 |
+
recorder.write(f'{transcript}\t{fout_path}\n')
|
32 |
+
recorder.flush()
|
33 |
+
recorder.close()
|
models/arch_util.py
CHANGED
@@ -1,9 +1,11 @@
|
|
|
|
1 |
import math
|
2 |
|
3 |
import torch
|
4 |
import torch.nn as nn
|
5 |
import torch.nn.functional as F
|
6 |
import torchaudio
|
|
|
7 |
|
8 |
|
9 |
def zero_module(module):
|
@@ -316,4 +318,46 @@ class TorchMelSpectrogram(nn.Module):
|
|
316 |
if self.mel_norms is not None:
|
317 |
self.mel_norms = self.mel_norms.to(mel.device)
|
318 |
mel = mel / self.mel_norms.unsqueeze(0).unsqueeze(-1)
|
319 |
-
return mel
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
1 |
+
import functools
|
2 |
import math
|
3 |
|
4 |
import torch
|
5 |
import torch.nn as nn
|
6 |
import torch.nn.functional as F
|
7 |
import torchaudio
|
8 |
+
from x_transformers import ContinuousTransformerWrapper
|
9 |
|
10 |
|
11 |
def zero_module(module):
|
|
|
318 |
if self.mel_norms is not None:
|
319 |
self.mel_norms = self.mel_norms.to(mel.device)
|
320 |
mel = mel / self.mel_norms.unsqueeze(0).unsqueeze(-1)
|
321 |
+
return mel
|
322 |
+
|
323 |
+
|
324 |
+
class CheckpointedLayer(nn.Module):
|
325 |
+
"""
|
326 |
+
Wraps a module. When forward() is called, passes kwargs that require_grad through torch.checkpoint() and bypasses
|
327 |
+
checkpoint for all other args.
|
328 |
+
"""
|
329 |
+
def __init__(self, wrap):
|
330 |
+
super().__init__()
|
331 |
+
self.wrap = wrap
|
332 |
+
|
333 |
+
def forward(self, x, *args, **kwargs):
|
334 |
+
for k, v in kwargs.items():
|
335 |
+
assert not (isinstance(v, torch.Tensor) and v.requires_grad) # This would screw up checkpointing.
|
336 |
+
partial = functools.partial(self.wrap, **kwargs)
|
337 |
+
return torch.utils.checkpoint.checkpoint(partial, x, *args)
|
338 |
+
|
339 |
+
|
340 |
+
class CheckpointedXTransformerEncoder(nn.Module):
|
341 |
+
"""
|
342 |
+
Wraps a ContinuousTransformerWrapper and applies CheckpointedLayer to each layer and permutes from channels-mid
|
343 |
+
to channels-last that XTransformer expects.
|
344 |
+
"""
|
345 |
+
def __init__(self, needs_permute=True, exit_permute=True, checkpoint=True, **xtransformer_kwargs):
|
346 |
+
super().__init__()
|
347 |
+
self.transformer = ContinuousTransformerWrapper(**xtransformer_kwargs)
|
348 |
+
self.needs_permute = needs_permute
|
349 |
+
self.exit_permute = exit_permute
|
350 |
+
|
351 |
+
if not checkpoint:
|
352 |
+
return
|
353 |
+
for i in range(len(self.transformer.attn_layers.layers)):
|
354 |
+
n, b, r = self.transformer.attn_layers.layers[i]
|
355 |
+
self.transformer.attn_layers.layers[i] = nn.ModuleList([n, CheckpointedLayer(b), r])
|
356 |
+
|
357 |
+
def forward(self, x, **kwargs):
|
358 |
+
if self.needs_permute:
|
359 |
+
x = x.permute(0,2,1)
|
360 |
+
h = self.transformer(x, **kwargs)
|
361 |
+
if self.exit_permute:
|
362 |
+
h = h.permute(0,2,1)
|
363 |
+
return h
|
models/diffusion_decoder.py
CHANGED
@@ -15,7 +15,8 @@ from torch.nn import Linear
|
|
15 |
from torch.utils.checkpoint import checkpoint
|
16 |
from x_transformers import ContinuousTransformerWrapper, Encoder
|
17 |
|
18 |
-
from models.arch_util import normalization, zero_module, Downsample, Upsample, AudioMiniEncoder, AttentionBlock
|
|
|
19 |
|
20 |
|
21 |
def is_latent(t):
|
@@ -157,43 +158,6 @@ class ResBlock(TimestepBlock):
|
|
157 |
return self.skip_connection(x) + h
|
158 |
|
159 |
|
160 |
-
class CheckpointedLayer(nn.Module):
|
161 |
-
"""
|
162 |
-
Wraps a module. When forward() is called, passes kwargs that require_grad through torch.checkpoint() and bypasses
|
163 |
-
checkpoint for all other args.
|
164 |
-
"""
|
165 |
-
def __init__(self, wrap):
|
166 |
-
super().__init__()
|
167 |
-
self.wrap = wrap
|
168 |
-
|
169 |
-
def forward(self, x, *args, **kwargs):
|
170 |
-
for k, v in kwargs.items():
|
171 |
-
assert not (isinstance(v, torch.Tensor) and v.requires_grad) # This would screw up checkpointing.
|
172 |
-
partial = functools.partial(self.wrap, **kwargs)
|
173 |
-
return torch.utils.checkpoint.checkpoint(partial, x, *args)
|
174 |
-
|
175 |
-
|
176 |
-
class CheckpointedXTransformerEncoder(nn.Module):
|
177 |
-
"""
|
178 |
-
Wraps a ContinuousTransformerWrapper and applies CheckpointedLayer to each layer and permutes from channels-mid
|
179 |
-
to channels-last that XTransformer expects.
|
180 |
-
"""
|
181 |
-
def __init__(self, needs_permute=True, **xtransformer_kwargs):
|
182 |
-
super().__init__()
|
183 |
-
self.transformer = ContinuousTransformerWrapper(**xtransformer_kwargs)
|
184 |
-
self.needs_permute = needs_permute
|
185 |
-
|
186 |
-
for i in range(len(self.transformer.attn_layers.layers)):
|
187 |
-
n, b, r = self.transformer.attn_layers.layers[i]
|
188 |
-
self.transformer.attn_layers.layers[i] = nn.ModuleList([n, CheckpointedLayer(b), r])
|
189 |
-
|
190 |
-
def forward(self, x, **kwargs):
|
191 |
-
if self.needs_permute:
|
192 |
-
x = x.permute(0,2,1)
|
193 |
-
h = self.transformer(x, **kwargs)
|
194 |
-
return h.permute(0,2,1)
|
195 |
-
|
196 |
-
|
197 |
class DiffusionTts(nn.Module):
|
198 |
"""
|
199 |
The full UNet model with attention and timestep embedding.
|
|
|
15 |
from torch.utils.checkpoint import checkpoint
|
16 |
from x_transformers import ContinuousTransformerWrapper, Encoder
|
17 |
|
18 |
+
from models.arch_util import normalization, zero_module, Downsample, Upsample, AudioMiniEncoder, AttentionBlock, \
|
19 |
+
CheckpointedXTransformerEncoder
|
20 |
|
21 |
|
22 |
def is_latent(t):
|
|
|
158 |
return self.skip_connection(x) + h
|
159 |
|
160 |
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
161 |
class DiffusionTts(nn.Module):
|
162 |
"""
|
163 |
The full UNet model with attention and timestep embedding.
|
models/text_voice_clip.py
CHANGED
@@ -2,6 +2,9 @@ import torch
|
|
2 |
import torch.nn as nn
|
3 |
import torch.nn.functional as F
|
4 |
from torch import einsum
|
|
|
|
|
|
|
5 |
from models.transformer import Transformer
|
6 |
|
7 |
|
@@ -13,7 +16,6 @@ def masked_mean(t, mask, dim = 1):
|
|
13 |
t = t.masked_fill(~mask[:, :, None], 0.)
|
14 |
return t.sum(dim = 1) / mask.sum(dim = 1)[..., None]
|
15 |
|
16 |
-
|
17 |
class VoiceCLIP(nn.Module):
|
18 |
"""
|
19 |
CLIP model retrofitted for performing contrastive evaluation between tokenized audio data and the corresponding
|
@@ -39,40 +41,69 @@ class VoiceCLIP(nn.Module):
|
|
39 |
text_mask_percentage=0,
|
40 |
voice_mask_percentage=0,
|
41 |
wav_token_compression=1024,
|
|
|
42 |
):
|
43 |
super().__init__()
|
44 |
self.text_emb = nn.Embedding(num_text_tokens, dim_text)
|
45 |
-
self.text_pos_emb = nn.Embedding(text_seq_len, dim_text)
|
46 |
-
self.text_transformer = Transformer(causal=False, seq_len=text_seq_len, dim=dim_text, depth=text_enc_depth,
|
47 |
-
heads=text_heads)
|
48 |
self.to_text_latent = nn.Linear(dim_text, dim_latent, bias=False)
|
49 |
|
50 |
self.speech_emb = nn.Embedding(num_speech_tokens, dim_speech)
|
51 |
-
self.speech_pos_emb = nn.Embedding(num_speech_tokens, dim_speech)
|
52 |
-
self.speech_transformer = Transformer(causal=False, seq_len=speech_seq_len, dim=dim_speech,
|
53 |
-
depth=speech_enc_depth, heads=speech_heads)
|
54 |
self.to_speech_latent = nn.Linear(dim_speech, dim_latent, bias=False)
|
55 |
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
56 |
self.temperature = nn.Parameter(torch.tensor(1.))
|
57 |
self.text_mask_percentage = text_mask_percentage
|
58 |
self.voice_mask_percentage = voice_mask_percentage
|
59 |
self.wav_token_compression = wav_token_compression
|
|
|
|
|
|
|
|
|
60 |
|
61 |
def forward(
|
62 |
self,
|
63 |
text,
|
64 |
-
text_lengths,
|
65 |
speech_tokens,
|
66 |
-
wav_lengths,
|
67 |
return_loss=False
|
68 |
):
|
69 |
-
# This model will receive micro-batches with a ton of padding for both the text and MELs. Ameliorate this by
|
70 |
-
# chopping the inputs by the maximum actual length.
|
71 |
-
max_text_len = text_lengths.max()
|
72 |
-
text = text[:, :max_text_len]
|
73 |
-
max_mel_len = wav_lengths.max() // self.wav_token_compression
|
74 |
-
speech_tokens = speech_tokens[:, :max_mel_len]
|
75 |
-
|
76 |
b, device = text.shape[0], text.device
|
77 |
if self.training:
|
78 |
text_mask = torch.rand_like(text.float()) > self.text_mask_percentage
|
@@ -82,10 +113,11 @@ class VoiceCLIP(nn.Module):
|
|
82 |
voice_mask = torch.ones_like(speech_tokens.float()).bool()
|
83 |
|
84 |
text_emb = self.text_emb(text)
|
85 |
-
text_emb += self.text_pos_emb(torch.arange(text.shape[1], device=device))
|
86 |
-
|
87 |
speech_emb = self.speech_emb(speech_tokens)
|
88 |
-
|
|
|
|
|
|
|
89 |
|
90 |
enc_text = self.text_transformer(text_emb, mask=text_mask)
|
91 |
enc_speech = self.speech_transformer(speech_emb, mask=voice_mask)
|
|
|
2 |
import torch.nn as nn
|
3 |
import torch.nn.functional as F
|
4 |
from torch import einsum
|
5 |
+
from x_transformers import Encoder
|
6 |
+
|
7 |
+
from models.arch_util import CheckpointedXTransformerEncoder
|
8 |
from models.transformer import Transformer
|
9 |
|
10 |
|
|
|
16 |
t = t.masked_fill(~mask[:, :, None], 0.)
|
17 |
return t.sum(dim = 1) / mask.sum(dim = 1)[..., None]
|
18 |
|
|
|
19 |
class VoiceCLIP(nn.Module):
|
20 |
"""
|
21 |
CLIP model retrofitted for performing contrastive evaluation between tokenized audio data and the corresponding
|
|
|
41 |
text_mask_percentage=0,
|
42 |
voice_mask_percentage=0,
|
43 |
wav_token_compression=1024,
|
44 |
+
use_xformers=False,
|
45 |
):
|
46 |
super().__init__()
|
47 |
self.text_emb = nn.Embedding(num_text_tokens, dim_text)
|
|
|
|
|
|
|
48 |
self.to_text_latent = nn.Linear(dim_text, dim_latent, bias=False)
|
49 |
|
50 |
self.speech_emb = nn.Embedding(num_speech_tokens, dim_speech)
|
|
|
|
|
|
|
51 |
self.to_speech_latent = nn.Linear(dim_speech, dim_latent, bias=False)
|
52 |
|
53 |
+
if use_xformers:
|
54 |
+
self.text_transformer = CheckpointedXTransformerEncoder(
|
55 |
+
needs_permute=False,
|
56 |
+
exit_permute=False,
|
57 |
+
max_seq_len=-1,
|
58 |
+
use_pos_emb=False,
|
59 |
+
attn_layers=Encoder(
|
60 |
+
dim=dim_text,
|
61 |
+
depth=text_enc_depth,
|
62 |
+
heads=text_heads,
|
63 |
+
ff_dropout=.1,
|
64 |
+
ff_mult=2,
|
65 |
+
attn_dropout=.1,
|
66 |
+
use_rmsnorm=True,
|
67 |
+
ff_glu=True,
|
68 |
+
rotary_pos_emb=True,
|
69 |
+
))
|
70 |
+
self.speech_transformer = CheckpointedXTransformerEncoder(
|
71 |
+
needs_permute=False,
|
72 |
+
exit_permute=False,
|
73 |
+
max_seq_len=-1,
|
74 |
+
use_pos_emb=False,
|
75 |
+
attn_layers=Encoder(
|
76 |
+
dim=dim_speech,
|
77 |
+
depth=speech_enc_depth,
|
78 |
+
heads=speech_heads,
|
79 |
+
ff_dropout=.1,
|
80 |
+
ff_mult=2,
|
81 |
+
attn_dropout=.1,
|
82 |
+
use_rmsnorm=True,
|
83 |
+
ff_glu=True,
|
84 |
+
rotary_pos_emb=True,
|
85 |
+
))
|
86 |
+
else:
|
87 |
+
self.text_transformer = Transformer(causal=False, seq_len=text_seq_len, dim=dim_text, depth=text_enc_depth,
|
88 |
+
heads=text_heads)
|
89 |
+
self.speech_transformer = Transformer(causal=False, seq_len=speech_seq_len, dim=dim_speech,
|
90 |
+
depth=speech_enc_depth, heads=speech_heads)
|
91 |
+
|
92 |
self.temperature = nn.Parameter(torch.tensor(1.))
|
93 |
self.text_mask_percentage = text_mask_percentage
|
94 |
self.voice_mask_percentage = voice_mask_percentage
|
95 |
self.wav_token_compression = wav_token_compression
|
96 |
+
self.xformers = use_xformers
|
97 |
+
if not use_xformers:
|
98 |
+
self.text_pos_emb = nn.Embedding(text_seq_len, dim_text)
|
99 |
+
self.speech_pos_emb = nn.Embedding(num_speech_tokens, dim_speech)
|
100 |
|
101 |
def forward(
|
102 |
self,
|
103 |
text,
|
|
|
104 |
speech_tokens,
|
|
|
105 |
return_loss=False
|
106 |
):
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
107 |
b, device = text.shape[0], text.device
|
108 |
if self.training:
|
109 |
text_mask = torch.rand_like(text.float()) > self.text_mask_percentage
|
|
|
113 |
voice_mask = torch.ones_like(speech_tokens.float()).bool()
|
114 |
|
115 |
text_emb = self.text_emb(text)
|
|
|
|
|
116 |
speech_emb = self.speech_emb(speech_tokens)
|
117 |
+
|
118 |
+
if not self.xformers:
|
119 |
+
text_emb += self.text_pos_emb(torch.arange(text.shape[1], device=device))
|
120 |
+
speech_emb += self.speech_pos_emb(torch.arange(speech_emb.shape[1], device=device))
|
121 |
|
122 |
enc_text = self.text_transformer(text_emb, mask=text_mask)
|
123 |
enc_speech = self.speech_transformer(speech_emb, mask=voice_mask)
|