Add redaction support
Browse files- tortoise/api.py +22 -4
- tortoise/utils/wav2vec_alignment.py +82 -0
tortoise/api.py
CHANGED
@@ -19,6 +19,7 @@ from tortoise.models.vocoder import UnivNetGenerator
|
|
19 |
from tortoise.utils.audio import wav_to_univnet_mel, denormalize_tacotron_mel
|
20 |
from tortoise.utils.diffusion import SpacedDiffusion, space_timesteps, get_named_beta_schedule
|
21 |
from tortoise.utils.tokenizer import VoiceBpeTokenizer
|
|
|
22 |
|
23 |
pbar = None
|
24 |
|
@@ -158,11 +159,23 @@ def classify_audio_clip(clip):
|
|
158 |
class TextToSpeech:
|
159 |
"""
|
160 |
Main entry point into Tortoise.
|
161 |
-
:param autoregressive_batch_size: Specifies how many samples to generate per batch. Lower this if you are seeing
|
162 |
-
GPU OOM errors. Larger numbers generates slightly faster.
|
163 |
"""
|
164 |
-
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
165 |
self.autoregressive_batch_size = autoregressive_batch_size
|
|
|
|
|
|
|
|
|
166 |
self.tokenizer = VoiceBpeTokenizer()
|
167 |
download_models()
|
168 |
|
@@ -380,7 +393,6 @@ class TextToSpeech:
|
|
380 |
wav_candidates = []
|
381 |
self.diffusion = self.diffusion.cuda()
|
382 |
self.vocoder = self.vocoder.cuda()
|
383 |
-
diffusion_conds =
|
384 |
for b in range(best_results.shape[0]):
|
385 |
codes = best_results[b].unsqueeze(0)
|
386 |
latents = best_latents[b].unsqueeze(0)
|
@@ -403,6 +415,12 @@ class TextToSpeech:
|
|
403 |
self.diffusion = self.diffusion.cpu()
|
404 |
self.vocoder = self.vocoder.cpu()
|
405 |
|
|
|
|
|
|
|
|
|
|
|
406 |
if len(wav_candidates) > 1:
|
407 |
return wav_candidates
|
408 |
return wav_candidates[0]
|
|
|
|
19 |
from tortoise.utils.audio import wav_to_univnet_mel, denormalize_tacotron_mel
|
20 |
from tortoise.utils.diffusion import SpacedDiffusion, space_timesteps, get_named_beta_schedule
|
21 |
from tortoise.utils.tokenizer import VoiceBpeTokenizer
|
22 |
+
from tortoise.utils.wav2vec_alignment import Wav2VecAlignment
|
23 |
|
24 |
pbar = None
|
25 |
|
|
|
159 |
class TextToSpeech:
|
160 |
"""
|
161 |
Main entry point into Tortoise.
|
|
|
|
|
162 |
"""
|
163 |
+
|
164 |
+
def __init__(self, autoregressive_batch_size=16, models_dir='.models', enable_redaction=True):
|
165 |
+
"""
|
166 |
+
Constructor
|
167 |
+
:param autoregressive_batch_size: Specifies how many samples to generate per batch. Lower this if you are seeing
|
168 |
+
GPU OOM errors. Larger numbers generates slightly faster.
|
169 |
+
:param models_dir: Where model weights are stored. This should only be specified if you are providing your own
|
170 |
+
models, otherwise use the defaults.
|
171 |
+
:param enable_redaction: When true, text enclosed in brackets are automatically redacted from the spoken output
|
172 |
+
(but are still rendered by the model). This can be used for prompt engineering.
|
173 |
+
"""
|
174 |
self.autoregressive_batch_size = autoregressive_batch_size
|
175 |
+
self.enable_redaction = enable_redaction
|
176 |
+
if self.enable_redaction:
|
177 |
+
self.aligner = Wav2VecAlignment()
|
178 |
+
|
179 |
self.tokenizer = VoiceBpeTokenizer()
|
180 |
download_models()
|
181 |
|
|
|
393 |
wav_candidates = []
|
394 |
self.diffusion = self.diffusion.cuda()
|
395 |
self.vocoder = self.vocoder.cuda()
|
|
|
396 |
for b in range(best_results.shape[0]):
|
397 |
codes = best_results[b].unsqueeze(0)
|
398 |
latents = best_latents[b].unsqueeze(0)
|
|
|
415 |
self.diffusion = self.diffusion.cpu()
|
416 |
self.vocoder = self.vocoder.cpu()
|
417 |
|
418 |
+
def potentially_redact(self, clip, text):
|
419 |
+
if self.enable_redaction:
|
420 |
+
return self.aligner.redact(clip, text)
|
421 |
+
return clip
|
422 |
+
wav_candidates = [potentially_redact(wav_candidate, text) for wav_candidate in wav_candidates]
|
423 |
if len(wav_candidates) > 1:
|
424 |
return wav_candidates
|
425 |
return wav_candidates[0]
|
426 |
+
|
tortoise/utils/wav2vec_alignment.py
ADDED
@@ -0,0 +1,82 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
1 |
+
import torch
|
2 |
+
import torchaudio
|
3 |
+
from transformers import Wav2Vec2ForCTC, Wav2Vec2FeatureExtractor, Wav2Vec2CTCTokenizer, Wav2Vec2Processor
|
4 |
+
|
5 |
+
from tortoise.utils.audio import load_audio
|
6 |
+
|
7 |
+
|
8 |
+
class Wav2VecAlignment:
|
9 |
+
def __init__(self):
|
10 |
+
self.model = Wav2Vec2ForCTC.from_pretrained("jbetker/wav2vec2-large-robust-ft-libritts-voxpopuli").cpu()
|
11 |
+
self.feature_extractor = Wav2Vec2FeatureExtractor.from_pretrained(f"facebook/wav2vec2-large-960h")
|
12 |
+
self.tokenizer = Wav2Vec2CTCTokenizer.from_pretrained('jbetker/tacotron_symbols')
|
13 |
+
|
14 |
+
def align(self, audio, expected_text, audio_sample_rate=24000, topk=3):
|
15 |
+
orig_len = audio.shape[-1]
|
16 |
+
|
17 |
+
with torch.no_grad():
|
18 |
+
self.model = self.model.cuda()
|
19 |
+
audio = audio.to('cuda')
|
20 |
+
audio = torchaudio.functional.resample(audio, audio_sample_rate, 16000)
|
21 |
+
clip_norm = (audio - audio.mean()) / torch.sqrt(audio.var() + 1e-7)
|
22 |
+
logits = self.model(clip_norm).logits
|
23 |
+
self.model = self.model.cpu()
|
24 |
+
|
25 |
+
logits = logits[0]
|
26 |
+
w2v_compression = orig_len // logits.shape[0]
|
27 |
+
expected_tokens = self.tokenizer.encode(expected_text)
|
28 |
+
if len(expected_tokens) == 1:
|
29 |
+
return [0] # The alignment is simple; there is only one token.
|
30 |
+
expected_tokens.pop(0) # The first token is a given.
|
31 |
+
next_expected_token = expected_tokens.pop(0)
|
32 |
+
alignments = [0]
|
33 |
+
for i, logit in enumerate(logits):
|
34 |
+
top = logit.topk(topk).indices.tolist()
|
35 |
+
if next_expected_token in top:
|
36 |
+
alignments.append(i * w2v_compression)
|
37 |
+
if len(expected_tokens) > 0:
|
38 |
+
next_expected_token = expected_tokens.pop(0)
|
39 |
+
else:
|
40 |
+
break
|
41 |
+
|
42 |
+
if len(expected_tokens) > 0:
|
43 |
+
print(f"Alignment did not work. {len(expected_tokens)} were not found, with the following string un-aligned:"
|
44 |
+
f" {self.tokenizer.decode(expected_tokens)}")
|
45 |
+
return None
|
46 |
+
|
47 |
+
return alignments
|
48 |
+
|
49 |
+
def redact(self, audio, expected_text, audio_sample_rate=24000, topk=3):
|
50 |
+
if '[' not in expected_text:
|
51 |
+
return audio
|
52 |
+
splitted = expected_text.split('[')
|
53 |
+
fully_split = [splitted[0]]
|
54 |
+
for spl in splitted[1:]:
|
55 |
+
assert ']' in spl, 'Every "[" character must be paired with a "]" with no nesting.'
|
56 |
+
fully_split.extend(spl.split(']'))
|
57 |
+
# At this point, fully_split is a list of strings, with every other string being something that should be redacted.
|
58 |
+
non_redacted_intervals = []
|
59 |
+
last_point = 0
|
60 |
+
for i in range(len(fully_split)):
|
61 |
+
if i % 2 == 0:
|
62 |
+
non_redacted_intervals.append((last_point, last_point + len(fully_split[i]) - 1))
|
63 |
+
last_point += len(fully_split[i])
|
64 |
+
|
65 |
+
bare_text = ''.join(fully_split)
|
66 |
+
alignments = self.align(audio, bare_text, audio_sample_rate, topk)
|
67 |
+
if alignments is None:
|
68 |
+
return audio # Cannot redact because alignment did not succeed.
|
69 |
+
|
70 |
+
output_audio = []
|
71 |
+
for nri in non_redacted_intervals:
|
72 |
+
start, stop = nri
|
73 |
+
output_audio.append(audio[:, alignments[start]:alignments[stop]])
|
74 |
+
return torch.cat(output_audio, dim=-1)
|
75 |
+
|
76 |
+
|
77 |
+
if __name__ == '__main__':
|
78 |
+
some_audio = load_audio('../../results/favorites/morgan_freeman_metallic_hydrogen.mp3', 24000)
|
79 |
+
aligner = Wav2VecAlignment()
|
80 |
+
text = "instead of molten iron, jupiter [and brown dwaves] have hydrogen, which [is under so much pressure that it] develops metallic properties"
|
81 |
+
redact = aligner.redact(some_audio, text)
|
82 |
+
torchaudio.save(f'test_output.wav', redact, 24000)
|