Spaces:
Runtime error
Runtime error
"""Ask a question to the netspresso database.""" | |
import json | |
import sys | |
import argparse | |
from typing import List | |
from langchain.chat_models import ChatOpenAI # for `gpt-3.5-turbo` & `gpt-4` | |
from langchain.chains import RetrievalQAWithSourcesChain | |
from langchain.prompts import ( | |
ChatPromptTemplate, | |
SystemMessagePromptTemplate, | |
HumanMessagePromptTemplate, | |
) | |
from langchain.schema import BaseRetriever, Document | |
import gradio as gr | |
from search_online import OnlineSearcher | |
# A Korean-language variant of DEFAULT_QUESTION (same example questions, with the
# intro line "Ask the Netspresso bot about model lightweighting and optimization.
# For example" written in Korean) was used previously.
DEFAULT_QUESTION = "Ask the Netspresso bot about model lightweighting and optimization.\nFor example \n\n- Why do I need to use Netspresso?\n- Summarize how to compress the model with netspresso.\n- Tell me what the pruning is.\n- What kinds of hardware can I use with this toolkit?\n- Can I use YOLOv8 with this tool? If so, tell me the examples."
TEMPERATURE = 0  # sampling temperature passed to ChatOpenAI

# Retriever configuration, hard-coded for now (FIXME: expose via CLI).
# This is a Namespace *instance* on purpose: setting attributes on the
# `argparse.Namespace` class itself would leak them into every Namespace
# created in the process, including the CLI args parsed in __main__.
RETRIEVER_ARGS = argparse.Namespace(
    index_type="hybrid",
    # comma-separated "sparse,dense" index directories (format presumably
    # expected by OnlineSearcher -- TODO confirm)
    index="/root/indexes/docs-netspresso-ai/sparse,/root/indexes/docs-netspresso-ai/dense",
    encoder="castorini/mdpr-question-nq",
    device="cuda:0",
    alpha=0.5,
    normalization=True,
    lang_abbr="en",
    K=10,  # passed as the second argument of RETRIEVER.search (presumably top-k hits)
)
# Backward-compatible alias for code that still reads the module-level `args`.
args = RETRIEVER_ARGS

# initialize qabot
print("initialize NP doc retrieval bot")
RETRIEVER = OnlineSearcher(RETRIEVER_ARGS)


class LangChainCustomRetrieverWrapper(BaseRetriever):
    """Expose the module-level ``RETRIEVER`` searcher through LangChain's retriever API.

    The searcher and its settings are read from module globals rather than
    stored on the instance, because ``BaseRetriever`` objects cannot hold
    extra fields.
    """

    def __init__(self, args):
        # `args` is accepted for interface compatibility but is currently unused;
        # the wrapper reads RETRIEVER / RETRIEVER_ARGS from module scope instead.
        super().__init__()
        print("Initialize LangChainCustomRetrieverWrapper, TODO: fix minor bug")

    def get_relevant_documents(self, query: str) -> List[Document]:
        """Get texts relevant for a query.

        Args:
            query: string to find relevant texts for
        Returns:
            One ``Document`` per search hit, in hit order. ``page_content`` is the
            ``"contents"`` field of the hit's raw JSON record in the sparse index,
            and ``metadata["source"]`` is the hit's docid.
        """
        print(f"query = {query}")
        # Read K from RETRIEVER_ARGS, not the global `args`: the __main__ block
        # may rebind `args` to the parsed CLI namespace, which has no `K`.
        hits = RETRIEVER.search(query, RETRIEVER_ARGS.K)
        sparse_searcher = RETRIEVER.searcher.sparse_searcher
        return [
            Document(
                page_content=json.loads(sparse_searcher.doc(hit.docid).raw())["contents"],
                metadata={"source": hit.docid},
            )
            for hit in hits
        ]

    async def aget_relevant_documents(
        self, query: str
    ) -> List[Document]:  # abstractmethod
        """Async retrieval is not supported by this wrapper."""
        raise NotImplementedError
