zotero-search / app.py
rbiswasfc's picture
app
7324658
import os
from typing import ClassVar
# import dotenv
import gradio as gr
import lancedb
import srsly
from huggingface_hub import snapshot_download
from lancedb.embeddings.base import TextEmbeddingFunction
from lancedb.embeddings.registry import register
from lancedb.pydantic import LanceModel, Vector
from lancedb.rerankers import CohereReranker, ColbertReranker
from lancedb.util import attempt_import_or_raise
# dotenv.load_dotenv()
@register("coherev3")
class CohereEmbeddingFunction_2(TextEmbeddingFunction):
    """Text embedding function that calls Cohere's embed endpoint.

    Registered in the LanceDB embedding registry under the name
    ``"coherev3"``. The Cohere client is created lazily on first use and is
    shared by every instance through the ``client`` class attribute.
    """

    # Cohere model identifier, passed as ``model=`` to ``client.embed``.
    name: str = "embed-english-v3.0"
    # Shared Cohere client; stays ``None`` until ``_init_client`` runs.
    client: ClassVar = None

    def ndims(self):
        """Return the dimensionality of the vectors this function produces.

        ``embed-english-v3.0`` emits 1024-dimensional embeddings, which is
        also the width ``ArxivModel.vector`` is declared with. The previous
        value (768) disagreed with both.
        """
        return 1024

    def generate_embeddings(self, texts):
        """
        Get the embeddings for the given texts

        Parameters
        ----------
        texts: list[str] or np.ndarray (of str)
            The texts to embed

        Returns
        -------
        list
            One embedding per input text, in the order returned by Cohere.
        """
        # TODO retry, rate limit, token limit
        # NOTE(review): every call uses input_type="search_document", including
        # query-time embeddings. Cohere v3 models document "search_query" for
        # queries -- confirm whether a query-specific override is wanted.
        self._init_client()
        rs = CohereEmbeddingFunction_2.client.embed(
            texts=texts, model=self.name, input_type="search_document"
        )
        return list(rs.embeddings)

    def _init_client(self):
        """Create the shared Cohere client on first use.

        Raises
        ------
        KeyError
            If the ``COHERE_API_KEY`` environment variable is not set.
        """
        cohere = attempt_import_or_raise("cohere")
        if CohereEmbeddingFunction_2.client is None:
            CohereEmbeddingFunction_2.client = cohere.Client(
                os.environ["COHERE_API_KEY"]
            )
# Module-wide embedder instance; ArxivModel uses it to declare its source and
# vector fields.
COHERE_EMBEDDER = CohereEmbeddingFunction_2.create()
class ArxivModel(LanceModel):
    """Row schema whose ``text`` column is embedded through ``COHERE_EMBEDDER``.

    NOTE(review): presumably this mirrors the ``arxiv_zotero_v0`` table opened
    further down; the visible code never passes it to LanceDB -- confirm.
    """

    # Source text handed to the embedder.
    text: str = COHERE_EMBEDDER.SourceField()
    # 1024-wide embedding column tied to the same embedder.
    vector: Vector(1024) = COHERE_EMBEDDER.VectorField()
    # NOTE(review): presumably a section/chunk title, distinct from
    # paper_title -- confirm against the ingestion code.
    title: str
    # Paper title; query_db reads this column to label each result.
    paper_title: str
    # NOTE(review): kind of content held in `text`; allowed values are not
    # visible in this file.
    content_type: str
    # arXiv identifier; query_db groups rows by it and _format_results builds
    # the https://arxiv.org/abs/<id> URL from it.
    arxiv_id: str
def download_data():
    """Fetch the ``rbiswasfc/zotero_db`` dataset snapshot into ``./data``.

    The Hugging Face access token is read from the ``HF_TOKEN`` environment
    variable (``KeyError`` if it is unset). A confirmation line is printed
    once the download call has returned.
    """
    hf_token = os.environ["HF_TOKEN"]
    snapshot_options = {
        "repo_id": "rbiswasfc/zotero_db",
        "repo_type": "dataset",
        "local_dir": "./data",
        "token": hf_token,
    }
    snapshot_download(**snapshot_options)
    print("Data downloaded!")
