# gemma-2-2b-it-ud / gemma2_inference_hf.py
# Uploaded by tonyzhao6 ("Upload gemma2_inference_hf.py"), commit c985775 (verified).
"""This module contains functionality for running inference on a Gemma 2 model
fine-tuned for urgency detection using the Hugging Face Transformers library.
"""
# Standard Library
import ast
import json
from textwrap import dedent
from typing import Any, Optional

# Third Party Library
import torch
from transformers import AutoModelForCausalLM, AutoTokenizer, PreTrainedTokenizerBase
def _construct_prompt(*, rules_list: list[str]) -> str:
"""Construct the prompt for the finetuned model.
Parameters
----------
rules_list
The list of urgency rules to match against the user message.
Returns
-------
str
The prompt for the finetuned model.
"""
_prompt_base: str = dedent(
"""
You are a highly sensitive urgency detector. Score if ANY part of the
user message corresponds to any part of the urgency rules provided below.
Ignore any part of the user message that does not correspond to the rules.
Respond with (a) the rule that is most consistent with the user message,
(b) the probability between 0 and 1 with increments of 0.1 that ANY part of
the user message matches the rule, and (c) the reason for the probability.
Respond in json string:
{
best_matching_rule: str
probability: float
reason: str
}
"""
).strip()
_prompt_rules: str = dedent(
"""
Urgency Rules:
{urgency_rules}
"""
).strip()
urgency_rules_str = "\n".join(
[f"{i}. {rule}" for i, rule in enumerate(rules_list, 1)]
)
prompt = (
_prompt_base + "\n\n" + _prompt_rules.format(urgency_rules=urgency_rules_str)
)
return prompt
def _parse_model_response(
    decoded_output: str, *, start_marker: str, end_marker: str
) -> Any:
    """Extract the model's turn from a decoded generation and parse it.

    Parameters
    ----------
    decoded_output
        The decoded model output, including the prompt.
    start_marker
        Text that immediately precedes the model's response.
    end_marker
        Text that immediately follows the model's response.

    Returns
    -------
    Any
        The parsed response (a dictionary when the model follows the prompt).

    Raises
    ------
    ValueError
        If either marker is missing, or the response is neither valid JSON nor a
        valid Python literal.
    """
    # `str.index` raises ValueError when a marker is absent, e.g. when special
    # tokens were skipped during decoding or generation stopped early.
    start = decoded_output.index(start_marker) + len(start_marker)
    # Look for the end marker only after the model turn has started.
    end = decoded_output.index(end_marker, start)
    response = decoded_output[start:end].strip()

    try:
        return json.loads(response)
    except (json.JSONDecodeError, RecursionError):
        pass

    # Fall back to Python literal syntax (e.g. single-quoted strings).
    try:
        return ast.literal_eval(response)
    except (
        MemoryError,
        RecursionError,
        SyntaxError,
        TypeError,
        ValueError,
    ) as err:
        raise ValueError(
            "Model response is neither valid JSON nor a Python literal."
        ) from err


