from ollama import Client

from modules.presets import i18n

from ..utils import count_token
from .base_model import BaseLLMModel
|
|
class OllamaClient(BaseLLMModel):
    """Chat model client that proxies requests to an Ollama server."""

    def __init__(self, model_name, user_name="", ollama_host="", backend_model="") -> None:
        super().__init__(model_name=model_name, user=user_name)
        self.backend_model = backend_model
        self.ollama_host = ollama_host
        # Size the context window to match the selected backend model.
        self.update_token_limit()
|
    def get_model_list(self):
        """Return the model listing reported by the Ollama server."""
        client = Client(host=self.ollama_host)
        return client.list()
|
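    # Hypothetical caller-side sketch (not part of the original module): the
    # shape of the listing depends on the installed ollama-python version.
    # Older releases return a plain dict parsed from the server's /api/tags
    # endpoint, so model names can be pulled out as
    #
    #     names = [m["name"] for m in self.get_model_list()["models"]]
    #
    # while newer releases return a typed ListResponse object whose entries
    # expose the name via their `.model` attribute instead.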
|
    def update_token_limit(self):
        """Set `token_upper_limit` from the backend model's context window.

        Matching is by substring, in insertion order, so "llama2-chinese"
        must be checked before "llama2"; unrecognized models keep the
        default limit inherited from BaseLLMModel.
        """
        context_windows = {
            "mistral": 8 * 1024,
            "gemma": 8 * 1024,
            "codellama": 4 * 1024,
            "llama2-chinese": 4 * 1024,
            "llama2": 4 * 1024,
            "mixtral": 32 * 1024,
            "llava": 4 * 1024,
        }
        lower_model_name = self.backend_model.lower()
        for fragment, limit in context_windows.items():
            if fragment in lower_model_name:
                self.token_upper_limit = limit
                break
|
    def get_answer_stream_iter(self):
        """Stream the assistant reply, yielding the accumulated text so far."""
        if self.backend_model == "":
            # This is a generator, so the prompt must be yielded; a plain
            # `return` of the string would be swallowed by StopIteration and
            # the message would never reach the UI.
            yield i18n("请先选择Ollama后端模型\n\n")  # "Please select an Ollama backend model first"
            return
        client = Client(host=self.ollama_host)
        response = client.chat(model=self.backend_model, messages=self.history, stream=True)
        partial_text = ""
        for chunk in response:
            # Each streamed chunk carries the next fragment of the reply.
            partial_text += chunk["message"]["content"]
            yield partial_text
        # Record the completed reply's token count, then emit the final text.
        self.all_token_counts[-1] = count_token(partial_text)
        yield partial_text
|
|
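# Hypothetical usage sketch (not part of the original module). The relative
# imports above mean this file only runs inside its package, so the example
# stays commented out. It assumes an Ollama server at the default
# http://localhost:11434 with a "llama2" model already pulled, and it seeds
# `history` and `all_token_counts` by hand, which BaseLLMModel would normally
# manage:
#
#     model = OllamaClient(
#         model_name="ollama",
#         user_name="demo",
#         ollama_host="http://localhost:11434",
#         backend_model="llama2",
#     )
#     model.history = [{"role": "user", "content": "Hello!"}]
#     model.all_token_counts = [0]
#     for partial in model.get_answer_stream_iter():
#         print(partial)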