ccm commited on
Commit
bc256ab
β€’
1 Parent(s): 3c26677

Update main.py

Browse files
Files changed (1) hide show
  1. main.py +8 -5
main.py CHANGED
@@ -8,9 +8,6 @@ import pandas # Needed for operating on dataset
8
  import sentence_transformers # Needed for query embedding
9
  import faiss # Needed for fast similarity search
10
 
11
- # Load the model for later use in embeddings
12
- model = sentence_transformers.SentenceTransformer("allenai-specter")
13
-
14
  # Load the dataset and convert to pandas
15
  full_data = datasets.load_dataset("ccm/publications")["train"].to_pandas()
16
 
@@ -28,12 +25,18 @@ data.reset_index(inplace=True)
28
  # Create a FAISS index for fast similarity search
29
  index = faiss.IndexFlatL2(len(data["embedding"][0]))
30
  index.metric_type = faiss.METRIC_INNER_PRODUCT
31
- index.add(faiss.normalize_L2(numpy.stack(data["embedding"].tolist(), axis=0)))
 
 
 
 
 
32
 
33
 
34
  # Define the search function
35
  def search(query: str, k: int):
36
- query = numpy.expand_dims(faiss.normalize_L2(model.encode(query)), axis=0)
 
37
  D, I = index.search(query, k)
38
  top_five = data.loc[I[0]]
39
  search_results = ""
 
8
  import sentence_transformers # Needed for query embedding
9
  import faiss # Needed for fast similarity search
10
 
 
 
 
11
  # Load the dataset and convert to pandas
12
  full_data = datasets.load_dataset("ccm/publications")["train"].to_pandas()
13
 
 
25
  # Create a FAISS index for fast similarity search
26
  index = faiss.IndexFlatL2(len(data["embedding"][0]))
27
  index.metric_type = faiss.METRIC_INNER_PRODUCT
28
+ vectors = numpy.stack(data["embedding"].tolist(), axis=0)
29
+ faiss.normalize_L2(vectors)
30
+ index.add(vectors)
31
+
32
+ # Load the model for later use in embeddings
33
+ model = sentence_transformers.SentenceTransformer("allenai-specter")
34
 
35
 
36
  # Define the search function
37
  def search(query: str, k: int):
38
+ query = numpy.expand_dims(model.encode(query), axis=0)
39
+ faiss.normalize_L2(query)
40
  D, I = index.search(query, k)
41
  top_five = data.loc[I[0]]
42
  search_results = ""