bstraehle commited on
Commit
5355a29
1 Parent(s): 227477e

Update app.py

Browse files
Files changed (1) hide show
  1. app.py +2 -3
app.py CHANGED
@@ -15,15 +15,14 @@ system_prompt = "You are a text to SQL query translator. Given a question in Eng
15
  user_prompt = "What is the total trade value and average price for each trader and stock in the trade_history table?"
16
  schema = "CREATE TABLE trade_history (id INT, trader_id INT, stock VARCHAR(255), price DECIMAL(5,2), quantity INT, trade_time TIMESTAMP);"
17
 
18
- base_model_id = "codellama/CodeLlama-7b-hf"
19
  dataset = "b-mc2/sql-create-context"
20
 
21
  def prompt_model(model_id, system_prompt, user_prompt, schema):
22
  pipe = pipeline("text-generation",
23
  model=model_id,
24
  model_kwargs={"torch_dtype": torch.bfloat16},
25
- device_map="auto",
26
- max_new_tokens=1000)
27
  messages = [
28
  {"role": "system", "content": system_prompt.format(schema=schema)},
29
  {"role": "user", "content": user_prompt},
 
15
  user_prompt = "What is the total trade value and average price for each trader and stock in the trade_history table?"
16
  schema = "CREATE TABLE trade_history (id INT, trader_id INT, stock VARCHAR(255), price DECIMAL(5,2), quantity INT, trade_time TIMESTAMP);"
17
 
18
+ base_model_id = "meta-llama/Meta-Llama-3-8B-Instruct"
19
  dataset = "b-mc2/sql-create-context"
20
 
21
  def prompt_model(model_id, system_prompt, user_prompt, schema):
22
  pipe = pipeline("text-generation",
23
  model=model_id,
24
  model_kwargs={"torch_dtype": torch.bfloat16},
25
+ device_map="auto")
 
26
  messages = [
27
  {"role": "system", "content": system_prompt.format(schema=schema)},
28
  {"role": "user", "content": user_prompt},