# Chat_ / app.py
# Author: linl03 — last commit 284dcd9 ("Update app.py")
import gradio as gr
from langchain_community.utilities import DuckDuckGoSearchAPIWrapper
from langchain_community.tools import DuckDuckGoSearchRun
from langchain.prompts import PromptTemplate, MessagesPlaceholder
from langchain_core.output_parsers import JsonOutputParser, StrOutputParser
from typing_extensions import TypedDict
from langchain_core.prompts import ChatPromptTemplate
import pickle
from langchain_core.messages import HumanMessage, AIMessage
from langgraph.graph import END, StateGraph
from huggingface_hub import hf_hub_download
from langchain_community.llms import LlamaCpp
# DuckDuckGo search: the API wrapper caps each search at 5 results, and the
# tool built on top of it is what the web-search graph node invokes.
wrapper = DuckDuckGoSearchAPIWrapper(max_results=5)
web_search_tool = DuckDuckGoSearchRun(api_wrapper=wrapper)

# Local Llama 3.1 8B Instruct (GGUF, Q5_K_M quantisation) run through llama.cpp.
# temperature=0 with top_p=1 keeps sampling near-deterministic; each reply is
# capped at 512 tokens inside a 2000-token context window.
llm = LlamaCpp(
    model_path="Llama-3.1-8B-Instruct.Q5_K_M.gguf",
    n_ctx=2000,
    max_tokens=512,
    temperature=0,
    top_p=1,
    verbose=True,  # enable LlamaCpp's verbose logging
)

# Module-level list of LangChain messages. Being a global, it is shared by
# every chat session served by this process.
chat_history = []
# Prompt templates live in a pickled dict ("template.pkl"). Use the local copy
# when it exists and can be read. Otherwise fetch it from the Hugging Face
# dataset repo "linl03/dataAboxChat" into the working directory and load the
# downloaded file.
# SECURITY NOTE(review): pickle.load can execute arbitrary code embedded in the
# file, so this fully trusts the remote dataset repo. Consider storing the
# templates as JSON instead.
try:
    with open("template.pkl", "rb") as file:
        template_abox = pickle.load(file)
except (OSError, EOFError, pickle.UnpicklingError):
    # Missing, unreadable or truncated/corrupt local file -> re-download.
    # (A bare `except:` here used to also swallow KeyboardInterrupt/SystemExit.)
    template_path = hf_hub_download(
        repo_id="linl03/dataAboxChat",
        filename="template.pkl",
        repo_type="dataset",
        local_dir="./",
    )
    # Use the path hf_hub_download reports rather than re-deriving it.
    with open(template_path, "rb") as file:
        template_abox = pickle.load(file)
# --- Prompts (template texts come from the pickled template_abox dict) ---

# Router prompt. route_question expects its JSON reply to carry a "choice" of
# "web_search" or "generate".
router_prompt = PromptTemplate(
    input_variables=["question"],
    template=template_abox["router_template"],
)

# Answer prompt: system instructions, then earlier chat turns, then the user's
# question. NOTE(review): callers also pass "context", so the system prompt
# presumably contains a {context} slot — confirm against template.pkl.
_generate_messages = [
    ("system", template_abox["system_prompt"]),
    MessagesPlaceholder(variable_name="chat_history"),
    ("human", "{question}"),
]
generate_prompt = ChatPromptTemplate.from_messages(_generate_messages)

# Query-rewrite prompt. transform_query reads the "query" field of its JSON reply.
query_prompt = PromptTemplate(
    input_variables=["question"],
    template=template_abox["query_template"],
)

# Scheduling/reminder prompt keyed on {time}; not referenced elsewhere in this
# visible file.
remind_prompt = PromptTemplate(
    input_variables=["time"],
    template=template_abox["schedule_template"],
)

# --- Chains: prompt -> local LLM -> output parser ---
question_router = router_prompt | llm | JsonOutputParser()
generate_chain = generate_prompt | llm | StrOutputParser()
query_chain = query_prompt | llm | JsonOutputParser()
class State(TypedDict):
    """Shared LangGraph state exchanged between the workflow nodes."""

    question: str  # the user's question as handed to the graph
    generation: str  # not written by any node visible in this file
    search_query: str  # web-search query produced by transform_query
    context: str  # web-search results, later fed to the answer prompt
def generate(state):
    """Final graph node: pass the question and any search context through.

    The actual answer is streamed afterwards by the caller. This node only
    fixes the values it will use.

    Args:
        state: Graph state holding ``"question"`` and, when the web-search
            branch ran, ``"context"``.

    Returns:
        ``{"question": ..., "context": ...}``. ``context`` is ``""`` when no
        search ran: on the direct "generate" route nothing writes ``context``,
        so indexing it unconditionally raised KeyError.
    """
    print("Step: Đang tạo câu trả lời")
    question = state["question"]
    context = state.get("context") or ""
    return {'question': question, 'context': context}
