bstraehle commited on
Commit
e92ef1c
1 Parent(s): 065dd39

Update app.py

Browse files
Files changed (1) hide show
  1. app.py +6 -9
app.py CHANGED
@@ -8,7 +8,6 @@ from transformers import AutoTokenizer, AutoModelForCausalLM, pipeline
8
 
9
  # Model IDs:
10
  #
11
- # google/gemma-2-9b-it
12
  # meta-llama/Meta-Llama-3-8B-Instruct
13
 
14
  # Datasets:
@@ -17,9 +16,8 @@ from transformers import AutoTokenizer, AutoModelForCausalLM, pipeline
17
 
18
  profile = "bstraehle"
19
 
20
- action_1 = "Prompt base model"
21
- action_2 = "Fine-tune base model"
22
- action_3 = "Prompt fine-tuned model"
23
 
24
  system_prompt = "You are a text to SQL query translator. Given a question in English, generate a SQL query based on the provided SCHEMA. Do not generate any additional text. SCHEMA: {schema}"
25
  user_prompt = "What is the total trade value and average price for each trader and stock in the trade_history table?"
@@ -76,12 +74,12 @@ def upload_model(model_id, tokenizer):
76
 
77
  return model_repo_name
78
 
79
- def process(action, system_prompt, user_prompt, schema, base_model_id, fine_tuned_model_id, dataset):
80
  if action == action_1:
81
- result = prompt_model(base_model_id, system_prompt, user_prompt, schema)
82
- elif action == action_2:
83
  result = fine_tune_model(base_model_id)
84
- elif action == action_3:
 
 
85
  result = prompt_model(fine_tuned_model_id, system_prompt, user_prompt, schema)
86
 
87
  return result
@@ -92,7 +90,6 @@ demo = gr.Interface(fn=process,
92
  gr.Textbox(label = "User Prompt", value = user_prompt, lines = 2),
93
  gr.Textbox(label = "Schema", value = schema, lines = 2),
94
  gr.Textbox(label = "Base Model ID", value = base_model_id, lines = 1),
95
- gr.Textbox(label = "Fine-Tuned Model ID", value = fine_tuned_model_id, lines = 1),
96
  gr.Textbox(label = "Dataset", value = dataset, lines = 1)],
97
  outputs=[gr.Textbox(label = "Completion")])
98
  demo.launch()
 
8
 
9
  # Model IDs:
10
  #
 
11
  # meta-llama/Meta-Llama-3-8B-Instruct
12
 
13
  # Datasets:
 
16
 
17
  profile = "bstraehle"
18
 
19
+ action_1 = "Fine-tune pre-trained model"
20
+ action_2 = "Prompt fine-tuned model"
 
21
 
22
  system_prompt = "You are a text to SQL query translator. Given a question in English, generate a SQL query based on the provided SCHEMA. Do not generate any additional text. SCHEMA: {schema}"
23
  user_prompt = "What is the total trade value and average price for each trader and stock in the trade_history table?"
 
74
 
75
  return model_repo_name
76
 
77
+ def process(action, system_prompt, user_prompt, schema, base_model_id, dataset):
78
  if action == action_1:
 
 
79
  result = fine_tune_model(base_model_id)
80
+ elif action == action_2:
81
+ model_id = base_model_id[base_model_id.rfind('/')+1:]
82
+ fine_tuned_model_id = f"{profile}/{model_id}"
83
  result = prompt_model(fine_tuned_model_id, system_prompt, user_prompt, schema)
84
 
85
  return result
 
90
  gr.Textbox(label = "User Prompt", value = user_prompt, lines = 2),
91
  gr.Textbox(label = "Schema", value = schema, lines = 2),
92
  gr.Textbox(label = "Base Model ID", value = base_model_id, lines = 1),
 
93
  gr.Textbox(label = "Dataset", value = dataset, lines = 1)],
94
  outputs=[gr.Textbox(label = "Completion")])
95
  demo.launch()