Update app.py
Browse files
app.py
CHANGED
@@ -30,7 +30,7 @@ fine_tuned_model_id = "bstraehle/Meta-Llama-3-8B-Instruct"
|
|
30 |
dataset = "gretelai/synthetic_text_to_sql"
|
31 |
|
32 |
def prompt_model(model_id, system_prompt, user_prompt):
|
33 |
-
pipe = pipeline("text-generation", model=model_id)
|
34 |
|
35 |
messages = [
|
36 |
{"role": "system", "content": system_prompt},
|
@@ -38,7 +38,7 @@ def prompt_model(model_id, system_prompt, user_prompt):
|
|
38 |
{"role": "assistant", "content": ""}
|
39 |
]
|
40 |
|
41 |
-
output = pipe(messages
|
42 |
|
43 |
return output[0]["generated_text"][-1]["content"]
|
44 |
|
@@ -87,7 +87,7 @@ def process(action, system_prompt, user_prompt, schema, base_model_id, fine_tune
|
|
87 |
return result
|
88 |
|
89 |
demo = gr.Interface(fn=process,
|
90 |
-
inputs=[gr.Radio([action_1, action_2, action_3], label = "Action", value =
|
91 |
gr.Textbox(label = "System Prompt", value = system_prompt, lines = 2),
|
92 |
gr.Textbox(label = "User Prompt", value = user_prompt, lines = 2),
|
93 |
gr.Textbox(label = "Schema", value = schema, lines = 2),
|
|
|
30 |
dataset = "gretelai/synthetic_text_to_sql"
|
31 |
|
32 |
def prompt_model(model_id, system_prompt, user_prompt):
|
33 |
+
pipe = pipeline("text-generation", model=model_id, model_kwargs={"torch_dtype": torch.bfloat16}, device_map="auto")
|
34 |
|
35 |
messages = [
|
36 |
{"role": "system", "content": system_prompt},
|
|
|
38 |
{"role": "assistant", "content": ""}
|
39 |
]
|
40 |
|
41 |
+
output = pipe(messages)
|
42 |
|
43 |
return output[0]["generated_text"][-1]["content"]
|
44 |
|
|
|
87 |
return result
|
88 |
|
89 |
demo = gr.Interface(fn=process,
|
90 |
+
inputs=[gr.Radio([action_1, action_2, action_3], label = "Action", value = action_1),
|
91 |
gr.Textbox(label = "System Prompt", value = system_prompt, lines = 2),
|
92 |
gr.Textbox(label = "User Prompt", value = user_prompt, lines = 2),
|
93 |
gr.Textbox(label = "Schema", value = schema, lines = 2),
|