JuanJoseMV
committed on
Commit
•
8f5d925
1
Parent(s):
71eacb0
hotfix
Browse files
- NeuralTextGenerator.py +2 -2
- app.py +3 -3
NeuralTextGenerator.py
CHANGED
@@ -20,7 +20,7 @@ DEFAULT_DEVICE = 'cuda' if torch.cuda.is_available() else 'cpu'
|
|
20 |
|
21 |
|
22 |
class BertTextGenerator:
|
23 |
-
def __init__(self, model_version, device=DEFAULT_DEVICE, use_apex=APEX_AVAILABLE, use_fast=True,
|
24 |
do_basic_tokenize=True):
|
25 |
"""
|
26 |
Wrapper of a BERT model from AutoModelForMaskedLM from huggingfaces.
|
@@ -47,7 +47,7 @@ class BertTextGenerator:
|
|
47 |
self.model, optimizer = amp.initialize(self.model, optimizer, opt_level="O2", keep_batchnorm_fp32=True,
|
48 |
loss_scale="dynamic")
|
49 |
|
50 |
-
self.tokenizer = AutoTokenizer.from_pretrained(
|
51 |
use_fast=use_fast,
|
52 |
do_basic_tokenize=do_basic_tokenize) # added to avoid splitting of unused tokens
|
53 |
self.num_attention_masks = len(self.model.base_model.base_model.encoder.layer)
|
|
|
20 |
|
21 |
|
22 |
class BertTextGenerator:
|
23 |
+
def __init__(self, model_version, tokenizer, device=DEFAULT_DEVICE, use_apex=APEX_AVAILABLE, use_fast=True,
|
24 |
do_basic_tokenize=True):
|
25 |
"""
|
26 |
Wrapper of a BERT model from AutoModelForMaskedLM from huggingfaces.
|
|
|
47 |
self.model, optimizer = amp.initialize(self.model, optimizer, opt_level="O2", keep_batchnorm_fp32=True,
|
48 |
loss_scale="dynamic")
|
49 |
|
50 |
+
self.tokenizer = AutoTokenizer.from_pretrained(tokenizer, do_lower_case="uncased" in model_version,
|
51 |
use_fast=use_fast,
|
52 |
do_basic_tokenize=do_basic_tokenize) # added to avoid splitting of unused tokens
|
53 |
self.num_attention_masks = len(self.model.base_model.base_model.encoder.layer)
|
app.py
CHANGED
@@ -2,13 +2,13 @@ import gradio as gr
|
|
2 |
from NeuralTextGenerator import BertTextGenerator
|
3 |
|
4 |
model_name = "cardiffnlp/twitter-xlm-roberta-base" #"dbmdz/bert-base-italian-uncased"
|
5 |
-
en_model = BertTextGenerator(model_name)
|
6 |
|
7 |
finetunned_BERT_model_name = "JuanJoseMV/BERT_text_gen"
|
8 |
-
finetunned_BERT_en_model = BertTextGenerator(finetunned_BERT_model_name)
|
9 |
|
10 |
finetunned_RoBERTa_model_name = "JuanJoseMV/XLM_RoBERTa_text_gen"
|
11 |
-
finetunned_RoBERTa_en_model = BertTextGenerator(finetunned_RoBERTa_model_name)
|
12 |
|
13 |
special_tokens = [
|
14 |
'[POSITIVE-0]',
|
|
|
2 |
from NeuralTextGenerator import BertTextGenerator
|
3 |
|
4 |
model_name = "cardiffnlp/twitter-xlm-roberta-base" #"dbmdz/bert-base-italian-uncased"
|
5 |
+
en_model = BertTextGenerator(model_name, tokenizer='xlm-roberta')
|
6 |
|
7 |
finetunned_BERT_model_name = "JuanJoseMV/BERT_text_gen"
|
8 |
+
finetunned_BERT_en_model = BertTextGenerator(finetunned_BERT_model_name, tokenizer='bert-base-uncased')
|
9 |
|
10 |
finetunned_RoBERTa_model_name = "JuanJoseMV/XLM_RoBERTa_text_gen"
|
11 |
+
finetunned_RoBERTa_en_model = BertTextGenerator(finetunned_RoBERTa_model_name, tokenizer='xlm-roberta')
|
12 |
|
13 |
special_tokens = [
|
14 |
'[POSITIVE-0]',
|