# Hugging Face Space page header captured with this source (status shown: Runtime error).
import os | |
import pandas as pd | |
import json | |
import pickle | |
import pprint | |
import textwrap | |
import time | |
from tqdm.autonotebook import tqdm | |
from pinecone import Pinecone, ServerlessSpec | |
# Pinecone client (API key taken from the PINECONE_API_KEY environment
# variable) and a handle to the index that is queried for candidate matches.
index_name = "prorata-postman-ds-128-v2"
pc = Pinecone(api_key=os.environ.get("PINECONE_API_KEY"))
index = pc.Index(index_name)
from openai import OpenAI | |
# OpenAI client used to embed queries; the key comes from OPENAI_API_KEY.
openai_client = OpenAI(api_key=os.environ.get("OPENAI_API_KEY"))
def get_embedding(text, model="text-embedding-3-small"):
    """Return the OpenAI embedding vector for ``text``.

    Newlines in ``text`` are replaced with spaces before the request is made.
    The text is sent as a one-element batch and the embedding of that single
    item is returned.

    Args:
        text: The string to embed.
        model: Name of the OpenAI embedding model to use.

    Returns:
        The ``embedding`` field of the first item in the API response.
    """
    single_line = text.replace("\n", " ")
    response = openai_client.embeddings.create(input=[single_line], model=model)
    return response.data[0].embedding
from transformers import AutoTokenizer, AutoModel | |
# Load the tokenizer and the model
# Both come from the "colbert-ir/colbertv2.0" checkpoint; together they are
# used to encode the query text when matches are reranked.
tokenizer = AutoTokenizer.from_pretrained("colbert-ir/colbertv2.0")
model = AutoModel.from_pretrained("colbert-ir/colbertv2.0")
# Load the pickled bundle from the working directory.
# NOTE(review): pickle.load can execute arbitrary code contained in the file,
# so this path must only ever point at a trusted file.
with open('colbertv2_pc_128_d.pkl', 'rb') as f:
    colbertv2_pc_128_d = pickle.load(f)
# Entry stored under 'version_notes'; not read anywhere else in the visible code.
version_notes = colbertv2_pc_128_d['version_notes']
# Looked up by Pinecone match id during reranking and passed to maxsim() as
# the document embedding.
chunkid_to_colbertv2 = colbertv2_pc_128_d['chunkid_to_colbertv2']
import torch | |
# MaxSim scoring between a query and a document
def maxsim(query_embedding, document_embedding):
    """Return the mean, over query tokens, of the best cosine match in the document.

    Args:
        query_embedding: Tensor laid out as
            [batch_size, query_length, embedding_size].
        document_embedding: Tensor laid out as
            [batch_size, doc_length, embedding_size].

    Returns:
        A tensor with one score per batch entry.
    """
    # Insert singleton axes so every query token is compared with every
    # document token: [batch, query, 1, emb] against [batch, 1, doc, emb].
    pairwise = torch.nn.functional.cosine_similarity(
        query_embedding[:, :, None],
        document_embedding[:, None],
        dim=-1,
    )
    # pairwise is [batch, query, doc]: keep the best document token for each
    # query token, then average those maxima across the query tokens.
    best_per_query_token = pairwise.max(dim=2).values
    return best_per_query_token.mean(dim=1)
def get_matches_reranked(q, k=20):
    """Fetch the top-``k`` Pinecone matches for ``q`` and rerank them with MaxSim.

    The query is embedded with ``get_embedding`` for the Pinecone lookup, then
    encoded with the ColBERTv2 tokenizer/model and mean-pooled over its tokens
    into a single vector. Each match is scored against the stored embedding
    found under its id in ``chunkid_to_colbertv2``; the score is written to
    ``match['colbertv2_score']``.

    NOTE(review): the query is pooled to one vector rather than kept per token;
    the stored document embeddings are presumably built to match - confirm.

    Args:
        q: The query text.
        k: Number of candidates requested from Pinecone.

    Returns:
        The matches sorted by ``'colbertv2_score'`` in descending order. A
        match id that is missing from ``chunkid_to_colbertv2`` propagates the
        lookup error.
    """
    matches = index.query(vector=get_embedding(q), top_k=k, include_metadata=True)['matches']
    if not matches:
        # Nothing to score, so skip the model forward pass entirely.
        return []

    # Inference only: without no_grad() the forward pass records an autograd
    # graph that is never used, which wastes memory on every request.
    with torch.no_grad():
        q_encoding = tokenizer(q, return_tensors='pt')
        q_embedding = model(**q_encoding).last_hidden_state.mean(dim=1)
        # Same query tensor for every match, so add the leading axis once.
        query_batch = q_embedding.unsqueeze(0)
        # Calculate MaxSim scores
        for match in matches:
            score = maxsim(query_batch, chunkid_to_colbertv2[match['id']])
            match['colbertv2_score'] = score.item()

    return sorted(matches, key=lambda m: m['colbertv2_score'], reverse=True)
