# RAGchat: app/llm.py
import fastapi
from fastapi.responses import JSONResponse
from fastapi_users import schemas
from time import time
#from fastapi.middleware.cors import CORSMiddleware
#MODEL_PATH = "./qwen1_5-0_5b-chat-q4_0.gguf" #"./qwen1_5-0_5b-chat-q4_0.gguf"
import logging
from langchain_community.llms import LlamaCpp
import llama_cpp
import llama_cpp.llama_tokenizer
from pydantic import BaseModel
from fastapi import APIRouter
from app.users import current_active_user
from langchain_community.document_loaders import WebBaseLoader
from langchain_text_splitters import RecursiveCharacterTextSplitter
from langchain_chroma import Chroma
from langchain_community.embeddings import GPT4AllEmbeddings
from langchain_core.output_parsers import StrOutputParser
from langchain_core.prompts import PromptTemplate
from langchain import hub
from langchain_core.runnables import RunnablePassthrough, RunnablePick
rag_prompt_llama = hub.pull("rlm/rag-prompt-llama")
# rag_prompt_llama.messages  # uncomment to inspect the pulled prompt template
llm = llama_cpp.Llama.from_pretrained(
    repo_id="Qwen/Qwen1.5-0.5B-Chat-GGUF",
    filename="*q4_0.gguf",
    tokenizer=llama_cpp.llama_tokenizer.LlamaHFTokenizer.from_pretrained("Qwen/Qwen1.5-0.5B-Chat"),
    verbose=False,
    n_ctx=512,
    n_gpu_layers=0,
    # chat_format="llama-2"
)
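# NOTE: this model runs CPU-only (n_gpu_layers=0) with a 512-token context window,
# which caps how much retrieved context plus question the RAG prompt can carry.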
def format_docs(docs):
    # Join retrieved documents into a single context string for the prompt.
    return "\n\n".join(doc.page_content for doc in docs)


class RagChat:
    def agent(self):
        loader = WebBaseLoader("https://lilianweng.github.io/posts/2023-06-23-agent/")
        data = loader.load()
        text_splitter = RecursiveCharacterTextSplitter(chunk_size=500, chunk_overlap=0)
        all_splits = text_splitter.split_documents(data)
        return all_splits

    def download_embedding(self):
        vectorstore = Chroma.from_documents(documents=self.agent(), embedding=GPT4AllEmbeddings())
        return vectorstore

    def chat(self, question):
        retriever = self.download_embedding().as_retriever()
        # llama_cpp.Llama is not a LangChain Runnable, so adapt the prompt value to a
        # plain string and pull the generated text out of the completion dict.
        chain = (
            {"context": retriever | format_docs, "question": RunnablePassthrough()}
            | rag_prompt_llama
            | (lambda prompt_value: llm(prompt_value.to_string())["choices"][0]["text"])
            | StrOutputParser()
        )
        return chain.invoke(question)

    def search(self, question):
        docs = self.download_embedding().similarity_search(question)
        return docs
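# Illustrative usage sketch (outside FastAPI). Assumes network access to fetch the
# blog post, the GGUF checkpoint, and the GPT4All embedding model; the question is
# a made-up example.
#
#   rag = RagChat()
#   print(rag.chat("What is task decomposition in LLM agents?"))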
class GenModel(BaseModel):
    question: str
    system: str = (
        "You are a helpful medical AI chat assistant. Help as much as you can. "
        "Also, continuously ask for possible symptoms in order to arrive at a conclusive "
        "ailment or sickness and possible solutions. Remember, respond in English."
    )
    temperature: float = 0.8
    seed: int = 101
    mirostat_mode: int = 2
    mirostat_tau: float = 4.0
    mirostat_eta: float = 1.1
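# Example request body for POST /llm/generate (illustrative values only):
#   {"question": "I have a persistent cough and a mild fever. What could it be?"}
# Note that the /generate endpoint below currently overrides system, temperature and
# seed server-side, so only `question` matters in practice.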
class ChatModel(BaseModel):
    question: list
    system: str = (
        "You are chatDoctor, a helpful health and medical assistant. You are chatting with a human. "
        "Help as much as you can. Also, continuously ask for possible symptoms in order to arrive at "
        "a conclusive ailment or sickness and possible solutions. Remember, respond in English."
    )
    temperature: float = 0.8
    seed: int = 101
    mirostat_mode: int = 2
    mirostat_tau: float = 4.0
    mirostat_eta: float = 1.1
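# Example request body for POST /llm/chat/ (illustrative values only); `question`
# is the full chat-completion message list passed straight to llama.cpp:
#   {
#     "question": [
#       {"role": "system", "content": "You are chatDoctor, a helpful health and medical assistant."},
#       {"role": "user", "content": "I have a headache and feel dizzy."}
#     ],
#     "temperature": 0.8,
#     "seed": 101
#   }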
llm_chat = llama_cpp.Llama.from_pretrained(
    repo_id="Qwen/Qwen1.5-0.5B-Chat-GGUF",
    filename="*q4_0.gguf",
    tokenizer=llama_cpp.llama_tokenizer.LlamaHFTokenizer.from_pretrained("Qwen/Qwen1.5-0.5B-Chat"),
    verbose=False,
    n_ctx=512,
    n_gpu_layers=0,
    # chat_format="llama-2"
)
llm_generate = llama_cpp.Llama.from_pretrained(
    repo_id="Qwen/Qwen1.5-0.5B-Chat-GGUF",
    filename="*q4_0.gguf",
    # tokenizer=llama_cpp.llama_tokenizer.LlamaHFTokenizer.from_pretrained("Qwen/Qwen1.5-0.5B-Chat"),
    verbose=False,
    n_ctx=4096,
    n_gpu_layers=0,
    mirostat_mode=2,
    mirostat_tau=4.0,
    mirostat_eta=1.1,
    # chat_format="llama-2"
)
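# mirostat_mode=2 enables Mirostat 2.0 sampling in llama.cpp: rather than fixed
# top-k/top-p, it adapts sampling so the output "surprise" stays near mirostat_tau,
# updating its estimate with learning rate mirostat_eta.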
# Logger setup
logging.basicConfig(level=logging.INFO)
logger = logging.getLogger(__name__)
# app = fastapi.FastAPI(
#     title="OpenGenAI",
#     description="Your Excellent AI Physician")
"""
app.add_middleware(
    CORSMiddleware,
    allow_origins=["*"],
    allow_credentials=True,
    allow_methods=["*"],
    allow_headers=["*"],
)
"""
llm_router = APIRouter(prefix="/llm")
@llm_router.get("/health", tags=["llm"])
def health():
    return {"status": "ok"}
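# Quick check, assuming the FastAPI app that mounts this router serves on localhost:8000:
#   import requests
#   requests.get("http://localhost:8000/llm/health").json()  # -> {"status": "ok"}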
@llm_router.post("/rag/", tags=["llm"])
async def ragchat(chatm:ChatModel):#, user: schemas.BaseUser = fastapi.Depends(current_active_user)):
r = RagChat().chat(chatml.question)
print(r)
# Chat Completion API
@llm_router.post("/chat/", tags=["llm"])
async def chat(chatm:ChatModel):#, user: schemas.BaseUser = fastapi.Depends(current_active_user)):
#chatm.system = chatm.system.format("")#user.email)
try:
st = time()
output = llm_chat.create_chat_completion(
messages = chatm.question,
temperature = chatm.temperature,
seed = chatm.seed,
#stream=True
)
print(output)
#print(output)
et = time()
output["time"] = et - st
#messages.append({'role': "assistant", "content": output['choices'][0]['message']['content']})
#print(messages)
return output
except Exception as e:
logger.error(f"Error in /complete endpoint: {e}")
return JSONResponse(
status_code=500, content={"message": "Internal Server Error"}
)
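# Illustrative client call (host/port assumed; the response is the raw llama.cpp
# chat-completion dict plus a "time" field):
#   import requests
#   r = requests.post(
#       "http://localhost:8000/llm/chat/",
#       json={"question": [{"role": "user", "content": "Hello, I feel feverish."}]},
#   )
#   print(r.json()["choices"][0]["message"]["content"])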
# Text Completion API
@llm_router.post("/generate", tags=["llm"])
async def generate(gen: GenModel):  # , user: schemas.BaseUser = fastapi.Depends(current_active_user)):
    # Note: `system` is kept for parity with GenModel but is not passed to
    # create_completion, which takes a plain prompt string.
    gen.system = "You are a helpful medical AI assistant."
    gen.temperature = 0.5
    gen.seed = 42
    try:
        st = time()
        output = llm_generate.create_completion(
            gen.question,
            temperature=gen.temperature,
            seed=gen.seed,
            # chat_format="llama-2",
            stream=True,
            echo=True,
        )
        # Stream the completion chunks, printing and accumulating the generated text.
        text = ""
        for chunk in output:
            piece = chunk["choices"][0]["text"]
            print(piece, end="")
            text += piece
        et = time()
        return {"response": text, "time": et - st}
    except Exception as e:
        logger.error(f"Error in /generate endpoint: {e}")
        return JSONResponse(
            status_code=500, content={"message": "Internal Server Error"}
        )
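# Illustrative client call (host/port assumed; matches the {"response", "time"}
# shape returned above):
#   import requests
#   r = requests.post(
#       "http://localhost:8000/llm/generate",
#       json={"question": "What are common causes of a sore throat?"},
#   )
#   print(r.json()["response"])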