def transform_query(state):
    """Rewrite the user's question into a web-search query with the LLM.

    Args:
        state: Graph state holding ``"question"``.

    Returns:
        ``{"search_query": ...}`` taken from the ``"query"`` field of the
        model's JSON reply. If the reply is not a dict or has no non-empty
        ``"query"``, the original question is used. The search step therefore
        always has something to run instead of the request failing on a
        KeyError.
    """
    print("Step: Tối ưu câu hỏi của người dùng")
    question = state['question']
    gen_query = query_chain.invoke({"question": question})
    print(gen_query)
    search_query = gen_query.get("query") if isinstance(gen_query, dict) else None
    if not search_query:
        search_query = question
    return {"search_query": search_query}
def web_search(state):
    """Graph node: search DuckDuckGo for the prepared query.

    Args:
        state: Graph state holding ``"search_query"``.

    Returns:
        A partial state update ``{"context": <raw search-tool output>}``.
    """
    query = state['search_query']
    print(f'Step: Đang tìm kiếm web cho: "{query}"')
    results = web_search_tool.invoke(query)
    print("Search result:", results)
    return {"context": results}
def route_question(state):
    """Decide whether the question needs a web search before answering.

    Asks ``question_router`` for a JSON object whose ``"choice"`` is
    ``"web_search"`` or ``"generate"``.

    Args:
        state: Graph state holding ``"question"``.

    Returns:
        ``"websearch"`` or ``"generate"``, the keys of the conditional entry
        point. A missing, malformed or unexpected choice falls back to
        ``"generate"``. Returning ``None`` (or raising KeyError) would leave
        the graph with no valid branch.
    """
    print("Step: Routing Query")
    question = state['question']
    output = question_router.invoke({"question": question})
    print('Lựa chọn của AI là: ', output)
    choice = output.get("choice") if isinstance(output, dict) else None
    if choice == "web_search":
        return "websearch"
    return "generate"
def Agents():
    """Build and compile the question-answering LangGraph workflow.

    Flow:
      * route_question picks the entry point.
      * "websearch" runs transform_query -> websearch -> generate.
      * "generate" goes straight to generate.
      * generate always ends the run.

    Returns:
        The compiled graph (invoked by the caller with ``.invoke``).
    """
    graph = StateGraph(State)

    for node_name, node_fn in (
        ("websearch", web_search),
        ("transform_query", transform_query),
        ("generate", generate),
    ):
        graph.add_node(node_name, node_fn)

    # Map the router's return value to the first node to execute.
    graph.set_conditional_entry_point(
        route_question,
        {"websearch": "transform_query", "generate": "generate"},
    )

    for source, target in (
        ("transform_query", "websearch"),
        ("websearch", "generate"),
        ("generate", END),
    ):
        graph.add_edge(source, target)

    return graph.compile()
def _message_text(message):
    """Return the plain-text part of a Gradio chat message ("" if none).

    With ``multimodal=True`` Gradio passes the message as a dict with ``"text"``
    and ``"files"`` keys. Without it, the message is a plain string.
    """
    if isinstance(message, dict):
        return message.get("text") or ""
    return message if isinstance(message, str) else ""


def _history_to_messages(history):
    """Convert Gradio chat history into LangChain chat messages.

    Accepts both the ``[user, bot]`` pair format and the
    ``{"role": ..., "content": ...}`` message format. Entries whose content is
    not plain text (e.g. uploaded files) are skipped.
    """
    messages = []
    for item in history or []:
        if isinstance(item, dict):
            content = item.get("content")
            if not isinstance(content, str):
                continue
            if item.get("role") == "user":
                messages.append(HumanMessage(content=content))
            elif item.get("role") == "assistant":
                messages.append(AIMessage(content=content))
        elif isinstance(item, (list, tuple)) and len(item) == 2:
            user_msg, bot_msg = item
            if isinstance(user_msg, str):
                messages.append(HumanMessage(content=user_msg))
            if isinstance(bot_msg, str):
                messages.append(AIMessage(content=bot_msg))
    return messages


def QA(question: str, history: list):
    """Gradio chat handler: route, optionally web-search, then stream an answer.

    Args:
        question: The user's message. Because the ChatInterface is multimodal
            this is a dict with ``"text"``/``"files"``; a plain string is also
            accepted. Only the text is used.
        history: Gradio's per-session chat history. It supplies the
            ``chat_history`` for the answer prompt, so conversations no longer
            leak between users through a process-wide list.

    Yields:
        The growing answer text after each streamed chunk.
    """
    text = _message_text(question)
    local_agent = Agents()
    gr.Info("Đang tạo câu trả lời!")
    output = local_agent.invoke({"question": text})
    context = output.get("context", "")
    graph_question = output.get("question", text)
    prompt_inputs = {
        "context": context,
        "question": graph_question,
        "chat_history": _history_to_messages(history),
    }
    response = ''
    for chunk in generate_chain.stream(prompt_inputs):
        response += chunk
        print(chunk, end="", flush=True)
        yield response
# Gradio chat UI. QA is a generator, so answers stream in as they are produced.
# multimodal=True lets users attach files next to their text.
demo = gr.ChatInterface(
    fn=QA,
    title="Box Chat(Agent)",
    multimodal=True,
    fill_height=True,
)

if __name__ == "__main__":
    demo.launch()