class RaLM:
    """Retrieval-augmented QA bot: an OpenAI chat model answering over retrieved docs."""

    def __init__(self, args):
        """Store ``args`` (must provide ``model_name``) and build the QA chain."""
        self.args = args
        self.initialize_ralm()

    def initialize_ralm(self):
        """Create ``self.retriever`` and the "stuff" ``RetrievalQAWithSourcesChain`` in ``self.chain``."""
        # initialize custom retriever
        self.retriever = LangChainCustomRetrieverWrapper(self.args)
        # Prompt for RaLM. The "SOURCES" format requested here is comma-separated
        # so that it matches the comma splitting done in run_chain.
        # `{summaries}` is presumably filled with the retrieved documents by the
        # stuff chain -- TODO confirm against the LangChain version in use.
        system_template = """Use the following pieces of context to answer the users question.
Take note of the sources and include them in the answer in the format: "SOURCES: source1, source2", use "SOURCES" in capital letters regardless of the number of sources.
Always try to generate answer from source.
----------------
{summaries}"""
        messages = [
            SystemMessagePromptTemplate.from_template(system_template),
            HumanMessagePromptTemplate.from_template("{question}"),
        ]
        prompt = ChatPromptTemplate.from_messages(messages)
        llm = ChatOpenAI(model_name=self.args.model_name, temperature=TEMPERATURE)
        self.chain = RetrievalQAWithSourcesChain.from_chain_type(
            llm=llm,
            chain_type="stuff",
            retriever=self.retriever,
            return_source_documents=True,
            reduce_k_below_max_tokens=True,
            chain_type_kwargs={"prompt": prompt},
        )

    def run_chain(self, question, force_korean=False):
        """Answer ``question`` with the chain, clean up the output, print it, and return it.

        Args:
            question: the user's question.
            force_korean: if True, append an instruction asking for a Korean answer
                based on the provided text.
        Returns:
            The chain's result dict, with ``"answer"`` post-processed and
            ``"sources"`` turned into a list of non-empty, stripped source ids.
        """
        if force_korean:
            # "Answer in Korean, referring to the provided text."
            question = f"{question} 본문을 참고해서 한글로 대답해줘"
        result = self.chain({"question": question})
        # postprocess
        result["answer"] = self.postprocess(result["answer"])
        sources = result["sources"]
        if isinstance(sources, str):
            sources = self.postprocess(sources).split(",")
        # Drop empty entries so that an empty sources string yields [] instead of [""].
        result["sources"] = [src.strip() for src in sources if src.strip()]
        self.print_result(result)
        return result

    def print_result(self, result):
        """Print the answer, the cited sources, and the text of each cited source document."""
        print(f"Answer: {result['answer']}")
        print("Sources: ")
        print(result["sources"])
        assert isinstance(result["sources"], list)
        documents = result.get("source_documents", [])
        for source_title in result["sources"]:
            print(f"{source_title}: ")
            # Print the first retrieved document whose docid matches the cited source.
            matched = next(
                (doc for doc in documents if doc.metadata["source"] == source_title),
                None,
            )
            if matched is not None:
                print(matched.page_content)

    def postprocess(self, text):
        """Strip surrounding whitespace and dangling brackets from the end of ``text``.

        A trailing "(" or "[" is always removed. A trailing ")" or "]" is removed
        only when the text has more closers than openers of that kind, so that a
        legitimate closing bracket such as "(page 3)" is kept.
        NOTE(review): the leftovers presumably come from splitting model output
        such as "answer (SOURCES: a, b)" on "SOURCES:" -- TODO confirm.
        """
        closer_to_opener = {")": "(", "]": "["}
        text = text.strip()
        while text:
            last = text[-1]
            dangling_open = last in "(["
            unmatched_close = last in closer_to_opener and text.count(last) > text.count(
                closer_to_opener[last]
            )
            if not (dangling_open or unmatched_close):
                break
            text = text[:-1].rstrip()
        return text
def main():
    """Parse CLI arguments, build the RaLM bot, and serve it through a Gradio UI.

    CLI arguments are kept in a local variable so they do not replace the
    module-level retriever configuration (`args`) used by the retriever wrapper.
    """
    parser = argparse.ArgumentParser(
        description="Ask a question to the netspresso docs."
    )
    parser.add_argument(
        "--model_name",
        type=str,
        default="gpt-3.5-turbo-16k-0613",
        help="model name for openai api",
    )
    # Retriever options are hard-coded at module level for now rather than
    # exposed on this parser.
    cli_args = parser.parse_args()
    # to prevent collision with DensePhrase native argparser
    sys.argv = [sys.argv[0]]
    # initialize class
    app = RaLM(cli_args)

    def question_answer(question):
        """Return (bot answer, retrieved passages joined by a separator) for the UI."""
        result = app.run_chain(question=question, force_korean=False)
        passages = "\n######################################################\n\n".join(
            f"Source {idx}\n{doc.page_content}"
            for idx, doc in enumerate(result["source_documents"])
        )
        return result["answer"], passages

    # launch gradio; gr.Textbox(value=...) replaces the deprecated
    # gr.inputs.Textbox(default=...), which gradio 4 removed.
    gr.Interface(
        fn=question_answer,
        inputs=gr.Textbox(value=DEFAULT_QUESTION, label="Question"),
        outputs=[
            gr.Textbox(value="", label="Bot response"),
            gr.Textbox(value="", label="Search result used by bot"),
        ],
        title="Netspresso Q&A bot",
        theme="dark-grass",
        description="Ask the Netspresso bot about model lightweighting and optimization.",  # simplified version, hide detail version
        # Detailed version (previously in Korean): "Ask the Netspresso bot about model
        # lightweighting and optimization.\n\n retriever: BM25&mdpr-question-nq,
        # generator: gpt-3.5-turbo-16k-0613 (API)"
    ).launch(share=True)


if __name__ == "__main__":
    main()