Text Classification
Transformers
Safetensors
English
HHEMv2Config
custom_code
ofermend committed on
Commit
e2b6d9e
1 Parent(s): 5838338

updated files for hhem-2.1-open

README.md CHANGED
@@ -1,165 +1,95 @@
1
  ---
2
- license: apache-2.0
3
  language: en
4
- tags:
5
- - microsoft/deberta-v3-base
6
- datasets:
7
- - multi_nli
8
- - snli
9
- - fever
10
- - tals/vitaminc
11
- - paws
12
- metrics:
13
- - accuracy
14
- - auc
15
- - balanced accuracy
16
- pipeline_tag: text-classification
17
- widget:
18
- - text: "A man walks into a bar and buys a drink [SEP] A bloke swigs alcohol at a pub"
19
- example_title: "Positive"
20
- - text: "A boy is jumping on skateboard in the middle of a red bridge. [SEP] The boy skates down the sidewalk on a blue bridge"
21
- example_title: "Negative"
22
-
23
  ---
24
- <img src="candle.png" width="50" height="50" style="display: inline;"> In Loving memory of Simon Mark Hughes...
25
 
26
- # Introduction
27
- The HHEM model is an open source model, created by [Vectara](https://vectara.com), for detecting hallucinations in LLMs. It is particularly useful in the context of building retrieval-augmented-generation (RAG) applications where a set of facts is summarized by an LLM, but the model can also be used in other contexts.
 
 
28
 
29
  If you are interested in learning more about RAG or experimenting with Vectara, you can [sign up](https://console.vectara.com/signup/?utm_source=huggingface&utm_medium=space&utm_term=hhem-model&utm_content=console&utm_campaign=) for a free Vectara account.
30
- Vectara now implements an improved version of HHEM that produces a calibrated score, supports longer sequences, and adds non-English language support. See more details [here](https://vectara.com/blog/automating-hallucination-detection-introducing-vectara-factual-consistency-score/).
31
- Now let's dive into the details of the model.
32
 
33
- ## Cross-Encoder for Hallucination Detection
34
- This model was trained using the [SentenceTransformers](https://sbert.net) [Cross-Encoder](https://www.sbert.net/examples/applications/cross-encoder/README.html) class.
- The model outputs a probability from 0 to 1: 0 being a hallucination and 1 being factually consistent.
36
- The predictions can be thresholded at 0.5 to predict whether a document is consistent with its source.
37
 
38
- ## Training Data
39
- This model is based on [microsoft/deberta-v3-base](https://huggingface.co/microsoft/deberta-v3-base). It was first trained on NLI data to determine textual entailment, then further fine-tuned on summarization datasets with samples annotated for factual consistency, including [FEVER](https://huggingface.co/datasets/fever), [Vitamin C](https://huggingface.co/datasets/tals/vitaminc) and [PAWS](https://huggingface.co/datasets/paws).
40
 
41
- ## Performance
42
 
43
- * [TRUE Dataset](https://arxiv.org/pdf/2204.04991.pdf) (Minus Vitamin C, FEVER and PAWS) - 0.872 AUC Score
44
- * [SummaC Benchmark](https://aclanthology.org/2022.tacl-1.10.pdf) (Test Split) - 0.764 Balanced Accuracy, 0.831 AUC Score
45
- * [AnyScale Ranking Test for Hallucinations](https://www.anyscale.com/blog/llama-2-is-about-as-factually-accurate-as-gpt-4-for-summaries-and-is-30x-cheaper) - 86.6 % Accuracy
46
 
47
- ## LLM Hallucination Leaderboard
48
- If you want to stay up to date with results of the latest tests using this model to evaluate the top LLM models, we have a [public leaderboard](https://huggingface.co/spaces/vectara/leaderboard) that is periodically updated, and results are also available on the [GitHub repository](https://github.com/vectara/hallucination-leaderboard).
 
 
 
 
 
 
 
 
 
 
 
 
 
 
49
 
50
- # Using HHEM
 
 
 
51
 
52
- ## Using the Inference API Widget on the Right
53
- To use the model with the widget, you need to pass both documents as a single string separated with [SEP]. For example:
54
 
55
- * A man walks into a bar and buys a drink [SEP] A bloke swigs alcohol at a pub
56
- * A person on a horse jumps over a broken down airplane. [SEP] A person is at a diner, ordering an omelette.
57
- * A person on a horse jumps over a broken down airplane. [SEP] A person is outdoors, on a horse.
58
 
59
- etc. See examples below for expected probability scores.
60
 
61
- ## Usage with Sentence Transformers (Recommended)
62
 
63
- ### Inference
64
- The model can be used like this, on pairs of documents passed as a list of lists of strings (```List[List[str]]```):
65
 
66
- ```python
- from sentence_transformers import CrossEncoder
-
- model = CrossEncoder('vectara/hallucination_evaluation_model')
- scores = model.predict([
-     ["A man walks into a bar and buys a drink", "A bloke swigs alcohol at a pub"],
-     ["A person on a horse jumps over a broken down airplane.", "A person is at a diner, ordering an omelette."],
-     ["A person on a horse jumps over a broken down airplane.", "A person is outdoors, on a horse."],
-     ["A boy is jumping on skateboard in the middle of a red bridge.", "The boy skates down the sidewalk on a blue bridge"],
-     ["A man with blond-hair, and a brown shirt drinking out of a public water fountain.", "A blond drinking water in public."],
-     ["A man with blond-hair, and a brown shirt drinking out of a public water fountain.", "A blond man wearing a brown shirt is reading a book."],
-     ["Mark Wahlberg was a fan of Manny.", "Manny was a fan of Mark Wahlberg."],
- ])
- ```
80
 
81
- This returns a numpy array of factual consistency scores (a score < 0.5 indicates a likely hallucination):
- ```
- array([0.61051559, 0.00047493709, 0.99639291, 0.00021221573, 0.99599433, 0.0014127002, 0.0028262993], dtype=float32)
- ```
 
85
 
86
- Note that the model is designed to work with entire documents, as long as they fit into the 512-token context window (across both documents).
- Also note that the order of the documents matters: the first document is the source, and the second is validated against it for factual consistency, e.g. as a summary of the first or a claim drawn from the source.
 
 
 
88
 
89
- ### Training
 
 
90
 
91
- ```python
- import math
-
- from torch.utils.data import DataLoader
- from sentence_transformers.cross_encoder import CrossEncoder
- from sentence_transformers.cross_encoder.evaluation import CEBinaryClassificationEvaluator
- from sentence_transformers import InputExample
-
- num_epochs = 5
- train_batch_size = 16  # adjust to your hardware
- model_save_path = "./model_dump"
- model_name = 'cross-encoder/nli-deberta-v3-base'  # base model; use 'vectara/hallucination_evaluation_model' if you want to further fine-tune ours
-
- model = CrossEncoder(model_name, num_labels=1, automodel_args={'ignore_mismatched_sizes': True})
-
- # Load training examples from pandas dataframes (df_train, df_test) with source, summary, and label columns:
- train_examples, test_examples = [], []
- for i, row in df_train.iterrows():
-     train_examples.append(InputExample(texts=[row['source'], row['summary']], label=int(row['label'])))
-
- for i, row in df_test.iterrows():
-     test_examples.append(InputExample(texts=[row['source'], row['summary']], label=int(row['label'])))
- test_evaluator = CEBinaryClassificationEvaluator.from_input_examples(test_examples, name='test_eval')
-
- # Then train the model per the CrossEncoder API:
- train_dataloader = DataLoader(train_examples, shuffle=True, batch_size=train_batch_size)
- warmup_steps = math.ceil(len(train_dataloader) * num_epochs * 0.1)  # 10% of train data for warm-up
- model.fit(train_dataloader=train_dataloader,
-           evaluator=test_evaluator,
-           epochs=num_epochs,
-           evaluation_steps=10_000,
-           warmup_steps=warmup_steps,
-           output_path=model_save_path,
-           show_progress_bar=True)
- ```
122
 
123
- ## Usage with Transformers AutoModel
124
- You can also use the model directly with the Transformers library (without the SentenceTransformers library):
125
 
126
- ```python
- from transformers import AutoTokenizer, AutoModelForSequenceClassification
- import torch
- import numpy as np
-
- model = AutoModelForSequenceClassification.from_pretrained('vectara/hallucination_evaluation_model')
- tokenizer = AutoTokenizer.from_pretrained('vectara/hallucination_evaluation_model')
-
- pairs = [
-     ["A man walks into a bar and buys a drink", "A bloke swigs alcohol at a pub"],
-     ["A person on a horse jumps over a broken down airplane.", "A person is at a diner, ordering an omelette."],
-     ["A person on a horse jumps over a broken down airplane.", "A person is outdoors, on a horse."],
-     ["A boy is jumping on skateboard in the middle of a red bridge.", "The boy skates down the sidewalk on a blue bridge"],
-     ["A man with blond-hair, and a brown shirt drinking out of a public water fountain.", "A blond drinking water in public."],
-     ["A man with blond-hair, and a brown shirt drinking out of a public water fountain.", "A blond man wearing a brown shirt is reading a book."],
-     ["Mark Wahlberg was a fan of Manny.", "Manny was a fan of Mark Wahlberg."],
- ]
-
- inputs = tokenizer.batch_encode_plus(pairs, return_tensors='pt', padding=True)
-
- model.eval()
- with torch.no_grad():
-     outputs = model(**inputs)
- logits = outputs.logits.cpu().detach().numpy()
- # convert logits to probabilities
- scores = 1 / (1 + np.exp(-logits)).flatten()
- ```
153
 
154
- This returns a numpy array of factual consistency scores (a score < 0.5 indicates a likely hallucination):
- ```
- array([0.61051559, 0.00047493709, 0.99639291, 0.00021221573, 0.99599433, 0.0014127002, 0.0028262993], dtype=float32)
- ```
158
 
159
- ## Contact Details
160
- Feel free to contact us with any questions:
161
- * X/Twitter - https://twitter.com/vectara or http://twitter.com/ofermend
162
- * Discussion [forums](https://discuss.vectara.com/)
163
- * Discord [server](https://discord.gg/GFb8gMz6UH)
164
 
165
- For more information about [Vectara](https://vectara.com) and how to use our RAG-as-a-service API platform, check out our [documentation](https://docs.vectara.com/docs/).
 
 
1
  ---
 
2
  language: en
3
+ license: apache-2.0
4
+ base_model: google/flan-t5-base
5
+ pipeline_tag: text-classification
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
6
  ---
 
7
 
8
+ <img src="https://huggingface.co/vectara/hallucination_evaluation_model/resolve/main/candle.png" width="50" height="50" style="display: inline;"> In Loving memory of Simon Mark Hughes...
9
+
10
+
11
+ HHEM-2.1-Open is a major upgrade to HHEM-1.0-Open, which was created by [Vectara](https://vectara.com) in November 2023. The HHEM model series is designed for detecting hallucinations in LLMs. The models are particularly useful in the context of building retrieval-augmented generation (RAG) applications, where a set of facts is summarized by an LLM, and HHEM can be used to measure the extent to which this summary is factually consistent with those facts.
12
 
13
  If you are interested in learning more about RAG or experimenting with Vectara, you can [sign up](https://console.vectara.com/signup/?utm_source=huggingface&utm_medium=space&utm_term=hhem-model&utm_content=console&utm_campaign=) for a free Vectara account.
 
 
14
 
15
+ ## Hallucination Detection 101
16
+ By "hallucinated" or "factually inconsistent", we mean that a text (the hypothesis, to be judged) is not supported by another text (the evidence/premise, which is given). You **always need two** pieces of text to determine whether a text is hallucinated or not. When applied to RAG (retrieval-augmented generation), the LLM is provided with several pieces of text (often called facts or context) retrieved from some dataset, and a hallucination would indicate that the summary (hypothesis) is not supported by those facts (evidence).
 
 
17
 
18
+ A common type of hallucination in RAG is **factual but hallucinated**. For example, given the premise _"The capital of France is Berlin"_, the hypothesis _"The capital of France is Paris"_ is hallucinated -- even though it is true according to world knowledge. This happens when LLMs generate content based not on the textual data provided to them as part of the RAG retrieval process, but rather on their pre-trained knowledge.
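+
+ For illustration, here is a minimal sketch of how the two pieces of text are typically assembled in a RAG pipeline. The variable names, the example facts, and the simple newline-joining of the retrieved facts are illustrative assumptions of ours, not something prescribed by HHEM:
+
+ ```python
+ # Illustrative only: forming one (premise, hypothesis) pair for HHEM in a RAG setting.
+ retrieved_facts = [                     # what the retriever handed to the LLM
+     "The capital of France is Berlin.",
+     "Berlin is a popular travel destination.",
+ ]
+ llm_summary = "The capital of France is Paris."  # what the LLM generated
+
+ premise = "\n".join(retrieved_facts)  # evidence: everything the LLM was given
+ hypothesis = llm_summary              # the text to be judged against the evidence
+ pair = (premise, hypothesis)          # one input pair for HHEM
+ ```
+
+ The resulting `pair` is exactly the kind of (premise, hypothesis) tuple that `model.predict` expects, as shown in the next section.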
 
19
 
20
+ ## Using HHEM-2.1-Open
21
 
22
+ HHEM-2.1-Open can be loaded easily using the `transformers` library. Just remember to set `trust_remote_code=True` to take advantage of the pre-/post-processing code we provided for your convenience. The **input** of the model is a list of pairs of (premise, hypothesis). For each pair, the model will **return** a score between 0 and 1, where 0 means that the hypothesis is not evidenced at all by the premise and 1 means the hypothesis is fully supported by the premise.
 
 
23
 
24
+ ```python
+ from transformers import AutoModelForSequenceClassification
+
+ # Load the model
+ model = AutoModelForSequenceClassification.from_pretrained(
+     'vectara/hallucination_evaluation_model', trust_remote_code=True)
+
+ pairs = [  # Test data, List[Tuple[str, str]]
+     ("The capital of France is Berlin.", "The capital of France is Paris."),  # factual but hallucinated
+     ('I am in California', 'I am in United States.'),  # Consistent
+     ('I am in United States', 'I am in California.'),  # Hallucinated
+     ("A person on a horse jumps over a broken down airplane.", "A person is outdoors, on a horse."),
+     ("A boy is jumping on skateboard in the middle of a red bridge.", "The boy skates down the sidewalk on a red bridge"),
+     ("A man with blond-hair, and a brown shirt drinking out of a public water fountain.", "A blond man wearing a brown shirt is reading a book."),
+     ("Mark Wahlberg was a fan of Manny.", "Manny was a fan of Mark Wahlberg.")
+ ]
+
+ # Use the model to predict
+ model.predict(pairs)  # note the predict() method. Do not do model(pairs).
+ # tensor([0.0111, 0.6474, 0.1290, 0.8969, 0.1846, 0.0050, 0.0543])
+ ```
45
 
46
+ Note that the order of a pair is important. For example, notice how the 2nd and 3rd examples in the `pairs` list are consistent and hallucinated, respectively.
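+
+ The scores are continuous. If your application needs a binary consistent/hallucinated decision, you can threshold the scores yourself. The sketch below continues from the snippet above; the 0.5 cut-off is only an illustrative choice (the HHEM-1.0 card suggested thresholding at 0.5), not a calibrated property of HHEM-2.1-Open:
+
+ ```python
+ scores = model.predict(pairs)          # tensor of consistency scores in [0, 1]
+ threshold = 0.5                        # illustrative cut-off; tune for your use case
+ labels = (scores >= threshold).long()  # 1 = consistent, 0 = likely hallucinated
+ for (premise, hypothesis), label in zip(pairs, labels.tolist()):
+     print("consistent" if label else "hallucinated", "->", hypothesis)
+ ```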
 
47
 
 
 
 
48
 
49
+ ## HHEM-2.1-Open vs. HHEM-1.0
50
 
51
+ The major difference between HHEM-2.1-Open and the original HHEM-1.0 is that HHEM-2.1-Open has an unlimited context length, while HHEM-1.0 is capped at 512 tokens. The longer context length allows HHEM-2.1-Open to provide more accurate hallucination detection for RAG applications, which often involve more than 512 tokens.
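+
+ To get a rough sense of when HHEM-1.0's 512-token cap becomes a problem for your data, you can count tokens before scoring. The sketch below uses the `google/flan-t5-base` tokenizer (the `base_model` declared in this card); HHEM-1.0 used a different tokenizer, so treat the count only as an approximation:
+
+ ```python
+ from transformers import AutoTokenizer
+
+ tokenizer = AutoTokenizer.from_pretrained("google/flan-t5-base")
+
+ premise = " ".join(["Some retrieved passage about the topic."] * 200)  # stand-in for a long RAG context
+ hypothesis = "A short LLM-generated summary."
+
+ n_tokens = len(tokenizer(premise + " " + hypothesis)["input_ids"])
+ print(n_tokens, "tokens ->", "would not fit HHEM-1.0's 512-token window" if n_tokens > 512 else "fits")
+ ```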
52
 
53
+ The tables below compare the two models on the [AggreFact](https://arxiv.org/pdf/2205.12854) and [RAGTruth](https://arxiv.org/abs/2401.00396) benchmarks. In particular, on AggreFact, we focus on its SOTA subset (denoted as `AggreFact-SOTA`), which contains summaries generated by Google's T5, Meta's BART, and Google's Pegasus, the three most recent models covered in the AggreFact benchmark. The results on RAGTruth's summarization (denoted as `RAGTruth-Summ`) and QA (denoted as `RAGTruth-QA`) subsets are reported separately.
 
54
 
55
+ Table 1: Performance on AggreFact-SOTA
56
+ | model | Balanced Accuracy | F1 | Recall | Precision |
57
+ |:----------------------|---------:|-------:|-------:|----------:|
58
+ | HHEM-1.0 | 0.7887 | 0.9047 | 0.7081 | 0.6728 |
59
+ | HHEM-2.1-Open | 0.7655 | 0.6677 | 0.6848 | 0.6513 |
 
 
 
 
 
 
 
 
 
60
 
61
+ Table 2: Performance on RAGTruth-Summ
62
+ | model | Balanced Accuracy | F1 | Recall | Precision |
63
+ |:----------------------|---------:|-----------:|----------:|----------:|
64
+ | HHEM-1.0 | 0.5336 | 0.1577 | 0.0931 | 0.5135 |
65
+ | HHEM-2.1-Open | 0.6442 | 0.4883 | 0.3186 | 0.7558 |
66
 
67
+ Table 3: Performance on RAGTruth-QA
68
+ | model | Balanced Accuracy | F1 | Recall | Precision |
69
+ |:----------------------|---------:|-----------:|----------:|----------:|
70
+ | HHEM-1.0 | 0.5258 | 0.1940 | 0.1625 | 0.2407 |
71
+ | HHEM-2.1-Open | 0.7428 | 0.6000 | 0.5438 | 0.6692 |
72
 
73
+ The tables above show that HHEM-2.1-Open improves significantly over HHEM-1.0 on the RAGTruth-Summ and RAGTruth-QA benchmarks, with a slight decrease on the AggreFact-SOTA benchmark. However, when interpreting these results, please note that AggreFact-SOTA is evaluated on relatively older types of LLMs:
74
+ - LLMs in AggreFact-SOTA: T5, BART, and Pegasus;
75
+ - LLMs in RAGTruth: GPT-4-0613, GPT-3.5-turbo-0613, Llama-2-7B/13B/70B-chat, and Mistral-7B-instruct.
76
 
77
+ Therefore, we conclude that HHEM-2.1-Open is better than HHEM-1.0.
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
78
 
79
+ ## Want something more powerful?
 
80
 
81
+ As you may have already sensed from the name, HHEM-2.1-Open is the open source version of the premium HHEM-2.1. HHEM-2.1 (without the `-Open`) is offered exclusively via Vectara's RAG-as-a-service platform. The major difference between HHEM-2.1 and HHEM-2.1-Open is that HHEM-2.1 is cross-lingual on three languages: English, German, and French, while HHEM-2.1-Open is English-only. "Cross-lingual" means any combination of the three languages, e.g., documents in German, query in English, results in French.
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
82
 
83
+ ### Why RAG in Vectara?
84
 
85
+ Vectara provides a Trusted Generative AI platform. The platform allows organizations to rapidly create an AI assistant experience which is grounded in the data, documents, and knowledge that they have. Vectara's serverless RAG-as-a-Service also solves critical problems required for enterprise adoption, namely: reduces hallucination, provides explainability / provenance, enforces access control, allows for real-time updatability of the knowledge, and mitigates intellectual property / bias concerns from large language models.
 
 
 
 
 
 
86
 
87
+ To start benefiting from HHEM-2.1, you can [sign up](https://console.vectara.com/signup/?utm_source=huggingface&utm_medium=space&utm_term=hhem-model&utm_content=console&utm_campaign=) for a free Vectara account, and you will get the HHEM-2.1 score returned with every query automatically.
 
 
 
88
 
89
+ Here are some additional resources:
90
+ 1. Vectara [API documentation](https://docs.vectara.com/docs).
91
+ 2. Quick start using Forrest's [vectara-python-cli](https://vectara-python-cli.readthedocs.io/en/latest/crash_course.html).
92
+ 3. Learn more about Vectara's [Boomerang embedding model](https://vectara.com/blog/introducing-boomerang-vectaras-new-and-improved-retrieval-model/), [Slingshot reranker](https://vectara.com/blog/deep-dive-into-vectara-multilingual-reranker-v1-state-of-the-art-reranker-across-100-languages/), and [Mockingbird LLM](https://vectara.com/blog/mockingbird-a-rag-and-structured-output-focused-llm/)
 
93
 
94
+ ## LLM Hallucination Leaderboard
95
+ If you want to stay up to date with results of the latest tests using this model to evaluate the top LLM models, we have a [public leaderboard](https://huggingface.co/spaces/vectara/leaderboard) that is periodically updated, and results are also available on the [GitHub repository](https://github.com/vectara/hallucination-leaderboard).
added_tokens.json DELETED
@@ -1,3 +0,0 @@
1
- {
2
- "[MASK]": 128000
3
- }
 
 
 
 
candle.png DELETED
Binary file (487 kB)
 
config.json CHANGED
@@ -1,41 +1,12 @@
1
  {
2
- "_name_or_path": "./cross-encoder-binary-bce_loss/distilbert_nli_epoch_crossencoder_binary-bce-3-score_0.9844",
3
  "architectures": [
4
- "DebertaV2ForSequenceClassification"
5
  ],
6
- "attention_probs_dropout_prob": 0.1,
7
- "hidden_act": "gelu",
8
- "hidden_dropout_prob": 0.1,
9
- "hidden_size": 768,
10
- "id2label": {
11
- "0": "LABEL_0"
12
  },
13
- "initializer_range": 0.02,
14
- "intermediate_size": 3072,
15
- "label2id": {
16
- "LABEL_0": 0
17
- },
18
- "layer_norm_eps": 1e-07,
19
- "max_position_embeddings": 512,
20
- "max_relative_positions": -1,
21
- "model_type": "deberta-v2",
22
- "norm_rel_ebd": "layer_norm",
23
- "num_attention_heads": 12,
24
- "num_hidden_layers": 12,
25
- "pad_token_id": 0,
26
- "pooler_dropout": 0,
27
- "pooler_hidden_act": "gelu",
28
- "pooler_hidden_size": 768,
29
- "pos_att_type": [
30
- "p2c",
31
- "c2p"
32
- ],
33
- "position_biased_input": false,
34
- "position_buckets": 256,
35
- "relative_attention": true,
36
- "share_att_key": true,
37
  "torch_dtype": "float32",
38
- "transformers_version": "4.33.3",
39
- "type_vocab_size": 0,
40
- "vocab_size": 128100
41
  }
 
1
  {
 
2
  "architectures": [
3
+ "HHEMv2ForSequenceClassification"
4
  ],
5
+ "auto_map": {
6
+ "AutoConfig": "configuration_hhem_v2.HHEMv2Config",
7
+ "AutoModelForSequenceClassification": "modeling_hhem_v2.HHEMv2ForSequenceClassification"
 
 
 
8
  },
9
+ "model_type": "HHEMv2Config",
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
10
  "torch_dtype": "float32",
11
+ "transformers_version": "4.39.3"
 
 
12
  }
configuration_hhem_v2.py ADDED
@@ -0,0 +1,22 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ from transformers import PretrainedConfig
+
+
+ class HHEMv2Config(PretrainedConfig):
+     model_type = "HHEMv2"
+     foundation = "google/flan-t5-base"
+     prompt = "<pad> Determine if the hypothesis is true given the premise?\n\nPremise: {text1}\n\nHypothesis: {text2}"
+
+     def __init__(self,
+                  foundation="google/flan-t5-base",
+                  prompt="<pad> Determine if the hypothesis is true given the premise?\n\nPremise: {text1}\n\nHypothesis: {text2}",
+                  **kwargs):
+         super().__init__(**kwargs)
+         self.foundation = foundation
+         self.prompt = prompt
leaderboard_summaries.csv DELETED
@@ -1,3 +0,0 @@
1
- version https://git-lfs.github.com/spec/v1
2
- oid sha256:ab3c650afa79db152974ad311258fb573b1815a020b6f7e8a4c3a8636eb27487
3
- size 19446482
 
 
 
 
spm.model → model.safetensors RENAMED
@@ -1,3 +1,3 @@
1
  version https://git-lfs.github.com/spec/v1
2
- oid sha256:c679fbf93643d19aab7ee10c0b99e460bdbc02fedf34b92b05af343b4af586fd
3
- size 2464616
 
1
  version https://git-lfs.github.com/spec/v1
2
+ oid sha256:634de18a38cf1e991c1acd0f7a9e0d30f7ea187fba42bb4798f862d3edd31e72
3
+ size 438535352
modeling_hhem_v2.py ADDED
@@ -0,0 +1,65 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ import torch
+
+ from peft import PeftModel
+ from transformers import PreTrainedModel, AutoConfig, T5ForTokenClassification, AutoModel, AutoTokenizer, AutoModelForTokenClassification
+
+ from .configuration_hhem_v2 import HHEMv2Config
+
+
+ class HHEMv2Model(PreTrainedModel):
+     config_class = HHEMv2Config
+
+     def __init__(self, config):
+         super().__init__(config)
+         # self.t5 = T5ForTokenClassification.from_config(
+         #     AutoConfig.from_pretrained(config.foundation)
+         # )
+
+     # def populate(self, model):
+     #     self.t5 = model
+
+     # def forward(self, **kwarg):
+     #     return self.t5.transformer(**kwarg)
+
+
+ class HHEMv2ForSequenceClassification(PreTrainedModel):
+     config_class = HHEMv2Config
+
+     def __init__(self, config=HHEMv2Config()):
+         super().__init__(config)
+         self.t5 = T5ForTokenClassification(
+             AutoConfig.from_pretrained(config.foundation)
+         )
+         self.prompt = config.prompt
+         self.tokenizer = AutoTokenizer.from_pretrained(config.foundation)
+
+     def populate(self, model: AutoModel):
+         """Initialize the model with the pretrained foundation model.
+
+         This method should only be called by a Vectara employee who prepares the model for publishing. Users do not need to call this method.
+         """
+         self.t5 = model
+
+     # TODO: Figure out how to publish only the adapter yet still be able to do end-to-end pulling and inference.
+     # def populate_lora(self, checkpoint: str):
+     #     base_model = AutoModelForTokenClassification.from_pretrained(self.config.foundation)
+     #     combined_model = PeftModel.from_pretrained(base_model, checkpoint, is_trainable=False)
+     #     self.t5 = combined_model
+
+     def forward(self, **kwargs):
+         return self.t5(**kwargs)
+
+     def predict(self, text_pairs):
+         """Score (premise, hypothesis) pairs and return the consistency probability for each pair."""
+         tokenizer = self.tokenizer
+         pair_dict = [{'text1': pair[0], 'text2': pair[1]} for pair in text_pairs]
+         inputs = tokenizer(
+             [self.prompt.format(**pair) for pair in pair_dict], return_tensors='pt', padding=True)
+         self.t5.eval()
+         with torch.no_grad():
+             outputs = self.t5(**inputs)
+         logits = outputs.logits
+         logits = logits[:, 0, :]  # tok_cls: use the logits of the first token
+         transformed_probs = torch.softmax(logits, dim=-1)
+         raw_scores = transformed_probs[:, 1]  # the probability of class 1 (consistent)
+         return raw_scores
64
+
65
+
pytorch_model.bin DELETED
@@ -1,3 +0,0 @@
1
- version https://git-lfs.github.com/spec/v1
2
- oid sha256:eb3789996c599ab27a5506005f03229a0a4600fe3634d121c48731ae95cde40e
3
- size 737761269
 
 
 
 
special_tokens_map.json DELETED
@@ -1,9 +0,0 @@
1
- {
2
- "bos_token": "[CLS]",
3
- "cls_token": "[CLS]",
4
- "eos_token": "[SEP]",
5
- "mask_token": "[MASK]",
6
- "pad_token": "[PAD]",
7
- "sep_token": "[SEP]",
8
- "unk_token": "[UNK]"
9
- }
 
 
 
 
 
 
 
 
 
 
tokenizer.json DELETED
The diff for this file is too large to render. See raw diff
 
tokenizer_config.json DELETED
@@ -1,23 +0,0 @@
1
- {
2
- "bos_token": "[CLS]",
3
- "clean_up_tokenization_spaces": true,
4
- "cls_token": "[CLS]",
5
- "do_lower_case": false,
6
- "eos_token": "[SEP]",
7
- "mask_token": "[MASK]",
8
- "max_length": 512,
9
- "model_max_length": 512,
10
- "pad_to_multiple_of": null,
11
- "pad_token": "[PAD]",
12
- "pad_token_type_id": 0,
13
- "padding_side": "right",
14
- "sep_token": "[SEP]",
15
- "sp_model_kwargs": {},
16
- "split_by_punct": false,
17
- "stride": 0,
18
- "tokenizer_class": "DebertaV2Tokenizer",
19
- "truncation_side": "right",
20
- "truncation_strategy": "longest_first",
21
- "unk_token": "[UNK]",
22
- "vocab_type": "spm"
23
- }