Spaces: Running on Zero
import tqdm | |
from PIL import Image | |
import hashlib | |
import torch | |
import torch.nn.functional as F | |
import fitz | |
import threading | |
import gradio as gr | |
import spaces | |
import os | |
from transformers import AutoModel | |
from transformers import AutoTokenizer | |
from PIL import Image | |
import torch | |
import os | |
import numpy as np | |
import json | |
import time | |
# Root directory under which every knowledge base gets its own sub-directory
# (``<cache_dir>/<knowledge_base_id>``). '/data' is presumably the Space's
# persistent-storage mount; set VISRAG_CACHE_DIR to use another location, e.g.
# when running locally without write access to /data. An empty value falls
# back to the default.
cache_dir = os.environ.get('VISRAG_CACHE_DIR') or '/data/KB'
os.makedirs(cache_dir, exist_ok=True)
def weighted_mean_pooling(hidden, attention_mask):
    """Pool per-token hidden states into one vector per sequence.

    Each attended token is weighted by the running count of attended tokens up
    to and including its position (``attention_mask.cumsum(dim=1)``, zeroed
    where the mask is 0). Later tokens therefore contribute more. The weighted
    sum is divided by the total weight.

    Args:
        hidden: token states, expected shape (batch, seq_len, dim).
        attention_mask: 0/1 mask, expected shape (batch, seq_len).

    Returns:
        Tensor of shape (batch, dim).
    """
    weights = attention_mask * attention_mask.cumsum(dim=1)
    numerator = (hidden * weights.unsqueeze(-1).float()).sum(dim=1)
    denominator = weights.sum(dim=1, keepdim=True).float()
    return numerator / denominator
def encode(text_or_image_list):
    """Embed a batch of texts or a batch of images with the global retriever.

    The whole batch is treated as text when its first element is a ``str``, and
    as images otherwise. The unused modality is filled with placeholders
    (``None`` per entry for a text batch, ``''`` per entry for an image batch).
    Both lists are passed to ``model(text=..., image=..., tokenizer=...)``.

    The token states are pooled with ``weighted_mean_pooling`` and L2-normalised
    row-wise. Callers in this file wrap the call in ``torch.no_grad()``.

    Returns:
        numpy array with one normalised embedding row per input element.
    """
    global model, tokenizer
    batch_size = len(text_or_image_list)
    if isinstance(text_or_image_list[0], str):
        texts, images = text_or_image_list, [None] * batch_size
    else:
        texts, images = [''] * batch_size, text_or_image_list
    outputs = model(text=texts, image=images, tokenizer=tokenizer)
    pooled = weighted_mean_pooling(outputs.last_hidden_state, outputs.attention_mask)
    return F.normalize(pooled, p=2, dim=1).detach().cpu().numpy()
def _create_knowledge_base_dir():
    """Create a fresh, previously unused knowledge-base directory under ``cache_dir``.

    The ID starts at the current Unix time in seconds. ``os.makedirs`` is
    called without ``exist_ok``, so it raises if the directory already exists.
    An upload that collides with another one in the same second moves on to the
    next integer ID instead of sharing, and corrupting, that directory.

    Returns:
        ``(knowledge_base_name, directory_path)``.
    """
    kb_id = int(time.time())
    while True:
        knowledge_base_name = str(kb_id)
        kb_dir = os.path.join(cache_dir, knowledge_base_name)
        try:
            os.makedirs(kb_dir)
        except FileExistsError:
            kb_id += 1
            continue
        return knowledge_base_name, kb_dir