def filter_matches(matches_colbertv2, score_thr=0.0):
    """Keep only the matches whose ColBERTv2 score is strictly above ``score_thr``.

    Filtering depends solely on ``match['colbertv2_score']``; the input order
    is preserved and the input list itself is not modified.

    Args:
        matches_colbertv2: Iterable of matches carrying a ``'colbertv2_score'``.
        score_thr: Exclusive lower bound on the score.

    Returns:
        A new list with the matches that passed the threshold.
    """
    return [
        match
        for match in matches_colbertv2
        if match['colbertv2_score'] > score_thr
    ]
# Inline stylesheet that is written in front of the HTML returned by
# output_chunks_reranked. It styles the .doc-title, .doc-url and .doc-text
# blocks, the favicon <img> inside a title, and the custom <score> element
# that is floated to the right of the title.
style_str = """
<style>
.doc-title {
/* font-family: cursive, sans-serif; */
font-family: Optima, sans-serif;
width: 100%;
display: inline-block;
font-size: 2em;
font-weight: bolder;
padding-top: 20px;
/* font-style: italic; */
}
.doc-url {
/* font-family: cursive, sans-serif; */
font-size: 1em;
padding-left: 40px;
padding-bottom: 10px;
/* font-weight: bolder; */
/* font-style: italic; */
}
.doc-text {
/* font-family: cursive, sans-serif; */
font-family: Optima, sans-serif;
font-size: 1.5em;
padding-left: 40px;
padding-bottom: 20px;
/* font-weight: bolder; */
/* font-style: italic; */
}
.doc-title > img {
width: 22px;
height: 22px;
border-radius: 50%;
overflow: hidden;
background-color: transparent;
display: inline-block;
vertical-align: middle;
}
.doc-title > score {
font-family: Optima, sans-serif;
font-weight: normal;
float: right;
}
</style>
"""
import gradio as gr | |
from io import StringIO | |
from urllib.parse import urlparse | |
def output_chunks_reranked(msg):
    """Render the reranked matches for ``msg`` as an HTML fragment.

    Retrieves 20 candidates with ``get_matches_reranked``, keeps those whose
    score is above 0.55, and produces one block per match with a favicon,
    title, score, link and text. When nothing passes the threshold a single
    "no sources" message is returned instead.

    Title, URL and text come from the index metadata, so they are HTML-escaped
    before being placed in the markup; otherwise a stray ``<`` or ``&`` would
    break the page or inject markup.

    Args:
        msg: The query text entered by the user.

    Returns:
        The HTML as a string, beginning with the shared stylesheet.
    """
    # Function-scope import so this block does not rely on a file-level import.
    from html import escape

    matches_colbertv2 = get_matches_reranked(msg, k=20)
    matches_colbertv2 = filter_matches(matches_colbertv2, score_thr=0.55)

    # The stylesheet is the same for every result, so emit it a single time
    # rather than once per match.
    lines = [style_str]

    if not matches_colbertv2:
        lines.append('<div>')
        lines.append('<div class="doc-title">No sources relevant to this target were found.</div>')
        lines.append('</div>')
        return '\n'.join(lines) + '\n'

    for match in matches_colbertv2:
        metadata = match['metadata']
        raw_url = metadata['url']
        # str() keeps non-string metadata values renderable, as the f-strings did.
        url = escape(str(raw_url), quote=True)
        domain = escape(str(urlparse(raw_url).netloc), quote=True)
        title = escape(str(metadata['title']))
        text = escape(str(metadata['text']))

        favicon = f"<img src=\"https://www.google.com/s2/favicons?sz=128&domain={domain}\"/>"
        lines.append('<div>')
        lines.append(f"<div class=\"doc-title\">{favicon}  {title}<score>{match['colbertv2_score']:.2f}</score></div>")
        lines.append(f"<div class=\"doc-url\"><a href=\"{url}\" target=\"_blank\">{url}</a></div>")
        lines.append(f"<div class=\"doc-text\">{text}</div>")
        lines.append('</div>')

    return '\n'.join(lines) + '\n'
# Gradio UI: a single "Target" textbox whose submitted value is passed to
# output_chunks_reranked, with the returned markup shown in an HTML component.
with gr.Blocks() as demo:
    msg = gr.Textbox(label='Target')
    # A plain Textbox was used for the results earlier; gr.HTML is used now so
    # the generated markup is rendered instead of shown as text.
    results_box = gr.HTML(label='Matches')
    msg.submit(
        fn=output_chunks_reranked,
        inputs=msg,
        outputs=results_box,
        queue=False,
    )
# Script entry point: enable the request queue, then start the Gradio server.
if __name__ == "__main__":
    demo.queue().launch()