[#2] Eval script added. It still needs testing
- config.yaml +3 -2
- explore/explore_torchmetrics_bleu.py +28 -0
- idiomify/metrics.py +0 -4
- idiomify/models.py +20 -3
- main_eval.py +34 -0
config.yaml
CHANGED
@@ -12,9 +12,10 @@ train:
 upload:
   idioms:
     ver: d-1-2
-    description: the set of idioms in the traning set of literal2idiomatic:d-1-2
+    description: the set of idioms in the traning set of literal2idiomatic:d-1-2.
   literal2idiomatic:
     ver: d-1-2
-    description: PIE data split into train & test set (80 / 20 split)
+    description: PIE data split into train & test set (80 / 20 split). There is no validation set, because I don't intend to
+      do hyperparameter tuning on this set.
     train_ratio: 0.8
     seed: 104
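The config above pins an 80/20 train/test split with a fixed seed and no validation set. The splitting code itself is not part of this commit; a minimal, hypothetical sketch consistent with `train_ratio: 0.8` and `seed: 104` (assuming the PIE data sits in a CSV) could look like:

```python
# Hypothetical sketch of the 80/20 split implied by config.yaml;
# not the repo's actual upload script.
import pandas as pd

pie = pd.read_csv("pie.csv")                        # hypothetical path to the PIE dataset
train_df = pie.sample(frac=0.8, random_state=104)   # train_ratio: 0.8, seed: 104
test_df = pie.drop(train_df.index)                  # the remaining 20% becomes the test set
# no validation split: hyperparameters are not tuned on this dataset
```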
explore/explore_torchmetrics_bleu.py
ADDED
@@ -0,0 +1,28 @@
+
+from torchmetrics import BLEUScore
+from transformers import BartTokenizer
+
+
+pairs = [
+    ("I knew you could do it", "I knew you could do it"),
+    ("I knew you could do it", "you knew you could do it")
+]
+
+
+def main():
+    tokenizer = BartTokenizer.from_pretrained('facebook/bart-base')
+    metric = BLEUScore()
+    preds = tokenizer([pred for pred, _ in pairs])['input_ids']
+    targets = tokenizer([target for _, target in pairs])['input_ids']
+    print(preds)
+    print(targets)
+    print(metric(preds, targets))
+    # arghhh, so bleu score does not support tensors...
+    """
+    AttributeError: 'int' object has no attribute 'split'
+    """
+    # let's just go for the accuracies then.
+
+
+if __name__ == '__main__':
+    main()
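For reference (not part of this commit): in the torchmetrics versions from around this time, `BLEUScore` expects untokenized strings rather than token ids, with `preds` as a list of hypothesis sentences and `target` as one list of reference sentences per hypothesis, which is why feeding it `input_ids` raises the `AttributeError` above. A minimal sketch that sidesteps the tokenizer:

```python
# Sketch only: BLEUScore on raw strings (no BART tokenizer involved).
from torchmetrics import BLEUScore

pairs = [
    ("I knew you could do it", "I knew you could do it"),
    ("I knew you could do it", "you knew you could do it"),
]

metric = BLEUScore()
preds = [pred for pred, _ in pairs]            # list[str] of predictions
targets = [[target] for _, target in pairs]    # one reference list per prediction
print(metric(preds, targets))                  # corpus-level BLEU as a tensor
```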
idiomify/metrics.py
DELETED
@@ -1,4 +0,0 @@
-"""
-you may want to include bleu score.
-and more metrics for paraphrasing.
-"""
idiomify/models.py
CHANGED
@@ -7,7 +7,7 @@ from torch.nn import functional as F
 import pytorch_lightning as pl
 from transformers import BartForConditionalGeneration, BartTokenizer
 from idiomify.builders import SourcesBuilder
-
+from torchmetrics import Accuracy
 
 class Idiomifier(pl.LightningModule):  # noqa
     """
@@ -15,8 +15,11 @@ class Idiomifier(pl.LightningModule):  # noqa
     """
     def __init__(self, bart: BartForConditionalGeneration, lr: float, bos_token_id: int, pad_token_id: int):  # noqa
         super().__init__()
-        self.bart = bart
         self.save_hyperparameters(ignore=["bart"])
+        self.bart = bart
+        # metrics (using accuracies as of right now)
+        self.acc_train = Accuracy(ignore_index=pad_token_id)
+        self.acc_test = Accuracy(ignore_index=pad_token_id)
 
     def forward(self, srcs: torch.Tensor, tgts_r: torch.Tensor) -> torch.Tensor:
         """
@@ -40,13 +43,27 @@ class Idiomifier(pl.LightningModule):  # noqa
         logits = self.forward(srcs, tgts_r).transpose(1, 2)  # ... -> (N, L, |V|) -> (N, |V|, L)
         loss = F.cross_entropy(logits, tgts, ignore_index=self.hparams['pad_token_id'])\
                    .sum()  # (N, L, |V|), (N, L) -> (N,) -> (1,)
+        self.acc_train.update(logits.detach(), target=tgts.detach())
         return {
             "loss": loss
         }
 
-    def on_train_batch_end(self, outputs: dict,
+    def on_train_batch_end(self, outputs: dict, **kwargs):
         self.log("Train/Loss", outputs['loss'])
 
+    def on_train_epoch_end(self, *args, **kwargs) -> None:
+        self.log("Train/Accuracy", self.acc_train.compute())
+        self.acc_train.reset()
+
+    def test_step(self, batch: Tuple[torch.Tensor, torch.Tensor, torch.Tensor], **kwargs):
+        srcs, tgts_r, tgts = batch  # (N, 2, L_s), (N, 2, L_t), (N, 2, L_t)
+        logits = self.forward(srcs, tgts_r).transpose(1, 2)  # ... -> (N, L, |V|) -> (N, |V|, L)
+        self.acc_test.update(logits.detach(), target=tgts.detach())
+
+    def on_test_end(self):
+        self.log("Test/Accuracy", self.acc_test.compute())
+        self.acc_test.reset()
+
     def configure_optimizers(self) -> torch.optim.Optimizer:
         """
         Instantiates and returns the optimizer to be used for this model
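As an aside (not part of the commit): with the pre-0.11 torchmetrics API assumed here, `Accuracy` accepts float logits of shape `(N, |V|, L)` next to integer targets of shape `(N, L)`, argmaxes over the vocabulary dimension, and skips positions whose target equals `ignore_index`, which is how the padded positions drop out of the token-level accuracy. A small self-contained sketch of that behaviour:

```python
# Sketch of the token-level accuracy used above (pre-0.11 torchmetrics API assumed).
import torch
from torchmetrics import Accuracy

pad_token_id = 1                                # BART's default pad id
acc = Accuracy(ignore_index=pad_token_id)       # pad positions do not count

vocab_size, N, L = 100, 2, 8                    # toy sizes instead of BART's |V|
logits = torch.randn(N, vocab_size, L)          # (N, |V|, L), as in training_step
targets = torch.randint(0, vocab_size, (N, L))  # (N, L)
targets[:, -2:] = pad_token_id                  # pretend the last two tokens are padding

acc.update(logits, targets)                     # argmax over dim=1, then compare
print(acc.compute())                            # accuracy over non-pad positions only
acc.reset()                                     # start fresh for the next epoch
```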
main_eval.py
ADDED
@@ -0,0 +1,34 @@
+import torch
+import argparse
+import os
+import wandb
+import pytorch_lightning as pl
+from pytorch_lightning.loggers import WandbLogger
+from transformers import BartTokenizer
+from idiomify.data import IdiomifyDataModule
+from idiomify.fetchers import fetch_config, fetch_idiomifier
+from paths import ROOT_DIR
+
+
+def main():
+    parser = argparse.ArgumentParser()
+    parser.add_argument("--num_workers", type=int, default=os.cpu_count())
+    args = parser.parse_args()
+    config = fetch_config()['train']
+    config.update(vars(args))
+    # prepare the model
+    tokenizer = BartTokenizer.from_pretrained(config['bart'])
+    # prepare the datamodule
+    with wandb.init(entity="eubinecto", project="idiomify", config=config) as run:
+        model = fetch_idiomifier(config['ver'], run)
+        datamodule = IdiomifyDataModule(config, tokenizer, run)
+        logger = WandbLogger(log_model=False)
+        trainer = pl.Trainer(fast_dev_run=config['fast_dev_run'],
+                             gpus=torch.cuda.device_count(),
+                             default_root_dir=str(ROOT_DIR),
+                             logger=logger)
+        trainer.test(model, datamodule)
+
+
+if __name__ == '__main__':
+    main()
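Usage note (an assumption, since the script is untested as the commit message says): provided the W&B artifacts that `fetch_config()` and `fetch_idiomifier` point to have already been uploaded under the `eubinecto/idiomify` project, the evaluation would be launched from the repository root with something like `python3 main_eval.py --num_workers 4`, where `--num_workers` defaults to `os.cpu_count()`.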