import pickle

import gradio as gr
from huggingface_hub import hf_hub_download
from langchain.prompts import PromptTemplate, MessagesPlaceholder
from langchain_community.llms import LlamaCpp
from langchain_community.tools import DuckDuckGoSearchRun
from langchain_community.utilities import DuckDuckGoSearchAPIWrapper
from langchain_core.messages import HumanMessage, AIMessage
from langchain_core.output_parsers import JsonOutputParser, StrOutputParser
from langchain_core.prompts import ChatPromptTemplate
from langgraph.graph import END, StateGraph
from typing_extensions import TypedDict
# DuckDuckGo web search tool, capped at 5 results per query.
wrapper = DuckDuckGoSearchAPIWrapper(max_results=5)
web_search_tool = DuckDuckGoSearchRun(api_wrapper=wrapper)
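# Note: DuckDuckGoSearchRun.invoke() returns the result snippets as a single
# string, which is what gets passed straight into the prompt as context below.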
# Local Llama 3.1 8B Instruct model served via llama.cpp.
# temperature=0 keeps the output deterministic, which helps the JSON parsers below.
llm = LlamaCpp(
    model_path="Llama-3.1-8B-Instruct.Q5_K_M.gguf",
    temperature=0,
    max_tokens=512,
    n_ctx=2000,
    top_p=1,
    verbose=True,
)
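# The quantized GGUF weights are expected to already sit next to this script;
# unlike template.pkl below, nothing here downloads them at startup.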
chat_history = []

# Load the prompt templates; download them from the Hub on first run.
try:
    with open("template.pkl", "rb") as file:
        template_abox = pickle.load(file)
except FileNotFoundError:
    hf_hub_download(
        repo_id="linl03/dataAboxChat",
        local_dir="./",
        filename="template.pkl",
        repo_type="dataset",
    )
    with open("./template.pkl", "rb") as file:
        template_abox = pickle.load(file)
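# template_abox is a dict of prompt strings; the keys used below are
# "router_template", "system_prompt", "query_template", and "schedule_template".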
# Routes a question to either web search or direct generation.
router_prompt = PromptTemplate(
    template=template_abox["router_template"],
    input_variables=["question"],
)
# Answers the question given the retrieved context and the chat history.
generate_prompt = ChatPromptTemplate.from_messages(
    [
        ("system", template_abox["system_prompt"]),
        MessagesPlaceholder(variable_name="chat_history"),
        ("human", "{question}"),
    ]
)
# Rewrites the user's question into a concise web search query.
query_prompt = PromptTemplate(
    template=template_abox["query_template"],
    input_variables=["question"],
)
# Scheduling/reminder prompt; defined here but not wired into the graph below.
remind_prompt = PromptTemplate(
    template=template_abox["schedule_template"],
    input_variables=["time"],
)
question_router = router_prompt | llm | JsonOutputParser()
generate_chain = generate_prompt | llm | StrOutputParser()
query_chain = query_prompt | llm | JsonOutputParser()
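# Expected JSON shapes (enforced only by the prompts, so a model that drifts
# from the format will make JsonOutputParser raise):
#   question_router -> {"choice": "web_search" | "generate"}
#   query_chain     -> {"query": "<search query>"}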
class State(TypedDict):
    """Shared state passed between the graph nodes."""
    question: str
    generation: str
    search_query: str
    context: str
def generate(state):
    # The actual LLM call is streamed from generate_chain inside QA(); this
    # node just passes the question and any gathered context through.
    print("Step: generating the answer")
    question = state["question"]
    # No context has been set when the router goes straight to "generate".
    context = state.get("context", "")
    return {"question": question, "context": context}
def transform_query(state):
    # Rewrite the user's question into a web search query.
    print("Step: optimizing the user's question")
    question = state["question"]
    gen_query = query_chain.invoke({"question": question})
    print(gen_query)
    search_query = gen_query["query"]
    return {"search_query": search_query}
def web_search(state):
    search_query = state["search_query"]
    print(f'Step: searching the web for: "{search_query}"')
    # Web search tool call
    search_result = web_search_tool.invoke(search_query)
    print("Search result:", search_result)
    return {"context": search_result}
def route_question(state):
    print("Step: routing the query")
    question = state["question"]
    output = question_router.invoke({"question": question})
    print("The AI's choice is:", output)
    if output["choice"] == "web_search":
        return "websearch"
    # Fall back to direct generation for "generate" or any unexpected choice,
    # so the entry point never returns None.
    return "generate"
def Agents():
    # Build the graph: route -> (transform_query -> websearch) -> generate.
    workflow = StateGraph(State)
    workflow.add_node("websearch", web_search)
    workflow.add_node("transform_query", transform_query)
    workflow.add_node("generate", generate)
    # Build the edges
    workflow.set_conditional_entry_point(
        route_question,
        {
            "websearch": "transform_query",
            "generate": "generate",
        },
    )
    workflow.add_edge("transform_query", "websearch")
    workflow.add_edge("websearch", "generate")
    workflow.add_edge("generate", END)
    # Compile the workflow
    return workflow.compile()
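# A minimal usage sketch of the compiled graph (the question string is a
# hypothetical example):
#
#   agent = Agents()
#   final_state = agent.invoke({"question": "Who won the 2022 World Cup?"})
#   # final_state["context"] holds the search snippets, or "" when the
#   # router sent the question straight to "generate".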
def QA(question, history: list):
    # With multimodal=True, Gradio passes the message as a dict of the form
    # {"text": ..., "files": [...]}; keep only the text here.
    if isinstance(question, dict):
        question = question.get("text", "")
    local_agent = Agents()
    gr.Info("Generating an answer!")
    response = ""
    output = local_agent.invoke({"question": question})
    context = output["context"]
    questions = output["question"]
    # Stream the final answer chunk by chunk back to the UI.
    for chunk in generate_chain.stream(
        {"context": context, "question": questions, "chat_history": chat_history}
    ):
        response += chunk
        print(chunk, end="", flush=True)
        yield response
    chat_history.append(HumanMessage(content=question))
    chat_history.append(AIMessage(content=response))
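# Note: chat_history is module-level state, so it is shared across all
# concurrent users/sessions of the app rather than kept per conversation.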
demo = gr.ChatInterface(
    QA,
    fill_height=True,
    multimodal=True,
    title="Box Chat (Agent)",
)

if __name__ == "__main__":
    demo.launch()