Spaces:
Runtime error
Runtime error
Anton Forsman
committed on
Commit
•
ed6fc19
1
Parent(s):
9078865
Uploaded model
Browse files
GPTTTS.py
ADDED
@@ -0,0 +1,71 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
1 |
+
from transformers import AutoTokenizer, AutoModelForCausalLM, PreTrainedModel, PreTrainedTokenizer, AutoConfig
|
2 |
+
from encodec import EncodecModel
|
3 |
+
from encodec.utils import convert_audio
|
4 |
+
import torch
|
5 |
+
import torchaudio
|
6 |
+
import re
|
7 |
+
|
8 |
+
class GPTTTS(PreTrainedModel):
    """Text-to-speech model: a fine-tuned causal LM generates EnCodec audio
    tokens from a text prompt, and a 24 kHz EnCodec decoder turns those
    tokens back into a waveform.
    """

    # First id of the audio-token range appended to the base GPT-2 vocabulary
    # (GPT-2 vocab size is 50257); generated ids above this offset encode
    # EnCodec codebook entries.
    AUDIO_TOKEN_OFFSET = 50257
    # EnCodec at the 1.5 kbps target bandwidth emits two codebooks per frame.
    NUM_CODEBOOKS = 2

    def __init__(self, *model_args, **model_kwargs):
        # Config and LM weights come from the same fine-tuned checkpoint;
        # both are fetched from the Hugging Face Hub.
        super().__init__(
            AutoConfig.from_pretrained("Ekgren/distilgpt2-finetuned-common-voice"),
            *model_args,
            **model_kwargs,
        )
        self.model = AutoModelForCausalLM.from_pretrained(
            "Ekgren/distilgpt2-finetuned-common-voice"
        )
        self.encodec_model = EncodecModel.encodec_model_24khz()
        self.encodec_model.set_target_bandwidth(1.5)
        self.sample_rate = self.encodec_model.sample_rate

    def forward(self, input_ids):
        """Generate audio for a tokenized prompt.

        Args:
            input_ids: prompt token ids, shape (1, seq_len) — the output of
                GPTTTSTokenizer.

        Returns:
            Decoded waveform tensor of shape (channels, samples) at
            ``self.sample_rate``.
        """
        tokens = self.model.generate(
            input_ids,
            do_sample=True,
            max_length=1024,
            temperature=1,
            top_k=50,
            top_p=0.95,
        )[0]

        # Keep only audio tokens and map them back to codebook indices.
        # NOTE(review): ids strictly greater than the offset are kept, so a
        # generated id equal to the offset (codebook index 0) is dropped —
        # confirm this matches the training-time token mapping.
        audio_tokens = [
            int(t) - self.AUDIO_TOKEN_OFFSET
            for t in tokens
            if t > self.AUDIO_TOKEN_OFFSET
        ]

        # Tokens are interleaved per sample: (cb0, cb1, cb0, cb1, ...).
        # Reshape to (samples, codebooks) and transpose to the
        # (1, codebooks, samples) layout EnCodec expects — replaces the
        # original element-by-element Python double loop. Any trailing
        # partial sample is truncated, as before.
        n_samples = len(audio_tokens) // self.NUM_CODEBOOKS
        frame = (
            torch.tensor(
                audio_tokens[: n_samples * self.NUM_CODEBOOKS], dtype=torch.long
            )
            .view(n_samples, self.NUM_CODEBOOKS)
            .t()
            .contiguous()
            .unsqueeze(0)
        )

        # EnCodec's decode API takes a list of (codes, scale) frames;
        # scale is None for the fixed-bandwidth codecs.
        frames = [(frame, None)]

        with torch.no_grad():
            wav = self.encodec_model.decode(frames)

        return wav[0, :, :]
|
41 |
+
|
42 |
+
|
43 |
+
class GPTTTSTokenizer(PreTrainedTokenizer):
    """Tokenizer facade for GPTTTS.

    Delegates to the fine-tuned distilgpt2 tokenizer from the Hub, wrapping
    raw input text in the "text: ...\\nsound:" prompt template the model was
    trained on. Several PreTrainedTokenizer hooks are routed through the
    same prompt-building path.
    """

    def __init__(self, *args, **kwargs):
        super().__init__(*args, **kwargs)
        # Underlying HF tokenizer; GPT-2-style models have no pad token,
        # so reuse EOS for padding.
        self.tokenizer = AutoTokenizer.from_pretrained(
            "anforsm/distilgpt2-finetuned-common-voice"
        )
        self.tokenizer.pad_token_id = self.tokenizer.eos_token_id

    def tokenize(self, text, *args, **kwargs):
        # Wrap the raw text in the training prompt template and encode it
        # as PyTorch tensors.
        return self.tokenizer(f"text: {text}\nsound:", return_tensors="pt")

    def _tokenize(self, *args, **kwargs):
        return self.tokenize(*args, **kwargs)

    def convert_tokens_to_ids(self, tokens):
        # `tokens` here is the BatchEncoding produced by tokenize().
        return tokens["input_ids"]

    def convert_ids_to_tokens(self, ids):
        # Decode the first (and only) sequence in the batch back to text.
        return self.tokenizer.decode(ids[0], skip_special_tokens=True)

    def _batch_encode_plus(self, *args, **kwargs):
        return self.tokenize(*args, **kwargs)

    def _encode_plus(self, *args, **kwargs):
        return self.tokenize(*args, **kwargs)

    def save_vocabulary(self, *args, **kwargs):
        return self.tokenizer.save_vocabulary(*args, **kwargs)
|
71 |
+
|