def add_pdf_gradio(pdf_file_list, progress=gr.Progress()):
    """Build a new knowledge base from uploaded PDF files.

    Every page of every PDF is rendered at 200 DPI and embedded with ``encode``.
    Results are cached in a new directory ``<cache_dir>/<knowledge_base_name>``:

    * a copy of each uploaded PDF;
    * ``<pdf_name>_<page_idx>.png`` for every page;
    * ``<pdf stem>.npy`` holding that PDF's page embeddings;
    * ``index2img_filename.txt`` listing the PNG names in global page order.

    ``retrieve_gradio`` concatenates the ``.npy`` files in sorted file-name
    order and maps embedding row ``i`` to line ``i`` of
    ``index2img_filename.txt``. PDFs are therefore processed in basename order,
    and each PDF must get its own ``.npy`` file.

    Args:
        pdf_file_list: paths of the uploaded PDF files (value of ``gr.File``).
        progress: Gradio progress tracker used to report per-page progress.

    Returns:
        The knowledge base ID (the cache sub-directory name).

    Raises:
        gr.Error: if no PDF was uploaded.
    """
    global model, tokenizer
    if not pdf_file_list:
        raise gr.Error("Please upload at least one PDF file before processing.")
    model.eval()
    print(pdf_file_list)
    pdf_file_list = sorted(pdf_file_list, key=os.path.basename)
    print(pdf_file_list)
    knowledge_base_name, this_cache_dir = _create_knowledge_base_dir()

    # Keep a copy of every source PDF next to its rendered pages.
    for pdf_file_path in pdf_file_list:
        target_path = os.path.join(this_cache_dir, os.path.basename(pdf_file_path))
        with open(target_path, 'wb') as file1, open(pdf_file_path, 'rb') as file2:
            file1.write(file2.read())

    dpi = 200
    index2img_filename = []
    for pdf_file_path in pdf_file_list:
        print(f"Processing {pdf_file_path}")
        pdf_name = os.path.basename(pdf_file_path)
        reps_list = []
        images = []
        # Context manager releases the PyMuPDF document even if rendering fails.
        with fitz.open(pdf_file_path) as doc:
            for page in progress.tqdm(doc):
                pix = page.get_pixmap(dpi=dpi)
                image = Image.frombytes("RGB", [pix.width, pix.height], pix.samples)
                with torch.no_grad():
                    reps = encode([image])
                reps_list.append(reps)
                images.append(image)
        for idx, image in enumerate(images):
            cache_image_path = os.path.join(this_cache_dir, f"{pdf_name}_{idx}.png")
            image.save(cache_image_path)
            index2img_filename.append(os.path.basename(cache_image_path))
        # Use the full stem: split('.')[0] would map "a.v1.pdf" and "a.v2.pdf"
        # to the same "a.npy", overwriting embeddings and misaligning the index.
        pdf_stem = os.path.splitext(pdf_name)[0]
        np.save(os.path.join(this_cache_dir, f"{pdf_stem}.npy"), reps_list)

    with open(os.path.join(this_cache_dir, "index2img_filename.txt"), 'w', encoding='utf-8') as f:
        f.write('\n'.join(index2img_filename))
    return knowledge_base_name
def retrieve_gradio(knowledge_base: str, query: str, topk: int):
    """Return the pages of a knowledge base most similar to ``query``.

    The query is prefixed with the retrieval instruction and embedded with
    ``encode``. It is then scored by dot product against every cached page
    embedding. ``encode`` L2-normalises both sides, so the score is cosine
    similarity.

    Side effect: writes ``q-<md5(query)>.json`` (query and retrieved page
    paths) into the knowledge-base directory. ``upvote``/``downvote`` read
    this file back.

    Args:
        knowledge_base: knowledge base ID returned by ``add_pdf_gradio``.
        query: the user's question.
        topk: number of pages to return. It is coerced to ``int`` (``None``
            means 1) and clamped to ``[1, number of pages]``.

    Returns:
        List of PIL images, most similar first. Returns ``None`` if the ID is
        invalid or unknown, or if the knowledge base holds no pages.
    """
    global model, tokenizer
    model.eval()
    # The ID comes from a free-text box: accept only a plain directory name so
    # it cannot point outside cache_dir (e.g. "../..", "/etc").
    if (not knowledge_base or knowledge_base in ('.', '..')
            or os.path.basename(knowledge_base) != knowledge_base):
        gr.Warning("Invalid Knowledge Base ID.")
        return None
    target_cache_dir = os.path.join(cache_dir, knowledge_base)
    if not os.path.exists(target_cache_dir):
        gr.Warning("Knowledge Base not found.")
        return None
    with open(os.path.join(target_cache_dir, "index2img_filename.txt"), 'r', encoding='utf-8') as f:
        index2img_filename = f.read().split('\n')

    doc_list = sorted(f for f in os.listdir(target_cache_dir) if f.endswith('.npy'))
    doc_reps = [torch.Tensor(np.load(os.path.join(target_cache_dir, f))) for f in doc_list]
    if not doc_reps or all(rep.numel() == 0 for rep in doc_reps):
        gr.Warning("This Knowledge Base contains no pages.")
        return None
    # add_pdf_gradio saves one (num_pages, 1, dim) array per PDF; stack them
    # and flatten to (total_pages, dim) in page order.
    doc_reps_cat = torch.cat(doc_reps, dim=0)
    doc_reps_cat = doc_reps_cat.reshape(-1, doc_reps_cat.shape[-1])

    query_with_instruction = "Represent this query for retrieving relevant document: " + query
    with torch.no_grad():
        query_rep = torch.Tensor(encode([query_with_instruction]))
    query_md5 = hashlib.md5(query.encode()).hexdigest()
    print(f"query_rep_shape: {query_rep.shape}, doc_reps_cat_shape: {doc_reps_cat.shape}")
    similarities = torch.matmul(query_rep, doc_reps_cat.T)

    # gr.Number may deliver a float or None, and torch.topk raises when k
    # exceeds the number of pages (e.g. top-3 on a 2-page PDF).
    requested_k = int(topk) if topk is not None else 1
    k = max(1, min(requested_k, doc_reps_cat.shape[0]))
    topk_values, topk_doc_ids = torch.topk(similarities, k=k)
    topk_values_np = topk_values.squeeze(0).cpu().numpy()
    topk_doc_ids_np = topk_doc_ids.squeeze(0).cpu().numpy()
    print(f"topk_doc_ids_np: {topk_doc_ids_np}, topk_values_np: {topk_values_np}")

    retrieved_paths = [os.path.join(target_cache_dir, index2img_filename[idx]) for idx in topk_doc_ids_np]
    images_topk = [Image.open(path) for path in retrieved_paths]
    with open(os.path.join(target_cache_dir, f"q-{query_md5}.json"), 'w', encoding='utf-8') as f:
        f.write(json.dumps(
            {
                "knowledge_base": knowledge_base,
                "query": query,
                # Key spelling kept as-is: existing logs and readers use it.
                "retrived_docs": retrieved_paths,
            },
            indent=4, ensure_ascii=False,
        ))
    return images_topk
