# Hugging Face Hub file page: uploaded by gardner with huggingface_hub,
# commit c475ecf (verified), 2.26 kB, virus scan clean.
from transformers import AutoModelForCausalLM, AutoTokenizer
from datasets import load_dataset
import argilla as rg
import os
from tqdm import tqdm
# Directory containing this script; the model checkpoint is resolved from it.
dir_path = os.path.dirname(os.path.realpath(__file__))

# Read Argilla connection settings from the environment so real credentials
# do not have to be committed; the original placeholders remain as fallbacks.
rg.init(
    api_url=os.environ.get("ARGILLA_API_URL", "<argilla-api-url>"),
    api_key=os.environ.get("ARGILLA_API_KEY", "<argilla-api-key>"),
)

# Preference-modeling dataset with two responses per prompt; the records added
# later fill the "prompt", "response1" and "response2" fields.
ds = rg.FeedbackDataset.for_preference_modeling(
    number_of_responses=2,
    context=False,
    use_markdown=False,
    guidelines=None,
    metadata_properties=None,
    vectors_settings=None,
)

# Model/tokenizer directory under the script (presumably the merged LoRA
# output — confirm against the training config).
model_dir = os.path.join(dir_path, "lora-out", "merged")
tokenizer = AutoTokenizer.from_pretrained(model_dir)

# Held-out split used to generate the candidate responses.
dataset = load_dataset("gardner/glaive-function-calling-v2-sharegpt", split="test")
def preprocess_function_calling(examples):
    """Render a batch of ShareGPT conversations as chat-template prompts.

    Intended for ``dataset.map(..., batched=True)``. Each turn's ShareGPT keys
    ``from``/``value`` become ``role``/``content``, and the speaker names
    ``human``/``gpt`` become ``user``/``assistant``. Any other role name is
    passed through unchanged. If a conversation ends with an assistant turn,
    that turn is split off as the reference answer. The remaining turns are
    rendered (untokenized) with the module-level ``tokenizer``'s chat template.

    Args:
        examples: batch dict whose ``"conversations"`` entry is a list of
            conversations, each a list of ``{"from": ..., "value": ...}`` turns.

    Returns:
        ``{"texts": [...], "answer": [...]}`` with exactly one entry per input
        conversation. ``answer`` is ``None`` when the conversation does not end
        with an assistant turn; in that case ``texts`` covers every turn.
    """
    role_names = {"human": "user", "gpt": "assistant"}
    texts = []
    answers = []
    for chat in examples["conversations"]:
        # Build fresh message dicts rather than popping keys in place, so the
        # caller's batch is not mutated.
        messages = [
            {
                "role": role_names.get(turn["from"], turn["from"]),
                "content": turn["value"],
            }
            for turn in chat
        ]
        answer = None
        if messages and messages[-1]["role"] == "assistant":
            # Hold out the final assistant reply as the reference answer.
            answer = messages.pop()["content"]
        # Always append to both lists: batched map requires every returned
        # column to have the same length as the input batch when the original
        # columns are kept.
        answers.append(answer)
        texts.append(tokenizer.apply_chat_template(messages, tokenize=False))
    return {"texts": texts, "answer": answers}
# Add "texts" (rendered prompt) and "answer" (reference reply) columns.
texts = dataset.map(preprocess_function_calling, batched=True)
model = AutoModelForCausalLM.from_pretrained(model_dir).to("cuda")

# Prompts are rendered without a generation prompt, so the assistant turn is
# opened here by hand using ChatML markup.
assistant_header = "<|im_start|>assistant\n"
records = []
for text in tqdm(texts):
    response2 = text['answer']
    if response2 is None:
        # No reference answer to compare against; skip instead of pushing an
        # incomplete preference record.
        continue
    prompt = text['texts'] + assistant_header
    # NOTE(review): if the chat template already emits a BOS token, consider
    # add_special_tokens=False here to avoid a duplicated BOS.
    inputs = tokenizer(prompt, return_tensors="pt").to("cuda")
    # Sampling temperature belongs to generate() and only applies when
    # do_sample=True. (It was previously passed to decode(), where it had no
    # sampling effect.)
    outputs = model.generate(
        **inputs, max_new_tokens=512, do_sample=True, temperature=0.9
    )
    # Decode only the newly generated tokens. Slicing the full decoded text at
    # len(prompt) gives a wrong offset whenever skip_special_tokens=True strips
    # prompt markup such as <|im_start|> (when it is a special token).
    new_tokens = outputs[0][inputs["input_ids"].shape[-1]:]
    response1 = tokenizer.decode(new_tokens, skip_special_tokens=True)
    # Keep only the first assistant turn in case <|im_end|> is a plain token
    # and generation continued past it.
    response1 = response1.partition("<|im_end|>")[0].strip()
    tqdm.write(response1)  # print without corrupting the progress bar
    records.append(rg.FeedbackRecord(fields={
        "prompt": prompt,
        "response1": response1,
        "response2": response2,
    }))
ds.add_records(records)
ds.push_to_argilla(name="function-calling", workspace="argilla")