# NOTE: the "Spaces: Running" banner below this comment was a Hugging Face
# Spaces page header captured by the scrape, not part of the program.
import gradio as gr | |
from langchain_community.document_loaders import PyPDFLoader | |
from langchain.text_splitter import RecursiveCharacterTextSplitter | |
from langchain_community.document_loaders import PyPDFLoader | |
from langchain_community.embeddings import GPT4AllEmbeddings | |
from langchain_community.vectorstores import FAISS | |
from langchain.schema.runnable import RunnablePassthrough | |
# from langchain.prompts import ChatPromptTemplate | |
# from langchain_community.chat_models import ChatOllama | |
from prompt_template import * | |
from langgraph.graph import END, StateGraph | |
from langchain_community.llms import LlamaCpp | |
# --- LLM and chains --------------------------------------------------------
# Earlier Ollama-backed model, kept for reference:
# local_llm = 'aleni_ox'
# llm = ChatOllama(model=local_llm,
#                  keep_alive="3h",
#                  max_tokens=512,
#                  temperature=0,
#                  # callbacks=[StreamingStdOutCallbackHandler()]
#                  )

# Local llama.cpp model shared by every chain below.
llm = LlamaCpp(
    model_path="Llama-3.1-8B-Instruct.Q5_K_M.gguf",  # local GGUF weights file
    temperature=0,  # deterministic decoding
    max_tokens=512,
    top_p=1,
    # callback_manager=callback_manager,
    verbose=True,  # Verbose is required to pass to the callback manager
)

# LCEL chains built from the prompts in prompt_template (wildcard-imported).
# NOTE(review): JsonOutputParser / StrOutputParser are not imported in this
# file — presumably they come in via `from prompt_template import *`; confirm.
question_router = router_prompt | llm | JsonOutputParser()  # picks web_search vs generate
generate_chain = generate_prompt | llm | StrOutputParser()  # answers from gathered context
query_chain = query_prompt | llm | JsonOutputParser()  # rewrites question into a search query
llm_chain = nomalqa_prompt | llm | StrOutputParser()  # plain QA chain, no retrieval
def generate(state):
    """Forward the question and retrieved context unchanged.

    The answer itself is streamed later by ``generate_chain`` (see ``QA``),
    so this terminal graph node only passes through the state it needs.

    Args:
        state (dict): Current graph state; ``question`` and ``context``
            are read.

    Returns:
        dict: ``question`` and ``context`` copied from the incoming state.
    """
    print("Step: Đang tạo câu trả lời từ những gì tìm được")
    return {
        "question": state["question"],
        "context": state["context"],
    }
def transform_query(state):
    """Rewrite the user question into a web-search query.

    Args:
        state (dict): Current graph state; ``question`` is read.

    Returns:
        dict: ``search_query`` — the query string produced by ``query_chain``.
    """
    print("Step: Tối ưu câu hỏi của người dùng")
    rewritten = query_chain.invoke({"question": state["question"]})
    return {"search_query": rewritten["query"]}
def web_search(state):
    """Run a web search for the optimized query and store the results.

    Args:
        state (dict): Current graph state; ``search_query`` is read.

    Returns:
        dict: ``context`` holding the raw search-tool output.
    """
    query = state["search_query"]
    print(f'Step: Đang tìm kiếm web cho: "{query}"')
    # NOTE(review): web_search_tool is not defined in this file — presumably
    # it arrives via `from prompt_template import *`; confirm.
    results = web_search_tool.invoke(query)
    print("Search result:", results)
    return {"context": results}
def route_question(state):
    """Route the question to the web-search pipeline or direct generation.

    Args:
        state (dict): Current graph state; ``question`` is read.

    Returns:
        str: ``"websearch"`` or ``"generate"`` — keys of the conditional
        entry-point mapping on the workflow.
    """
    print("Step: Routing Query")
    question = state['question']
    output = question_router.invoke({"question": question})
    print('Lựa chọn của AI là: ', output)
    if output == "web_search":
        return "websearch"
    if output == 'generate':
        return "generate"
    # BUG FIX: the router is an LLM and can emit an unexpected value (e.g. a
    # dict or a misspelled label). The original code fell through and
    # implicitly returned None, which crashes the graph's conditional entry
    # point. Default to plain generation instead.
    print("Step: Unrecognized router output, defaulting to generate")
    return "generate"
