# https://www.philschmid.de/fine-tune-llms-in-2024-with-trl#3-create-and-prepare-the-dataset
import gradio as gr
import os, torch
from datasets import load_dataset
from huggingface_hub import HfApi, login
from transformers import AutoTokenizer, AutoModelForCausalLM, BitsAndBytesConfig, pipeline
from transformers import Trainer, TrainingArguments, DataCollatorForLanguageModeling
hf_profile = "bstraehle"
action_1 = "Fine-tune pre-trained model"
action_2 = "Prompt fine-tuned model"
system_prompt = "You are a text to SQL query translator. Given a question in English, generate a SQL query based on the provided SCHEMA. Do not generate any additional text. SCHEMA: {schema}"
user_prompt = "What is the total trade value and average price for each trader and stock in the trade_history table?"
schema = "CREATE TABLE trade_history (id INT, trader_id INT, stock VARCHAR(255), price DECIMAL(5,2), quantity INT, trade_time TIMESTAMP);"
base_model_id = "meta-llama/Meta-Llama-3.1-8B-Instruct"
dataset = "b-mc2/sql-create-context"
def process(action, base_model_id, dataset, system_prompt, user_prompt, schema):
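    # Dispatch on the selected action: either fine-tune the base model or
    # prompt the previously fine-tuned model under our own Hub profile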
#raise gr.Error("Please clone and bring your own credentials.")
if action == action_1:
result = fine_tune_model(base_model_id, dataset)
elif action == action_2:
fine_tuned_model_id = replace_hf_profile(base_model_id)
result = prompt_model(fine_tuned_model_id, system_prompt, user_prompt, schema)
return result
def fine_tune_model(base_model_id, dataset):
# tokenizer = download_model(base_model_id)
# fine_tuned_model_id = upload_model(base_model_id, tokenizer)
# return fine_tuned_model_id
    # Load the dataset selected in the UI (expects "sql_prompt" and "sql"
    # columns, as in gretelai/synthetic_text_to_sql)
    dataset = load_dataset(dataset)
bnb_config = BitsAndBytesConfig(
load_in_4bit=True, bnb_4bit_use_double_quant=True, bnb_4bit_quant_type="nf4", bnb_4bit_compute_dtype=torch.bfloat16
)
    # Load the pre-trained model selected in the UI with 4-bit quantization
    model = AutoModelForCausalLM.from_pretrained(base_model_id, device_map="auto", quantization_config=bnb_config)
    tokenizer = AutoTokenizer.from_pretrained(base_model_id)
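    # Llama tokenizers ship without a padding token; reusing EOS is a common
    # convention so the data collator can pad batches
    tokenizer.pad_token = tokenizer.eos_token
    # A purely 4-bit-quantized model cannot be fully fine-tuned directly; as a
    # minimal QLoRA sketch (the r/alpha/dropout values below are assumptions,
    # adjust as needed), attach trainable LoRA adapters to the frozen base
    from peft import LoraConfig, get_peft_model, prepare_model_for_kbit_training
    model = prepare_model_for_kbit_training(model)
    model = get_peft_model(model, LoraConfig(r=16, lora_alpha=32, lora_dropout=0.05, task_type="CAUSAL_LM"))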
    # Preprocess the dataset: for causal LM fine-tuning, concatenate each
    # prompt and its target SQL into a single training sequence
    def preprocess(examples):
        texts = [f"{prompt}\n{sql}{tokenizer.eos_token}" for prompt, sql in zip(examples["sql_prompt"], examples["sql"])]
        return tokenizer(texts, max_length=512, truncation=True)
    dataset = dataset.map(preprocess, batched=True, remove_columns=dataset["train"].column_names)
    # Split the dataset into training and validation sets
train_dataset = dataset["train"].shuffle(seed=42).select(range(1000)) # Adjust the range as needed
val_dataset = dataset["test"].shuffle(seed=42).select(range(100)) # Adjust the range as needed
    # Set training arguments; select the best checkpoint by eval loss
    training_args = TrainingArguments(
        output_dir="./results",
        num_train_epochs=3, # Adjust as needed
        per_device_train_batch_size=8,
        per_device_eval_batch_size=8,
        warmup_steps=500,
        weight_decay=0.01,
        logging_dir="./logs",
        save_total_limit=2,
        save_steps=500,
        eval_steps=500,
        metric_for_best_model="eval_loss",
        greater_is_better=False,
        load_best_model_at_end=True,
        eval_strategy="steps",
        save_strategy="steps",
        gradient_checkpointing=True,
    )
    # Create Trainer instance; DataCollatorForLanguageModeling(mlm=False)
    # builds causal-LM labels from the input ids and pads batches
    trainer = Trainer(
        model=model,
        args=training_args,
        train_dataset=train_dataset,
        eval_dataset=val_dataset,
        data_collator=DataCollatorForLanguageModeling(tokenizer, mlm=False),
    )
# Train the model
trainer.train()
# Save the trained model
trainer.save_model("./fine_tuned_model")
    # Log in and push the fine-tuned model and tokenizer to the Hugging Face Hub
    login(token=os.environ["HF_TOKEN"])
    fine_tuned_model_id = replace_hf_profile(base_model_id)
    trainer.model.push_to_hub(fine_tuned_model_id, commit_message="Initial commit")
    tokenizer.push_to_hub(fine_tuned_model_id)
    return fine_tuned_model_id
def prompt_model(model_id, system_prompt, user_prompt, schema):
pipe = pipeline("text-generation",
model=model_id,
model_kwargs={"torch_dtype": torch.bfloat16},
device_map="auto",
max_new_tokens=1000)
    # Chat-style input; the pipeline applies the model's chat template and
    # generates the assistant turn
    messages = [
        {"role": "system", "content": system_prompt.format(schema=schema)},
        {"role": "user", "content": user_prompt}
    ]
output = pipe(messages)
result = output[0]["generated_text"][-1]["content"]
print(result)
return result
def download_model(base_model_id):
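    # Download the base model and save a copy under a local directory named
    # after the model id, so upload_model can push that folder to the Hub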
tokenizer = AutoTokenizer.from_pretrained(base_model_id)
model = AutoModelForCausalLM.from_pretrained(base_model_id)
model.save_pretrained(base_model_id)
return tokenizer
def upload_model(base_model_id, tokenizer):
fine_tuned_model_id = replace_hf_profile(base_model_id)
login(token=os.environ["HF_TOKEN"])
api = HfApi()
#api.delete_repo(repo_id=fine_tuned_model_id, repo_type="model")
    api.create_repo(repo_id=fine_tuned_model_id, exist_ok=True)
api.upload_folder(
folder_path=base_model_id,
repo_id=fine_tuned_model_id
)
tokenizer.push_to_hub(fine_tuned_model_id)
return fine_tuned_model_id
def replace_hf_profile(base_model_id):
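    # e.g. "meta-llama/Meta-Llama-3.1-8B-Instruct" -> "bstraehle/Meta-Llama-3.1-8B-Instruct"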
model_id = base_model_id[base_model_id.rfind('/')+1:]
return f"{hf_profile}/{model_id}"
demo = gr.Interface(fn=process,
inputs=[gr.Radio([action_1, action_2], label = "Action", value = action_1),
gr.Textbox(label = "Base Model ID", value = base_model_id, lines = 1),
gr.Textbox(label = "Dataset", value = dataset, lines = 1),
gr.Textbox(label = "System Prompt", value = system_prompt, lines = 2),
gr.Textbox(label = "User Prompt", value = user_prompt, lines = 2),
gr.Textbox(label = "Schema", value = schema, lines = 2)],
                    outputs=[gr.Textbox(label = "Completion", value = os.environ.get("OUTPUT", ""))])
demo.launch()