bstraehle committed on
Commit
76019db
1 Parent(s): ad99c34

Update app.py

Browse files
Files changed (1) hide show
  1. app.py +3 -3
app.py CHANGED
@@ -30,7 +30,7 @@ fine_tuned_model_id = "bstraehle/Meta-Llama-3-8B-Instruct"
30
  dataset = "gretelai/synthetic_text_to_sql"
31
 
32
  def prompt_model(model_id, system_prompt, user_prompt):
33
- pipe = pipeline("text-generation", model=model_id)
34
 
35
  messages = [
36
  {"role": "system", "content": system_prompt},
@@ -38,7 +38,7 @@ def prompt_model(model_id, system_prompt, user_prompt):
38
  {"role": "assistant", "content": ""}
39
  ]
40
 
41
- output = pipe(messages, model_kwargs={"torch_dtype": torch.bfloat16}, device="cuda")
42
 
43
  return output[0]["generated_text"][-1]["content"]
44
 
@@ -87,7 +87,7 @@ def process(action, system_prompt, user_prompt, schema, base_model_id, fine_tune
87
  return result
88
 
89
  demo = gr.Interface(fn=process,
90
- inputs=[gr.Radio([action_1, action_2, action_3], label = "Action", value = action_2),
91
  gr.Textbox(label = "System Prompt", value = system_prompt, lines = 2),
92
  gr.Textbox(label = "User Prompt", value = user_prompt, lines = 2),
93
  gr.Textbox(label = "Schema", value = schema, lines = 2),
 
30
  dataset = "gretelai/synthetic_text_to_sql"
31
 
32
  def prompt_model(model_id, system_prompt, user_prompt):
33
+ pipe = pipeline("text-generation", model=model_id, model_kwargs={"torch_dtype": torch.bfloat16}, device_map="auto")
34
 
35
  messages = [
36
  {"role": "system", "content": system_prompt},
 
38
  {"role": "assistant", "content": ""}
39
  ]
40
 
41
+ output = pipe(messages)
42
 
43
  return output[0]["generated_text"][-1]["content"]
44
 
 
87
  return result
88
 
89
  demo = gr.Interface(fn=process,
90
+ inputs=[gr.Radio([action_1, action_2, action_3], label = "Action", value = action_1),
91
  gr.Textbox(label = "System Prompt", value = system_prompt, lines = 2),
92
  gr.Textbox(label = "User Prompt", value = user_prompt, lines = 2),
93
  gr.Textbox(label = "Schema", value = schema, lines = 2),