Update app.py
app.py
CHANGED
@@ -21,7 +21,7 @@ action_1 = "Prompt base model"
 action_2 = "Prompt fine-tuned model"
 action_3 = "Fine-tune base model"
 
-system_prompt = "You are a text to SQL query translator. Given a question in English, generate a SQL query based on the provided SCHEMA. Do not generate any additional text. SCHEMA: "
+system_prompt = "You are a text to SQL query translator. Given a question in English, generate a SQL query based on the provided SCHEMA. Do not generate any additional text. SCHEMA: {schema}"
 user_prompt = "What is the total trade value and average price for each trader and stock in the trade_history table?"
 schema = "CREATE TABLE trade_history (id INT, trader_id INT, stock VARCHAR(255), price DECIMAL(5,2), quantity INT, trade_time TIMESTAMP);"
 
@@ -29,15 +29,17 @@ base_model_id = "meta-llama/Meta-Llama-3-8B-Instruct"
 fine_tuned_model_id = "bstraehle/Meta-Llama-3-8B-Instruct"
 dataset = "gretelai/synthetic_text_to_sql"
 
-def prompt_model(model_id, system_prompt, user_prompt):
+def prompt_model(model_id, system_prompt, user_prompt, schema):
     pipe = pipeline("text-generation", model=model_id, model_kwargs={"torch_dtype": torch.bfloat16}, device_map="auto")
 
     messages = [
-        {"role": "system", "content": system_prompt},
+        {"role": "system", "content": system_prompt.format(schema=schema)},
         {"role": "user", "content": user_prompt},
         {"role": "assistant", "content": ""}
     ]
 
+    print(messages)
+
     output = pipe(messages)
 
     return output[0]["generated_text"][-1]["content"]
@@ -78,9 +80,9 @@ def upload_model(model_id, tokenizer):
 
 def process(action, system_prompt, user_prompt, schema, base_model_id, fine_tuned_model_id, dataset):
     if action == action_1:
-        result = prompt_model(base_model_id, system_prompt, user_prompt)
+        result = prompt_model(base_model_id, system_prompt, user_prompt, schema)
     elif action == action_2:
-        result = prompt_model(fine_tuned_model_id, system_prompt, user_prompt)
+        result = prompt_model(fine_tuned_model_id, system_prompt, user_prompt, schema)
     elif action == action_3:
         result = fine_tune_model(base_model_id)
 
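For context, a minimal sketch (not part of the commit) of what the updated prompting path now does: the {schema} placeholder added to system_prompt is filled in with str.format before the chat messages are built, so the concrete CREATE TABLE statement reaches the model in the system turn. The build_messages helper below is hypothetical; in app.py the messages are assembled inline inside prompt_model and passed to the transformers text-generation pipeline, whose output is unpacked as output[0]["generated_text"][-1]["content"] to recover the assistant reply, as shown in the diff.

# Illustrative sketch only; build_messages is a hypothetical helper, not a
# function defined in app.py. It mirrors how prompt_model now injects the
# table schema into the system prompt via str.format. The model call itself
# is omitted so the snippet runs without downloading Meta-Llama-3-8B-Instruct.

system_prompt = (
    "You are a text to SQL query translator. Given a question in English, "
    "generate a SQL query based on the provided SCHEMA. Do not generate any "
    "additional text. SCHEMA: {schema}"
)
user_prompt = (
    "What is the total trade value and average price for each trader and "
    "stock in the trade_history table?"
)
schema = (
    "CREATE TABLE trade_history (id INT, trader_id INT, stock VARCHAR(255), "
    "price DECIMAL(5,2), quantity INT, trade_time TIMESTAMP);"
)

def build_messages(system_prompt, user_prompt, schema):
    # str.format replaces the {schema} placeholder introduced by this commit,
    # so the model sees the concrete CREATE TABLE statement in its system turn.
    return [
        {"role": "system", "content": system_prompt.format(schema=schema)},
        {"role": "user", "content": user_prompt},
        {"role": "assistant", "content": ""},
    ]

if __name__ == "__main__":
    # Print each chat turn, mirroring the print(messages) debug line added
    # to prompt_model in this commit.
    for message in build_messages(system_prompt, user_prompt, schema):
        print(f"{message['role']}: {message['content'][:80]}")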