File size: 3,830 Bytes
a70f9f8 ac66ae2 6137b7d a70f9f8 d6a8f30 206ee23 d6a8f30 74c640a c8534fb e92ef1c df44c11 38576e5 df44c11 467c88a df44c11 a70f9f8 0aa11b1 df44c11 38576e5 03a8827 a70f9f8 df44c11 6dd3828 a70f9f8 7f9f34a 613b540 ada7179 74c640a cbbb9fd 5eaca58 ada7179 cbbb9fd ada7179 39546c6 0fb434b ada7179 2b03f9f 74c640a ada7179 74c640a 2b03f9f e69ea59 e566aba df44c11 227477e e92ef1c 74c640a 065dd39 a184b8b 2371111 835fa92 e6a4e68 13e776a 94ca6da 13e776a 73899fd 083fde1 |
1 2 3 4 5 6 7 8 9 10 11 12 13 14 15 16 17 18 19 20 21 22 23 24 25 26 27 28 29 30 31 32 33 34 35 36 37 38 39 40 41 42 43 44 45 46 47 48 49 50 51 52 53 54 55 56 57 58 59 60 61 62 63 64 65 66 67 68 69 70 71 72 73 74 75 76 77 78 79 80 81 82 83 84 |
# https://www.philschmid.de/fine-tune-llms-in-2024-with-trl#3-create-and-prepare-the-dataset
import gradio as gr
import os#, torch
from datasets import load_dataset
from huggingface_hub import HfApi, login
#from peft import AutoPeftModelForCausalLM, LoraConfig
#from random import randint
from transformers import AutoTokenizer, AutoModelForCausalLM#, BitsAndBytesConfig, TrainingArguments, pipeline
#from trl import SFTTrainer, setup_chat_format
# Hugging Face account under which model copies are published
# (see replace_hf_profile).
hf_profile = "bstraehle"

# Labels of the two actions selectable in the UI; process() dispatches on them.
action_1 = "Fine-tune pre-trained model"
action_2 = "Prompt fine-tuned model"

# Default prompt texts shown in the UI. The "{schema}" placeholder in the
# system prompt is filled in with str.format() when the model is prompted.
system_prompt = (
    "You are a text to SQL query translator. "
    "Given a question in English, generate a SQL query based on the provided SCHEMA. "
    "Do not generate any additional text. "
    "SCHEMA: {schema}"
)
user_prompt = (
    "What is the total trade value and average price "
    "for each trader and stock in the trade_history table?"
)
schema = (
    "CREATE TABLE trade_history ("
    "id INT, trader_id INT, stock VARCHAR(255), "
    "price DECIMAL(5,2), quantity INT, trade_time TIMESTAMP);"
)

# Default Hub identifiers for the base model and the dataset fields of the UI.
base_model_id = "codellama/CodeLlama-7b-hf"
dataset = "b-mc2/sql-create-context"
def prompt_model(model_id, system_prompt, user_prompt, schema):
    """Run a chat-style text-generation pipeline and return the completion.

    Args:
        model_id: Identifier of the model passed to the pipeline.
        system_prompt: System message template; its ``{schema}`` placeholder
            is replaced with ``schema``.
        user_prompt: User message sent to the model.
        schema: Text substituted into ``system_prompt``.

    Returns:
        The ``content`` of the last message in the pipeline's generated chat.
    """
    # The module-level imports of `torch` and `pipeline` are commented out,
    # so they are imported here; otherwise this function fails with NameError.
    import torch
    from transformers import pipeline

    pipe = pipeline(
        "text-generation",
        model=model_id,
        model_kwargs={"torch_dtype": torch.bfloat16},
        device_map="auto",
        max_new_tokens=1000,
    )

    messages = [
        {"role": "system", "content": system_prompt.format(schema=schema)},
        {"role": "user", "content": user_prompt},
        # NOTE(review): an empty trailing assistant message is passed through
        # unchanged; how the pipeline treats it (new turn vs. prefill) may
        # depend on the installed transformers version -- confirm.
        {"role": "assistant", "content": ""},
    ]

    output = pipe(messages)
    result = output[0]["generated_text"][-1]["content"]

    print(result)

    return result
