import gradio as gr
from langchain_community.document_loaders import PyPDFLoader
from langchain.text_splitter import RecursiveCharacterTextSplitter
from langchain_community.embeddings import GPT4AllEmbeddings
from langchain_community.vectorstores import FAISS
from langchain.schema.runnable import RunnablePassthrough
# Output parsers used by the chains below
from langchain_core.output_parsers import JsonOutputParser, StrOutputParser
# from langchain.prompts import ChatPromptTemplate
# from langchain_community.chat_models import ChatOllama
from prompt_template import *
from langgraph.graph import END, StateGraph
from langchain_community.llms import LlamaCpp
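# NOTE: the star import from prompt_template above is assumed to provide the prompt
# objects used below (router_prompt, generate_prompt, query_prompt, nomalqa_prompt,
# nomaldoc_prompt) as well as the graph State schema and web_search_tool.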
# local_llm = 'aleni_ox'
# llm = ChatOllama(model=local_llm,
# keep_alive="3h",
# max_tokens=512,
# temperature=0,
# # callbacks=[StreamingStdOutCallbackHandler()]
# )
llm = LlamaCpp(
model_path="Llama-3.1-8B-Instruct.Q5_K_M.gguf",
temperature=0,
max_tokens=512,
top_p=1,
# callback_manager=callback_manager,
verbose=True, # Verbose is required to pass to the callback manager
)
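# LCEL chains: a JSON router that picks between web search and direct generation,
# an answer-generation chain, a search-query rewriting chain, and a plain QA chain
# used when no agent/document context is selected.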
question_router = router_prompt | llm | JsonOutputParser()
generate_chain = generate_prompt | llm | StrOutputParser()
query_chain = query_prompt | llm | JsonOutputParser()
llm_chain = nomalqa_prompt | llm | StrOutputParser()
def generate(state):
    """
    Pass the question and retrieved context through the graph; the answer itself
    is streamed from generate_chain in QA().
    Args:
        state (dict): The current graph state
    Returns:
        state (dict): The question and context to generate the answer from
    """
    print("Step: Generating an answer from the retrieved context")
    question = state["question"]
    context = state["context"]
    return {'question': question, 'context': context}
def transform_query(state):
    """
    Rewrite the user question into a web search query
    Args:
        state (dict): The current graph state
    Returns:
        state (dict): Appended search query
    """
    print("Step: Rewriting the user's question for web search")
    question = state['question']
    gen_query = query_chain.invoke({"question": question})
    search_query = gen_query["query"]
    return {"search_query": search_query}
def web_search(state):
    """
    Web search based on the question
    Args:
        state (dict): The current graph state
    Returns:
        state (dict): Appended web results to context
    """
    search_query = state['search_query']
    print(f'Step: Searching the web for: "{search_query}"')
    # Web search tool call (web_search_tool is expected to come from prompt_template)
    search_result = web_search_tool.invoke(search_query)
    print("Search result:", search_result)
    return {"context": search_result}
def route_question(state):
    """
    Route the question to web search or generation.
    Args:
        state (dict): The current graph state
    Returns:
        str: Next node to call
    """
    print("Step: Routing Query")
    question = state['question']
    output = question_router.invoke({"question": question})
    print('Router decision: ', output)
    if output == "web_search":
        # print("Step: Routing Query to Web Search")
        return "websearch"
    elif output == 'generate':
        # print("Step: Routing Query to Generation")
        return "generate"
    # Fall back to direct generation if the router output is unrecognized
    return "generate"
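# Build the LangGraph workflow. `State` is assumed to be the graph-state schema
# (e.g. a TypedDict with question/context/search_query keys) exported by prompt_template.
# Flow: route_question picks the entry point; web-search questions go through
# transform_query -> websearch -> generate, everything else goes straight to generate.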
workflow = StateGraph(State)
workflow.add_node("websearch", web_search)
workflow.add_node("transform_query", transform_query)
workflow.add_node("generate", generate)
# Build the edges
workflow.set_conditional_entry_point(
route_question,
{
"websearch": "transform_query",
"generate": "generate",
},
)
workflow.add_edge("transform_query", "websearch")
workflow.add_edge("websearch", "generate")
workflow.add_edge("generate", END)
# Compile the workflow
local_agent = workflow.compile()
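# Debug helper: runs the compiled agent for a single query (the nodes print their
# progress; the result is not returned).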
def run_agent(query):
local_agent.invoke({"question": query})
print("=======")
def QA(question: str, history: list, type: str):
    if 'Agent' in type:
        gr.Info("Generating an answer!")
        respon = ''
        # print(question)
        output = local_agent.invoke({"question": question})
        # print(output)
        context = output['context']
        questions = output['question']
        for chunk in generate_chain.stream({"context": context, "question": questions}):
            respon += chunk
            print(chunk, end="", flush=True)
            yield respon
    else:
        gr.Info("Generating an answer!")
        print(question, history)
        respon = ''
        for chunk in llm_chain.stream(question):
            respon += chunk
            print(chunk, end="", flush=True)
            yield respon
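# Build a simple RAG chain from an uploaded PDF: split the document into chunks,
# embed them with GPT4All embeddings, index them in FAISS, and wire a top-3
# similarity retriever into nomaldoc_prompt so the "Doc" chat mode answers from the file.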
def create_db(doc: str) -> None:
    # Rebind the module-level llm_chain so the "Doc" chat mode uses the uploaded PDF
    global llm_chain
    loader = PyPDFLoader(doc)
    text_splitter = RecursiveCharacterTextSplitter(chunk_size=200, chunk_overlap=40)
    chunked_documents = loader.load_and_split(text_splitter)
    embedding_model = GPT4AllEmbeddings(model_name="all-MiniLM-L6-v2.gguf2.f16.gguf", gpt4all_kwargs={'allow_download': 'True'})
    db = FAISS.from_documents(chunked_documents, embedding_model)
    gr.Info("Loaded data from the PDF!")
    retriever = db.as_retriever(
        search_type="similarity",
        search_kwargs={"k": 3}
    )
    llm_chain = (
        {
            "context": retriever,
            "question": RunnablePassthrough()
        }
        | nomaldoc_prompt
        | llm
    )
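# Gradio UI: a file-upload panel on the left (builds the PDF retriever) and a chat
# panel on the right with a dropdown to pick the chat mode (None / Agent / Doc).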
with gr.Blocks(fill_height=True) as demo:
with gr.Row(equal_height=True):
with gr.Column(scale=1):
democ2 = gr.Interface(
create_db,
[gr.File(file_count='single')],
None,
)
with gr.Column(scale=2):
democ1 = gr.ChatInterface(
QA,
            additional_inputs=[gr.Dropdown(["None", "Agent", "Doc"], label="Type", info="Choose a chat type!"),]
)
if __name__ == "__main__":
demo.launch()