Spaces:
Sleeping
Sleeping
arslan-ahmed
commited on
Commit
•
bde0da6
1
Parent(s):
12314a0
updated BAM models
Browse files- app.py +11 -9
- ttyd_consts.py +3 -4
- ttyd_functions.py +3 -0
app.py
CHANGED
@@ -20,6 +20,7 @@ from ibm_watson_machine_learning.metanames import GenTextParamsMetaNames as GenP
|
|
20 |
from ibm_watson_machine_learning.foundation_models.utils.enums import DecodingMethods
|
21 |
from ibm_watson_machine_learning.foundation_models import Model
|
22 |
from ibm_watson_machine_learning.foundation_models.extensions.langchain import WatsonxLLM
|
|
|
23 |
|
24 |
import genai
|
25 |
|
@@ -77,12 +78,13 @@ def setOaiApiKey(creds):
|
|
77 |
def setBamApiKey(creds):
|
78 |
creds = getBamCreds(creds)
|
79 |
try:
|
80 |
-
genai.Model.models(credentials=creds['bam_creds'])
|
|
|
81 |
api_key_st = creds
|
82 |
-
return 'BAM credentials accepted.', *[x.update(interactive=False) for x in credComps_btn_tb], api_key_st
|
83 |
except Exception as e:
|
84 |
gr.Warning(str(e))
|
85 |
-
return [x.update() for x in credComps_op]
|
86 |
|
87 |
def setWxApiKey(key, p_id):
|
88 |
creds = getWxCreds(key, p_id)
|
@@ -97,7 +99,7 @@ def setWxApiKey(key, p_id):
|
|
97 |
|
98 |
# convert user uploaded data to vectorstore
|
99 |
def uiData_vecStore(userFiles, userUrls, api_key_st, vsDict_st={}, progress=gr.Progress()):
|
100 |
-
opComponents = [data_ingest_btn, upload_fb, urls_tb]
|
101 |
# parse user data
|
102 |
file_paths = []
|
103 |
documents = []
|
@@ -129,7 +131,7 @@ def uiData_vecStore(userFiles, userUrls, api_key_st, vsDict_st={}, progress=gr.P
|
|
129 |
src_str = str(src_str[1]) + ' source document(s) successfully loaded in vector store.'+'\n\n' + src_str[0]
|
130 |
|
131 |
progress(1, 'Data loaded')
|
132 |
-
return vsDict_st, src_str, *[x.update(interactive=False) for x in [data_ingest_btn, upload_fb]], urls_tb.update(interactive=False, placeholder='')
|
133 |
|
134 |
# initialize chatbot function sets the QA Chain, and also sets/updates any other components to start chatting. updateQaChain function only updates QA chain and will be called whenever Adv Settings are updated.
|
135 |
def initializeChatbot(temp, k, modelNameDD, stdlQs, api_key_st, vsDict_st, progress=gr.Progress()):
|
@@ -247,7 +249,7 @@ with gr.Blocks(theme=gr.themes.Default(primary_hue='orange', secondary_hue='gray
|
|
247 |
, placeholder=url_tb_ph)
|
248 |
data_ingest_btn = gr.Button("Load Data")
|
249 |
status_tb = gr.TextArea(label='Status Info')
|
250 |
-
initChatbot_btn = gr.Button("Initialize Chatbot", variant="primary")
|
251 |
|
252 |
credComps_btn_tb = [oaiKey_tb, oaiKey_btn, bamKey_tb, bamKey_btn, wxKey_tb, wxPid_tb, wxKey_btn]
|
253 |
credComps_op = [status_tb] + credComps_btn_tb + [api_key_state]
|
@@ -266,7 +268,7 @@ with gr.Blocks(theme=gr.themes.Default(primary_hue='orange', secondary_hue='gray
|
|
266 |
temp_sld = gr.Slider(minimum=0, maximum=1, step=0.1, value=0.7, label="Temperature", info='Sampling temperature to use when calling LLM. Defaults to 0.7')
|
267 |
k_sld = gr.Slider(minimum=1, maximum=10, step=1, value=mode.k, label="K", info='Number of relavant documents to return from Vector Store. Defaults to 4')
|
268 |
model_dd = gr.Dropdown(label='Model Name'\
|
269 |
-
, choices=
|
270 |
, info=model_dd_info)
|
271 |
stdlQs_rb = gr.Radio(label='Standalone Question', info=stdlQs_rb_info\
|
272 |
, type='index', value=stdlQs_rb_choices[1]\
|
@@ -280,7 +282,7 @@ with gr.Blocks(theme=gr.themes.Default(primary_hue='orange', secondary_hue='gray
|
|
280 |
oaiKey_tb.submit(**oaiKey_btn_args)
|
281 |
|
282 |
# BAM API button
|
283 |
-
bamKey_btn_args = {'fn':setBamApiKey, 'inputs':[bamKey_tb], 'outputs':credComps_op}
|
284 |
bamKey_btn.click(**bamKey_btn_args)
|
285 |
bamKey_tb.submit(**bamKey_btn_args)
|
286 |
|
@@ -289,7 +291,7 @@ with gr.Blocks(theme=gr.themes.Default(primary_hue='orange', secondary_hue='gray
|
|
289 |
wxKey_btn.click(**wxKey_btn_args)
|
290 |
|
291 |
# Data Ingest Button
|
292 |
-
data_ingest_event = data_ingest_btn.click(uiData_vecStore, [upload_fb, urls_tb, api_key_state, chromaVS_state], [chromaVS_state, status_tb, data_ingest_btn, upload_fb, urls_tb])
|
293 |
|
294 |
# Adv Settings
|
295 |
advSet_args = {'fn':updateQaChain, 'inputs':[temp_sld, k_sld, model_dd, stdlQs_rb, api_key_state, chromaVS_state], 'outputs':[qa_state, model_dd]}
|
|
|
20 |
from ibm_watson_machine_learning.foundation_models.utils.enums import DecodingMethods
|
21 |
from ibm_watson_machine_learning.foundation_models import Model
|
22 |
from ibm_watson_machine_learning.foundation_models.extensions.langchain import WatsonxLLM
|
23 |
+
from ibm_watson_machine_learning.foundation_models.utils.enums import ModelTypes
|
24 |
|
25 |
import genai
|
26 |
|
|
|
78 |
def setBamApiKey(creds):
|
79 |
creds = getBamCreds(creds)
|
80 |
try:
|
81 |
+
bam_models = genai.Model.models(credentials=creds['bam_creds'])
|
82 |
+
bam_models = sorted(x.id for x in bam_models)
|
83 |
api_key_st = creds
|
84 |
+
return 'BAM credentials accepted.', *[x.update(interactive=False) for x in credComps_btn_tb], api_key_st, model_dd.update(choices=getModelChoices(openAi_models, ModelTypes, bam_models))
|
85 |
except Exception as e:
|
86 |
gr.Warning(str(e))
|
87 |
+
return *[x.update() for x in credComps_op], model_dd.update()
|
88 |
|
89 |
def setWxApiKey(key, p_id):
|
90 |
creds = getWxCreds(key, p_id)
|
|
|
99 |
|
100 |
# convert user uploaded data to vectorstore
|
101 |
def uiData_vecStore(userFiles, userUrls, api_key_st, vsDict_st={}, progress=gr.Progress()):
|
102 |
+
opComponents = [data_ingest_btn, upload_fb, urls_tb, initChatbot_btn]
|
103 |
# parse user data
|
104 |
file_paths = []
|
105 |
documents = []
|
|
|
131 |
src_str = str(src_str[1]) + ' source document(s) successfully loaded in vector store.'+'\n\n' + src_str[0]
|
132 |
|
133 |
progress(1, 'Data loaded')
|
134 |
+
return vsDict_st, src_str, *[x.update(interactive=False) for x in [data_ingest_btn, upload_fb]], urls_tb.update(interactive=False, placeholder=''), initChatbot_btn.update(interactive=True)
|
135 |
|
136 |
# initialize chatbot function sets the QA Chain, and also sets/updates any other components to start chatting. updateQaChain function only updates QA chain and will be called whenever Adv Settings are updated.
|
137 |
def initializeChatbot(temp, k, modelNameDD, stdlQs, api_key_st, vsDict_st, progress=gr.Progress()):
|
|
|
249 |
, placeholder=url_tb_ph)
|
250 |
data_ingest_btn = gr.Button("Load Data")
|
251 |
status_tb = gr.TextArea(label='Status Info')
|
252 |
+
initChatbot_btn = gr.Button("Initialize Chatbot", variant="primary", interactive=False)
|
253 |
|
254 |
credComps_btn_tb = [oaiKey_tb, oaiKey_btn, bamKey_tb, bamKey_btn, wxKey_tb, wxPid_tb, wxKey_btn]
|
255 |
credComps_op = [status_tb] + credComps_btn_tb + [api_key_state]
|
|
|
268 |
temp_sld = gr.Slider(minimum=0, maximum=1, step=0.1, value=0.7, label="Temperature", info='Sampling temperature to use when calling LLM. Defaults to 0.7')
|
269 |
k_sld = gr.Slider(minimum=1, maximum=10, step=1, value=mode.k, label="K", info='Number of relavant documents to return from Vector Store. Defaults to 4')
|
270 |
model_dd = gr.Dropdown(label='Model Name'\
|
271 |
+
, choices=getModelChoices(openAi_models, ModelTypes, bam_models_old), allow_custom_value=True\
|
272 |
, info=model_dd_info)
|
273 |
stdlQs_rb = gr.Radio(label='Standalone Question', info=stdlQs_rb_info\
|
274 |
, type='index', value=stdlQs_rb_choices[1]\
|
|
|
282 |
oaiKey_tb.submit(**oaiKey_btn_args)
|
283 |
|
284 |
# BAM API button
|
285 |
+
bamKey_btn_args = {'fn':setBamApiKey, 'inputs':[bamKey_tb], 'outputs':credComps_op+[model_dd]}
|
286 |
bamKey_btn.click(**bamKey_btn_args)
|
287 |
bamKey_tb.submit(**bamKey_btn_args)
|
288 |
|
|
|
291 |
wxKey_btn.click(**wxKey_btn_args)
|
292 |
|
293 |
# Data Ingest Button
|
294 |
+
data_ingest_event = data_ingest_btn.click(uiData_vecStore, [upload_fb, urls_tb, api_key_state, chromaVS_state], [chromaVS_state, status_tb, data_ingest_btn, upload_fb, urls_tb, initChatbot_btn])
|
295 |
|
296 |
# Adv Settings
|
297 |
advSet_args = {'fn':updateQaChain, 'inputs':[temp_sld, k_sld, model_dd, stdlQs_rb, api_key_state, chromaVS_state], 'outputs':[qa_state, model_dd]}
|
ttyd_consts.py
CHANGED
@@ -1,5 +1,4 @@
|
|
1 |
from langchain import PromptTemplate
|
2 |
-
from ibm_watson_machine_learning.foundation_models.utils.enums import ModelTypes
|
3 |
import os
|
4 |
from dotenv import load_dotenv
|
5 |
load_dotenv()
|
@@ -45,7 +44,7 @@ Question: {question} [/INST]
|
|
45 |
|
46 |
promptLlama=PromptTemplate(input_variables=['context', 'question'], template=llamaPromptTemplate)
|
47 |
|
48 |
-
|
49 |
'salesforce/codegen2-16b',
|
50 |
'codellama/codellama-34b-instruct',
|
51 |
'tiiuae/falcon-40b',
|
@@ -71,9 +70,9 @@ bam_models = sorted(['bigscience/bloom',
|
|
71 |
'bigcode/starcoder',
|
72 |
'google/ul2'])
|
73 |
|
74 |
-
|
75 |
|
76 |
-
|
77 |
|
78 |
|
79 |
OaiDefaultModel = 'gpt-3.5-turbo (openai)'
|
|
|
1 |
from langchain import PromptTemplate
|
|
|
2 |
import os
|
3 |
from dotenv import load_dotenv
|
4 |
load_dotenv()
|
|
|
44 |
|
45 |
promptLlama=PromptTemplate(input_variables=['context', 'question'], template=llamaPromptTemplate)
|
46 |
|
47 |
+
bam_models_old = sorted(['bigscience/bloom',
|
48 |
'salesforce/codegen2-16b',
|
49 |
'codellama/codellama-34b-instruct',
|
50 |
'tiiuae/falcon-40b',
|
|
|
70 |
'bigcode/starcoder',
|
71 |
'google/ul2'])
|
72 |
|
73 |
+
openAi_models = ['gpt-3.5-turbo (openai)', 'gpt-3.5-turbo-16k (openai)', 'gpt-4 (openai)', 'text-davinci-003 (Legacy - openai)', 'text-curie-001 (Legacy - openai)', 'babbage-002 (openai)']
|
74 |
|
75 |
+
model_dd_info = 'Make sure your credentials are submitted before changing the model. You can also input any OpenAI model name or Watsonx/BAM model ID.'
|
76 |
|
77 |
|
78 |
OaiDefaultModel = 'gpt-3.5-turbo (openai)'
|
ttyd_functions.py
CHANGED
@@ -372,3 +372,6 @@ def changeModel(oldModel, newModel):
|
|
372 |
gr.Warning(warning)
|
373 |
time.sleep(1)
|
374 |
return newModel
|
|
|
|
|
|
|
|
372 |
gr.Warning(warning)
|
373 |
time.sleep(1)
|
374 |
return newModel
|
375 |
+
|
376 |
+
def getModelChoices(openAi_models, wml_models, bam_models):
|
377 |
+
return [model for model in openAi_models] + [model.value+' (watsonx)' for model in wml_models] + [model + ' (bam)' for model in bam_models]
|