Update app.py
Browse files
app.py
CHANGED
@@ -16,22 +16,22 @@ SYSTEM_PROMPT = "You are a text to SQL query translator. Given a question in Eng
|
|
16 |
USER_PROMPT = "How many new users joined from countries with stricter data privacy laws than the United States in the past month?"
|
17 |
SQL_CONTEXT = "CREATE TABLE users (user_id INT, country VARCHAR(50), joined_date DATE); CREATE TABLE data_privacy_laws (country VARCHAR(50), privacy_level INT); INSERT INTO users (user_id, country, joined_date) VALUES (1, 'USA', '2023-02-15'), (2, 'Germany', '2023-02-27'); INSERT INTO data_privacy_laws (country, privacy_level) VALUES ('USA', 5), ('Germany', 8);"
|
18 |
|
19 |
-
|
20 |
FT_MODEL_NAME = "Meta-Llama-3.1-8B-text-to-sql"
|
21 |
DATASET_NAME = "gretelai/synthetic_text_to_sql"
|
22 |
|
23 |
-
def process(action,
|
24 |
raise gr.Error("Please clone and bring your own Hugging Face credentials.")
|
25 |
|
26 |
if action == ACTION_1:
|
27 |
-
result = prompt_model(
|
28 |
elif action == ACTION_2:
|
29 |
-
result = fine_tune_model(
|
30 |
elif action == ACTION_3:
|
31 |
result = prompt_model(ft_model_name, system_prompt, user_prompt, sql_context)
|
32 |
return result
|
33 |
|
34 |
-
def fine_tune_model(
|
35 |
# Load dataset
|
36 |
|
37 |
dataset = load_dataset(dataset_name)
|
@@ -44,7 +44,7 @@ def fine_tune_model(base_model_name, dataset_name):
|
|
44 |
|
45 |
# Load model
|
46 |
|
47 |
-
model, tokenizer = load_model(
|
48 |
|
49 |
print("### Model")
|
50 |
print(model)
|
@@ -80,7 +80,7 @@ def fine_tune_model(base_model_name, dataset_name):
|
|
80 |
# Configure training arguments
|
81 |
|
82 |
training_args = Seq2SeqTrainingArguments(
|
83 |
-
output_dir=f"./{
|
84 |
num_train_epochs=3, # 37,500 steps
|
85 |
#max_steps=1, # overwrites num_train_epochs
|
86 |
# TODO https://huggingface.co/docs/transformers/main_classes/trainer#transformers.Seq2SeqTrainingArguments
|
@@ -106,8 +106,8 @@ def fine_tune_model(base_model_name, dataset_name):
|
|
106 |
|
107 |
# Push model and tokenizer to HF
|
108 |
|
109 |
-
model.push_to_hub(
|
110 |
-
tokenizer.push_to_hub(
|
111 |
|
112 |
def prompt_model(model_name, system_prompt, user_prompt, sql_context):
|
113 |
pipe = pipeline("text-generation",
|
@@ -142,9 +142,9 @@ def load_model(model_name):
|
|
142 |
|
143 |
demo = gr.Interface(fn=process,
|
144 |
inputs=[gr.Radio([ACTION_1, ACTION_2, ACTION_3], label = "Action", value = ACTION_3),
|
145 |
-
gr.Textbox(label = "
|
146 |
-
gr.Textbox(label = "Fine-Tuned Model Name", value = f"{HF_ACCOUNT}/{FT_MODEL_NAME}", lines = 1),
|
147 |
gr.Textbox(label = "Dataset Name", value = DATASET_NAME, lines = 1),
|
|
|
148 |
gr.Textbox(label = "System Prompt", value = SYSTEM_PROMPT, lines = 2),
|
149 |
gr.Textbox(label = "User Prompt", value = USER_PROMPT, lines = 2),
|
150 |
gr.Textbox(label = "SQL Context", value = SQL_CONTEXT, lines = 4)],
|
|
|
16 |
USER_PROMPT = "How many new users joined from countries with stricter data privacy laws than the United States in the past month?"
|
17 |
SQL_CONTEXT = "CREATE TABLE users (user_id INT, country VARCHAR(50), joined_date DATE); CREATE TABLE data_privacy_laws (country VARCHAR(50), privacy_level INT); INSERT INTO users (user_id, country, joined_date) VALUES (1, 'USA', '2023-02-15'), (2, 'Germany', '2023-02-27'); INSERT INTO data_privacy_laws (country, privacy_level) VALUES ('USA', 5), ('Germany', 8);"
|
18 |
|
19 |
+
PT_MODEL_NAME = "meta-llama/Meta-Llama-3.1-8B"
|
20 |
FT_MODEL_NAME = "Meta-Llama-3.1-8B-text-to-sql"
|
21 |
DATASET_NAME = "gretelai/synthetic_text_to_sql"
|
22 |
|
23 |
+
def process(action, pt_model_name, dataset_name, ft_model_name, system_prompt, user_prompt, sql_context):
|
24 |
raise gr.Error("Please clone and bring your own Hugging Face credentials.")
|
25 |
|
26 |
if action == ACTION_1:
|
27 |
+
result = prompt_model(pt_model_name, system_prompt, user_prompt, sql_context)
|
28 |
elif action == ACTION_2:
|
29 |
+
result = fine_tune_model(pt_model_name, dataset_name, ft_model_name)
|
30 |
elif action == ACTION_3:
|
31 |
result = prompt_model(ft_model_name, system_prompt, user_prompt, sql_context)
|
32 |
return result
|
33 |
|
34 |
+
def fine_tune_model(pt_model_name, dataset_name, ft_model_name):
|
35 |
# Load dataset
|
36 |
|
37 |
dataset = load_dataset(dataset_name)
|
|
|
44 |
|
45 |
# Load model
|
46 |
|
47 |
+
model, tokenizer = load_model(pt_model_name)
|
48 |
|
49 |
print("### Model")
|
50 |
print(model)
|
|
|
80 |
# Configure training arguments
|
81 |
|
82 |
training_args = Seq2SeqTrainingArguments(
|
83 |
+
output_dir=f"./{ft_model_name}",
|
84 |
num_train_epochs=3, # 37,500 steps
|
85 |
#max_steps=1, # overwrites num_train_epochs
|
86 |
# TODO https://huggingface.co/docs/transformers/main_classes/trainer#transformers.Seq2SeqTrainingArguments
|
|
|
106 |
|
107 |
# Push model and tokenizer to HF
|
108 |
|
109 |
+
model.push_to_hub(ft_model_name)
|
110 |
+
tokenizer.push_to_hub(ft_model_name)
|
111 |
|
112 |
def prompt_model(model_name, system_prompt, user_prompt, sql_context):
|
113 |
pipe = pipeline("text-generation",
|
|
|
142 |
|
143 |
demo = gr.Interface(fn=process,
|
144 |
inputs=[gr.Radio([ACTION_1, ACTION_2, ACTION_3], label = "Action", value = ACTION_3),
|
145 |
+
gr.Textbox(label = "Pre-Trained Model Name", value = PT_MODEL_NAME, lines = 1),
|
|
|
146 |
gr.Textbox(label = "Dataset Name", value = DATASET_NAME, lines = 1),
|
147 |
+
gr.Textbox(label = "Fine-Tuned Model Name", value = f"{HF_ACCOUNT}/{FT_MODEL_NAME}", lines = 1),
|
148 |
gr.Textbox(label = "System Prompt", value = SYSTEM_PROMPT, lines = 2),
|
149 |
gr.Textbox(label = "User Prompt", value = USER_PROMPT, lines = 2),
|
150 |
gr.Textbox(label = "SQL Context", value = SQL_CONTEXT, lines = 4)],
|