def get_completions(
    *,
    model,
    rules_list: list[str],
    skip_special_tokens_during_decode: bool = False,
    text_generation_params: Optional[dict[str, Any]] = None,
    tokenizer: PreTrainedTokenizerBase,
    user_message: str,
) -> dict[str, Any]:
    """Get completions from the model for the given data.

    Parameters
    ----------
    model
        The model for inference.
    rules_list
        The list of urgency rules to match against the user message.
    skip_special_tokens_during_decode
        Specifies whether to skip special tokens during the decoding process. If
        True, the turn markers are removed from the decoded text, so the response
        cannot be located and the raw decoded output is returned.
    text_generation_params
        Dictionary containing text generation parameters for the LLM model. If not
        specified (None or empty), then default values will be used.
    tokenizer
        The tokenizer for the model. It is not modified.
    user_message
        The user message to match against the urgency rules.

    Returns
    -------
    dict[str, Any]
        A dictionary with the keys "user_message" and "generated_json". If the
        model output does not contain a parsable JSON (or Python literal) response,
        then the full decoded output is returned in the "generated_json" key.

    Raises
    ------
    ValueError
        If any element of `rules_list` is empty/falsy.
    """
    if not all(rules_list):
        raise ValueError("Rules must be non-empty strings!")

    # `or` (not `is None`) is kept deliberately: an empty dict also selects defaults.
    text_generation_params = text_generation_params or {
        "do_sample": True,
        "eos_token_id": tokenizer.eos_token_id,
        "max_new_tokens": 1024,
        "num_return_sequences": 1,
        "repetition_penalty": 1.1,
        "temperature": 1e-6,
        "top_p": 0.9,
    }

    # Assumes the tokenizer has exactly two additional special tokens, ordered
    # (start-of-turn, end-of-turn); the unpacking fails otherwise.
    start_of_turn, end_of_turn = tokenizer.additional_special_tokens
    start_of_turn_model = f"{start_of_turn}model"
    end_of_turn_model = f"{end_of_turn}{tokenizer.eos_token}"

    input_ = (
        _construct_prompt(rules_list=rules_list) + f"\n\nUser Message:\n{user_message}"
    )
    chat = [{"role": "user", "content": input_}]
    prompt = tokenizer.apply_chat_template(
        chat, add_generation_prompt=True, tokenize=False
    )
    # The chat template already supplies the special tokens, so encoding must not
    # add them again.
    inputs = tokenizer.encode(prompt, add_special_tokens=False, return_tensors="pt")
    outputs = model.generate(
        input_ids=inputs.to(model.device), **text_generation_params
    )
    decoded_output = tokenizer.decode(
        outputs[0], skip_special_tokens=skip_special_tokens_during_decode
    )

    completion_dict: dict[str, Any] = {
        "user_message": user_message,
        "generated_json": decoded_output,
    }
    try:
        completion_dict["generated_json"] = _parse_model_response(
            decoded_output,
            start_marker=start_of_turn_model,
            end_marker=end_of_turn_model,
        )
    except ValueError:
        # Best effort: keep the raw decoded output when parsing fails.
        pass
    return completion_dict
def main() -> None:
    """Run one urgency-detection example with the published model and print it.

    Loads the tokenizer and model for "idinsight/gemma-2-2b-it-ud", then calls
    `get_completions` using its default text-generation parameters.
    """
    model_id = "idinsight/gemma-2-2b-it-ud"

    tokenizer = AutoTokenizer.from_pretrained(model_id, add_eos_token=False)
    tokenizer.pad_token = tokenizer.eos_token
    tokenizer.padding_side = "right"
    model = AutoModelForCausalLM.from_pretrained(
        model_id, device_map="auto", return_dict=True, torch_dtype=torch.bfloat16
    )

    # NOTE(review): "Bad headache that won’t go away" is listed twice in a row,
    # which looks like a copy error. "Fits or uncontrolled shaking" also appears
    # twice. Both are kept so the prompt sent to the model is unchanged; confirm
    # whether the duplicates are intended.
    rules_list = [
        "NOT URGENT",
        "Bleeding from the vagina",
        "Bad tummy pain",
        "Bad headache that won’t go away",
        "Bad headache that won’t go away",
        "Changes to vision",
        "Trouble breathing",
        "Hot or very cold, and very weak",
        "Fits or uncontrolled shaking",
        "Baby moves less",
        "Fluid from the vagina",
        "Feeding problems",
        "Fits or uncontrolled shaking",
        "Fast, slow or difficult breathing",
        "Too hot or cold",
        "Baby’s colour changes",
        "Vomiting and watery poo",
        "Infected belly button",
        "Swollen or infected eyes",
        "Bulging or sunken soft spot",
    ]

    # `text_generation_params` is omitted, so `get_completions` applies its own
    # defaults.
    response = get_completions(
        model=model,
        rules_list=rules_list,
        skip_special_tokens_during_decode=False,
        tokenizer=tokenizer,
        user_message="If my newborn can't able to breathe what can i do",
    )
    print(f"{response = }")


if __name__ == "__main__":
    main()