Federico Galatolo
spelling
a09dcd0
raw
history blame
3.87 kB
import os
import streamlit as st
from elasticsearch import Elasticsearch
from embedders.labse import LaBSE
def _fetch_context(document, number_bounds):
    """Return the joined sentences of ``document`` whose number is within bounds.

    Args:
        document: Value matched against the ``document`` field of the
            ``sentences`` index.
        number_bounds: Elasticsearch ``range`` clause for the ``number``
            field, e.g. ``{"gte": 4, "lt": 7}``.

    Returns:
        The matching sentences, ordered by ``number`` and concatenated
        without a separator.
    """
    response = es.search(
        index="sentences",
        query={
            "bool": {
                "must": [
                    {"term": {"document": document}},
                    {"range": {"number": number_bounds}},
                ]
            }
        },
    )
    # Hits come back ordered by relevance; reorder them into reading order.
    hits = sorted(
        response["hits"]["hits"],
        key=lambda hit: hit["_source"]["number"],
    )
    return "".join(hit["_source"]["sentence"] for hit in hits)


def _fetch_document_name(document):
    """Return ``"<title> - <author>"`` for ``document`` from the ``documents`` index.

    Falls back to the document id itself when the index has no entry for it,
    so one orphaned sentence does not abort the whole result list.
    """
    response = es.search(
        index="documents",
        query={"bool": {"must": [{"term": {"id": document}}]}},
    )
    hits = response["hits"]["hits"]
    if not hits:
        return str(document)
    data = hits[0]["_source"]
    return f"{data['title']} - {data['author']}"


def search():
    """Run the semantic search and render the results in the Streamlit page.

    Reads the module-level ``query``, ``model_name`` and ``limit`` widget
    values, embeds the query, ranks the ``sentences`` index by cosine
    similarity against the ``<model_name>_features`` field and, for every
    hit, writes the sentence (in bold) surrounded by up to three preceding
    and three following sentences of the same document to
    ``results_placeholder``. Progress is reported through
    ``status_indicator``.
    """
    if not query.strip():
        # Nothing meaningful to embed; avoid loading the model for a blank query.
        status_indicator.write("Please enter a query.")
        return

    status_indicator.write(f"Loading model {model_name} (it can take ~1 minute the first time)...")
    # The selected name is resolved to the embedder class imported at module
    # level and instantiated. NOTE(review): this happens on every search;
    # caching the instance would avoid the reload.
    model = globals()[model_name]()

    status_indicator.write("Computing query embeddings...")
    # Only one query is embedded, so take the first row of the model output.
    query_vector = model(query)[0, :].tolist()

    status_indicator.write("Performing query...")
    target_field = f"{model_name}_features"
    results = es.search(
        index="sentences",
        query={
            "script_score": {
                "query": {"match_all": {}},
                "script": {
                    # "+ 1.0" shifts the cosine similarity so the score is never negative.
                    "source": f"cosineSimilarity(params.query_vector, '{target_field}') + 1.0",
                    "params": {"query_vector": query_vector},
                },
            }
        },
        size=limit,
    )

    # Several hits may belong to the same document; look each name up once.
    document_names = {}

    for result in results["hits"]["hits"]:
        source = result["_source"]
        sentence = source["sentence"]
        score = result["_score"]
        document = source["document"]
        number = source["number"]

        previous_context = _fetch_context(document, {"gte": number - 3, "lt": number})
        subsequent_context = _fetch_context(document, {"lte": number + 3, "gt": number})

        if document not in document_names:
            document_names[document] = _fetch_document_name(document)
        document_name = document_names[document]

        results_placeholder.markdown(f"#### {document_name} (score: {score:.2f})\n{previous_context} **{sentence}** {subsequent_context}")

    status_indicator.write("Results ready...")
# Elasticsearch client used by search(). ELASTIC_AUTH is expected in
# "user:password" form; splitting on the first colon only keeps any colon
# inside the password intact and yields the (user, password) pair.
es = Elasticsearch(
    os.environ["ELASTIC_HOST"],
    basic_auth=tuple(os.environ["ELASTIC_AUTH"].split(":", 1)),
)

st.header("Serica Semantic Search")
st.write("Perform a semantic search using a Sentence Embedding Transformer model on the SERICA database")

# Widgets whose values search() reads as module-level globals.
model_name = st.selectbox("Model", ["LaBSE"])
# The second positional argument of st.number_input is min_value, not the
# default value, so passing 10 there prevented asking for fewer than 10
# results. Keep 10 as the default but allow any positive count.
limit = st.number_input("Number of results", min_value=1, value=10, step=1)
query = st.text_input("Query", value="")

# Placeholders: one for progress messages, one for the rendered results.
status_indicator = st.empty()
do_search = st.button("Search")
results_placeholder = st.container()

if do_search:
    search()