tonyzhao6 committed on
Commit
c985775
1 Parent(s): fcae2b5

Upload gemma2_inference_hf.py

Files changed (1)
  1. gemma2_inference_hf.py +194 -0
gemma2_inference_hf.py ADDED
@@ -0,0 +1,194 @@
"""This module contains functionality for running inference on a Gemma 2 model
finetuned for urgency detection, using the HuggingFace library.
"""

# Standard Library
import ast

from textwrap import dedent
from typing import Any, Optional

# Third Party Library
import torch

from transformers import AutoModelForCausalLM, AutoTokenizer, PreTrainedTokenizerBase


def _construct_prompt(*, rules_list: list[str]) -> str:
    """Construct the prompt for the finetuned model.

    Parameters
    ----------
    rules_list
        The list of urgency rules to match against the user message.

    Returns
    -------
    str
        The prompt for the finetuned model.
    """

    _prompt_base: str = dedent(
        """
        You are a highly sensitive urgency detector. Score if ANY part of the
        user message corresponds to any part of the urgency rules provided below.
        Ignore any part of the user message that does not correspond to the rules.
        Respond with (a) the rule that is most consistent with the user message,
        (b) the probability between 0 and 1 with increments of 0.1 that ANY part of
        the user message matches the rule, and (c) the reason for the probability.


        Respond in json string:

        {
            best_matching_rule: str
            probability: float
            reason: str
        }
        """
    ).strip()
    _prompt_rules: str = dedent(
        """
        Urgency Rules:
        {urgency_rules}
        """
    ).strip()
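    # Present the urgency rules to the model as a numbered list (1..N) appended
    # after the base instructions.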
    urgency_rules_str = "\n".join(
        [f"{i}. {rule}" for i, rule in enumerate(rules_list, 1)]
    )
    prompt = (
        _prompt_base + "\n\n" + _prompt_rules.format(urgency_rules=urgency_rules_str)
    )
    return prompt


def get_completions(
    *,
    model,
    rules_list: list[str],
    skip_special_tokens_during_decode: bool = False,
    text_generation_params: Optional[dict[str, Any]] = None,
    tokenizer: PreTrainedTokenizerBase,
    user_message: str,
) -> dict[str, Any]:
    """Get completions from the model for the given data.

    Parameters
    ----------
    model
        The model for inference.
    rules_list
        The list of urgency rules to match against the user message.
    skip_special_tokens_during_decode
        Specifies whether to skip special tokens during the decoding process.
    text_generation_params
        Dictionary containing text generation parameters for the LLM model. If not
        specified, then default values will be used.
    tokenizer
        The tokenizer for the model.
    user_message
        The user message to match against the urgency rules.

    Returns
    -------
    dict[str, Any]
        The completion from the model. If the model output is not a valid JSON
        string, then the raw decoded output is returned in the "generated_json" key.
    """

    assert all(x for x in rules_list), "Rules must be non-empty strings!"
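    # Default generation parameters: a temperature of 1e-6 with do_sample=True
    # makes sampling effectively greedy, so the output is near-deterministic.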
    text_generation_params = text_generation_params or {
        "do_sample": True,
        "eos_token_id": tokenizer.eos_token_id,
        "max_new_tokens": 1024,
        "num_return_sequences": 1,
        "repetition_penalty": 1.1,
        "temperature": 1e-6,
        "top_p": 0.9,
    }
    tokenizer.add_special_tokens = False  # Because we are using the chat template

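    # Gemma 2's chat template wraps each turn in <start_of_turn>/<end_of_turn>
    # markers, so the model's reply is delimited by "<start_of_turn>model" and
    # "<end_of_turn><eos>"; these markers are used further below to slice the
    # generated answer out of the decoded output.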
    start_of_turn, end_of_turn = tokenizer.additional_special_tokens
    eos = tokenizer.eos_token
    start_of_turn_model = f"{start_of_turn}model"
    end_of_turn_model = f"{end_of_turn}{eos}"
    input_ = (
        _construct_prompt(rules_list=rules_list) + f"\n\nUser Message:\n{user_message}"
    )
    chat = [{"role": "user", "content": input_}]
    prompt = tokenizer.apply_chat_template(
        chat, add_generation_prompt=True, tokenize=False
    )
    inputs = tokenizer.encode(prompt, add_special_tokens=False, return_tensors="pt")
    outputs = model.generate(
        input_ids=inputs.to(model.device), **text_generation_params
    )
    decoded_output = tokenizer.decode(
        outputs[0], skip_special_tokens=skip_special_tokens_during_decode
    )
    completion_dict = {"user_message": user_message, "generated_json": decoded_output}
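    # Slice out the model turn and parse the JSON-like dict with ast.literal_eval.
    # If the turn markers are missing (ValueError) or the string is not a valid
    # Python literal (SyntaxError/ValueError), fall back to the raw decoded output.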
    try:
        start_of_turn_model_index = decoded_output.index(start_of_turn_model)
        end_of_turn_model_index = decoded_output.index(end_of_turn_model)
        generated_response = decoded_output[
            start_of_turn_model_index
            + len(start_of_turn_model) : end_of_turn_model_index
        ].strip()
        completion_dict["generated_json"] = ast.literal_eval(generated_response)
    except (SyntaxError, ValueError):
        pass
    return completion_dict


if __name__ == "__main__":
    DTYPE = torch.bfloat16
    MODEL_ID = "idinsight/gemma-2-2b-it-ud"

    tokenizer = AutoTokenizer.from_pretrained(MODEL_ID, add_eos_token=False)
    tokenizer.pad_token = tokenizer.eos_token
    tokenizer.padding_side = "right"

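    # device_map="auto" places the model on the available GPU(s) via Accelerate,
    # falling back to CPU if none is present.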
    model = AutoModelForCausalLM.from_pretrained(
        MODEL_ID, device_map="auto", return_dict=True, torch_dtype=DTYPE
    )

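    # Explicit generation parameters; these match the defaults used inside
    # get_completions.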
    text_generation_params = {
        "do_sample": True,
        "eos_token_id": tokenizer.eos_token_id,
        "max_new_tokens": 1024,
        "num_return_sequences": 1,
        "repetition_penalty": 1.1,
        "temperature": 1e-6,
        "top_p": 0.9,
    }

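    # Example invocation: screen a single user message against a list of urgency
    # rules and print the parsed result.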
    response = get_completions(
        model=model,
        rules_list=[
            "NOT URGENT",
            "Bleeding from the vagina",
            "Bad tummy pain",
            "Bad headache that won’t go away",
            "Bad headache that won’t go away",
            "Changes to vision",
            "Trouble breathing",
            "Hot or very cold, and very weak",
            "Fits or uncontrolled shaking",
            "Baby moves less",
            "Fluid from the vagina",
            "Feeding problems",
            "Fits or uncontrolled shaking",
            "Fast, slow or difficult breathing",
            "Too hot or cold",
            "Baby’s colour changes",
            "Vomiting and watery poo",
            "Infected belly button",
            "Swollen or infected eyes",
            "Bulging or sunken soft spot",
        ],
        skip_special_tokens_during_decode=False,
        text_generation_params=text_generation_params,
        tokenizer=tokenizer,
        user_message="If my newborn can't able to breathe what can i do",
    )
    print(f"{response = }")