# Spaces: Sleeping
import os | |
import streamlit as st | |
from elasticsearch import Elasticsearch | |
from embedders.labse import LaBSE | |
def _context_sentences(document, number_range):
    """Return the sentences of ``document`` selected by ``number_range``, joined in order.

    Args:
        document: Value matched against the ``document`` field of the
            ``sentences`` index.
        number_range: Elasticsearch range clause applied to the ``number``
            field, e.g. ``{"gte": 4, "lt": 7}``.

    Returns:
        The ``sentence`` values of the matching hits, sorted by ``number``
        and concatenated without a separator.
    """
    response = es.search(
        index="sentences",
        query={
            "bool": {
                "must": [
                    {"term": {"document": document}},
                    {"range": {"number": number_range}},
                ]
            }
        },
    )
    # The query does not request a sort order, so order by sentence number here.
    ordered = sorted(response["hits"]["hits"], key=lambda hit: hit["_source"]["number"])
    return "".join(hit["_source"]["sentence"] for hit in ordered)


def _document_label(document):
    """Return the ``"<title> - <author>"`` heading for ``document``.

    Looks the document up by ``id`` in the ``documents`` index. When no entry
    matches, the document id itself is returned so that one missing metadata
    record does not abort rendering of the whole result list.
    """
    response = es.search(
        index="documents",
        query={"bool": {"must": [{"term": {"id": document}}]}},
    )
    hits = response["hits"]["hits"]
    if not hits:
        return str(document)
    data = hits[0]["_source"]
    return f"{data['title']} - {data['author']}"


def search():
    """Run the semantic search for the current query and render the results.

    Reads the module-level ``query``, ``model_name`` and ``limit`` widget
    values, embeds the query with the selected model, ranks the ``sentences``
    index by cosine similarity against the ``<model_name>_features`` field and
    writes each hit to ``results_placeholder`` together with up to three
    preceding and three following sentences of the same document. Progress
    messages are written to ``status_indicator``.
    """
    if not query.strip():
        # Nothing meaningful to embed; skip the model load and the queries.
        status_indicator.write("Please enter a query.")
        return

    status_indicator.write(f"Loading model {model_name} (it can take ~1 minute the first time)...")
    # The selectbox value doubles as the name of the embedder class imported
    # at module level, so the class is resolved by name.
    model = globals()[model_name]()
    status_indicator.write("Computing query embeddings...")
    query_vector = model(query)[0, :].tolist()
    status_indicator.write("Performing query...")
    target_field = f"{model_name}_features"
    results = es.search(
        index="sentences",
        query={
            "script_score": {
                "query": {"match_all": {}},
                "script": {
                    # Cosine similarity lies in [-1, 1]; the + 1.0 shift keeps
                    # the script score non-negative.
                    "source": f"cosineSimilarity(params.query_vector, '{target_field}') + 1.0",
                    "params": {"query_vector": query_vector},
                },
            }
        },
        size=limit,
    )

    # Several hits may belong to the same document; look each title up once.
    labels = {}
    for result in results["hits"]["hits"]:
        source = result["_source"]
        sentence = source["sentence"]
        score = result["_score"]
        document = source["document"]
        number = source["number"]

        previous_context = _context_sentences(document, {"gte": number - 3, "lt": number})
        subsequent_context = _context_sentences(document, {"lte": number + 3, "gt": number})

        if document not in labels:
            labels[document] = _document_label(document)
        document_name = labels[document]

        results_placeholder.markdown(f"#### {document_name} (score: {score:.2f})\n{previous_context} **{sentence}** {subsequent_context}")
    status_indicator.write("Results ready...")
# Connect to the Elasticsearch instance described by the environment.
# ELASTIC_AUTH has the form "<user>:<password>"; split on the first colon only
# so a password containing ":" stays intact, and pass the pair as a tuple.
es = Elasticsearch(
    os.environ["ELASTIC_HOST"],
    basic_auth=tuple(os.environ["ELASTIC_AUTH"].split(":", 1)),
)

st.header("Serica Semantic Search")
st.write("Perform a semantic search using a Sentence Embedding Transformer model on the SERICA database")

# Each option must match the name of an embedder class imported at the top of
# the file, because search() resolves the class by this name.
model_name = st.selectbox("Model", ["LaBSE"])
# The second positional argument of st.number_input is min_value, not the
# default; pass both explicitly so 10 is the default and smaller result counts
# remain selectable.
limit = st.number_input("Number of results", min_value=1, value=10)
query = st.text_input("Query", value="")
# Single-slot placeholder that search() overwrites with progress messages.
status_indicator = st.empty()
do_search = st.button("Search")
# Created after the button so results render below it.
results_placeholder = st.container()

if do_search:
    search()