eubinecto committed on
Commit
ff61478
1 Parent(s): 210581d

[#7] training & fetching m-1-3 is ready

Browse files
config.yaml CHANGED
@@ -1,12 +1,13 @@
1
  # for training an idiomifier
2
  idiomifier:
3
- ver: m-1-2
4
- desc: just overfitting the model, but on the entire PIE dataset.
5
  bart: facebook/bart-base
6
  lr: 0.0001
7
- literal2idiomatic_ver: d-1-2
8
- idioms_ver: d-1-2
9
- max_epochs: 2
 
10
  batch_size: 40
11
  shuffle: true
12
  seed: 104
 
1
  # for training an idiomifier
2
  idiomifier:
3
+ ver: m-1-3
4
+ desc: Just overfitting on PIE dataset, but now with <idiom> & </idiom> special tokens.
5
  bart: facebook/bart-base
6
  lr: 0.0001
7
+ literal2idiomatic_ver: d-1-3
8
+ idioms_ver: d-1-3
9
+ tokenizer_ver: t-1-1
10
+ max_epochs: 3
11
  batch_size: 40
12
  shuffle: true
13
  seed: 104
explore/explore_fetch_tokenizer.py CHANGED
@@ -12,6 +12,9 @@ def main():
12
  print(tokenizer.unk_token)
13
  print(tokenizer.additional_special_tokens) # this should have been added
14
 
 
 
 
15
 
16
  """
17
  <s>
@@ -22,6 +25,7 @@ def main():
22
  <pad>
23
  <unk>
24
  ['<idiom>', '</idiom>']
 
25
  """
26
 
27
  if __name__ == '__main__':
 
12
  print(tokenizer.unk_token)
13
  print(tokenizer.additional_special_tokens) # this should have been added
14
 
15
+ # the size of the vocab
16
+ print(len(tokenizer))
17
+
18
 
19
  """
20
  <s>
 
25
  <pad>
26
  <unk>
27
  ['<idiom>', '</idiom>']
28
+ 50267
29
  """
30
 
31
  if __name__ == '__main__':
idiomify/fetchers.py CHANGED
@@ -60,6 +60,7 @@ def fetch_idiomifier(ver: str, run: Run = None) -> Idiomifier:
60
  artifact_dir = artifact.download(root=idiomifier_dir(ver))
61
  ckpt_path = path.join(artifact_dir, "model.ckpt")
62
  bart = AutoModelForSeq2SeqLM.from_config(AutoConfig.from_pretrained(config['bart']))
 
63
  model = Idiomifier.load_from_checkpoint(ckpt_path, bart=bart)
64
  return model
65
 
 
60
  artifact_dir = artifact.download(root=idiomifier_dir(ver))
61
  ckpt_path = path.join(artifact_dir, "model.ckpt")
62
  bart = AutoModelForSeq2SeqLM.from_config(AutoConfig.from_pretrained(config['bart']))
63
+ bart.resize_token_embeddings(config['vocab_size'])  # grow the embeddings to the vocab size recorded at training time, before loading the checkpoint
64
  model = Idiomifier.load_from_checkpoint(ckpt_path, bart=bart)
65
  return model
66
 
main_train.py CHANGED
@@ -5,9 +5,9 @@ import argparse
5
  import pytorch_lightning as pl
6
  from termcolor import colored
7
  from pytorch_lightning.loggers import WandbLogger
8
- from transformers import BartTokenizer, BartForConditionalGeneration
9
  from idiomify.datamodules import IdiomifyDataModule
10
- from idiomify.fetchers import fetch_config
11
  from idiomify.models import Idiomifier
12
  from idiomify.paths import ROOT_DIR
13
 
@@ -23,12 +23,13 @@ def main():
23
  config.update(vars(args))
24
  if not config['upload']:
25
  print(colored("WARNING: YOU CHOSE NOT TO UPLOAD. NOTHING BUT LOGS WILL BE SAVED TO WANDB", color="red"))
26
- # prepare the model
27
  bart = BartForConditionalGeneration.from_pretrained(config['bart'])
28
- tokenizer = BartTokenizer.from_pretrained(config['bart'])
29
- model = Idiomifier(bart, config['lr'], tokenizer.bos_token_id, tokenizer.pad_token_id)
30
  # prepare the datamodule
31
  with wandb.init(entity="eubinecto", project="idiomify", config=config) as run:
 
 
 
32
  datamodule = IdiomifyDataModule(config, tokenizer, run)
33
  logger = WandbLogger(log_model=False)
34
  trainer = pl.Trainer(max_epochs=config['max_epochs'],
@@ -44,6 +45,7 @@ def main():
44
  if not config['fast_dev_run'] and trainer.current_epoch == config['max_epochs'] - 1:
45
  ckpt_path = ROOT_DIR / "model.ckpt"
46
  trainer.save_checkpoint(str(ckpt_path))
 
47
  artifact = wandb.Artifact(name="idiomifier", type="model", metadata=config)
48
  artifact.add_file(str(ckpt_path))
49
  run.log_artifact(artifact, aliases=["latest", config['ver']])
 
5
  import pytorch_lightning as pl
6
  from termcolor import colored
7
  from pytorch_lightning.loggers import WandbLogger
8
+ from transformers import BartForConditionalGeneration
9
  from idiomify.datamodules import IdiomifyDataModule
10
+ from idiomify.fetchers import fetch_config, fetch_tokenizer
11
  from idiomify.models import Idiomifier
12
  from idiomify.paths import ROOT_DIR
13
 
 
23
  config.update(vars(args))
24
  if not config['upload']:
25
  print(colored("WARNING: YOU CHOSE NOT TO UPLOAD. NOTHING BUT LOGS WILL BE SAVED TO WANDB", color="red"))
26
+ # prepare a pre-trained BART
27
  bart = BartForConditionalGeneration.from_pretrained(config['bart'])
 
 
28
  # prepare the datamodule
29
  with wandb.init(entity="eubinecto", project="idiomify", config=config) as run:
30
+ tokenizer = fetch_tokenizer(config['tokenizer_ver'], run)
31
+ bart.resize_token_embeddings(len(tokenizer)) # because new tokens are added, this process is necessary
32
+ model = Idiomifier(bart, config['lr'], tokenizer.bos_token_id, tokenizer.pad_token_id)
33
  datamodule = IdiomifyDataModule(config, tokenizer, run)
34
  logger = WandbLogger(log_model=False)
35
  trainer = pl.Trainer(max_epochs=config['max_epochs'],
 
45
  if not config['fast_dev_run'] and trainer.current_epoch == config['max_epochs'] - 1:
46
  ckpt_path = ROOT_DIR / "model.ckpt"
47
  trainer.save_checkpoint(str(ckpt_path))
48
+ config['vocab_size'] = len(tokenizer) # this will be needed to fetch a pretrained idiomifier later
49
  artifact = wandb.Artifact(name="idiomifier", type="model", metadata=config)
50
  artifact.add_file(str(ckpt_path))
51
  run.log_artifact(artifact, aliases=["latest", config['ver']])