jonathanjordan21 commited on
Commit
fb1e6a9
1 Parent(s): 162773e

Update apis/chat_api.py

Browse files
Files changed (1) hide show
  1. apis/chat_api.py +29 -6
apis/chat_api.py CHANGED
@@ -335,6 +335,22 @@ class ChatAPIApp:
335
  options: Optional[dict] = None
336
 
337
  def get_embeddings(self, request: EmbeddingRequest, api_key: str = Depends(extract_api_key)):
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
338
  try:
339
  model = request.model
340
  model_kwargs = request.options
@@ -383,12 +399,19 @@ class ChatAPIApp:
383
  summary="Ollama Chat completions in conversation session",
384
  include_in_schema=include_in_schema,
385
  )(self.chat_completions_ollama)
386
-
387
- self.app.post(
388
- prefix + "/embeddings",
389
- summary="Get Embeddings with prompt",
390
- include_in_schema=include_in_schema,
391
- )(self.get_embeddings)
 
 
 
 
 
 
 
392
 
393
  self.app.get(
394
  "/api/tags",
 
335
  options: Optional[dict] = None
336
 
337
  def get_embeddings(self, request: EmbeddingRequest, api_key: str = Depends(extract_api_key)):
338
+ try:
339
+ model = request.model
340
+ model_kwargs = request.options
341
+ embeddings = self.embeddings[model].encode(request.prompt, convert_to_tensor=True)#, **model_kwargs)
342
+ return {
343
+ "object":"list",
344
+ "data":[
345
+ "object": "embedding", "index": 0, "embedding": embeddings.tolist()
346
+ ],
347
+ "model": model,
348
+ "usage":{},
349
+ }
350
+ except ValueError as e:
351
+ raise HTTPException(status_code=400, detail=str(e))
352
+
353
+ def get_embeddings_ollama(self, request: EmbeddingRequest, api_key: str = Depends(extract_api_key)):
354
  try:
355
  model = request.model
356
  model_kwargs = request.options
 
399
  summary="Ollama Chat completions in conversation session",
400
  include_in_schema=include_in_schema,
401
  )(self.chat_completions_ollama)
402
+
403
+ if prefix in ["/api"]:
404
+ self.app.post(
405
+ prefix + "/embeddings",
406
+ summary="Ollama Get Embeddings with prompt",
407
+ include_in_schema=True,
408
+ )(self.get_embeddings_ollama)
409
+ else:
410
+ self.app.post(
411
+ prefix + "/embeddings",
412
+ summary="Get Embeddings with prompt",
413
+ include_in_schema=include_in_schema,
414
+ )(self.get_embeddings)
415
 
416
  self.app.get(
417
  "/api/tags",