---
license: apache-2.0
tags:
- trl
- ppo
- transformers
- reinforcement-learning
language:
- en
---

# TRL Model

This is a [TRL language model](https://github.com/huggingface/trl) fine-tuned with reinforcement learning to guide its outputs according to simulated human feedback. The model was fine-tuned to classify cancer and diabetes from clinical notes.

Install the required packages:

```bash
pip install torch transformers trl peft
```

## Usage

```python
from transformers import AutoTokenizer
from trl import AutoModelForCausalLMWithValueHead

# Only padding_side is a load-time setting (left padding is required for
# decoder-only generation); padding/truncation options are applied when
# calling the tokenizer, not when loading it
tokenizer = AutoTokenizer.from_pretrained(
    "hanyinwang/layer-project-diagnostic-mistral",
    padding_side="left",
)
tokenizer.pad_token = tokenizer.eos_token

generation_kwargs = {
    "min_length": -1,
    "top_k": 40,
    "top_p": 0.95,
    "do_sample": True,
    "pad_token_id": tokenizer.eos_token_id,
    "max_new_tokens": 11,
    "temperature": 0.1,
    "repetition_penalty": 1.2,
}

# Requires a CUDA-capable GPU
model = AutoModelForCausalLMWithValueHead.from_pretrained(
    "hanyinwang/layer-project-diagnostic-mistral"
).cuda()

def format_prompt_mistral(text, condition):
    # Prompt template used during fine-tuning; keep the wording verbatim
    prompt = """<s>[INST]You are a medical doctor specialized in %s diagnosis.
From the provided document, assert if the patient historically and currently has %s.
For each condition, only pick from "YES", "NO", or "MAYBE". And you must follow format without anything further. The results have to be directly parseable with python json.loads().
Sample output: {"%s": "MAYBE"}
Never output anything beyond the format.[/INST]
Provided document: %s""" % (condition, condition, condition, text)
    return prompt

note = "..."          # the clinical note to classify
condition = "cancer"  # or "diabetes"

query_tensors = tokenizer.encode(format_prompt_mistral(note, condition), return_tensors="pt")
prompt_length = query_tensors.shape[1]

outputs = model.generate(query_tensors.cuda(), **generation_kwargs)
# Decode only the newly generated tokens, excluding the prompt
response = tokenizer.decode(outputs[0][prompt_length:])
```
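
The prompt instructs the model to return a single JSON object that `json.loads()` can parse directly. A minimal sketch of reading out the prediction, assuming the model followed the instructed format and that trailing special tokens such as `</s>` are stripped before parsing:

```python
import json

# Assumes the model emitted exactly the instructed format, e.g. {"cancer": "YES"}
cleaned = response.replace(tokenizer.eos_token, "").strip()
prediction = json.loads(cleaned)
print(prediction[condition])  # one of "YES", "NO", or "MAYBE"
```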