# mmt_retrieval / app.py
import gradio as gr
import re
import numpy as np
from langchain_community.vectorstores import Chroma
from langchain_community.embeddings import HuggingFaceBgeEmbeddings
import os
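# The HF token is expected as the Space secret "token"; it is read here but not
# referenced again below.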
HUGGINGFACEHUB_API_TOKEN = os.environ["token"]
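
# Return the text after the first "=" (or "\n0: ") marker in a retrieved
# document's page_content, stripped of surrounding whitespace.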
def clean_(s):
    s = s.replace("\n0: ", "=")
    return re.split('=', s, maxsplit=1)[-1].strip()
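
# Query the Chroma store for the top-k matches; when "unique" is selected in the
# UI, only the single best match is returned.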
def similarity_search2(vectordb, query, k, unique):
    print(f"\nQuery Key: {query}, \nrows requested: {k}\nUnique values: {unique}")
    k = int(k)  # the slider may hand back a float; Chroma expects an integer k
    if unique == "False":
        vals = vectordb.similarity_search(query, k=k)
    else:
        # "unique" mode simply returns the single best match
        vals = vectordb.similarity_search(query, k=1)
    temp = []
    for val in vals:
        temp.append(clean_(val.page_content))
    # crude formatting: use numpy's array repr and drop the surrounding brackets
    return str(np.array(temp))[1:-1]
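
# Gradio UI: query box, sample-count slider and "unique" toggle on the left,
# the retrieved text on the right.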
with gr.Blocks() as demo:
    gr.Markdown(
        """
        <h2> <center> Query Retrieval </center> </h2>
        """)
    with gr.Row():
        with gr.Column():
            query = gr.Textbox(placeholder="your query", label="Query")
            k = gr.Slider(1, 306, 1, label="number of samples to check")
            unique = gr.Radio(["True", "False"], label="Return Unique values")
            with gr.Row():
                btn = gr.Button("Submit")
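
            # Rebuild the BGE embedder and reopen the persisted Chroma index on
            # every request, then delegate to similarity_search2.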
            def mmt_query(query, k, unique):
                model_id = "BAAI/bge-large-en-v1.5"
                model_kwargs = {"device": "cpu"}
                embedding = HuggingFaceBgeEmbeddings(
                    model_name=model_id,
                    model_kwargs=model_kwargs,
                    cache_folder=r"models",
                    encode_kwargs={'normalize_embeddings': True},
                )
                persist_directory = "MMT_unique"
                vectordb = Chroma(persist_directory=persist_directory, embedding_function=embedding)
                return similarity_search2(vectordb, query, k, unique)

        with gr.Column():
            output = gr.Textbox(scale=10, label="Output")

    btn.click(mmt_query, [query, k, unique], output)
# interface = gr.Interface(fn=auto_eda, inputs="dataframe", outputs="json")
# demo.queue()
demo.launch(share=True)
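
# Local-run note (assumption, not part of the original Space): this app needs
# gradio, langchain-community, chromadb and sentence-transformers installed, plus
# the persisted "MMT_unique" Chroma directory and "models" cache folder next to app.py.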