# Csplk's picture
# Update app.py
# b29ddc0 verified
import gradio as gr
import spaces
import subprocess
import os
import shutil
import string
import random
import glob
from pypdf import PdfReader
from sentence_transformers import SentenceTransformer
# Runtime settings; each one can be overridden through an environment variable.
model_name = os.getenv("MODEL", "Snowflake/snowflake-arctic-embed-m")
# Number of characters per document chunk (passed to chunk_to_length below).
chunk_size = int(os.getenv("CHUNK_SIZE", "128"))
# Initial value of the "Max output characters" field in the UI.
default_max_characters = int(os.getenv("DEFAULT_MAX_CHARACTERS", "258"))

# Load the embedding model once at import time and move it to the GPU.
model = SentenceTransformer(model_name)
model.to(device="cuda")
def embed(queries, chunks) -> dict[str, list[tuple[int, float]]]:
    """Score every chunk against every query.

    Args:
        queries: Query strings; each is encoded with the model's "query" prompt.
        chunks: Chunk strings; encoded without a prompt.

    Returns:
        ``{query: [(chunk_idx, score), ...]}`` with one pair per chunk, in the
        original chunk order (not sorted by score). A score is the dot product
        of the query embedding and the chunk embedding (a numpy scalar).
    """
    query_embeddings = model.encode(queries, prompt_name="query")
    document_embeddings = model.encode(chunks)
    # Assuming encode returns one row per input, this is a
    # (num_queries, num_chunks) matrix of dot products.
    scores = query_embeddings @ document_embeddings.T
    # The index sequence is identical for every query, so build it once.
    chunk_idxs = range(len(chunks))
    return {
        query: list(zip(chunk_idxs, query_scores))
        for query, query_scores in zip(queries, scores)
    }
def extract_text_from_pdf(reader):
    """Concatenate the text of every page in ``reader``, with a header per page.

    Each page that yields text contributes a "---- Page N ----" header line
    (N is the 0-based page index) followed by its text. Pages are separated by
    a blank line. Pages with no extractable text (empty or None) are skipped.

    Args:
        reader: A ``pypdf.PdfReader``, or any object whose ``pages`` items
            provide an ``extract_text()`` method.

    Returns:
        The combined text with leading and trailing whitespace stripped.
    """
    parts = []
    for idx, page in enumerate(reader.pages):
        # Extract only once; extraction can be expensive on large pages.
        text = page.extract_text()
        # A truthiness test also tolerates a None return, not only "".
        if text:
            parts.append(f"---- Page {idx} ----\n{text}\n\n")
    # join avoids the quadratic cost of repeated string concatenation.
    return "".join(parts).strip()
def convert(filename) -> str:
    """Return the text content of ``filename``.

    Plain-text formats are read verbatim and decoded as UTF-8. PDFs are run
    through ``extract_text_from_pdf``. The extension match is case-insensitive,
    so e.g. ``README.MD`` and ``GUIDE.PDF`` are accepted.

    Raises:
        ValueError: if the file extension is not one of the supported types.
    """
    plain_text_filetypes = (
        ".txt",
        ".csv",
        ".tsv",
        ".md",
        ".yaml",
        ".toml",
        ".json",
        ".json5",
        ".jsonc",
    )
    lowered = filename.lower()
    # Already plain text, so the raw content is used unchanged.
    if lowered.endswith(plain_text_filetypes):
        # An explicit encoding keeps decoding independent of the host locale.
        with open(filename, "r", encoding="utf-8") as f:
            return f.read()
    if lowered.endswith(".pdf"):
        return extract_text_from_pdf(PdfReader(filename))
    raise ValueError(f"Unsupported file type: {filename}")
def chunk_to_length(text, max_length=512):
    """Split ``text`` into consecutive pieces of at most ``max_length`` characters.

    Pieces are cut purely by character count, with no regard for word or line
    boundaries. Only the last piece may be shorter. Empty text yields ``[""]``,
    i.e. a single empty chunk.

    Raises:
        ValueError: if ``max_length`` is not positive (no finite split exists).
    """
    if max_length <= 0:
        raise ValueError(f"max_length must be positive, got {max_length}")
    # Slicing by start index avoids re-copying the remaining text on every step.
    chunks = [text[start:start + max_length] for start in range(0, len(text), max_length)]
    return chunks or [text]
@spaces.GPU
def predict(query, max_characters) -> dict[str, list[str]]:
    """Return the chunks most similar to ``query``, grouped by source file.

    Chunks from every document in ``docs`` are ranked together, using the dot
    product of each chunk's embedding with the query embedding. They are then
    taken in descending order until the next chunk would exceed
    ``max_characters``.

    Args:
        query: Free-text question about the documents.
        max_characters: Character budget for all returned chunks combined.
            ``None`` (presumably a cleared number field in the UI) falls back
            to ``default_max_characters``.

    Returns:
        ``{filename: [chunk, ...]}``. Files appear in the order in which their
        first chunk was selected, and each file's chunks are in descending
        similarity.
    """
    if max_characters is None:
        max_characters = default_max_characters
    query_embedding = model.encode(query, prompt_name="query")
    # Pool (filename, chunk, similarity) across all documents so ranking is global.
    all_chunks = []
    for filename, doc in docs.items():
        similarities = doc["embeddings"] @ query_embedding.T
        all_chunks.extend((filename, chunk, sim) for chunk, sim in zip(doc["chunks"], similarities))
    all_chunks.sort(key=lambda item: item[2], reverse=True)

    relevant_chunks = {}
    total_chars = 0
    for filename, chunk, _ in all_chunks:
        # Stop at the first chunk that would overflow the budget, so the
        # selection is always a prefix of the global ranking.
        if total_chars + len(chunk) > max_characters:
            break
        relevant_chunks.setdefault(filename, []).append(chunk)
        total_chars += len(chunk)
    return relevant_chunks
# Index every file under sources/ at startup. Each file is converted to text,
# split into chunk_size-character chunks, and each chunk is embedded once, so a
# query only needs to embed the query itself.
docs = {}
# Sorted for a deterministic document order; glob's order depends on the filesystem.
for filename in sorted(glob.glob("sources/*")):
    # Skip "add_your_files_here" (presumably a placeholder) and any
    # sub-directories, which convert() cannot read.
    if filename.endswith("add_your_files_here") or not os.path.isfile(filename):
        continue
    converted_doc = convert(filename)
    chunks = chunk_to_length(converted_doc, chunk_size)
    embeddings = model.encode(chunks)
    docs[filename] = {
        "chunks": chunks,
        "embeddings": embeddings,
    }
# Web UI / API: a query and a character budget in, grouped relevant chunks out.
demo = gr.Interface(
    predict,
    inputs=[
        gr.Textbox(label="Query asked about the documents"),
        # precision=0 makes Gradio pass an int, since the budget counts characters.
        gr.Number(label="Max output characters", value=default_max_characters, precision=0),
    ],
    outputs=[gr.JSON(label="Relevant chunks")],
    title="Gradio Docs",
    description="This is a gradio docs rag tool for use in hf chat tools",
)
demo.launch()