"""This module contains functionality for running inference with a Gemma 2 model
finetuned for urgency detection, using the Hugging Face ``transformers`` library.
"""

import ast
from textwrap import dedent
from typing import Any, Optional

import torch
from transformers import (
    AutoModelForCausalLM,
    AutoTokenizer,
    PreTrainedModel,
    PreTrainedTokenizerBase,
)


def _construct_prompt(*, rules_list: list[str]) -> str:
    """Construct the prompt for the finetuned model.

    Parameters
    ----------
    rules_list
        The list of urgency rules to match against the user message.

    Returns
    -------
    str
        The prompt for the finetuned model.
    """

    _prompt_base: str = dedent(
        """
        You are a highly sensitive urgency detector. Score if ANY part of the
        user message corresponds to any part of the urgency rules provided below.
        Ignore any part of the user message that does not correspond to the rules.
        Respond with (a) the rule that is most consistent with the user message,
        (b) the probability between 0 and 1 with increments of 0.1 that ANY part of
        the user message matches the rule, and (c) the reason for the probability.


        Respond in json string:

        {
        best_matching_rule: str
        probability: float
        reason: str
        }
        """
    ).strip()
    _prompt_rules: str = dedent(
        """
        Urgency Rules:
        {urgency_rules}
        """
    ).strip()
    urgency_rules_str = "\n".join(
        [f"{i}. {rule}" for i, rule in enumerate(rules_list, 1)]
    )
    prompt = (
        _prompt_base + "\n\n" + _prompt_rules.format(urgency_rules=urgency_rules_str)
    )
    return prompt
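
# Example (for illustration only): the prompt built for a hypothetical two-rule list
# ends with the numbered rules, e.g.:
#
#   >>> print(_construct_prompt(rules_list=["NOT URGENT", "Trouble breathing"]))
#   You are a highly sensitive urgency detector. ...
#   ...
#   Urgency Rules:
#   1. NOT URGENT
#   2. Trouble breathing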


def get_completions(
    *,
    model: PreTrainedModel,
    rules_list: list[str],
    skip_special_tokens_during_decode: bool = False,
    text_generation_params: Optional[dict[str, Any]] = None,
    tokenizer: PreTrainedTokenizerBase,
    user_message: str,
) -> dict[str, Any]:
    """Get completions from the model for the given data.

    Parameters
    ----------
    model
        The model to use for inference.
    rules_list
        The list of urgency rules to match against the user message.
    skip_special_tokens_during_decode
        Specifies whether to skip special tokens during the decoding process.
    text_generation_params
        Dictionary containing text generation parameters for the LLM. If not
        specified, then default values will be used.
    tokenizer
        The tokenizer for the model.
    user_message
        The user message to match against the urgency rules.

    Returns
    -------
    dict[str, Any]
        The completion from the model. If the model output cannot be parsed into a
        valid JSON object, then the raw decoded output is returned under the
        "generated_json" key.
    """

    assert all(x for x in rules_list), "Rules must be non-empty strings!"
    text_generation_params = text_generation_params or {
        "do_sample": True,
        "eos_token_id": tokenizer.eos_token_id,
        "max_new_tokens": 1024,
        "num_return_sequences": 1,
        "repetition_penalty": 1.1,
        "temperature": 1e-6,
        "top_p": 0.9,
    }

    # Gemma chat-turn markers, used below to locate the model's reply in the decoded
    # output.
    start_of_turn, end_of_turn = tokenizer.additional_special_tokens
    eos = tokenizer.eos_token
    start_of_turn_model = f"{start_of_turn}model"
    end_of_turn_model = f"{end_of_turn}{eos}"
    input_ = (
        _construct_prompt(rules_list=rules_list) + f"\n\nUser Message:\n{user_message}"
    )
    chat = [{"role": "user", "content": input_}]
    prompt = tokenizer.apply_chat_template(
        chat, add_generation_prompt=True, tokenize=False
    )
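    # With Gemma-2 chat templates, the rendered prompt typically looks like
    # "<bos><start_of_turn>user\n{input_}<end_of_turn>\n<start_of_turn>model\n"
    # (an assumption about the tokenizer's template), which is why the reply is later
    # located between the "<start_of_turn>model" and "<end_of_turn><eos>" markers.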
    inputs = tokenizer.encode(prompt, add_special_tokens=False, return_tensors="pt")
    outputs = model.generate(
        input_ids=inputs.to(model.device), **text_generation_params
    )
    decoded_output = tokenizer.decode(
        outputs[0], skip_special_tokens=skip_special_tokens_during_decode
    )
    completion_dict = {"user_message": user_message, "generated_json": decoded_output}
    try:
        # Extract the text generated between the model's start-of-turn marker and the
        # end-of-turn/EOS markers, then parse it as a Python literal.
        start_of_turn_model_index = decoded_output.index(start_of_turn_model)
        end_of_turn_model_index = decoded_output.index(end_of_turn_model)
        generated_response = decoded_output[
            start_of_turn_model_index
            + len(start_of_turn_model) : end_of_turn_model_index
        ].strip()
        completion_dict["generated_json"] = ast.literal_eval(generated_response)
    except (SyntaxError, ValueError):
        # Fall back to the raw decoded output if the markers are missing or the
        # response is not a valid literal.
        pass
    return completion_dict
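
# Example of the expected return value (for illustration; values are hypothetical):
#
#   {
#       "user_message": "Baby moves less",
#       "generated_json": {
#           "best_matching_rule": "Baby moves less",
#           "probability": 0.9,
#           "reason": "...",
#       },
#   }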


if __name__ == "__main__":
    DTYPE = torch.bfloat16
    MODEL_ID = "idinsight/gemma-2-2b-it-ud"

    tokenizer = AutoTokenizer.from_pretrained(MODEL_ID, add_eos_token=False)
    tokenizer.pad_token = tokenizer.eos_token
    tokenizer.padding_side = "right"

    model = AutoModelForCausalLM.from_pretrained(
        MODEL_ID, device_map="auto", return_dict=True, torch_dtype=DTYPE
    )

    text_generation_params = {
        "do_sample": True,
        "eos_token_id": tokenizer.eos_token_id,
        "max_new_tokens": 1024,
        "num_return_sequences": 1,
        "repetition_penalty": 1.1,
        "temperature": 1e-6,
        "top_p": 0.9,
    }
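
    # Note: with `do_sample=True` and a temperature this close to zero, decoding is
    # effectively greedy; the remaining sampling parameters mainly matter if the
    # temperature is raised.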

    response = get_completions(
        model=model,
        rules_list=[
            "NOT URGENT",
            "Bleeding from the vagina",
            "Bad tummy pain",
            "Bad headache that won’t go away",
            "Changes to vision",
            "Trouble breathing",
            "Hot or very cold, and very weak",
            "Fits or uncontrolled shaking",
            "Baby moves less",
            "Fluid from the vagina",
            "Feeding problems",
            "Fits or uncontrolled shaking",
            "Fast, slow or difficult breathing",
            "Too hot or cold",
            "Baby’s colour changes",
            "Vomiting and watery poo",
            "Infected belly button",
            "Swollen or infected eyes",
            "Bulging or sunken soft spot",
        ],
        skip_special_tokens_during_decode=False,
        text_generation_params=text_generation_params,
        tokenizer=tokenizer,
        user_message="If my newborn can't able to breathe what can i do",
    )
    print(f"{response = }")