jonathanjordan21 commited on
Commit
ee910d2
1 Parent(s): 7b40e67

Update apis/chat_api.py

Browse files
Files changed (1) hide show
  1. apis/chat_api.py +78 -4
apis/chat_api.py CHANGED
@@ -140,6 +140,80 @@ class ChatAPIApp:
140
  raise HTTPException(status_code=e.status_code, detail=e.detail)
141
  except Exception as e:
142
  raise HTTPException(status_code=500, detail=str(e))
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
143
 
144
 
145
  class EmbeddingRequest(BaseModel):
@@ -181,19 +255,19 @@ class ChatAPIApp:
181
 
182
  self.app.post(
183
  prefix + "/chat/completions",
184
- summary="Chat completions in conversation session",
185
  include_in_schema=include_in_schema,
186
  )(self.chat_completions)
187
 
188
  self.app.post(
189
  prefix + "/generate",
190
- summary="Chat completions in conversation session",
191
  include_in_schema=include_in_schema,
192
- )(self.chat_completions)
193
 
194
  self.app.post(
195
  prefix + "/chat",
196
- summary="Chat completions in conversation session",
197
  include_in_schema=include_in_schema,
198
  )(self.chat_completions)
199
 
 
140
  raise HTTPException(status_code=e.status_code, detail=e.detail)
141
  except Exception as e:
142
  raise HTTPException(status_code=500, detail=str(e))
143
+
144
+
145
+ class GenerateRequest(BaseModel):
146
+ model: str = Field(
147
+ default="nous-mixtral-8x7b",
148
+ description="(str) `nous-mixtral-8x7b`",
149
+ )
150
+ prompt: str = Field(
151
+ default="Hello, who are you?",
152
+ description="(list) Messages",
153
+ )
154
+ temperature: Union[float, None] = Field(
155
+ default=0.5,
156
+ description="(float) Temperature",
157
+ )
158
+ top_p: Union[float, None] = Field(
159
+ default=0.95,
160
+ description="(float) top p",
161
+ )
162
+ max_tokens: Union[int, None] = Field(
163
+ default=-1,
164
+ description="(int) Max tokens",
165
+ )
166
+ use_cache: bool = Field(
167
+ default=False,
168
+ description="(bool) Use cache",
169
+ )
170
+ stream: bool = Field(
171
+ default=True,
172
+ description="(bool) Stream",
173
+ )
174
+
175
+ def generate_text(
176
+ self, item: GenerateRequest, api_key: str = Depends(extract_api_key)
177
+ ):
178
+ try:
179
+ api_key = self.auth_api_key(api_key)
180
+
181
+ if item.model == "gpt-3.5-turbo":
182
+ streamer = OpenaiStreamer()
183
+ stream_response = streamer.chat_response(messages=[{"user":item.prompt}])
184
+ elif item.model in PRO_MODELS:
185
+ streamer = HuggingchatStreamer(model=item.model)
186
+ stream_response = streamer.chat_response(
187
+ messages=[{"user":item.prompt}],
188
+ )
189
+ else:
190
+ streamer = HuggingfaceStreamer(model=item.model)
191
+ stream_response = streamer.chat_response(
192
+ prompt=item.prompt,
193
+ temperature=item.temperature,
194
+ top_p=item.top_p,
195
+ max_new_tokens=item.max_tokens,
196
+ api_key=api_key,
197
+ use_cache=item.use_cache,
198
+ )
199
+
200
+ if item.stream:
201
+ event_source_response = EventSourceResponse(
202
+ streamer.chat_return_generator(stream_response),
203
+ media_type="text/event-stream",
204
+ ping=2000,
205
+ ping_message_factory=lambda: ServerSentEvent(**{"comment": ""}),
206
+ )
207
+ return event_source_response
208
+ else:
209
+ data_response = streamer.chat_return_dict(stream_response)
210
+ return data_response
211
+ except HfApiException as e:
212
+ raise HTTPException(status_code=e.status_code, detail=e.detail)
213
+ except Exception as e:
214
+ raise HTTPException(status_code=500, detail=str(e))
215
+
216
+
217
 
218
 
219
  class EmbeddingRequest(BaseModel):
 
255
 
256
  self.app.post(
257
  prefix + "/chat/completions",
258
+ summary="OpenAI Chat completions in conversation session",
259
  include_in_schema=include_in_schema,
260
  )(self.chat_completions)
261
 
262
  self.app.post(
263
  prefix + "/generate",
264
+ summary="Ollama text generation",
265
  include_in_schema=include_in_schema,
266
+ )(self.generate_text)
267
 
268
  self.app.post(
269
  prefix + "/chat",
270
+ summary="Ollama Chat completions in conversation session",
271
  include_in_schema=include_in_schema,
272
  )(self.chat_completions)
273