def _save_user_preference(knowledge_base, query, preference):
    """Attach a user's vote to the retrieval log of ``query``.

    Reads ``q-<md5(query)>.json`` (written by ``retrieve_gradio``) from the
    knowledge-base directory. It adds ``"user_preference": preference`` and
    writes the result to ``q-<md5(query)>-withpref.json``.

    Returns:
        The path written, or ``None`` if the vote could not be recorded. In
        that case a Gradio warning is shown.
    """
    # The ID comes from a free-text box: accept only a plain directory name so
    # it cannot point outside cache_dir.
    if (not knowledge_base or knowledge_base in ('.', '..')
            or os.path.basename(knowledge_base) != knowledge_base):
        gr.Warning('Invalid Knowledge Base ID.')
        return None
    target_cache_dir = os.path.join(cache_dir, knowledge_base)
    query_md5 = hashlib.md5(query.encode()).hexdigest()
    try:
        with open(os.path.join(target_cache_dir, f"q-{query_md5}.json"), 'r', encoding='utf-8') as f:
            data = json.loads(f.read())
    except FileNotFoundError:
        # No retrieval was logged for this KB/query pair (e.g. vote clicked
        # before "Retrieve Pages", or the query was edited afterwards).
        gr.Warning('Please retrieve pages for this question before voting.')
        return None
    data["user_preference"] = preference
    pref_path = os.path.join(target_cache_dir, f"q-{query_md5}-withpref.json")
    with open(pref_path, 'w', encoding='utf-8') as f:
        f.write(json.dumps(data, indent=4, ensure_ascii=False))
    return pref_path


def upvote(knowledge_base, query):
    """Record that the user liked the pages retrieved for ``query``."""
    pref_path = _save_user_preference(knowledge_base, query, "upvote")
    if pref_path is None:
        return
    print("up", pref_path)
    gr.Info('Received! Thank you very much!')
    return


def downvote(knowledge_base, query):
    """Record that the user disliked the pages retrieved for ``query``."""
    pref_path = _save_user_preference(knowledge_base, query, "downvote")
    if pref_path is None:
        return
    print("down", pref_path)
    gr.Info('Received! Thank you very much!')
    return
# Both models below are moved to this device.
device = 'cuda'

# --- Retriever (embedding) model: VisRAG-Ret, used by encode() --------------
print("emb model load begin...")
model_path = 'openbmb/VisRAG-Ret'  # replace with your local model path
tokenizer = AutoTokenizer.from_pretrained(model_path, trust_remote_code=True)
model = AutoModel.from_pretrained(model_path, trust_remote_code=True)
model.eval()
model.to(device)
print("emb model load success!")

# --- Generator model: MiniCPM-V-2.6, used by answer_question() ---------------
print("gen model load begin...")
gen_model_path = 'openbmb/MiniCPM-V-2_6'
# attn_implementation is a model-loading option, so it is passed only to the
# model, not to the tokenizer.
gen_tokenizer = AutoTokenizer.from_pretrained(gen_model_path, trust_remote_code=True)
gen_model = AutoModel.from_pretrained(
    gen_model_path,
    trust_remote_code=True,
    attn_implementation='sdpa',
    torch_dtype=torch.bfloat16,
)
gen_model.eval()
gen_model.to(device)
print("gen model load success!")
def answer_question(images, question):
    """Answer ``question`` with MiniCPM-V-2.6 using the retrieved page images.

    Args:
        images: value of the "Retrieved Pages" gallery. Each element is a
            tuple whose first item is the image file path, e.g. ``(path, None)``.
        question: the user's question, sent together with all images in a
            single user message.

    Returns:
        The answer text returned by ``gen_model.chat``.

    Raises:
        gr.Error: if no pages have been retrieved yet (empty gallery).
    """
    global gen_model, gen_tokenizer
    if not images:
        raise gr.Error("Please retrieve pages (Step 2) before asking for an answer.")
    images_ = []
    for image in images:
        # convert() returns an independent in-memory copy, so the file can be
        # closed right away instead of leaking one handle per page.
        with Image.open(image[0]) as img:
            images_.append(img.convert('RGB'))
    msgs = [{'role': 'user', 'content': [question, *images_]}]
    answer = gen_model.chat(
        image=None,
        msgs=msgs,
        tokenizer=gen_tokenizer,
    )
    print(answer)
    return answer
