---
datasets:
- esnli
license: apache-2.0
language:
- en
metrics:
- accuracy
- f1
- precision
- recall
model-index:
- name: backpack-gpt2-nli
  results:
  - task:
      name: Natural Language Inference
      type: text-classification
    dataset:
      name: e-SNLI
      type: esnli
      split: validation
    metrics:
    - name: Accuracy
      type: accuracy
      value: 0.9006299532615322
    - name: F1
      type: f1
      value: 0.9004261302857443
    - name: Precision
      type: precision
      value: 0.9004584180714215
    - name: Recall
      type: recall
      value: 0.9004554220756779
  - task:
      name: Natural Language Inference
      type: text-classification
    dataset:
      name: e-SNLI
      type: esnli
      split: test
    metrics:
    - name: Accuracy
      type: accuracy
      value: 0.8957654723127035
    - name: F1
      type: f1
      value: 0.8954702227331482
    - name: Precision
      type: precision
      value: 0.8954036872157838
    - name: Recall
      type: recall
      value: 0.8955997285576146
pipeline_tag: text-classification
tags:
- Natural Language Inference
- Sequence Classification
- GPT2
- Backpack
- ESNLI
---
# Model Card for Backpack-GPT2-NLI

This is a fine-tuned version of backpack-gpt2 with an NLI classification head, trained on the esnli dataset. Results:

- On the validation set:
  - CrossEntropyLoss: 0.3168
  - Accuracy: 0.9006
  - F1: 0.9004
- On the test set:
  - CrossEntropyLoss: 0.3277
  - Accuracy: 0.8958
  - F1: 0.8955
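The Accuracy and F1 above can be recomputed from model predictions with the `evaluate` library. A minimal sketch; the averaging mode (`macro`) and the placeholder label ids are assumptions, not taken from the original evaluation script:

```python
import evaluate

accuracy = evaluate.load("accuracy")
f1 = evaluate.load("f1")

preds = [0, 1, 2]  # placeholder predicted label ids
refs = [0, 1, 1]   # placeholder gold label ids
print(accuracy.compute(predictions=preds, references=refs))
print(f1.compute(predictions=preds, references=refs, average="macro"))  # averaging mode assumed
```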
## Model Description

- Developed by: Erfan Moosavi Monazzah
- Model type: Sequence Classifier
- Language(s) (NLP): English
- License: apache-2.0
- Finetuned from model: Backpack-GPT2
## How to Get Started with the Model

```python
from transformers import AutoTokenizer, AutoModelForSequenceClassification

# The model reuses the GPT-2 tokenizer; GPT-2 has no pad token, so reuse EOS.
tokenizer = AutoTokenizer.from_pretrained('gpt2')
tokenizer.pad_token = tokenizer.eos_token

def tokenize_function(examples):
    # Join each premise/hypothesis pair with the ' ^ ' separator expected by the model.
    concatenated_sentences = [
        f'{premise.strip(".")}. ^ {hypothesis.strip(".")}.'
        for premise, hypothesis in zip(examples['premise'], examples['hypothesis'])
    ]
    tokenized_inputs = tokenizer(
        concatenated_sentences,
        padding="max_length",
        truncation=True,
        max_length=512,
        return_tensors="pt",
    )
    return tokenized_inputs

# trust_remote_code=True is required to load the custom Backpack classification head.
model = AutoModelForSequenceClassification.from_pretrained('ErfanMoosaviMonazzah/backpack-gpt2-nli', trust_remote_code=True)
model.eval()

tokenized_sent = tokenize_function({
    'premise': ['A boy is jumping on skateboard in the middle of a red bridge.',
                'Two women who just had lunch hugging and saying goodbye.',
                'Children smiling and waving at camera'],
    'hypothesis': ['The boy does a skateboarding trick.',
                   'The friends have just met for the first time in 20 years, and have had a great time catching up.',
                   'The kids are frowning']
})

model.predict(input_ids=tokenized_sent['input_ids'], attention_mask=tokenized_sent['attention_mask'])
```
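If the custom `predict` helper is unavailable in your environment, the standard forward pass works as well. A minimal sketch, assuming the remote-code head returns a regular `SequenceClassifierOutput` and preserves the e-SNLI label order (0 = entailment, 1 = neutral, 2 = contradiction):

```python
import torch

# Assumption: the custom head returns a standard SequenceClassifierOutput
# whose class ids follow the e-SNLI label order.
with torch.no_grad():
    outputs = model(input_ids=tokenized_sent['input_ids'],
                    attention_mask=tokenized_sent['attention_mask'])

labels = ['entailment', 'neutral', 'contradiction']
print([labels[i] for i in outputs.logits.argmax(dim=-1).tolist()])
```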
## Training hyperparameters

The following hyperparameters were used during training (a configuration sketch follows the list):
- learning_rate: 5e-5
- train_batch_size: 64
- eval_batch_size: 64
- seed: 2023
- optimizer: Adam with betas=(0.9,0.999) and epsilon=1e-08
- lr_scheduler_type: linear
- lr_scheduler_warmup_ratio: 0
- num_epochs: 3
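The list format matches the Hugging Face Trainer's autogenerated report, so the setup can plausibly be reconstructed with `TrainingArguments`. A minimal sketch under that assumption; `output_dir` is hypothetical, and the 512-step eval cadence comes from the results table below:

```python
from transformers import TrainingArguments

training_args = TrainingArguments(
    output_dir="backpack-gpt2-nli",   # hypothetical output path
    learning_rate=5e-5,
    per_device_train_batch_size=64,
    per_device_eval_batch_size=64,
    seed=2023,
    lr_scheduler_type="linear",
    warmup_ratio=0.0,
    num_train_epochs=3,
    evaluation_strategy="steps",
    eval_steps=512,                   # evaluation every 512 steps, per the results table
)
```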
## Training results
Step | Training Loss | Validation Loss | Precision | Recall | F1 | Accuracy |
---|---|---|---|---|---|---|
512 | 0.614900 | 0.463713 | 0.826792 | 0.824639 | 0.825133 | 0.824731 |
1024 | 0.503300 | 0.431796 | 0.844831 | 0.839414 | 0.839980 | 0.839565 |
1536 | 0.475600 | 0.400771 | 0.848741 | 0.847009 | 0.846287 | 0.847795 |
2048 | 0.455900 | 0.375981 | 0.859064 | 0.857357 | 0.857749 | 0.857448 |
2560 | 0.440400 | 0.365537 | 0.862000 | 0.862078 | 0.861917 | 0.862426 |
3072 | 0.433100 | 0.365180 | 0.864717 | 0.859693 | 0.860237 | 0.859785 |
3584 | 0.425100 | 0.346340 | 0.872312 | 0.870635 | 0.870865 | 0.870961 |
4096 | 0.413300 | 0.343761 | 0.873606 | 0.873046 | 0.873174 | 0.873298 |
4608 | 0.412000 | 0.344890 | 0.882609 | 0.882120 | 0.882255 | 0.882341 |
5120 | 0.402600 | 0.336744 | 0.876463 | 0.875629 | 0.875827 | 0.875737 |
5632 | 0.390600 | 0.323248 | 0.882598 | 0.880779 | 0.881129 | 0.880817 |
6144 | 0.388300 | 0.338029 | 0.877255 | 0.877041 | 0.877126 | 0.877261 |
6656 | 0.390800 | 0.333301 | 0.876357 | 0.876362 | 0.875965 | 0.876753 |
7168 | 0.383800 | 0.328297 | 0.883593 | 0.883675 | 0.883629 | 0.883967 |
7680 | 0.380800 | 0.331854 | 0.882362 | 0.880373 | 0.880764 | 0.880512 |
8192 | 0.368400 | 0.323076 | 0.881730 | 0.881378 | 0.881419 | 0.881528 |
8704 | 0.367000 | 0.313959 | 0.889204 | 0.889047 | 0.889053 | 0.889352 |
9216 | 0.315600 | 0.333637 | 0.885518 | 0.883965 | 0.884266 | 0.883967 |
9728 | 0.303100 | 0.319416 | 0.888667 | 0.888092 | 0.888256 | 0.888234 |
10240 | 0.307200 | 0.317827 | 0.887575 | 0.887647 | 0.887418 | 0.888031 |
10752 | 0.300100 | 0.311810 | 0.890908 | 0.890827 | 0.890747 | 0.891181 |
11264 | 0.303400 | 0.311010 | 0.889871 | 0.887939 | 0.888309 | 0.887929 |
11776 | 0.300500 | 0.309282 | 0.891041 | 0.889819 | 0.890077 | 0.889860 |
12288 | 0.303600 | 0.326918 | 0.891272 | 0.891250 | 0.890942 | 0.891689 |
12800 | 0.300300 | 0.301688 | 0.894516 | 0.894619 | 0.894481 | 0.894940 |
13312 | 0.302200 | 0.302173 | 0.896441 | 0.896527 | 0.896462 | 0.896769 |
13824 | 0.299800 | 0.293489 | 0.895047 | 0.895172 | 0.895084 | 0.895448 |
14336 | 0.294600 | 0.297645 | 0.895865 | 0.896012 | 0.895886 | 0.896261 |
14848 | 0.296700 | 0.300751 | 0.895277 | 0.895401 | 0.895304 | 0.895651 |
15360 | 0.293100 | 0.293049 | 0.896855 | 0.896705 | 0.896757 | 0.896871 |
15872 | 0.293600 | 0.294201 | 0.895933 | 0.895557 | 0.895624 | 0.895651 |
16384 | 0.290100 | 0.289367 | 0.897847 | 0.897889 | 0.897840 | 0.898090 |
16896 | 0.293600 | 0.283990 | 0.898889 | 0.898724 | 0.898789 | 0.898903 |
17408 | 0.285800 | 0.308257 | 0.898250 | 0.898102 | 0.898162 | 0.898293 |
17920 | 0.252400 | 0.327164 | 0.898860 | 0.898807 | 0.898831 | 0.899004 |
18432 | 0.219500 | 0.315286 | 0.898877 | 0.898835 | 0.898831 | 0.899004 |
18944 | 0.217900 | 0.312738 | 0.898857 | 0.898958 | 0.898886 | 0.899207 |
19456 | 0.186400 | 0.320669 | 0.899252 | 0.899166 | 0.899194 | 0.899411 |
19968 | 0.199000 | 0.316840 | 0.900458 | 0.900455 | 0.900426 | 0.900630 |