sft / app.py
bstraehle's picture
Update app.py
ba42713 verified
raw
history blame
5.97 kB
import gradio as gr
import os, torch
from datasets import load_dataset
from huggingface_hub import HfApi
from transformers import AutoModelForCausalLM, AutoTokenizer, Seq2SeqTrainer, Seq2SeqTrainingArguments, pipeline
# Labels of the three actions offered by the UI radio button.
ACTION_1, ACTION_2, ACTION_3 = (
    "Prompt base model",
    "Fine-tune base model",
    "Prompt fine-tuned model",
)

# Default system prompt; ``{sql_context}`` is filled in via ``str.format``.
SYSTEM_PROMPT = (
    "You are a text to SQL query translator. "
    "Given a question in English, generate a SQL query based on the provided SQL_CONTEXT. "
    "Do not generate any additional text. "
    "SQL_CONTEXT: {sql_context}"
)

# Default natural-language question shown in the UI.
USER_PROMPT = (
    "How many new users joined from countries with stricter data privacy laws "
    "than the United States in the past month?"
)

# Default schema and sample rows shown in the UI as SQL context.
SQL_CONTEXT = (
    "CREATE TABLE users (user_id INT, country VARCHAR(50), joined_date DATE); "
    "CREATE TABLE data_privacy_laws (country VARCHAR(50), privacy_level INT); "
    "INSERT INTO users (user_id, country, joined_date) VALUES (1, 'USA', '2023-02-15'), (2, 'Germany', '2023-02-27'); "
    "INSERT INTO data_privacy_laws (country, privacy_level) VALUES ('USA', 5), ('Germany', 8);"
)

# Default model and dataset identifiers shown in the UI.
BASE_MODEL_NAME = "meta-llama/Meta-Llama-3.1-8B-Instruct"
FT_MODEL_NAME = "bstraehle/Meta-Llama-3.1-8B-Instruct-text-to-sql"
DATASET_NAME = "gretelai/synthetic_text_to_sql"
def process(action, base_model_name, ft_model_name, dataset_name, system_prompt, user_prompt, sql_context):
    """Run the action selected in the UI and return its result.

    Args:
        action: One of ``ACTION_1`` (prompt the base model), ``ACTION_2``
            (fine-tune the base model) or ``ACTION_3`` (prompt the
            fine-tuned model).
        base_model_name: Model id used for ``ACTION_1`` and ``ACTION_2``.
        ft_model_name: Model id used for ``ACTION_3``.
        dataset_name: Dataset id used for ``ACTION_2``.
        system_prompt: System prompt template forwarded to ``prompt_model``.
        user_prompt: User question forwarded to ``prompt_model``.
        sql_context: SQL context forwarded to ``prompt_model``.

    Returns:
        Whatever the selected handler returns.

    Raises:
        gr.Error: If ``action`` is not one of the three known actions.
    """
    # Uncomment to disable the app for visitors without their own credentials:
    #raise gr.Error("Please clone and bring your own credentials.")

    if action == ACTION_1:
        return prompt_model(base_model_name, system_prompt, user_prompt, sql_context)
    if action == ACTION_2:
        return fine_tune_model(base_model_name, dataset_name)
    if action == ACTION_3:
        return prompt_model(ft_model_name, system_prompt, user_prompt, sql_context)

    # Previously an unknown action fell through to ``return result`` with
    # ``result`` unbound (UnboundLocalError); report it explicitly instead.
    raise gr.Error(f"Unknown action: {action!r}")
