[#7] training & fetching m-1-3 is ready
Browse files- config.yaml +6 -5
- explore/explore_fetch_tokenizer.py +4 -0
- idiomify/fetchers.py +1 -0
- main_train.py +7 -5
config.yaml
CHANGED
@@ -1,12 +1,13 @@
|
|
1 |
# for training an idiomifier
|
2 |
idiomifier:
|
3 |
-
ver: m-1-
|
4 |
-
desc:
|
5 |
bart: facebook/bart-base
|
6 |
lr: 0.0001
|
7 |
-
literal2idiomatic_ver: d-1-
|
8 |
-
idioms_ver: d-1-
|
9 |
-
|
|
|
10 |
batch_size: 40
|
11 |
shuffle: true
|
12 |
seed: 104
|
|
|
1 |
# for training an idiomifier
|
2 |
idiomifier:
|
3 |
+
ver: m-1-3
|
4 |
+
desc: Just overfitting on PIE dataset, but now with <idiom> & </idiom> special tokens.
|
5 |
bart: facebook/bart-base
|
6 |
lr: 0.0001
|
7 |
+
literal2idiomatic_ver: d-1-3
|
8 |
+
idioms_ver: d-1-3
|
9 |
+
tokenizer_ver: t-1-1
|
10 |
+
max_epochs: 3
|
11 |
batch_size: 40
|
12 |
shuffle: true
|
13 |
seed: 104
|
explore/explore_fetch_tokenizer.py
CHANGED
@@ -12,6 +12,9 @@ def main():
|
|
12 |
print(tokenizer.unk_token)
|
13 |
print(tokenizer.additional_special_tokens) # this should have been added
|
14 |
|
|
|
|
|
|
|
15 |
|
16 |
"""
|
17 |
<s>
|
@@ -22,6 +25,7 @@ def main():
|
|
22 |
<pad>
|
23 |
<unk>
|
24 |
['<idiom>', '</idiom>']
|
|
|
25 |
"""
|
26 |
|
27 |
if __name__ == '__main__':
|
|
|
12 |
print(tokenizer.unk_token)
|
13 |
print(tokenizer.additional_special_tokens) # this should have been added
|
14 |
|
15 |
+
# the size of the vocab
|
16 |
+
print(len(tokenizer))
|
17 |
+
|
18 |
|
19 |
"""
|
20 |
<s>
|
|
|
25 |
<pad>
|
26 |
<unk>
|
27 |
['<idiom>', '</idiom>']
|
28 |
+
50267
|
29 |
"""
|
30 |
|
31 |
if __name__ == '__main__':
|
idiomify/fetchers.py
CHANGED
@@ -60,6 +60,7 @@ def fetch_idiomifier(ver: str, run: Run = None) -> Idiomifier:
|
|
60 |
artifact_dir = artifact.download(root=idiomifier_dir(ver))
|
61 |
ckpt_path = path.join(artifact_dir, "model.ckpt")
|
62 |
bart = AutoModelForSeq2SeqLM.from_config(AutoConfig.from_pretrained(config['bart']))
|
|
|
63 |
model = Idiomifier.load_from_checkpoint(ckpt_path, bart=bart)
|
64 |
return model
|
65 |
|
|
|
60 |
artifact_dir = artifact.download(root=idiomifier_dir(ver))
|
61 |
ckpt_path = path.join(artifact_dir, "model.ckpt")
|
62 |
bart = AutoModelForSeq2SeqLM.from_config(AutoConfig.from_pretrained(config['bart']))
|
63 |
+
bart.resize_embeddings(config['vocab_size'])
|
64 |
model = Idiomifier.load_from_checkpoint(ckpt_path, bart=bart)
|
65 |
return model
|
66 |
|
main_train.py
CHANGED
@@ -5,9 +5,9 @@ import argparse
|
|
5 |
import pytorch_lightning as pl
|
6 |
from termcolor import colored
|
7 |
from pytorch_lightning.loggers import WandbLogger
|
8 |
-
from transformers import
|
9 |
from idiomify.datamodules import IdiomifyDataModule
|
10 |
-
from idiomify.fetchers import fetch_config
|
11 |
from idiomify.models import Idiomifier
|
12 |
from idiomify.paths import ROOT_DIR
|
13 |
|
@@ -23,12 +23,13 @@ def main():
|
|
23 |
config.update(vars(args))
|
24 |
if not config['upload']:
|
25 |
print(colored("WARNING: YOU CHOSE NOT TO UPLOAD. NOTHING BUT LOGS WILL BE SAVED TO WANDB", color="red"))
|
26 |
-
# prepare
|
27 |
bart = BartForConditionalGeneration.from_pretrained(config['bart'])
|
28 |
-
tokenizer = BartTokenizer.from_pretrained(config['bart'])
|
29 |
-
model = Idiomifier(bart, config['lr'], tokenizer.bos_token_id, tokenizer.pad_token_id)
|
30 |
# prepare the datamodule
|
31 |
with wandb.init(entity="eubinecto", project="idiomify", config=config) as run:
|
|
|
|
|
|
|
32 |
datamodule = IdiomifyDataModule(config, tokenizer, run)
|
33 |
logger = WandbLogger(log_model=False)
|
34 |
trainer = pl.Trainer(max_epochs=config['max_epochs'],
|
@@ -44,6 +45,7 @@ def main():
|
|
44 |
if not config['fast_dev_run'] and trainer.current_epoch == config['max_epochs'] - 1:
|
45 |
ckpt_path = ROOT_DIR / "model.ckpt"
|
46 |
trainer.save_checkpoint(str(ckpt_path))
|
|
|
47 |
artifact = wandb.Artifact(name="idiomifier", type="model", metadata=config)
|
48 |
artifact.add_file(str(ckpt_path))
|
49 |
run.log_artifact(artifact, aliases=["latest", config['ver']])
|
|
|
5 |
import pytorch_lightning as pl
|
6 |
from termcolor import colored
|
7 |
from pytorch_lightning.loggers import WandbLogger
|
8 |
+
from transformers import BartForConditionalGeneration
|
9 |
from idiomify.datamodules import IdiomifyDataModule
|
10 |
+
from idiomify.fetchers import fetch_config, fetch_tokenizer
|
11 |
from idiomify.models import Idiomifier
|
12 |
from idiomify.paths import ROOT_DIR
|
13 |
|
|
|
23 |
config.update(vars(args))
|
24 |
if not config['upload']:
|
25 |
print(colored("WARNING: YOU CHOSE NOT TO UPLOAD. NOTHING BUT LOGS WILL BE SAVED TO WANDB", color="red"))
|
26 |
+
# prepare a pre-trained BART
|
27 |
bart = BartForConditionalGeneration.from_pretrained(config['bart'])
|
|
|
|
|
28 |
# prepare the datamodule
|
29 |
with wandb.init(entity="eubinecto", project="idiomify", config=config) as run:
|
30 |
+
tokenizer = fetch_tokenizer(config['tokenizer_ver'], run)
|
31 |
+
bart.resize_token_embeddings(len(tokenizer)) # because new tokens are added, this process is necessary
|
32 |
+
model = Idiomifier(bart, config['lr'], tokenizer.bos_token_id, tokenizer.pad_token_id)
|
33 |
datamodule = IdiomifyDataModule(config, tokenizer, run)
|
34 |
logger = WandbLogger(log_model=False)
|
35 |
trainer = pl.Trainer(max_epochs=config['max_epochs'],
|
|
|
45 |
if not config['fast_dev_run'] and trainer.current_epoch == config['max_epochs'] - 1:
|
46 |
ckpt_path = ROOT_DIR / "model.ckpt"
|
47 |
trainer.save_checkpoint(str(ckpt_path))
|
48 |
+
config['vocab_size'] = len(tokenizer) # this will be needed to fetch a pretrained idiomifier later
|
49 |
artifact = wandb.Artifact(name="idiomifier", type="model", metadata=config)
|
50 |
artifact.add_file(str(ckpt_path))
|
51 |
run.log_artifact(artifact, aliases=["latest", config['ver']])
|