Anton Forsman committed on
Commit
ed6fc19
1 Parent(s): 9078865

Uploaded model

Browse files
Files changed (1) hide show
  1. GPTTTS.py +71 -0
GPTTTS.py ADDED
@@ -0,0 +1,71 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ from transformers import AutoTokenizer, AutoModelForCausalLM, PreTrainedModel, PreTrainedTokenizer, AutoConfig
2
+ from encodec import EncodecModel
3
+ from encodec.utils import convert_audio
4
+ import torch
5
+ import torchaudio
6
+ import re
7
+
8
class GPTTTS(PreTrainedModel):
    """Text-to-speech model.

    A fine-tuned causal LM (distilgpt2) generates a sequence that mixes text
    tokens with appended audio tokens; the audio tokens are mapped back to
    EnCodec codebook indices and decoded into a waveform.
    """

    # GPT-2 base vocabulary size. Generated ids above this offset are treated
    # as audio tokens (audio_token_k -> id AUDIO_TOKEN_OFFSET + k).
    AUDIO_TOKEN_OFFSET = 50257
    # Number of EnCodec codebooks in use at the 1.5 kbps target bandwidth.
    NUM_CODEBOOKS = 2

    def __init__(self, *model_args, **model_kwargs):
        """Load the fine-tuned LM and a 24 kHz EnCodec decoder."""
        super().__init__(
            AutoConfig.from_pretrained("Ekgren/distilgpt2-finetuned-common-voice"),
            *model_args,
            **model_kwargs,
        )
        # Causal LM whose vocabulary was extended with audio tokens during
        # fine-tuning on Common Voice.
        self.model = AutoModelForCausalLM.from_pretrained("Ekgren/distilgpt2-finetuned-common-voice")
        # 24 kHz EnCodec; 1.5 kbps bandwidth corresponds to 2 codebooks.
        self.encodec_model = EncodecModel.encodec_model_24khz()
        self.encodec_model.set_target_bandwidth(1.5)
        self.sample_rate = self.encodec_model.sample_rate

    def forward(self, input_ids):
        """Generate speech audio for a tokenized prompt.

        Args:
            input_ids: token id tensor produced by the companion tokenizer.

        Returns:
            Waveform tensor of shape (channels, samples) from EnCodec.
        """
        tokens = self.model.generate(
            input_ids,
            do_sample=True,
            max_length=1024,
            temperature=1,
            top_k=50,
            top_p=0.95,
        )[0]
        # Keep only audio-range ids and map them back to codebook indices.
        # NOTE(review): `>` excludes id AUDIO_TOKEN_OFFSET itself; if
        # audio_token_0 maps to exactly that id it is silently dropped —
        # confirm against the fine-tuning vocabulary layout.
        audio_tokens = [
            token - self.AUDIO_TOKEN_OFFSET
            for token in tokens
            if token > self.AUDIO_TOKEN_OFFSET
        ]

        # Tokens are interleaved per time step: [cb0, cb1, cb0, cb1, ...].
        # Any trailing remainder that does not fill a complete step is dropped.
        number_of_samples = len(audio_tokens) // self.NUM_CODEBOOKS
        frame = torch.zeros(1, self.NUM_CODEBOOKS, number_of_samples, dtype=torch.long)
        for sample in range(number_of_samples):
            for codebook in range(self.NUM_CODEBOOKS):
                frame[0, codebook, sample] = audio_tokens[sample * self.NUM_CODEBOOKS + codebook]

        # EnCodec's decode() expects a list of (codes, scale) pairs.
        frames = [(frame, None)]

        with torch.no_grad():
            wav = self.encodec_model.decode(frames)

        return wav[0, :, :]
43
class GPTTTSTokenizer(PreTrainedTokenizer):
    """Prompt-building wrapper around the pretrained common-voice tokenizer.

    Every encode path funnels through ``tokenize``, which embeds the raw text
    in the ``text: ...\\nsound:`` prompt template the LM was fine-tuned on and
    returns the underlying tokenizer's PyTorch-tensor encoding.
    """

    def __init__(self, *args, **kwargs):
        """Load the wrapped tokenizer and reuse EOS as the pad token."""
        super().__init__(*args, **kwargs)
        self.tokenizer = AutoTokenizer.from_pretrained("anforsm/distilgpt2-finetuned-common-voice")
        self.tokenizer.pad_token_id = self.tokenizer.eos_token_id

    def tokenize(self, text, *args, **kwargs):
        """Wrap *text* in the training prompt and encode it as tensors."""
        return self.tokenizer(f"text: {text}\nsound:", return_tensors="pt")

    def _tokenize(self, *args, **kwargs):
        # Delegate so internal HF code paths share the prompt-building logic.
        return self.tokenize(*args, **kwargs)

    def convert_tokens_to_ids(self, tokens):
        # ``tokens`` here is the encoding returned by tokenize(), not a list.
        return tokens["input_ids"]

    def convert_ids_to_tokens(self, ids):
        # Decode the first (only) sequence in the batch back to text.
        return self.tokenizer.decode(ids[0], skip_special_tokens=True)

    def _batch_encode_plus(self, *args, **kwargs):
        return self.tokenize(*args, **kwargs)

    def _encode_plus(self, *args, **kwargs):
        return self.tokenize(*args, **kwargs)

    def save_vocabulary(self, *args, **kwargs):
        # Persist the wrapped tokenizer's vocabulary files unchanged.
        return self.tokenizer.save_vocabulary(*args, **kwargs)