[#2] Pipeline now supports multiple sentences with batch_decode
- config.yaml +1 -0
- idiomify/pipeline.py +7 -5
- main_infer.py +7 -4
- main_train.py +0 -1
config.yaml
CHANGED
@@ -7,6 +7,7 @@ idiomifier:
   max_epochs: 2
   batch_size: 40
   shuffle: true
+  seed: 104
 
 # for building & uploading datasets or tokenizer
 idioms:
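For context, a `seed` key like this is typically consumed at training time so that shuffling and initialization are reproducible. A minimal sketch of how it could be read and applied, assuming PyTorch Lightning's `seed_everything` and the `idiomifier` section shown in the hunk header (the file path and loading code here are illustrative, not part of this commit):

import yaml
from pytorch_lightning import seed_everything

# load the training section of config.yaml (structure taken from the hunk above)
with open("config.yaml") as fh:
    config = yaml.safe_load(fh)["idiomifier"]

# seed the python, numpy and torch RNGs in one call; the diff sets seed: 104
seed_everything(config["seed"])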
idiomify/pipeline.py
CHANGED
@@ -1,5 +1,7 @@
 
 # for inference
+from typing import List
+
 from transformers import BartTokenizer
 
 from builders import SourcesBuilder
@@ -12,13 +14,13 @@ class Pipeline:
         self.model = model
         self.builder = SourcesBuilder(tokenizer)
 
-    def __call__(self, …
-        srcs = self.builder(literal2idiomatic=[(…
+    def __call__(self, sents: List[str], max_length=100) -> List[str]:
+        srcs = self.builder(literal2idiomatic=[(sent, "") for sent in sents])
         pred_ids = self.model.bart.generate(
             inputs=srcs[:, 0],  # (N, 2, L) -> (N, L)
             attention_mask=srcs[:, 1],  # (N, 2, L) -> (N, L)
             decoder_start_token_id=self.model.hparams['bos_token_id'],
             max_length=max_length,
-        )
-        …
-        return …
+        )  # -> (N, L_t)
+        tgts = self.builder.tokenizer.batch_decode(pred_ids, skip_special_tokens=True)
+        return tgts
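The key change is decoding the whole (N, L_t) tensor of generated ids with `batch_decode` instead of decoding a single sequence. For readers unfamiliar with the pattern, here is a self-contained sketch of the same batch flow using plain Hugging Face BART directly (the model name and input sentences are illustrative; the repo's `SourcesBuilder` wraps the tokenization step shown here):

from transformers import BartForConditionalGeneration, BartTokenizer

tokenizer = BartTokenizer.from_pretrained("facebook/bart-base")
model = BartForConditionalGeneration.from_pretrained("facebook/bart-base")
model.eval()  # disable dropout for deterministic inference

sents = [
    "I will work hard on this.",
    "She finished the project very quickly.",
]
# pad to a common length so N sentences become one (N, L) batch
batch = tokenizer(sents, return_tensors="pt", padding=True)
pred_ids = model.generate(
    input_ids=batch["input_ids"],            # (N, L)
    attention_mask=batch["attention_mask"],  # (N, L)
    max_length=100,
)  # -> (N, L_t)
# decode all N sequences in one call, dropping <s>, </s> and pad tokens
tgts = tokenizer.batch_decode(pred_ids, skip_special_tokens=True)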
main_infer.py
CHANGED
@@ -1,12 +1,15 @@
+"""
+This is for just a simple sanity check on the inference.
+"""
 import argparse
-from idiomify.…
+from idiomify.pipeline import Pipeline
 from idiomify.fetchers import fetch_config, fetch_idiomifier
 from transformers import BartTokenizer
 
 
 def main():
     parser = argparse.ArgumentParser()
-    parser.add_argument("--…
+    parser.add_argument("--sent", type=str,
                         default="If there's any good to loosing my job,"
                                 " it's that I'll now be able to go to school full-time and finish my degree earlier.")
     args = parser.parse_args()
@@ -16,8 +19,8 @@ def main():
     model.eval()  # this is crucial
     tokenizer = BartTokenizer.from_pretrained(config['bart'])
     pipeline = Pipeline(model, tokenizer)
-    src = config['…
-    tgt = pipeline(…
+    src = config['sent']
+    tgt = pipeline(sents=[config['sent']])
     print(src, "\n->", tgt)
 
 
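Since `Pipeline.__call__` now takes a list, the sanity-check script could just as easily exercise the new batch path. A hypothetical variant of the argument parsing (the `--sents` flag and `nargs="+"` are illustrative additions, not part of this commit; `pipeline` is built exactly as in the unchanged lines above):

import argparse

parser = argparse.ArgumentParser()
# nargs="+" collects one or more sentences into a list
parser.add_argument("--sents", type=str, nargs="+",
                    default=["If there's any good to loosing my job,"
                             " it's that I'll now be able to go to school"
                             " full-time and finish my degree earlier."])
args = parser.parse_args()

# pipeline is constructed as in the script above: Pipeline(model, tokenizer)
tgts = pipeline(sents=args.sents)  # one idiomatic rewrite per input sentence
for src, tgt in zip(args.sents, tgts):
    print(src, "\n->", tgt)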
main_train.py
CHANGED
@@ -23,7 +23,6 @@ def main():
     config.update(vars(args))
     if not config['upload']:
         print(colored("WARNING: YOU CHOSE NOT TO UPLOAD. NOTHING BUT LOGS WILL BE SAVED TO WANDB", color="red"))
-
     # prepare the model
     bart = BartForConditionalGeneration.from_pretrained(config['bart'])
     tokenizer = BartTokenizer.from_pretrained(config['bart'])