bstraehle commited on
Commit
38576e5
1 Parent(s): a763c52

Update app.py

Browse files
Files changed (1) hide show
  1. app.py +7 -5
app.py CHANGED
@@ -21,7 +21,7 @@ action_1 = "Prompt base model"
21
  action_2 = "Prompt fine-tuned model"
22
  action_3 = "Fine-tune base model"
23
 
24
- system_prompt = "You are a text to SQL query translator. Given a question in English, generate a SQL query based on the provided SCHEMA. Do not generate any additional text. SCHEMA: "
25
  user_prompt = "What is the total trade value and average price for each trader and stock in the trade_history table?"
26
  schema = "CREATE TABLE trade_history (id INT, trader_id INT, stock VARCHAR(255), price DECIMAL(5,2), quantity INT, trade_time TIMESTAMP);"
27
 
@@ -29,15 +29,17 @@ base_model_id = "meta-llama/Meta-Llama-3-8B-Instruct"
29
  fine_tuned_model_id = "bstraehle/Meta-Llama-3-8B-Instruct"
30
  dataset = "gretelai/synthetic_text_to_sql"
31
 
32
- def prompt_model(model_id, system_prompt, user_prompt):
33
  pipe = pipeline("text-generation", model=model_id, model_kwargs={"torch_dtype": torch.bfloat16}, device_map="auto")
34
 
35
  messages = [
36
- {"role": "system", "content": system_prompt},
37
  {"role": "user", "content": user_prompt},
38
  {"role": "assistant", "content": ""}
39
  ]
40
 
 
 
41
  output = pipe(messages)
42
 
43
  return output[0]["generated_text"][-1]["content"]
@@ -78,9 +80,9 @@ def upload_model(model_id, tokenizer):
78
 
79
  def process(action, system_prompt, user_prompt, schema, base_model_id, fine_tuned_model_id, dataset):
80
  if action == action_1:
81
- result = prompt_model(base_model_id, system_prompt, user_prompt)
82
  elif action == action_2:
83
- result = prompt_model(fine_tuned_model_id, system_prompt, user_prompt)
84
  elif action == action_3:
85
  result = fine_tune_model(base_model_id)
86
 
 
21
  action_2 = "Prompt fine-tuned model"
22
  action_3 = "Fine-tune base model"
23
 
24
+ system_prompt = "You are a text to SQL query translator. Given a question in English, generate a SQL query based on the provided SCHEMA. Do not generate any additional text. SCHEMA: {schema}"
25
  user_prompt = "What is the total trade value and average price for each trader and stock in the trade_history table?"
26
  schema = "CREATE TABLE trade_history (id INT, trader_id INT, stock VARCHAR(255), price DECIMAL(5,2), quantity INT, trade_time TIMESTAMP);"
27
 
 
29
  fine_tuned_model_id = "bstraehle/Meta-Llama-3-8B-Instruct"
30
  dataset = "gretelai/synthetic_text_to_sql"
31
 
32
+ def prompt_model(model_id, system_prompt, user_prompt, schema):
33
  pipe = pipeline("text-generation", model=model_id, model_kwargs={"torch_dtype": torch.bfloat16}, device_map="auto")
34
 
35
  messages = [
36
+ {"role": "system", "content": system_prompt.format(schema=schema)},
37
  {"role": "user", "content": user_prompt},
38
  {"role": "assistant", "content": ""}
39
  ]
40
 
41
+ print(messages)
42
+
43
  output = pipe(messages)
44
 
45
  return output[0]["generated_text"][-1]["content"]
 
80
 
81
  def process(action, system_prompt, user_prompt, schema, base_model_id, fine_tuned_model_id, dataset):
82
  if action == action_1:
83
+ result = prompt_model(base_model_id, system_prompt, user_prompt, schema)
84
  elif action == action_2:
85
+ result = prompt_model(fine_tuned_model_id, system_prompt, user_prompt, schema)
86
  elif action == action_3:
87
  result = fine_tune_model(base_model_id)
88