import json
import sys
import time
from pathlib import Path

import streamlit as st
from pyserini.search.lucene import LuceneImpactSearcher, LuceneSearcher

# Make the repository root importable so local helper modules resolve.
path_root = Path("./")
sys.path.append(str(path_root))

# Maps the UI label to (ONNX encoder name, Hugging Face model for PyTorch, index directory).
encoder_index_map = {
    'uniCOIL': ('UniCoil', 'castorini/unicoil-noexp-msmarco-passage', 'index-unicoil'),
    'SPLADE++ Ensemble Distil': ('SpladePlusPlusEnsembleDistil', 'naver/splade-cocondenser-ensembledistil', 'index-splade-pp-ed'),
    'SPLADE++ Self Distil': ('SpladePlusPlusSelfDistil', 'naver/splade-cocondenser-selfdistil', 'index-splade-pp-sd')
}

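# Note: the index directories named above are expected to already exist under
# ./indexes/ (built or downloaded ahead of time); this script does not create them.
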
# Default selections; these are overwritten once the user picks a runtime and encoder.
index = 'index-splade-pp-ed'
encoder = 'SpladePlusPlusEnsembleDistil'
runtime_index = 0

st.set_page_config(page_title="Pyserini with ONNX Runtime",
                   page_icon='🐸', layout="centered")

# Center the project logo.
cola, colb, colc = st.columns([5, 4, 5])
with colb:
    st.image("logo.jpeg")

# Runtime selector: encode queries with PyTorch or with ONNX Runtime.
colaa, colbb, colcc = st.columns([1, 8, 1])
with colbb:
    runtime = st.select_slider(
        'Select a runtime type',
        options=['PyTorch', 'ONNX Runtime'])
    st.write('Now using: ', runtime)

# Query encoder selector.
colaa, colbb, colcc = st.columns([1, 8, 1])
with colbb:
    encoder = st.select_slider(
        'Select a query encoder',
        options=['uniCOIL', 'SPLADE++ Ensemble Distil', 'SPLADE++ Self Distil'])
    st.write('Now Running Encoder: ', encoder)

# Pick which entry of the mapping tuple to use as the query encoder:
# position 1 (Hugging Face model) for PyTorch, position 0 (ONNX encoder name) for ONNX.
if runtime == 'PyTorch':
    runtime = 'pytorch'
    runtime_index = 1
else:
    runtime = 'onnx'
    runtime_index = 0

encoder, index = encoder_index_map[encoder][runtime_index], encoder_index_map[encoder][2]

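# For example, 'SPLADE++ Ensemble Distil' with ONNX Runtime resolves to
# encoder = 'SpladePlusPlusEnsembleDistil' and index = 'index-splade-pp-ed'.
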
# The impact searcher encodes queries with the selected encoder and runtime;
# a plain searcher over the uniCOIL index is used only to fetch raw document text.
searcher = LuceneImpactSearcher(
    f'indexes/{index}', encoder, encoder_type=runtime)

corpus = LuceneSearcher('indexes/index-unicoil')

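# A minimal caching sketch (an assumption, not part of the original app): wrapping
# searcher construction in Streamlit's st.cache_resource would avoid reloading the
# index on every rerun, e.g.
#
#     @st.cache_resource
#     def load_searcher(index_name, encoder_name, runtime_name):
#         return LuceneImpactSearcher(
#             f'indexes/{index_name}', encoder_name, encoder_type=runtime_name)
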
col1, col2 = st.columns([9, 1])
with col1:
    search_query = st.text_input(label="search query", placeholder="Search")

with col2:
    st.write('#')  # spacer so the button lines up with the text input
    button_clicked = st.button("🔎")

if search_query or button_clicked:
    t_0 = time.time()
    search_results = searcher.search(search_query, k=10)
    search_time = time.time() - t_0

    st.write(
        f'<p align="right" style="color:grey;">Retrieved {len(search_results):,.0f} documents in {search_time*1000:.2f} ms</p>',
        unsafe_allow_html=True)

    for i, result in enumerate(search_results[:10]):
        result_score = result.score
        result_id = result.docid

        # Some hits store their text under 'contents', others under 'content';
        # fall back to the corpus searcher when the stored text is empty.
        contents = json.loads(result.raw)
        contents = contents['contents'] if 'contents' in contents else contents['content']
        if contents == "":
            contents = json.loads(corpus.doc(result_id).raw())['contents']

        output = f'<div class="row"> <b>Rank</b>: {i+1} | <b>Document ID</b>: {result_id} | <b>Score</b>: {result_score:.2f}</div>'

        try:
            st.write(output, unsafe_allow_html=True)
            st.write(
                f'<div class="row">{contents}</div>', unsafe_allow_html=True)
        except Exception:
            # Skip results whose contents cannot be rendered.
            pass
        st.write('---')