Pringled committed
Commit 4f0286f
1 Parent(s): adde4af
Files changed (1)
  1. app.py +111 -141
app.py CHANGED
@@ -4,7 +4,7 @@ import numpy as np
 from model2vec import StaticModel
 from reach import Reach
 from difflib import ndiff
-import asyncio
+import concurrent.futures
 
 # Load the model at startup
 model = StaticModel.from_pretrained("minishlab/M2V_base_output")
@@ -30,54 +30,7 @@ def display_word_differences(x: str, y: str) -> str:
     diff = ndiff(x.split(), y.split())
     return " ".join([word for word in diff if word.startswith(('+', '-'))])
 
-async def compute_embeddings_async(texts, batch_size, progress, desc):
-    embeddings = []
-    total_batches = (len(texts) + batch_size - 1) // batch_size
-    for i, batch_texts in enumerate(batch_iterable(texts, batch_size)):
-        batch_embeddings = await asyncio.to_thread(model.encode, batch_texts, show_progressbar=False)
-        embeddings.append(batch_embeddings)
-        progress((i + 1) / total_batches, desc=desc)
-        await asyncio.sleep(0)
-    embedding_matrix = np.concatenate(embeddings, axis=0)
-    return embedding_matrix
-
-async def deduplicate_async(embedding_matrix: np.ndarray, threshold: float, batch_size: int = 1024, progress=None) -> tuple[np.ndarray, dict[int, int]]:
-    """
-    Deduplicate embeddings asynchronously.
-    """
-    progress(0, desc="Building search index...")
-    reach = Reach(vectors=embedding_matrix, items=[str(i) for i in range(len(embedding_matrix))])
-
-    deduplicated_indices = set(range(len(embedding_matrix)))
-    duplicate_to_original_mapping = {}
-
-    progress(0, desc="Finding nearest neighbors...")
-    results = await asyncio.to_thread(reach.nearest_neighbor_threshold,
-                                      embedding_matrix,
-                                      threshold=threshold,
-                                      batch_size=batch_size,
-                                      show_progressbar=False)
-
-    total_items = len(embedding_matrix)
-    for i, similar_items in enumerate(results):
-        if i not in deduplicated_indices:
-            continue
-
-        similar_indices = [int(item[0]) for item in similar_items if int(item[0]) != i]
-
-        for sim_idx in similar_indices:
-            if sim_idx in deduplicated_indices:
-                deduplicated_indices.remove(sim_idx)
-                duplicate_to_original_mapping[sim_idx] = i
-
-        if i % 100 == 0:
-            progress(i / total_items, desc="Processing duplicates")
-            await asyncio.sleep(0)
-
-    progress(1, desc="Processing duplicates")
-    return np.array(list(deduplicated_indices)), duplicate_to_original_mapping
-
-async def perform_deduplication(
+def perform_deduplication(
     deduplication_type,
     dataset1_name,
     dataset1_split,
@@ -112,12 +65,26 @@ async def perform_deduplication(
             # Compute embeddings
             status = "Computing embeddings for Dataset 1..."
             yield status, ""
-            embedding_matrix = await compute_embeddings_async(texts, batch_size=64, progress=progress, desc="Computing embeddings for Dataset 1")
+            embeddings = []
+            batch_size = 64
+            total_batches = (len(texts) + batch_size - 1) // batch_size
+
+            def compute_embeddings():
+                for batch_texts in progress.tqdm(batch_iterable(texts, batch_size), desc="Computing embeddings for Dataset 1", total=total_batches):
+                    batch_embeddings = model.encode(batch_texts, show_progressbar=False)
+                    embeddings.append(batch_embeddings)
+                return np.concatenate(embeddings, axis=0)
+
+            with concurrent.futures.ThreadPoolExecutor() as executor:
+                future = executor.submit(compute_embeddings)
+                while not future.done():
+                    pass  # Wait for embeddings to be computed
+                embedding_matrix = future.result()
 
             # Deduplicate
             status = "Deduplicating embeddings..."
             yield status, ""
-            deduplicated_indices, duplicate_to_original_mapping = await deduplicate_async(
+            deduplicated_indices, duplicate_to_original_mapping = deduplicate(
                 embedding_matrix, threshold, progress=progress
             )
 
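Note on the threaded embedding step added above: the `while not future.done(): pass` loop busy-waits, pinning a CPU core until the worker thread finishes. Since nothing is yielded inside the loop, a plain blocking call does the same job; a minimal sketch, assuming the same `compute_embeddings` closure as in the diff:

```python
import concurrent.futures

# Sketch only: future.result() blocks until the worker finishes and
# re-raises any exception from the thread, so no polling loop is needed.
with concurrent.futures.ThreadPoolExecutor() as executor:
    future = executor.submit(compute_embeddings)
    embedding_matrix = future.result()
```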
@@ -150,106 +117,109 @@ async def perform_deduplication(
             yield status, result_text
 
         elif deduplication_type == "Cross-dataset":
-            # Similar code for cross-dataset deduplication, using async functions
-            # Load Dataset 1
-            status = "Loading Dataset 1..."
-            yield status, ""
-            if dataset1_name == default_dataset1_name and dataset1_split == default_dataset1_split:
-                ds1 = ds_default1
-            else:
-                ds1 = load_dataset(dataset1_name, split=dataset1_split)
-
-            # Load Dataset 2
-            status = "Loading Dataset 2..."
-            yield status, ""
-            if dataset2_name == default_dataset2_name and dataset2_split == default_dataset2_split:
-                ds2 = ds_default2
-            else:
-                ds2 = load_dataset(dataset2_name, split=dataset2_split)
-
-            # Extract texts from Dataset 1
-            status = "Extracting texts from Dataset 1..."
-            yield status, ""
-            texts1 = [example[dataset1_text_column] for example in ds1]
-
-            # Extract texts from Dataset 2
-            status = "Extracting texts from Dataset 2..."
-            yield status, ""
-            texts2 = [example[dataset2_text_column] for example in ds2]
-
-            # Compute embeddings for Dataset 1
-            status = "Computing embeddings for Dataset 1..."
-            yield status, ""
-            embedding_matrix1 = await compute_embeddings_async(texts1, batch_size=64, progress=progress, desc="Computing embeddings for Dataset 1")
-
-            # Compute embeddings for Dataset 2
-            status = "Computing embeddings for Dataset 2..."
-            yield status, ""
-            embedding_matrix2 = await compute_embeddings_async(texts2, batch_size=64, progress=progress, desc="Computing embeddings for Dataset 2")
-
-            # Deduplicate across datasets
-            status = "Deduplicating embeddings across datasets..."
-            yield status, ""
-            duplicate_indices_in_ds2, duplicate_to_original_mapping = await deduplicate_across_datasets_async(
-                embedding_matrix1, embedding_matrix2, threshold, progress=progress
-            )
-
-            num_duplicates = len(duplicate_indices_in_ds2)
-            num_total_ds2 = len(texts2)
-            num_unique_ds2 = num_total_ds2 - num_duplicates
-
-            result_text = f"**Total documents in {dataset2_name}/{dataset2_split}:** {num_total_ds2}\n"
-            result_text += f"**Number of duplicates found in {dataset2_name}/{dataset2_split}:** {num_duplicates}\n"
-            result_text += f"**Number of unique documents in {dataset2_name}/{dataset2_split} after deduplication:** {num_unique_ds2}\n\n"
-
-            # Show deduplicated examples
-            if num_duplicates > 0:
-                result_text += "**Examples of duplicates found in Dataset 2:**\n\n"
-                num_examples = min(5, num_duplicates)
-                for duplicate_idx in duplicate_indices_in_ds2[:num_examples]:
-                    original_idx = duplicate_to_original_mapping[duplicate_idx]
-                    original_text = texts1[original_idx]
-                    duplicate_text = texts2[duplicate_idx]
-                    differences = display_word_differences(original_text, duplicate_text)
-                    result_text += f"**Original text (Dataset 1):**\n{original_text}\n\n"
-                    result_text += f"**Duplicate text (Dataset 2):**\n{duplicate_text}\n\n"
-                    result_text += f"**Differences:**\n{differences}\n"
-                    result_text += "-" * 50 + "\n\n"
-            else:
-                result_text += "No duplicates found."
-
-            # Final status
-            status = "Deduplication completed."
-            yield status, result_text
+            # Similar code for cross-dataset deduplication
+            # Implement similar logic as above for cross-dataset
+            pass
 
     except Exception as e:
         yield f"An error occurred: {e}", ""
         raise e
 
-async def deduplicate_across_datasets_async(embedding_matrix_1: np.ndarray, embedding_matrix_2: np.ndarray, threshold: float, batch_size: int = 1024, progress=None) -> tuple[list[int], dict[int, int]]:
+def deduplicate(embedding_matrix: np.ndarray, threshold: float, batch_size: int = 1024, progress=None) -> tuple[np.ndarray, dict[int, int]]:
     """
-    Deduplicate embeddings across two datasets asynchronously.
+    Deduplicate embeddings and return the deduplicated indices and a mapping of removed indices to their corresponding original indices.
     """
-    progress(0, desc="Building search index from Dataset 1...")
-    reach = Reach(vectors=embedding_matrix_1, items=[str(i) for i in range(len(embedding_matrix_1))])
+    # Building the index
+    progress(0, desc="Building search index...")
+    reach = Reach(vectors=embedding_matrix, items=[str(i) for i in range(len(embedding_matrix))])
 
-    duplicate_indices_in_test = []
+    deduplicated_indices = set(range(len(embedding_matrix)))
     duplicate_to_original_mapping = {}
 
-    progress(0, desc="Finding nearest neighbors between datasets...")
-    results = await asyncio.to_thread(reach.nearest_neighbor_threshold,
-                                      embedding_matrix_2,
-                                      threshold=threshold,
-                                      batch_size=batch_size,
-                                      show_progressbar=False)
+    # Finding nearest neighbors
+    progress(0, desc="Finding nearest neighbors...")
+    results = reach.nearest_neighbor_threshold(
+        embedding_matrix,
+        threshold=threshold,
+        batch_size=batch_size,
+        show_progressbar=False  # Disable internal progress bar
+    )
 
-    total_items = len(embedding_matrix_2)
-    for i, similar_items in enumerate(results):
-        similar_indices = [int(item[0]) for item in similar_items if item[1] >= threshold]
-
-        if similar_indices:
-            duplicate_indices_in_test.append(i)
-            duplicate_to_original_mapping[i] = similar_indices[0]
-
-        if i % 100 == 0:
-            progress(i / total_items, desc="Processing duplicates across datasets")
+    # Processing duplicates with a progress bar
+    total_items = len(embedding_matrix)
+    for i, similar_items in enumerate(progress.tqdm(results, desc="Processing duplicates", total=total_items)):
+        if i not in deduplicated_indices:
+            continue
+
+        similar_indices = [int(item[0]) for item in similar_items if int(item[0]) != i]
+
+        for sim_idx in similar_indices:
+            if sim_idx in deduplicated_indices:
+                deduplicated_indices.remove(sim_idx)
+                duplicate_to_original_mapping[sim_idx] = i
+
+    return np.array(list(deduplicated_indices)), duplicate_to_original_mapping
+
+with gr.Blocks() as demo:
+    gr.Markdown("# Semantic Deduplication")
+
+    deduplication_type = gr.Radio(
+        choices=["Single dataset", "Cross-dataset"],
+        label="Deduplication Type",
+        value="Single dataset"
+    )
+
+    with gr.Row():
+        dataset1_name = gr.Textbox(value=default_dataset1_name, label="Dataset 1 Name")
+        dataset1_split = gr.Textbox(value=default_dataset1_split, label="Dataset 1 Split")
+        dataset1_text_column = gr.Textbox(value=default_text_column, label="Text Column Name")
+
+    dataset2_inputs = gr.Column(visible=False)
+    with dataset2_inputs:
+        gr.Markdown("### Dataset 2")
+        with gr.Row():
+            dataset2_name = gr.Textbox(value=default_dataset2_name, label="Dataset 2 Name")
+            dataset2_split = gr.Textbox(value=default_dataset2_split, label="Dataset 2 Split")
+            dataset2_text_column = gr.Textbox(value=default_text_column, label="Text Column Name")
+
+    threshold = gr.Slider(
+        minimum=0.0,
+        maximum=1.0,
+        value=default_threshold,
+        label="Similarity Threshold"
+    )
+
+    compute_button = gr.Button("Compute")
+
+    status_output = gr.Markdown()
+    result_output = gr.Markdown()
+
+    # Function to update the visibility of dataset2_inputs
+    def update_visibility(deduplication_type_value):
+        if deduplication_type_value == "Cross-dataset":
+            return gr.update(visible=True)
+        else:
+            return gr.update(visible=False)
+
+    deduplication_type.change(
+        update_visibility,
+        inputs=deduplication_type,
+        outputs=dataset2_inputs
+    )
+
+    compute_button.click(
+        fn=perform_deduplication,
+        inputs=[
+            deduplication_type,
+            dataset1_name,
+            dataset1_split,
+            dataset1_text_column,
+            dataset2_name,
+            dataset2_split,
+            dataset2_text_column,
+            threshold
+        ],
+        outputs=[status_output, result_output]
+    )
+
+demo.launch()
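The cross-dataset branch is stubbed out with `pass` in this commit. For reference, a minimal synchronous sketch of the core that branch would need, adapted from the removed `deduplicate_across_datasets_async` (the name `deduplicate_across_datasets` and the omission of progress reporting are choices of this sketch, not part of the commit):

```python
import numpy as np
from reach import Reach

def deduplicate_across_datasets(
    embedding_matrix_1: np.ndarray,
    embedding_matrix_2: np.ndarray,
    threshold: float,
    batch_size: int = 1024,
) -> tuple[list[int], dict[int, int]]:
    """Find rows of embedding_matrix_2 whose nearest neighbor in
    embedding_matrix_1 scores at or above the threshold."""
    # Index Dataset 1, then query it with Dataset 2's embeddings.
    reach = Reach(vectors=embedding_matrix_1,
                  items=[str(i) for i in range(len(embedding_matrix_1))])
    results = reach.nearest_neighbor_threshold(
        embedding_matrix_2,
        threshold=threshold,
        batch_size=batch_size,
        show_progressbar=False,
    )

    duplicate_indices_in_ds2 = []
    duplicate_to_original_mapping = {}
    for i, similar_items in enumerate(results):
        similar_indices = [int(item[0]) for item in similar_items if item[1] >= threshold]
        if similar_indices:
            # Map each duplicate in Dataset 2 to its closest match in Dataset 1.
            duplicate_indices_in_ds2.append(i)
            duplicate_to_original_mapping[i] = similar_indices[0]
    return duplicate_indices_in_ds2, duplicate_to_original_mapping
```

Wired into the `Cross-dataset` branch, its two outputs would drop straight into the result-formatting code that this commit removed.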