---
license: apache-2.0
tags:
- trl
- ppo
- transformers
- reinforcement-learning
language:
- en
---
# TRL Model
This is a [TRL language model](https://github.com/huggingface/trl) that has been fine-tuned with reinforcement learning to
steer its outputs according to simulated human feedback. The model was fine-tuned to classify whether a patient has cancer or diabetes based on clinical notes.
```bash
pip install torch transformers trl peft
```
## Usage
```python
from transformers import AutoTokenizer
from trl import AutoModelForCausalLMWithValueHead
from peft import LoraConfig
tokenizer_kwargs = {
    "padding": "max_length",
    "truncation": True,
    "return_tensors": "pt",
    "padding_side": "left",
}
tokenizer = AutoTokenizer.from_pretrained("hanyinwang/layer-project-diagnostic-mistral", **tokenizer_kwargs)
tokenizer.pad_token = tokenizer.eos_token
generation_kwargs = {
    "min_length": -1,
    "top_k": 40,
    "top_p": 0.95,
    "do_sample": True,
    "pad_token_id": tokenizer.eos_token_id,
    "max_new_tokens": 11,
    "temperature": 0.1,
    "repetition_penalty": 1.2,
}
model = AutoModelForCausalLMWithValueHead.from_pretrained("hanyinwang/layer-project-diagnostic-mistral").cuda()
def format_prompt_mistral(text, condition):
    # Build the Mistral instruction prompt for one clinical note and one condition.
    # The prompt lines stay flush-left so no extra whitespace enters the string.
    prompt = """<s>[INST]You are a medical doctor specialized in %s diagnosis.
From the provided document, assert if the patient historically and currently has %s.
For each condition, only pick from "YES", "NO", or "MAYBE". And you must follow format without anything further. The results have to be directly parseable with python json.loads().
Sample output: {"%s": "MAYBE"}
Never output anything beyond the format.[/INST]
Provided document: %s""" % (condition, condition, condition, text)
    return prompt
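# Replace the placeholder values below with your own inputs.
note = "<clinical note text>"  # the clinical note to classify
condition = "cancer"  # "cancer" or "diabetes"
query_tensors = tokenizer.encode(format_prompt_mistral(note, condition), return_tensors="pt")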
prompt_length = query_tensors.shape[1]
outputs = model.generate(query_tensors.cuda(), **generation_kwargs)
response = tokenizer.decode(outputs[0][prompt_length:])
```
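The prompt asks the model for output that `json.loads()` can parse, so the decoded `response` can be turned into a label. The following is a minimal sketch of one way to do that, not part of the released model code: the helper name `parse_diagnosis` and the `VALID_LABELS` set are hypothetical, and it assumes the decoded text may carry trailing tokens such as `</s>` after the JSON object.

```python
import json
import re

# Labels the prompt allows; anything else is treated as unparseable.
VALID_LABELS = {"YES", "NO", "MAYBE"}


def parse_diagnosis(response, condition):
    """Return "YES", "NO", or "MAYBE" for `condition`, or None if parsing fails.

    Hypothetical helper for illustration only.
    """
    # Take the first {...} span so trailing tokens after the JSON are ignored.
    match = re.search(r"\{.*?\}", response, flags=re.DOTALL)
    if match is None:
        return None
    try:
        parsed = json.loads(match.group(0))
    except json.JSONDecodeError:
        return None
    if not isinstance(parsed, dict):
        return None
    label = parsed.get(condition)
    if not isinstance(label, str):
        return None
    label = label.strip().upper()
    return label if label in VALID_LABELS else None


# Example with a hand-written string standing in for a model response.
print(parse_diagnosis('{"cancer": "YES"}</s>', "cancer"))  # YES
```

Returning `None` instead of raising keeps a batch loop running when a generation is cut off by `max_new_tokens` or does not follow the requested format; callers can then count or retry those cases.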