C_METROPOLE / app.py
kheopss's picture
Update app.py
bea5856 verified
raw
history blame
3.92 kB
import gradio as gr
from huggingface_hub import InferenceClient
from llama_index.core import Document
import pandas as pd
import getpass
import os
from llama_index.core.retrievers import VectorIndexRetriever
from llama_index.core import VectorStoreIndex
from llama_index.core import QueryBundle
from IPython.display import display, HTML
from llama_index.core.postprocessor import LLMRerank
import logging
import sys
from llama_index.core.memory import ChatMemoryBuffer
from llama_index.llms.openai import OpenAI
# Load the knowledge base: a JSON file read into a DataFrame, one Document
# per row.
file_path = 'response_metropol.json'
data = pd.read_json(file_path)

# Each row's 'values' column becomes the document text; the file name and
# description columns are carried along as metadata.
documents = [
    Document(
        text=row['values'],
        metadata={
            "filename": row['file_name'],
            "description": row['file_description'],
        },
    )
    for _, row in data.iterrows()
]

# The deployment supplies the key as `api_key_openai`; mirror it into
# OPENAI_API_KEY. Assigning None to os.environ raises an opaque TypeError, so
# a missing key is reported explicitly instead. An OPENAI_API_KEY that is
# already present in the environment is accepted as-is.
_openai_api_key = os.getenv('api_key_openai')
if _openai_api_key is not None:
    os.environ["OPENAI_API_KEY"] = _openai_api_key
elif "OPENAI_API_KEY" not in os.environ:
    raise RuntimeError(
        "OpenAI API key not configured: set the 'api_key_openai' "
        "(or 'OPENAI_API_KEY') environment variable."
    )

# build index
index = VectorStoreIndex.from_documents(documents)
def get_retrieved_nodes(
    query_str: str,
    vector_top_k: int = 10,
    reranker_top_n: int = 5,
    with_reranker: bool = False,
):
    """Retrieve nodes for a query from the module-level ``index``.

    Args:
        query_str: The query text, wrapped in a ``QueryBundle``.
        vector_top_k: Forwarded to the retriever as ``similarity_top_k``.
        reranker_top_n: Forwarded to ``LLMRerank`` as ``top_n``; ignored
            unless ``with_reranker`` is true.
        with_reranker: When true, the retrieved nodes are passed through
            ``LLMRerank.postprocess_nodes`` before being returned.

    Returns:
        The result of ``retriever.retrieve`` or, when reranking is enabled,
        of ``reranker.postprocess_nodes``.
    """
    query_bundle = QueryBundle(query_str)

    # configure retriever
    retriever = VectorIndexRetriever(
        index=index,
        similarity_top_k=vector_top_k,
    )
    retrieved_nodes = retriever.retrieve(query_bundle)

    if not with_reranker:
        return retrieved_nodes

    # configure reranker
    reranker = LLMRerank(
        choice_batch_size=5,
        top_n=reranker_top_n,
    )
    return reranker.postprocess_nodes(retrieved_nodes, query_bundle)
def pretty_print(df):
    """Show *df* as an HTML table through IPython's ``display``.

    Args:
        df: Object exposing ``to_html()`` (a pandas DataFrame in this file).

    Returns:
        Whatever ``display`` returns.
    """
    # Drop every literal backslash-n character pair from the generated markup
    # (the pattern is the two characters '\' and 'n', not a real newline).
    html_table = df.to_html().replace("\\n", "")
    return display(HTML(html_table))
def visualize_retrieved_nodes(nodes) -> None:
    """Display the score and text of each retrieved node as a table.

    Args:
        nodes: Iterable of objects exposing ``.score`` and
            ``.node.get_text()``.
    """
    rows = [
        {"Score": scored.score, "Text": scored.node.get_text()}
        for scored in nodes
    ]
    pretty_print(pd.DataFrame(rows))
# Sample retrieval executed at import time: fetch 10 candidates for a fixed
# French query, rerank them keeping 10, and display the result table.
# `new_nodes` is reused further down the script.
new_nodes = get_retrieved_nodes(
    query_str="quel sont les agence qui sont adaptée aux personnes à mobilité réduite",
    with_reranker=True,
    reranker_top_n=10,
    vector_top_k=10,
)
visualize_retrieved_nodes(new_nodes)
# Send INFO-and-above log records to stdout.
# NOTE(review): basicConfig installs a stdout handler only when the root
# logger has none yet, and the explicit addHandler below always adds another,
# so on a fresh root logger each record is likely emitted twice - confirm
# whether that is intended.
root_logger = logging.getLogger()
logging.basicConfig(stream=sys.stdout, level=logging.INFO)
root_logger.addHandler(logging.StreamHandler(stream=sys.stdout))
def get_all_text(new_nodes) -> str:
    """Join the text of every node into one space-separated string.

    Args:
        new_nodes: Iterable of objects exposing a ``get_text()`` method
            that returns a string.

    Returns:
        str: The node texts in input order, separated by single spaces.
        An empty iterable yields an empty string.
    """
    return ' '.join(node.get_text() for node in new_nodes)
# Print the concatenated text of the sample retrieval to stdout.
get_texts = get_all_text(new_nodes)
print(get_texts)

# Chat memory handed to the chat engine, created with a 6500-token limit.
memory = ChatMemoryBuffer.from_defaults(token_limit=6500)

# Ingredients of the chat engine, named individually for readability.
_chat_llm = OpenAI(temperature=0, model="gpt-4")
_system_prompt = (
    "Assist public agents in providing responses to the residents and citizens of the metropolis in nice, guiding them to the appropriate services that best address their requests for assistance. This involves equipping public agents with the necessary tools and information to efficiently and effectively direct citizens to the services that can fulfill their needs.answer using french"
)
_chat_reranker = LLMRerank(
    choice_batch_size=5,
    top_n=5,
)

# A single engine instance is built at import time and shared by all requests.
chat_engine = index.as_chat_engine(
    llm=_chat_llm,
    chat_mode="context",
    memory=memory,
    system_prompt=_system_prompt,
    similarity_top_k=10,
    node_postprocessors=[_chat_reranker],
    response_mode="tree_summarize",
)
def process(input, history):  # `input` shadows the builtin; name kept for callers
    """Answer one chat message with the module-level ``chat_engine``.

    Args:
        input: The user's message, forwarded to ``chat_engine.stream_chat``.
        history: Accepted to satisfy the callback signature but not used in
            this function. NOTE(review): the engine's ``memory`` is a single
            module-level object, so it is presumably shared by every
            session - confirm this is intended.

    Returns:
        str: The full response, assembled from the streamed tokens. Each
        token is also echoed to stdout as it arrives; the caller only sees
        the text once the stream is exhausted.
    """
    response = chat_engine.stream_chat(input)
    tokens = []
    for token in response.response_gen:
        print(token, end="")
        tokens.append(token)
    return "".join(tokens)
# Gradio chat UI: every submitted message is routed to `process`.
iface = gr.ChatInterface(
    title="Métropole Chat_Openai",
    description="Provide a question and get a response.",
    fn=process,
)

# Start serving the interface.
iface.launch()