# --- LangGraph workflow ----------------------------------------------------
# NOTE(review): `State` is not defined in this file — presumably the state
# schema comes in via `from prompt_template import *`; confirm.
workflow = StateGraph(State)
workflow.add_node("websearch", web_search)  # runs the web-search tool
workflow.add_node("transform_query", transform_query)  # rewrites the question
workflow.add_node("generate", generate)  # forwards question/context

# Build the edges
# Conditional entry: route_question() returns a key; the dict maps that key
# to the first node to execute.
workflow.set_conditional_entry_point(
    route_question,
    {
        "websearch": "transform_query",
        "generate": "generate",
    },
)
workflow.add_edge("transform_query", "websearch")
workflow.add_edge("websearch", "generate")
workflow.add_edge("generate", END)

# Compile the workflow
local_agent = workflow.compile()
def run_agent(query):
    """Invoke the compiled agent on *query* for its side effects (prints)."""
    local_agent.invoke({"question": query})
    print("=======")
def QA(question: str, history: list, type: str):
    """Streaming chat handler for the Gradio ChatInterface.

    Yields the accumulated answer after each streamed chunk. When *type*
    contains ``'Agent'``, the LangGraph agent gathers context first and the
    answer comes from ``generate_chain``; otherwise the current module-level
    ``llm_chain`` answers the question directly.

    Args:
        question: The user's message.
        history: Prior chat turns (only printed in the non-agent branch).
        type: Mode selected in the UI dropdown ("None", "Agent" or "Doc").
    """
    gr.Info("Đang tạo câu trả lời!")
    if 'Agent' in type:
        output = local_agent.invoke({"question": question})
        stream = generate_chain.stream(
            {"context": output['context'], "question": output['question']}
        )
    else:
        print(question, history)
        stream = llm_chain.stream(question)
    answer = ''
    for chunk in stream:
        answer += chunk
        print(chunk, end="", flush=True)
        yield answer
def create_db(doc: str) -> None:
    """Build a FAISS index from a PDF and switch QA to retrieval mode.

    Loads *doc*, splits it into 200-character chunks (40 overlap), embeds the
    chunks with GPT4All MiniLM embeddings, and rebinds the module-level
    ``llm_chain`` to a retrieval chain over the new index so subsequent
    non-Agent chats answer from the uploaded PDF.

    Args:
        doc: Filesystem path to the uploaded PDF (supplied by gr.File).
    """
    # BUG FIX: without `global`, the rebinding at the bottom created a local
    # variable that was discarded on return, so the uploaded PDF never
    # affected the chain QA actually uses. (The `-> str` annotation was also
    # wrong: this function returns None.)
    global llm_chain
    loader = PyPDFLoader(doc)
    text_splitter = RecursiveCharacterTextSplitter(chunk_size=200, chunk_overlap=40)
    chunked_documents = loader.load_and_split(text_splitter)
    embedding_model = GPT4AllEmbeddings(
        model_name="all-MiniLM-L6-v2.gguf2.f16.gguf",
        gpt4all_kwargs={'allow_download': 'True'},
    )
    db = FAISS.from_documents(chunked_documents, embedding_model)
    gr.Info("Đã tải lên dữ liệu từ PDF!")
    retriever = db.as_retriever(
        search_type="similarity",
        search_kwargs={"k": 3},
    )
    llm_chain = (
        {
            "context": retriever,
            "question": RunnablePassthrough(),
        }
        | nomaldoc_prompt
        | llm
    )
# --- Gradio UI -------------------------------------------------------------
with gr.Blocks(fill_height=True) as demo:
    with gr.Row(equal_height=True):
        with gr.Column(scale=1):
            # Left panel: PDF upload; builds the FAISS index via create_db.
            democ2 = gr.Interface(
                create_db,
                [gr.File(file_count='single')],
                None,
            )
        with gr.Column(scale=2):
            # Right panel: streaming chat; the dropdown selects the QA mode.
            democ1 = gr.ChatInterface(
                QA,
                additional_inputs=[gr.Dropdown(["None", "Agent", "Doc"], label="Type", info="Chọn một kiểu chat!"),]
            )

if __name__ == "__main__":
    demo.launch()