arslan-ahmed commited on
Commit
8022f91
1 Parent(s): 7869db9

Update ttyd_functions.py

Browse files
Files changed (1) hide show
  1. ttyd_functions.py +8 -5
ttyd_functions.py CHANGED
@@ -84,15 +84,18 @@ def getOaiLlm(temp, modelNameDD, api_key_st):
84
  return llm
85
 
86
 
 
 
 
 
87
  def getWxLlm(temp, modelNameDD, api_key_st):
88
  modelName = modelNameDD.split('(')[0].strip()
89
  wxModelParams = {
90
  GenParams.DECODING_METHOD: DecodingMethods.SAMPLE,
91
- GenParams.MAX_NEW_TOKENS: 1000,
92
- GenParams.MIN_NEW_TOKENS: 1,
93
  GenParams.TEMPERATURE: float(temp),
94
- GenParams.TOP_K: 50,
95
- GenParams.TOP_P: 1
96
  }
97
  model = Model(
98
  model_id=modelName,
@@ -104,7 +107,7 @@ def getWxLlm(temp, modelNameDD, api_key_st):
104
 
105
  def getBamLlm(temp, modelNameDD, api_key_st):
106
  modelName = modelNameDD.split('(')[0].strip()
107
- parameters = GenerateParams(decoding_method="sample", max_new_tokens=1024, temperature=float(temp), top_k=50, top_p=1)
108
  llm = LangChainInterface(model=modelName, params=parameters, credentials=api_key_st['bam_creds'])
109
  return llm
110
 
 
84
  return llm
85
 
86
 
87
+ MAX_NEW_TOKENS = 1024
88
+ TOP_K = None
89
+ TOP_P = 1
90
+
91
  def getWxLlm(temp, modelNameDD, api_key_st):
92
  modelName = modelNameDD.split('(')[0].strip()
93
  wxModelParams = {
94
  GenParams.DECODING_METHOD: DecodingMethods.SAMPLE,
95
+ GenParams.MAX_NEW_TOKENS: MAX_NEW_TOKENS,
 
96
  GenParams.TEMPERATURE: float(temp),
97
+ GenParams.TOP_K: TOP_K,
98
+ GenParams.TOP_P: TOP_P
99
  }
100
  model = Model(
101
  model_id=modelName,
 
107
 
108
  def getBamLlm(temp, modelNameDD, api_key_st):
109
  modelName = modelNameDD.split('(')[0].strip()
110
+ parameters = GenerateParams(decoding_method="sample", max_new_tokens=MAX_NEW_TOKENS, temperature=float(temp), top_k=TOP_K, top_p=TOP_P)
111
  llm = LangChainInterface(model=modelName, params=parameters, credentials=api_key_st['bam_creds'])
112
  return llm
113