# Gradio UI: Step 1 builds a knowledge base from PDFs, Step 2 retrieves pages,
# Step 3 answers the question from the retrieved pages; votes are logged.
with gr.Blocks() as app:
    gr.Markdown("# VisRAG Pipeline: Vision-based Retrieval-augmented Generation on Multi-modality Documents")
    gr.Markdown("""
- A Vision Language Model Dense Retriever ([VisRAG-Ret](https://huggingface.co/openbmb/VisRAG-Ret)) **directly reads** your PDFs **without need for OCR**, generates **multimodal dense representations** and assists in building your personal library.
- **Ask a question**, and it will retrieve the most relevant pages. Then, [MiniCPM-V-2.6](https://huggingface.co/spaces/openbmb/MiniCPM-V-2_6) will answer your question based on the recalled pages, utilizing its strong multi-image understanding capabilities.
- It assists you in reading **lengthy**, **visually-intensive** or **text-oriented** PDF documents, helping you locate pages that answer your questions.
- It enables you to build a personal library and retrieve book pages from a large collection of books.
- It works like a human: reading, storing, retrieving, and answering with full visual comprehension.
""")
    gr.Markdown("- The current online demo supports PDF documents with fewer than 50 pages due to GPU time limitations. For longer PDFs and books, consider deploying it on your own machine.")
    gr.Markdown("Thank you very much to [@bokesyo](https://huggingface.co/bokesyo) for writing the code.")

    # Step 1: upload PDFs and build a knowledge base.
    with gr.Row():
        # Gradio expects file extensions with a leading dot.
        file_input = gr.File(file_types=[".pdf"], file_count="multiple", label="Step 1: Upload PDF")
        file_result = gr.Text(label="Knowledge Base ID (remember it, it is re-usable!)")
        process_button = gr.Button("Process PDF (Don't click until PDF uploaded successfully)")
    process_button.click(add_pdf_gradio, inputs=[file_input], outputs=file_result)

    # Step 2: retrieve the most relevant pages.
    with gr.Row():
        kb_id_input = gr.Text(label="Your Knowledge Base ID (paste your Knowledge Base ID here, it is re-usable):")
        query_input = gr.Text(label="Your Question")
        # precision=0 makes Gradio deliver an int, which torch.topk requires.
        topk_input = gr.Number(value=1, minimum=1, maximum=10, step=1, precision=0, label="Number of pages to retrieve")
        retrieve_button = gr.Button("Step 2: Retrieve Pages")
    with gr.Row():
        gr.Examples(
            examples=[
                [["car_owner_manual.pdf"], "1731341207", "怀孕如何系安全带?"],
                [["car_owner_manual.pdf"], "1731341207", "什么时候会触发侧气囊弹出?"],
                [["car_owner_manual.pdf"], "1731341207", "How to wear seat belts when pregnant?"],
                [["car_owner_manual.pdf"], "1731341207", "When will the side airbags be deployed?"],
                [["main_figure.pdf"], "1731342441", "What is VisRAG?"],
                [["main_figure.pdf"], "1731342441", "How does VisRAG perform?"]
            ],
            inputs=[file_input, kb_id_input, query_input],
        )
    with gr.Row():
        images_output = gr.Gallery(label="Retrieved Pages")
    retrieve_button.click(retrieve_gradio, inputs=[kb_id_input, query_input, topk_input], outputs=images_output)

    # Step 3: answer the question from the retrieved pages.
    with gr.Row():
        button = gr.Button("Step 3: Answer Question with Retrieved Pages")
        gen_model_response = gr.Textbox(label="MiniCPM-V-2.6's Answer")
    button.click(fn=answer_question, inputs=[images_output, query_input], outputs=gen_model_response)

    # User feedback on the retrieval result.
    with gr.Row():
        downvote_button = gr.Button("🤣Downvote")
        upvote_button = gr.Button("🤗Upvote")
    upvote_button.click(upvote, inputs=[kb_id_input, query_input], outputs=None)
    downvote_button.click(downvote, inputs=[kb_id_input, query_input], outputs=None)

    gr.Markdown("By using this demo, you agree to share your usage data with us for research purposes, helping us improve the user experience.")

app.launch()