# BioMedIA / save_faiss_index.py
# Alejandro Vaca
# initial commit
# 6bf4ad7
from datasets import load_dataset
from transformers import DPRContextEncoderTokenizer, DPRContextEncoder
from general_utils import embed_passages, embed_passages_haystack
import faiss
import argparse
import os
from haystack.nodes import DensePassageRetriever
from haystack.document_stores import InMemoryDocumentStore
# Cap OpenMP parallelism at 8 threads. OpenMP runtimes typically read
# OMP_NUM_THREADS only once, when the runtime is first loaded. faiss (and
# torch, via transformers) are already imported above, so setting the env var
# here may come too late for them. Set the limit on faiss's runtime explicitly
# as well; the env var is kept for anything that reads it later.
OMP_THREADS = 8
os.environ["OMP_NUM_THREADS"] = str(OMP_THREADS)
faiss.omp_set_num_threads(OMP_THREADS)
def create_faiss_index(args):
    """Embed the Spanish biomedical corpus with DPR and save a FAISS index.

    Steps:
    - Load the ``train`` split of ``IIC/spanish_biomedical_crawled_corpus``.
    - Keep rows whose ``text`` is longer than ``minchars`` characters.
    - Run the rows through ``embed_passages_haystack`` in ``map`` batches of
      8192 rows.
    - Build a compressed FAISS index over the ``embeddings`` column.
    - Write the index to ``args.index_file_name``.

    Args:
        args: Parsed command-line namespace. Only ``args.index_file_name``
            is read here.
    """
    minchars = 200  # rows whose text has this many characters or fewer are dropped
    dims = 128  # output dimensionality of the OPQ rotation in the index factory

    # NOTE(review): the passage model below is the *question* encoder
    # checkpoint, the same one used for queries. DPR normally embeds passages
    # with a separate passage encoder, so confirm this is intentional.
    dpr = DensePassageRetriever(
        document_store=InMemoryDocumentStore(),
        query_embedding_model="IIC/dpr-spanish-question_encoder-allqa-base",
        passage_embedding_model="IIC/dpr-spanish-question_encoder-allqa-base",
        max_seq_len_query=64,
        max_seq_len_passage=256,
        batch_size=512,
    )

    dataset = load_dataset(
        "IIC/spanish_biomedical_crawled_corpus", split="train"
    )
    dataset = dataset.filter(lambda example: len(example["text"]) > minchars)

    def embed_passages_retrieval(examples):
        """Embed one batch of rows with the DPR retriever built above."""
        return embed_passages_haystack(dpr, examples)

    # embed_passages_haystack is expected to add an "embeddings" column, which
    # is indexed below. TODO: confirm against general_utils.
    dataset = dataset.map(embed_passages_retrieval, batched=True, batch_size=8192)

    # Index factory string:
    #   OPQ64_{dims} - OPQ rotation down to `dims` dimensions, for a
    #                  64-subquantizer PQ
    #   IVF4898      - inverted-file coarse quantizer with 4898 lists
    #   PQ64x4fsr    - 64 x 4-bit fast-scan PQ codes
    # NOTE(review): no metric_type is passed, so FAISS's default (L2) is used.
    # DPR scores by inner product, so confirm L2 is intended.
    dataset.add_faiss_index(
        column="embeddings",
        string_factory=f"OPQ64_{dims},IVF4898,PQ64x4fsr",
        train_size=len(dataset),  # train the index on every vector, not a sample
    )
    dataset.save_faiss_index("embeddings", args.index_file_name)
if __name__ == "__main__":
    # Command-line entry point. parse_known_args is used so that unrecognized
    # flags are ignored instead of aborting the run.
    parser = argparse.ArgumentParser(
        description="Creates Faiss index file for the Spanish biomedical corpus"
    )
    # NOTE(review): create_faiss_index only reads ``index_file_name``.
    # ``--ctx_encoder_name`` and ``--device`` are parsed but currently unused.
    parser.add_argument(
        "--ctx_encoder_name",
        default="IIC/dpr-spanish-passage_encoder-squades-base",
        help="Encoding model to use for passage encoding",
    )
    parser.add_argument(
        "--index_file_name",
        default="dpr_index_bio_splitted.faiss",
        help="Faiss index file with passage embeddings",
    )
    parser.add_argument(
        "--device", default="cuda:0", help="The device to index data on."
    )
    main_args, _ = parser.parse_known_args()
    create_faiss_index(main_args)