# Pull the dataset snapshot first: everything below reads from ./data.
download_data()
# NOTE(review): not referenced anywhere else in the visible code.
VERSION = "0.0.0a"
# LanceDB database stored inside the downloaded snapshot directory.
DB = lancedb.connect("./data/.lancedb_zotero_v0")
# Lookup consumed by _format_results, keyed by arXiv id.
ID_TO_ABSTRACT = srsly.read_json("./data/id_to_abstract.json")
# Rerankers selectable by name through query_db's `reranker` argument.
RERANKERS = {"colbert": ColbertReranker(), "cohere": CohereReranker()}
# Table that query_db searches.
TBL = DB.open_table("arxiv_zotero_v0")
def _clean_abstract(abstract, paper_title):
    """Strip leading boilerplate (heading, repeated title, stray first block) from a raw abstract."""
    # these are all ugly hacks because the data preprocessing is poor. to be fixed v soon.
    if "Abstract\n\n" in abstract:
        abstract = abstract.split("Abstract\n\n")[-1]
    # An empty title is "contained" in every string, and str.split("") raises
    # ValueError, so only split on a non-empty title.
    if paper_title and paper_title in abstract:
        abstract = abstract.split(paper_title)[-1]
    if abstract.startswith("\n"):
        abstract = abstract[1:]
    # A paragraph break within the first 20 characters means the text still
    # opens with a short leftover fragment: drop everything up to that break.
    if "\n\n" in abstract[:20]:
        abstract = abstract.split("\n\n", 1)[1]
    return abstract


def _format_results(arxiv_refs):
    """Convert an ``{arxiv_id: paper_title}`` mapping into result dicts.

    Parameters
    ----------
    arxiv_refs: dict
        Maps arXiv ids to paper titles; iteration order is preserved in the
        output.

    Returns
    -------
    list[dict]
        One dict per entry with the keys ``"title"``, ``"url"`` (the
        ``https://arxiv.org/abs/<id>`` page) and ``"abstract"`` (looked up in
        ``ID_TO_ABSTRACT`` and cleaned; empty string when unavailable).
    """
    results = []
    for arx_id, paper_title in arxiv_refs.items():
        # `or ""` covers both a missing id and an id stored with a null value.
        raw_abstract = ID_TO_ABSTRACT.get(arx_id) or ""
        results.append(
            {
                "title": paper_title,
                "url": f"https://arxiv.org/abs/{arx_id}",
                "abstract": _clean_abstract(raw_abstract, paper_title),
            }
        )
    return results
def query_db(query: str, k: int = 10, reranker: str = "cohere", top_n: int = 3):
    """Search the table and return the best-matching papers.

    Runs a hybrid search limited to ``k`` rows, optionally reranks them, sums
    ``_relevance_score`` per ``arxiv_id`` and keeps the ``top_n`` papers with
    the highest totals.

    Parameters
    ----------
    query: str
        Free-text search query. A blank query returns ``[]`` without
        contacting the search backend.
    k: int
        Number of rows requested from the search.
    reranker: str
        Key into ``RERANKERS``; pass ``None`` to skip the explicit rerank step.
    top_n: int
        Number of papers to return.

    Returns
    -------
    list[dict]
        Output of ``_format_results`` (``title`` / ``url`` / ``abstract``),
        ordered from highest to lowest aggregated relevance score.

    Raises
    ------
    KeyError
        If ``reranker`` is not ``None`` and is not a key of ``RERANKERS``.
    """
    # Nothing to embed or rerank for an empty query.
    if not query or not query.strip():
        return []

    search = TBL.search(query, query_type="hybrid").limit(k)
    if reranker is not None:
        search = search.rerank(reranker=RERANKERS[reranker])
    hits = search.to_pandas()
    if hits.empty:
        return []

    # Aggregate per paper, then keep the ids in descending score order.
    paper_scores = hits.groupby("arxiv_id")["_relevance_score"].sum()
    best_ids = paper_scores.sort_values(ascending=False).head(top_n).index

    # One title per paper; keep="last" matches the value the former dict
    # comprehension ended up with when an id appeared on several rows.
    titles = hits.drop_duplicates("arxiv_id", keep="last").set_index("arxiv_id")[
        "paper_title"
    ]

    # Build the mapping from the sorted ids so the output follows the
    # aggregated ranking rather than the row order of the raw hits.
    top_refs = {arx_id: titles.loc[arx_id] for arx_id in best_ids}
    return _format_results(top_refs)
# Gradio UI: a query textbox, a submit button and a JSON panel for results.
# NOTE(review): nesting reconstructed from a flattened source -- confirm that
# `output` belongs below the Row rather than inside it.
with gr.Blocks() as demo:
    with gr.Row():
        query = gr.Textbox(label="Query", placeholder="Enter your query...")
        submit_btn = gr.Button("Submit")
    output = gr.JSON(label="Search Results")
    # # callback ---
    # Clicking Submit calls query_db with the textbox value only, so its other
    # parameters keep their defaults; the return value fills the JSON panel.
    submit_btn.click(
        fn=query_db,
        inputs=query,
        outputs=output,
    )
# Start the Gradio server.
demo.launch()