bstraehle commited on
Commit
ed270e5
1 Parent(s): 6dd3828

Update app.py

Browse files
Files changed (1) hide show
  1. app.py +6 -6
app.py CHANGED
@@ -15,8 +15,8 @@ system_prompt = "You are a text to SQL query translator. Given a question in Eng
15
  user_prompt = "What is the total trade value and average price for each trader and stock in the trade_history table?"
16
  schema = "CREATE TABLE trade_history (id INT, trader_id INT, stock VARCHAR(255), price DECIMAL(5,2), quantity INT, trade_time TIMESTAMP);"
17
 
18
- base_model_id = "meta-llama/Meta-Llama-3-8B-Instruct"
19
- dataset = "b-mc2/sql-create-context" #"gretelai/synthetic_text_to_sql"
20
 
21
  def prompt_model(model_id, system_prompt, user_prompt, schema):
22
  pipe = pipeline("text-generation",
@@ -35,10 +35,10 @@ def prompt_model(model_id, system_prompt, user_prompt, schema):
35
  return result
36
 
37
  def fine_tune_model(base_model_id, dataset):
38
- #tokenizer = download_model(base_model_id)
39
- download_dataset(dataset)
40
- #fine_tuned_model_id = upload_model(base_model_id, tokenizer)
41
- return "Done" #fine_tuned_model_id
42
 
43
  def download_model(base_model_id):
44
  tokenizer = AutoTokenizer.from_pretrained(base_model_id)
 
15
  user_prompt = "What is the total trade value and average price for each trader and stock in the trade_history table?"
16
  schema = "CREATE TABLE trade_history (id INT, trader_id INT, stock VARCHAR(255), price DECIMAL(5,2), quantity INT, trade_time TIMESTAMP);"
17
 
18
+ base_model_id = "codellama/CodeLlama-7b-hf"
19
+ dataset = "b-mc2/sql-create-context"
20
 
21
  def prompt_model(model_id, system_prompt, user_prompt, schema):
22
  pipe = pipeline("text-generation",
 
35
  return result
36
 
37
  def fine_tune_model(base_model_id, dataset):
38
+ tokenizer = download_model(base_model_id)
39
+ #download_dataset(dataset)
40
+ fine_tuned_model_id = upload_model(base_model_id, tokenizer)
41
+ return fine_tuned_model_id
42
 
43
  def download_model(base_model_id):
44
  tokenizer = AutoTokenizer.from_pretrained(base_model_id)