eubinecto committed
Commit 08409ff
1 Parent(s): 6de2ea9

[#2] Pipeline now supports multiple sentences with batch_decode

Files changed (4):
  1. config.yaml +1 -0
  2. idiomify/pipeline.py +7 -5
  3. main_infer.py +7 -4
  4. main_train.py +0 -1
config.yaml CHANGED
@@ -7,6 +7,7 @@ idiomifier:
   max_epochs: 2
   batch_size: 40
   shuffle: true
+  seed: 104
 
 # for building & uploading datasets or tokenizer
 idioms:
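The new seed value is presumably read back on the training side to make runs reproducible. Below is a minimal sketch of how such a value is typically wired in, assuming pytorch-lightning (the model exposes hparams, which suggests a LightningModule, but this wiring is an assumption and is not part of this commit):

    import pytorch_lightning as pl
    from idiomify.fetchers import fetch_config

    config = fetch_config()  # assumed to return the parsed config.yaml section
    pl.seed_everything(config['seed'], workers=True)  # fix Python/NumPy/Torch RNGs for reproducible runs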
idiomify/pipeline.py CHANGED
@@ -1,5 +1,7 @@
 
 # for inference
+from typing import List
+
 from transformers import BartTokenizer
 
 from builders import SourcesBuilder
@@ -12,13 +14,13 @@ class Pipeline:
         self.model = model
         self.builder = SourcesBuilder(tokenizer)
 
-    def __call__(self, src: str, max_length=100) -> str:
-        srcs = self.builder(literal2idiomatic=[(src, "")])
+    def __call__(self, sents: List[str], max_length=100) -> List[str]:
+        srcs = self.builder(literal2idiomatic=[(sent, "") for sent in sents])
         pred_ids = self.model.bart.generate(
            inputs=srcs[:, 0],  # (N, 2, L) -> (N, L)
            attention_mask=srcs[:, 1],  # (N, 2, L) -> (N, L)
            decoder_start_token_id=self.model.hparams['bos_token_id'],
            max_length=max_length,
-        ).squeeze()  # -> (N, L_t) -> (L_t)
-        tgt = self.builder.tokenizer.decode(pred_ids, skip_special_tokens=True)
-        return tgt
+        )  # -> (N, L_t)
+        tgts = self.builder.tokenizer.batch_decode(pred_ids, skip_special_tokens=True)
+        return tgts
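For reference, here is a minimal sketch of how the new batched __call__ could be called. The setup mirrors main_infer.py; the argument to fetch_idiomifier and the shape of the fetched config are assumptions (they are not shown in this diff), and the example sentences are invented placeholders:

    from typing import List
    from transformers import BartTokenizer
    from idiomify.fetchers import fetch_config, fetch_idiomifier
    from idiomify.pipeline import Pipeline

    config = fetch_config()  # assumed to return the parsed config.yaml section
    model = fetch_idiomifier(config['ver'])  # the 'ver' argument is an assumption
    model.eval()  # this is crucial
    tokenizer = BartTokenizer.from_pretrained(config['bart'])
    pipeline = Pipeline(model, tokenizer)

    # N sentences in -> N idiomified sentences out, via a single generate + batch_decode pass
    sents: List[str] = [
        "He decided to be honest and admit his mistake.",
        "She was nervous before the interview.",
    ]
    tgts: List[str] = pipeline(sents=sents, max_length=100)
    for sent, tgt in zip(sents, tgts):
        print(sent, "\n->", tgt)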
main_infer.py CHANGED
@@ -1,12 +1,15 @@
+"""
+This is for just a simple sanity check on the inference.
+"""
 import argparse
-from idiomify.models import Pipeline
+from idiomify.pipeline import Pipeline
 from idiomify.fetchers import fetch_config, fetch_idiomifier
 from transformers import BartTokenizer
 
 
 def main():
     parser = argparse.ArgumentParser()
-    parser.add_argument("--src", type=str,
+    parser.add_argument("--sent", type=str,
                         default="If there's any good to loosing my job,"
                                 " it's that I'll now be able to go to school full-time and finish my degree earlier.")
     args = parser.parse_args()
@@ -16,8 +19,8 @@ def main():
     model.eval()  # this is crucial
     tokenizer = BartTokenizer.from_pretrained(config['bart'])
     pipeline = Pipeline(model, tokenizer)
-    src = config['src']
-    tgt = pipeline(src=config['src'])
+    src = config['sent']
+    tgt = pipeline(sents=[config['sent']])
     print(src, "\n->", tgt)
 
 
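With the renamed flag, the sanity check can be invoked like so (the sentence below is only an illustrative placeholder):

    python main_infer.py --sent "She finally decided to tell her boss the truth."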
main_train.py CHANGED
@@ -23,7 +23,6 @@ def main():
     config.update(vars(args))
     if not config['upload']:
         print(colored("WARNING: YOU CHOSE NOT TO UPLOAD. NOTHING BUT LOGS WILL BE SAVED TO WANDB", color="red"))
-
     # prepare the model
     bart = BartForConditionalGeneration.from_pretrained(config['bart'])
     tokenizer = BartTokenizer.from_pretrained(config['bart'])