bstraehle commited on
Commit
cd72e2e
1 Parent(s): 96210aa

Update app.py

Browse files
Files changed (1) hide show
  1. app.py +6 -3
app.py CHANGED
@@ -22,8 +22,10 @@ action_2 = "Prompt fine-tuned model"
22
  action_3 = "Fine-tune base model"
23
 
24
  schema = "CREATE TABLE trade_history (id INT, trader_id INT, stock VARCHAR(255), price DECIMAL(5,2), quantity INT, trade_time TIMESTAMP);"
25
- system_prompt = """You are a text to SQL query translator. Given a question in English, generate a SQL query based on the provided SCHEMA. Do not generate any additional text.
26
- SCHEMA: """ + schema
 
 
27
  user_prompt = "What is the total trade value and average price for each trader and stock in the trade_history table?"
28
 
29
  base_model_id = "meta-llama/Meta-Llama-3-8B-Instruct"
@@ -77,7 +79,7 @@ def upload_model(model_id, tokenizer):
77
 
78
  return model_repo_name
79
 
80
- def process(action, system_prompt, user_prompt, base_model_id, fine_tuned_model_id, dataset):
81
  if action == action_1:
82
  result = prompt_model(base_model_id, system_prompt, user_prompt)
83
  elif action == action_2:
@@ -91,6 +93,7 @@ demo = gr.Interface(fn=process,
91
  inputs=[gr.Radio([action_1, action_2, action_3], label = "Action", value = action_2),
92
  gr.Textbox(label = "System Prompt", value = system_prompt, lines = 2),
93
  gr.Textbox(label = "User Prompt", value = user_prompt, lines = 2),
 
94
  gr.Textbox(label = "Base Model ID", value = base_model_id, lines = 1),
95
  gr.Textbox(label = "Fine-Tuned Model ID", value = fine_tuned_model_id, lines = 1),
96
  gr.Textbox(label = "Dataset", value = dataset, lines = 1)],
 
22
  action_3 = "Fine-tune base model"
23
 
24
  schema = "CREATE TABLE trade_history (id INT, trader_id INT, stock VARCHAR(255), price DECIMAL(5,2), quantity INT, trade_time TIMESTAMP);"
25
+ system_prompt = """You are a text to SQL query translator.
26
+ Given a question in English, generate a SQL query based on the provided SCHEMA.
27
+ Do not generate any additional text.
28
+ SCHEMA: """
29
  user_prompt = "What is the total trade value and average price for each trader and stock in the trade_history table?"
30
 
31
  base_model_id = "meta-llama/Meta-Llama-3-8B-Instruct"
 
79
 
80
  return model_repo_name
81
 
82
+ def process(action, system_prompt, user_prompt, schema, base_model_id, fine_tuned_model_id, dataset):
83
  if action == action_1:
84
  result = prompt_model(base_model_id, system_prompt, user_prompt)
85
  elif action == action_2:
 
93
  inputs=[gr.Radio([action_1, action_2, action_3], label = "Action", value = action_2),
94
  gr.Textbox(label = "System Prompt", value = system_prompt, lines = 2),
95
  gr.Textbox(label = "User Prompt", value = user_prompt, lines = 2),
96
+ gr.Textbox(label = "Schema", value = schema, lines = 2),
97
  gr.Textbox(label = "Base Model ID", value = base_model_id, lines = 1),
98
  gr.Textbox(label = "Fine-Tuned Model ID", value = fine_tuned_model_id, lines = 1),
99
  gr.Textbox(label = "Dataset", value = dataset, lines = 1)],