def fine_tune_model(base_model_name, dataset_name):
    """Prepare a fine-tuning run for a base model and upload the output folder.

    Loads the dataset and the base model, tokenizes the ``sql_prompt`` /
    ``sql`` columns, takes small shuffled subsets of the ``train`` and
    ``test`` splits, builds a ``Seq2SeqTrainer`` and finally uploads the
    output directory to the ``FT_MODEL_NAME`` model repository.
    Intermediate objects are printed for inspection.

    NOTE: the ``trainer.train()`` call is commented out, so this function
    does not currently update any weights.

    Args:
        base_model_name: Model id passed to ``load_model``.
        dataset_name: Dataset id passed to ``load_dataset``. The dataset is
            indexed by ``"train"`` / ``"test"`` and by the ``"sql_prompt"``
            / ``"sql"`` columns.

    Returns:
        None.
    """
    output_dir = "./output"

    # Load dataset

    dataset = load_dataset(dataset_name)

    print("### Dataset")
    print(dataset)
    print("### Example")
    print(dataset["train"][:1])
    print("###")

    # Load model

    model, tokenizer = load_model(base_model_name)

    print("### Model")
    print(model)
    print("### Tokenizer")
    print(tokenizer)
    print("###")

    # Pre-process dataset

    def preprocess(examples):
        """Tokenize prompts as inputs and SQL as targets, padded/truncated to 512."""
        model_inputs = tokenizer(examples["sql_prompt"], text_target=examples["sql"], max_length=512, padding="max_length", truncation=True)
        return model_inputs

    dataset = dataset.map(preprocess, batched=True)

    print("### Pre-processed dataset")
    print(dataset)
    print("### Example")
    print(dataset["train"][:1])
    print("###")

    # Split dataset into training and validation sets

    #train_dataset = dataset["train"]
    #test_dataset = dataset["test"]

    # Fixed-seed subsets keep the run small and reproducible.
    train_dataset = dataset["train"].shuffle(seed=42).select(range(1000))
    test_dataset = dataset["test"].shuffle(seed=42).select(range(100))

    print("### Training dataset")
    print(train_dataset)
    print("### Validation dataset")
    print(test_dataset)
    print("###")

    # Configure training arguments

    training_args = Seq2SeqTrainingArguments(
        output_dir=output_dir,
        logging_dir="./logging",
        num_train_epochs=1,
        max_steps=1, # overwrites num_train_epochs
        #push_to_hub=True,
        #per_device_train_batch_size=16,
        #per_device_eval_batch_size=64,
        #eval_strategy="steps",
        #save_total_limit=2,
        #save_steps=500,
        #eval_steps=500,
        #warmup_steps=500,
        #weight_decay=0.01,
        #metric_for_best_model="accuracy",
        #greater_is_better=True,
        #load_best_model_at_end=True,
        #save_on_each_node=True,
    )

    print("### Training arguments")
    print(training_args)
    print("###")

    # Create trainer

    trainer = Seq2SeqTrainer(
        model=model,
        args=training_args,
        train_dataset=train_dataset,
        eval_dataset=test_dataset,
        #compute_metrics=lambda pred: {"accuracy": torch.sum(pred.label_ids == pred.predictions.argmax(-1))},
    )

    # Train model

    #trainer.train()

    # Save model to HF

    # With training disabled the output directory may never have been
    # created; skip the upload instead of failing inside the Hub client.
    if not os.path.isdir(output_dir):
        print(f"### Nothing to upload: {output_dir} does not exist")
        return

    # ``upload_file`` takes a single file (``path_or_fileobj`` plus
    # ``path_in_repo``); the output is a directory, so use ``upload_folder``.
    # ``token=True`` requires the locally stored Hugging Face token.
    api = HfApi()
    api.upload_folder(
        folder_path=output_dir,
        repo_id=FT_MODEL_NAME,
        repo_type="model",
        token=True,
    )
def prompt_model(model_name, system_prompt, user_prompt, sql_context):
    """Generate a completion for a text-to-SQL chat prompt.

    Builds a ``text-generation`` pipeline for ``model_name``, sends a
    system / user / (empty) assistant conversation through it, prints the
    completion between ``###`` markers and returns it.

    Args:
        model_name: Model id handed to the pipeline.
        system_prompt: Template formatted with ``sql_context=...``.
        user_prompt: Content of the user message.
        sql_context: Value substituted into ``system_prompt``.

    Returns:
        The ``content`` of the last message of the first generation.
    """
    generator = pipeline(
        "text-generation",
        model=model_name,
        device_map="auto",
        max_new_tokens=1000,
    )

    system_message = {"role": "system", "content": system_prompt.format(sql_context=sql_context)}
    user_message = {"role": "user", "content": user_prompt}
    assistant_message = {"role": "assistant", "content": ""}

    generations = generator([system_message, user_message, assistant_message])

    # The pipeline output holds the conversation; the last message carries the completion.
    completion = generations[0]["generated_text"][-1]["content"]

    for line in ("###", completion, "###"):
        print(line)

    return completion
def load_model(model_name):
    """Load a causal language model together with its tokenizer.

    Args:
        model_name: Model id passed to both ``from_pretrained`` calls.

    Returns:
        A ``(model, tokenizer)`` tuple; the tokenizer's pad token is set to
        its end-of-sequence token.
    """
    causal_lm = AutoModelForCausalLM.from_pretrained(model_name, device_map="auto")

    tok = AutoTokenizer.from_pretrained(model_name)
    # Reuse EOS as the padding token so padded tokenization has a pad token to use.
    tok.pad_token = tok.eos_token

    return causal_lm, tok
# Build the UI: one radio button to pick the action plus six text inputs,
# passed positionally to ``process`` in this order.
demo = gr.Interface(
    fn=process,
    inputs=[
        gr.Radio([ACTION_1, ACTION_2, ACTION_3], label="Action", value=ACTION_3),
        gr.Textbox(label="Base Model Name", value=BASE_MODEL_NAME, lines=1),
        gr.Textbox(label="Fine-Tuned Model Name", value=FT_MODEL_NAME, lines=1),
        gr.Textbox(label="Dataset Name", value=DATASET_NAME, lines=1),
        gr.Textbox(label="System Prompt", value=SYSTEM_PROMPT, lines=2),
        gr.Textbox(label="User Prompt", value=USER_PROMPT, lines=2),
        gr.Textbox(label="SQL Context", value=SQL_CONTEXT, lines=2),
    ],
    outputs=[
        # The initial output text comes from the optional OUTPUT environment
        # variable; fall back to an empty box instead of raising KeyError at
        # startup when it is not set.
        gr.Textbox(label="Prompt Completion", value=os.environ.get("OUTPUT", "")),
    ],
)

demo.launch()