def fine_tune_model(base_model_id, dataset):
    """Download the base model, upload it again, and return the new model ID.

    Args:
        base_model_id: Identifier of the model to download.
        dataset: Accepted for interface compatibility; not used by this
            function.

    Returns:
        The model ID returned by ``upload_model``.
    """
    # download_model() runs first; the tokenizer it returns is handed on to
    # upload_model().
    return upload_model(base_model_id, download_model(base_model_id))
def download_model(base_model_id):
    """Load tokenizer and model for ``base_model_id`` and save the model locally.

    The model is written with ``save_pretrained`` to a path equal to
    ``base_model_id``; the tokenizer is not saved here.

    Returns:
        The loaded tokenizer.
    """
    loaded_tokenizer = AutoTokenizer.from_pretrained(base_model_id)

    # The same string serves as the model identifier and as the save path.
    AutoModelForCausalLM.from_pretrained(base_model_id).save_pretrained(base_model_id)

    return loaded_tokenizer
def upload_model(base_model_id, tokenizer):
    """Upload the locally saved model folder and the tokenizer to the Hub.

    The target repository ID is derived from ``base_model_id`` with
    ``replace_hf_profile``. Authentication uses the ``HF_TOKEN`` environment
    variable.

    Args:
        base_model_id: Base model identifier; also the local folder that is
            uploaded.
        tokenizer: Tokenizer object pushed to the same repository.

    Returns:
        The ID of the repository that was uploaded to.

    Raises:
        KeyError: If the ``HF_TOKEN`` environment variable is not set.
    """
    fine_tuned_model_id = replace_hf_profile(base_model_id)

    login(token=os.environ["HF_TOKEN"])

    api = HfApi()
    # exist_ok=True keeps repeated runs from failing once the repository
    # exists, without having to delete it first.
    api.create_repo(repo_id=fine_tuned_model_id, exist_ok=True)
    api.upload_folder(
        folder_path=base_model_id,
        repo_id=fine_tuned_model_id,
    )

    tokenizer.push_to_hub(fine_tuned_model_id)

    return fine_tuned_model_id
def replace_hf_profile(base_model_id):
    """Return ``base_model_id`` re-homed under the ``hf_profile`` namespace.

    Everything up to and including the last ``/`` is dropped and replaced
    with ``hf_profile``; an ID without ``/`` is used as a whole.
    """
    # rpartition yields the full string as the tail when no "/" is present.
    _, _, model_name = base_model_id.rpartition("/")
    return f"{hf_profile}/{model_name}"
def process(action, base_model_id, dataset, system_prompt, user_prompt, schema):
    """Dispatch the action chosen in the UI and return its result.

    Args:
        action: One of ``action_1`` (fine-tune) or ``action_2`` (prompt).
        base_model_id: Base model identifier.
        dataset: Dataset identifier, forwarded to ``fine_tune_model``.
        system_prompt: System prompt template, used when prompting.
        user_prompt: User prompt, used when prompting.
        schema: Schema text, used when prompting.

    Returns:
        The uploaded model ID for ``action_1``, or the model completion for
        ``action_2``.

    Raises:
        gr.Error: If ``action`` matches neither known action.
    """
    #raise gr.Error("Please clone and bring your own credentials.")

    if action == action_1:
        return fine_tune_model(base_model_id, dataset)

    if action == action_2:
        fine_tuned_model_id = replace_hf_profile(base_model_id)
        return prompt_model(fine_tuned_model_id, system_prompt, user_prompt, schema)

    # Previously this path fell through to `return result` with `result`
    # never assigned, which surfaced as an UnboundLocalError.
    raise gr.Error(f"Unknown action: {action!r}")
# Build the Gradio UI: one radio button to pick the action plus five text
# fields pre-filled with the module-level defaults, wired to process().
demo = gr.Interface(
    fn=process,
    inputs=[
        gr.Radio([action_1, action_2], label="Action", value=action_1),
        gr.Textbox(label="Base Model ID", value=base_model_id, lines=1),
        gr.Textbox(label="Dataset", value=dataset, lines=1),
        gr.Textbox(label="System Prompt", value=system_prompt, lines=2),
        gr.Textbox(label="User Prompt", value=user_prompt, lines=2),
        gr.Textbox(label="Schema", value=schema, lines=2),
    ],
    outputs=[
        # The initial completion text comes from the OUTPUT environment
        # variable; fall back to an empty box instead of raising KeyError at
        # startup when it is not set.
        gr.Textbox(label="Completion", value=os.environ.get("OUTPUT", "")),
    ],
)

demo.launch()