# Run on Google TPU v5e 2x4 or equivalent (220 vCPU, 380 GB RAM, 128 GB VRAM)
import gradio as gr
import os, torch
from datasets import load_dataset
from huggingface_hub import HfApi, login
from peft import LoraConfig, PeftModel, get_peft_model
from transformers import AutoModelForCausalLM, AutoTokenizer, Seq2SeqTrainer, Seq2SeqTrainingArguments, pipeline
ACTION_1 = "Prompt base model"
ACTION_2 = "Fine-tune base model"
ACTION_3 = "Prompt fine-tuned model"
HF_ACCOUNT = "bstraehle"
SYSTEM_PROMPT = "You are a text to SQL query translator. Given a question in English, generate a SQL query based on the provided SQL_CONTEXT. Do not generate any additional text. SQL_CONTEXT: {sql_context}"
USER_PROMPT = "How many new users joined from countries with stricter data privacy laws than the United States in the past month?"
SQL_CONTEXT = "CREATE TABLE users (user_id INT, country VARCHAR(50), joined_date DATE); CREATE TABLE data_privacy_laws (country VARCHAR(50), privacy_level INT); INSERT INTO users (user_id, country, joined_date) VALUES (1, 'USA', '2023-02-15'), (2, 'Germany', '2023-02-27'); INSERT INTO data_privacy_laws (country, privacy_level) VALUES ('USA', 5), ('Germany', 8);"
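# For reference, a correct completion for USER_PROMPT given SQL_CONTEXT looks roughly
# like this (an illustrative sketch derived from the schema above, not a captured
# model output; the date arithmetic varies by SQL dialect):
#   SELECT COUNT(*) FROM users u
#   JOIN data_privacy_laws d ON u.country = d.country
#   WHERE d.privacy_level > (SELECT privacy_level FROM data_privacy_laws WHERE country = 'USA')
#     AND u.joined_date >= DATE_SUB(CURRENT_DATE, INTERVAL 1 MONTH);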
BASE_MODEL_NAME = "meta-llama/Meta-Llama-3.1-8B-Instruct"
FT_MODEL_NAME = "Meta-Llama-3.1-8B-Instruct-text-to-sql"
DATASET_NAME = "gretelai/synthetic_text_to_sql"
def process(action, base_model_name, ft_model_name, dataset_name, system_prompt, user_prompt, sql_context):
    #raise gr.Error("Please clone and bring your own credentials.")
    if action == ACTION_1:
        result = prompt_model(base_model_name, system_prompt, user_prompt, sql_context)
    elif action == ACTION_2:
        result = fine_tune_model(base_model_name, dataset_name)
    elif action == ACTION_3:
        result = prompt_model(ft_model_name, system_prompt, user_prompt, sql_context)
    return result
def fine_tune_model(base_model_name, dataset_name):
    # Load dataset
    dataset = load_dataset(dataset_name)
    print("### Dataset")
    print(dataset)
    print("### Example")
    print(dataset["train"][:1])
    print("###")
    # Load model
    model, tokenizer = load_model(base_model_name)
    print("### Model")
    print(model)
    print("### Tokenizer")
    print(tokenizer)
    print("###")
    # Pre-process dataset
    def preprocess(examples):
        model_inputs = tokenizer(examples["sql_prompt"], text_target=examples["sql"], max_length=512, padding="max_length", truncation=True)
        return model_inputs
    dataset = dataset.map(preprocess, batched=True)
    print("### Pre-processed dataset")
    print(dataset)
    print("### Example")
    print(dataset["train"][:1])
    print("###")
    # Split dataset into training and evaluation sets
    train_dataset = dataset["train"]
    eval_dataset = dataset["test"]
    print("### Training dataset")
    print(train_dataset)
    print("### Evaluation dataset")
    print(eval_dataset)
    print("###")
    # Configure training arguments
    training_args = Seq2SeqTrainingArguments(
        output_dir=f"./{FT_MODEL_NAME}",
        num_train_epochs=3, # 37,500 steps
        max_steps=1, # overrides num_train_epochs (set low for a quick smoke test)
        push_to_hub=True, # only pushes the model; the tokenizer is pushed separately (see below)
        # TODO https://huggingface.co/docs/transformers/main_classes/trainer#transformers.Seq2SeqTrainingArguments
    )
    print("### Training arguments")
    print(training_args)
    print("###")
    # Create trainer
    trainer = Seq2SeqTrainer(
        model=model,
        args=training_args,
        train_dataset=train_dataset,
        eval_dataset=eval_dataset,
        # TODO https://huggingface.co/docs/transformers/main_classes/trainer#transformers.Seq2SeqTrainer
    )
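    # The Trainer TODO could likewise add a seq2seq data collator for dynamic padding
    # and label masking (a hedged sketch; the wiring below is an assumption):
    #from transformers import DataCollatorForSeq2Seq
    #trainer = Seq2SeqTrainer(
    #    model=model,
    #    args=training_args,
    #    train_dataset=train_dataset,
    #    eval_dataset=eval_dataset,
    #    data_collator=DataCollatorForSeq2Seq(tokenizer, model=model),
    #)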
    # Train model
    trainer.train()
    # Push tokenizer to HF (push_to_hub=True above only pushes the model)
    tokenizer.push_to_hub(FT_MODEL_NAME)
    return f"Fine-tuned model pushed to {HF_ACCOUNT}/{FT_MODEL_NAME}"
def prompt_model(model_name, system_prompt, user_prompt, sql_context):
    pipe = pipeline("text-generation",
                    model=model_name,
                    device_map="auto",
                    max_new_tokens=1000)
    messages = [
        {"role": "system", "content": system_prompt.format(sql_context=sql_context)},
        {"role": "user", "content": user_prompt}
    ]
    output = pipe(messages)
    # The pipeline appends the generated assistant turn to the conversation
    result = output[0]["generated_text"][-1]["content"]
    print("###")
    print(result)
    print("###")
    return result
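# Example usage outside the Gradio UI (commented out so that running this script
# only starts the app below):
#print(prompt_model(BASE_MODEL_NAME, SYSTEM_PROMPT, USER_PROMPT, SQL_CONTEXT))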
def load_model(model_name):
    model = AutoModelForCausalLM.from_pretrained(model_name, device_map="auto")
    tokenizer = AutoTokenizer.from_pretrained(model_name)
    tokenizer.pad_token = tokenizer.eos_token # Llama 3.1 ships without a pad token
    # PEFT, LoRA, QLoRA, see https://huggingface.co/blog/mlabonne/sft-llama3
    peft_config = LoraConfig(
        r=64,
        bias="none",
        lora_alpha=16,
        lora_dropout=0,
        task_type="CAUSAL_LM",
    )
    # Attach a fresh LoRA adapter so only the adapter weights are trained
    model = get_peft_model(model, peft_config)
    model.print_trainable_parameters() # trainable params: 6,815,744 || all params: 8,037,076,992 || trainable%: 0.0848
    return model, tokenizer
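# After training, the LoRA adapter can be merged into the base weights for standalone
# inference (a hedged sketch; adapter_repo is a placeholder for wherever the trained
# adapter ends up on the Hub):
#def merge_adapter(base_model_name, adapter_repo):
#    base = AutoModelForCausalLM.from_pretrained(base_model_name, device_map="auto")
#    return PeftModel.from_pretrained(base, adapter_repo).merge_and_unload()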
demo = gr.Interface(fn=process,
                    inputs=[gr.Radio([ACTION_1, ACTION_2, ACTION_3], label = "Action", value = ACTION_3),
                            gr.Textbox(label = "Base Model Name", value = BASE_MODEL_NAME, lines = 1),
                            gr.Textbox(label = "Fine-Tuned Model Name", value = f"{HF_ACCOUNT}/{FT_MODEL_NAME}", lines = 1),
                            gr.Textbox(label = "Dataset Name", value = DATASET_NAME, lines = 1),
                            gr.Textbox(label = "System Prompt", value = SYSTEM_PROMPT, lines = 2),
                            gr.Textbox(label = "User Prompt", value = USER_PROMPT, lines = 2),
                            gr.Textbox(label = "SQL Context", value = SQL_CONTEXT, lines = 4)],
                    outputs=[gr.Textbox(label = "Prompt Completion", value = os.environ.get("OUTPUT", ""))])
demo.launch()