Pringled committed on
Commit 5422464
1 parent: 504b6fc

Updated app with code for deduplication
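This commit inlines the two result-formatting helpers into perform_deduplication, replaces the per-step progress(...) callbacks with plain status prints, and broadens the tqdm monkey-patch to cover every loaded tqdm module. For orientation, a hedged sketch (not part of the commit) of how the reworked deduplicate() is called; random unit vectors stand in for real model2vec embeddings, and the function is assumed to be available in the session (importing app.py directly would also launch the Gradio demo):

import numpy as np

rng = np.random.default_rng(0)
embeddings = rng.normal(size=(100, 32)).astype(np.float32)
embeddings /= np.linalg.norm(embeddings, axis=1, keepdims=True)  # unit-norm, cosine-ready
embeddings[1] = embeddings[0]  # plant one exact duplicate pair

kept, dupe_map = deduplicate(embeddings, threshold=0.9)
print(len(kept))  # 99, if only the planted pair clears the threshold
print(dupe_map)   # {1: 0}: row 1 removed as a duplicate of row 0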

Files changed (1)
  1. app.py +396 -104
app.py CHANGED
@@ -22,19 +22,17 @@ default_threshold = 0.9
 ds_default1 = load_dataset(default_dataset1_name, split=default_dataset1_split)
 ds_default2 = load_dataset(default_dataset2_name, split=default_dataset2_split)
 
-def deduplicate(embedding_matrix: np.ndarray, threshold: float, batch_size: int = 1024, progress=None) -> tuple[np.ndarray, dict[int, int]]:
+def deduplicate(embedding_matrix: np.ndarray, threshold: float, batch_size: int = 1024) -> tuple[np.ndarray, dict[int, int]]:
     """
     Deduplicate embeddings and return the deduplicated indices and a mapping of removed indices to their corresponding original indices.
     """
-    # Update progress to indicate building the index
-    progress(0, desc="Building search index...")
+    # Building the index
     reach = Reach(vectors=embedding_matrix, items=[str(i) for i in range(len(embedding_matrix))])
 
     deduplicated_indices = set(range(len(embedding_matrix)))
     duplicate_to_original_mapping = {}
 
     # Finding nearest neighbors
-    progress(0, desc="Finding nearest neighbors...")
     results = reach.nearest_neighbor_threshold(
         embedding_matrix,
         threshold=threshold,
@@ -42,9 +40,8 @@ def deduplicate(embedding_matrix: np.ndarray, threshold: float, batch_size: int
         show_progressbar=True # Allow internal progress bar
     )
 
-    # Processing duplicates with a progress bar
-    total_items = len(embedding_matrix)
-    for i, similar_items in enumerate(progress.tqdm(results, desc="Processing duplicates", total=total_items)):
+    # Processing duplicates
+    for i, similar_items in enumerate(results):
         if i not in deduplicated_indices:
             continue
 
@@ -57,19 +54,17 @@ def deduplicate(embedding_matrix: np.ndarray, threshold: float, batch_size: int
 
     return np.array(list(deduplicated_indices)), duplicate_to_original_mapping
 
-def deduplicate_across_datasets(embedding_matrix_1: np.ndarray, embedding_matrix_2: np.ndarray, threshold: float, batch_size: int = 1024, progress=None) -> tuple[list[int], dict[int, int]]:
+def deduplicate_across_datasets(embedding_matrix_1: np.ndarray, embedding_matrix_2: np.ndarray, threshold: float, batch_size: int = 1024) -> tuple[list[int], dict[int, int]]:
     """
     Deduplicate embeddings across two datasets and return the indices of duplicates between them.
     """
-    # Update progress to indicate building the index
-    progress(0, desc="Building search index from Dataset 1...")
+    # Building the index from Dataset 1
     reach = Reach(vectors=embedding_matrix_1, items=[str(i) for i in range(len(embedding_matrix_1))])
 
     duplicate_indices_in_test = []
     duplicate_to_original_mapping = {}
 
     # Finding nearest neighbors between datasets
-    progress(0, desc="Finding nearest neighbors between datasets...")
     results = reach.nearest_neighbor_threshold(
         embedding_matrix_2,
         threshold=threshold,
@@ -77,9 +72,8 @@ def deduplicate_across_datasets(embedding_matrix_1: np.ndarray, embedding_matrix
         show_progressbar=True # Allow internal progress bar
     )
 
-    total_items = len(embedding_matrix_2)
-    # Processing duplicates with a progress bar
-    for i, similar_items in enumerate(progress.tqdm(results, desc="Processing duplicates across datasets", total=total_items)):
+    # Processing duplicates
+    for i, similar_items in enumerate(results):
         similar_indices = [int(item[0]) for item in similar_items if item[1] >= threshold]
 
         if similar_indices:
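The two duplicate-processing loops above consume results positionally, which relies on reach's nearest_neighbor_threshold returning one entry per query row, each a list of (item_id, similarity) pairs with the most similar item first. A hedged sketch of that consumption pattern on hand-written data (the shape is inferred from how this diff uses results, not from reach's documented API):

# Toy stand-in for the results object: one entry per query row.
results = [
    [("0", 1.00), ("17", 0.93)],  # row 0 matches itself and row 17
    [("1", 1.00)],                # row 1 only matches itself
]
for i, similar_items in enumerate(results):
    neighbors = [int(item[0]) for item in similar_items if int(item[0]) != i]
    print(i, neighbors)  # prints: 0 [17], then 1 []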
@@ -103,13 +97,12 @@ def perform_deduplication(
     threshold=default_threshold,
     progress=gr.Progress(track_tqdm=True)
 ):
-    # Monkey-patch tqdm
+    # Deep Monkey-Patching of tqdm
     original_tqdm = tqdm.tqdm
-    original_reach_tqdm = Reach.__dict__['tqdm'] if 'tqdm' in Reach.__dict__ else None
     tqdm.tqdm = progress.tqdm
-    sys.modules['tqdm'].tqdm = progress.tqdm
-    sys.modules['tqdm.auto'].tqdm = progress.tqdm
-    Reach.tqdm = progress.tqdm # Monkey-patch reach's tqdm
+    for mod_name in list(sys.modules.keys()):
+        if 'tqdm' in mod_name:
+            sys.modules[mod_name].tqdm = progress.tqdm
 
     try:
         # Convert threshold to float
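The replacement patch above walks sys.modules and rebinds tqdm on every module whose name contains "tqdm", rather than patching tqdm, tqdm.auto, and Reach individually; the finally block later undoes it the same way. A standalone sketch of the pattern, with a hypothetical passthrough standing in for gr.Progress().tqdm:

import sys
import tqdm

def patch_tqdm_everywhere(replacement):
    # Rebind `tqdm` on every loaded tqdm-related module; return the
    # original class so a finally block can restore it afterwards.
    original = tqdm.tqdm
    for mod_name in list(sys.modules.keys()):
        module = sys.modules[mod_name]
        if 'tqdm' in mod_name and module is not None:
            module.tqdm = replacement
    return original

def passthrough(iterable, *args, **kwargs):
    # Hypothetical stand-in for gr.Progress().tqdm.
    return iterable

original = patch_tqdm_everywhere(passthrough)
try:
    pass  # work that creates tqdm bars now routes through `passthrough`
finally:
    patch_tqdm_everywhere(original)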
@@ -117,140 +110,121 @@ def perform_deduplication(
 
         if deduplication_type == "Single dataset":
             # Load Dataset 1
-            progress(0, desc="Loading Dataset 1...")
+            print("Loading Dataset 1...")
             if dataset1_name == default_dataset1_name and dataset1_split == default_dataset1_split:
                 ds = ds_default1
             else:
                 ds = load_dataset(dataset1_name, split=dataset1_split)
 
             # Extract texts
-            progress(0, desc="Extracting texts from Dataset 1...")
+            print("Extracting texts from Dataset 1...")
             texts = [example[dataset1_text_column] for example in ds]
 
             # Compute embeddings
-            progress(0, desc="Computing embeddings for Dataset 1...")
+            print("Computing embeddings for Dataset 1...")
             embedding_matrix = model.encode(texts, show_progressbar=True) # Enable internal progress bar
 
             # Deduplicate
-            result_text = deduplicate_and_prepare_results_single(
-                embedding_matrix, texts, threshold, progress
+            print("Deduplicating embeddings...")
+            deduplicated_indices, duplicate_to_original_mapping = deduplicate(
+                embedding_matrix, threshold
             )
 
+            # Prepare the results
+            num_duplicates = len(duplicate_to_original_mapping)
+            num_total = len(texts)
+            num_deduplicated = len(deduplicated_indices)
+
+            result_text = f"**Total documents:** {num_total}\n"
+            result_text += f"**Number of duplicates found:** {num_duplicates}\n"
+            result_text += f"**Number of unique documents after deduplication:** {num_deduplicated}\n\n"
+
+            # Show deduplicated examples
+            if num_duplicates > 0:
+                result_text += "**Examples of duplicates found:**\n\n"
+                num_examples = min(5, num_duplicates)
+                for duplicate_idx, original_idx in list(duplicate_to_original_mapping.items())[:num_examples]:
+                    original_text = texts[original_idx]
+                    duplicate_text = texts[duplicate_idx]
+                    differences = display_word_differences(original_text, duplicate_text)
+                    result_text += f"**Original text:**\n{original_text}\n\n"
+                    result_text += f"**Duplicate text:**\n{duplicate_text}\n\n"
+                    result_text += f"**Differences:**\n{differences}\n"
+                    result_text += "-" * 50 + "\n\n"
+            else:
+                result_text += "No duplicates found."
+
             return result_text
 
         elif deduplication_type == "Cross-dataset":
             # Load Dataset 1
-            progress(0, desc="Loading Dataset 1...")
+            print("Loading Dataset 1...")
             if dataset1_name == default_dataset1_name and dataset1_split == default_dataset1_split:
                 ds1 = ds_default1
             else:
                 ds1 = load_dataset(dataset1_name, split=dataset1_split)
 
             # Load Dataset 2
-            progress(0, desc="Loading Dataset 2...")
+            print("Loading Dataset 2...")
             if dataset2_name == default_dataset2_name and dataset2_split == default_dataset2_split:
                 ds2 = ds_default2
             else:
                 ds2 = load_dataset(dataset2_name, split=dataset2_split)
 
             # Extract texts from Dataset 1
-            progress(0, desc="Extracting texts from Dataset 1...")
+            print("Extracting texts from Dataset 1...")
             texts1 = [example[dataset1_text_column] for example in ds1]
 
             # Extract texts from Dataset 2
-            progress(0, desc="Extracting texts from Dataset 2...")
+            print("Extracting texts from Dataset 2...")
             texts2 = [example[dataset2_text_column] for example in ds2]
 
             # Compute embeddings for Dataset 1
-            progress(0, desc="Computing embeddings for Dataset 1...")
+            print("Computing embeddings for Dataset 1...")
             embedding_matrix1 = model.encode(texts1, show_progressbar=True)
 
             # Compute embeddings for Dataset 2
-            progress(0, desc="Computing embeddings for Dataset 2...")
+            print("Computing embeddings for Dataset 2...")
             embedding_matrix2 = model.encode(texts2, show_progressbar=True)
 
             # Deduplicate across datasets
-            result_text = deduplicate_and_prepare_results_cross(
-                embedding_matrix1, embedding_matrix2, texts1, texts2, threshold, progress, dataset2_name, dataset2_split
+            print("Deduplicating embeddings across datasets...")
+            duplicate_indices_in_ds2, duplicate_to_original_mapping = deduplicate_across_datasets(
+                embedding_matrix1, embedding_matrix2, threshold
             )
 
+            num_duplicates = len(duplicate_indices_in_ds2)
+            num_total_ds2 = len(texts2)
+            num_unique_ds2 = num_total_ds2 - num_duplicates
+
+            result_text = f"**Total documents in {dataset2_name}/{dataset2_split}:** {num_total_ds2}\n"
+            result_text += f"**Number of duplicates found in {dataset2_name}/{dataset2_split}:** {num_duplicates}\n"
+            result_text += f"**Number of unique documents in {dataset2_name}/{dataset2_split} after deduplication:** {num_unique_ds2}\n\n"
+
+            # Show deduplicated examples
+            if num_duplicates > 0:
+                result_text += "**Examples of duplicates found in Dataset 2:**\n\n"
+                num_examples = min(5, num_duplicates)
+                for duplicate_idx in duplicate_indices_in_ds2[:num_examples]:
+                    original_idx = duplicate_to_original_mapping[duplicate_idx]
+                    original_text = texts1[original_idx]
+                    duplicate_text = texts2[duplicate_idx]
+                    differences = display_word_differences(original_text, duplicate_text)
+                    result_text += f"**Original text (Dataset 1):**\n{original_text}\n\n"
+                    result_text += f"**Duplicate text (Dataset 2):**\n{duplicate_text}\n\n"
+                    result_text += f"**Differences:**\n{differences}\n"
+                    result_text += "-" * 50 + "\n\n"
+            else:
+                result_text += "No duplicates found."
+
             return result_text
 
     finally:
         # Restore original tqdm
         tqdm.tqdm = original_tqdm
-        sys.modules['tqdm'].tqdm = original_tqdm
-        sys.modules['tqdm.auto'].tqdm = original_tqdm
-
-        # Restore reach's original tqdm
-        if original_reach_tqdm is not None:
-            Reach.tqdm = original_reach_tqdm
-        else:
-            del Reach.tqdm # If it wasn't originally in Reach's __dict__
-
-def deduplicate_and_prepare_results_single(embedding_matrix, texts, threshold, progress):
-    # Deduplicate
-    deduplicated_indices, duplicate_to_original_mapping = deduplicate(
-        embedding_matrix, threshold, progress=progress
-    )
-
-    # Prepare the results
-    num_duplicates = len(duplicate_to_original_mapping)
-    num_total = len(texts)
-    num_deduplicated = len(deduplicated_indices)
-
-    result_text = f"**Total documents:** {num_total}\n"
-    result_text += f"**Number of duplicates found:** {num_duplicates}\n"
-    result_text += f"**Number of unique documents after deduplication:** {num_deduplicated}\n\n"
-
-    # Show deduplicated examples
-    if num_duplicates > 0:
-        result_text += "**Examples of duplicates found:**\n\n"
-        num_examples = min(5, num_duplicates)
-        for duplicate_idx, original_idx in list(duplicate_to_original_mapping.items())[:num_examples]:
-            original_text = texts[original_idx]
-            duplicate_text = texts[duplicate_idx]
-            differences = display_word_differences(original_text, duplicate_text)
-            result_text += f"**Original text:**\n{original_text}\n\n"
-            result_text += f"**Duplicate text:**\n{duplicate_text}\n\n"
-            result_text += f"**Differences:**\n{differences}\n"
-            result_text += "-" * 50 + "\n\n"
-    else:
-        result_text += "No duplicates found."
-
-    return result_text
-
-def deduplicate_and_prepare_results_cross(embedding_matrix1, embedding_matrix2, texts1, texts2, threshold, progress, dataset2_name, dataset2_split):
-    # Deduplicate across datasets
-    duplicate_indices_in_ds2, duplicate_to_original_mapping = deduplicate_across_datasets(
-        embedding_matrix1, embedding_matrix2, threshold, progress=progress
-    )
-
-    num_duplicates = len(duplicate_indices_in_ds2)
-    num_total_ds2 = len(texts2)
-    num_unique_ds2 = num_total_ds2 - num_duplicates
-
-    result_text = f"**Total documents in {dataset2_name}/{dataset2_split}:** {num_total_ds2}\n"
-    result_text += f"**Number of duplicates found in {dataset2_name}/{dataset2_split}:** {num_duplicates}\n"
-    result_text += f"**Number of unique documents in {dataset2_name}/{dataset2_split} after deduplication:** {num_unique_ds2}\n\n"
-
-    # Show deduplicated examples
-    if num_duplicates > 0:
-        result_text += "**Examples of duplicates found in Dataset 2:**\n\n"
-        num_examples = min(5, num_duplicates)
-        for duplicate_idx in duplicate_indices_in_ds2[:num_examples]:
-            original_idx = duplicate_to_original_mapping[duplicate_idx]
-            original_text = texts1[original_idx]
-            duplicate_text = texts2[duplicate_idx]
-            differences = display_word_differences(original_text, duplicate_text)
-            result_text += f"**Original text (Dataset 1):**\n{original_text}\n\n"
-            result_text += f"**Duplicate text (Dataset 2):**\n{duplicate_text}\n\n"
-            result_text += f"**Differences:**\n{differences}\n"
-            result_text += "-" * 50 + "\n\n"
-    else:
-        result_text += "No duplicates found."
-
-    return result_text
 
 with gr.Blocks() as demo:
     gr.Markdown("# Semantic Deduplication")
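Both branches above format their report with display_word_differences, defined in an unchanged part of app.py (also visible in the commented-out legacy copy at the end of this diff): it keeps only the added and removed tokens from difflib.ndiff. For example:

from difflib import ndiff

def display_word_differences(x: str, y: str) -> str:
    diff = ndiff(x.split(), y.split())
    return " ".join([word for word in diff if word.startswith(('+', '-'))])

print(display_word_differences("the cat sat down", "the cat sat"))
# -> "- down"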
@@ -316,6 +290,324 @@ with gr.Blocks() as demo:
 demo.launch()
 
 
+# import gradio as gr
+# from datasets import load_dataset
+# import numpy as np
+# from model2vec import StaticModel
+# from reach import Reach
+# from difflib import ndiff
+# import sys
+# import tqdm
+
+# # Load the model at startup
+# model = StaticModel.from_pretrained("minishlab/M2V_base_output")
+
+# # Update default dataset to 'sst2' and set default threshold to 0.9
+# default_dataset1_name = "sst2"
+# default_dataset1_split = "train"
+# default_dataset2_name = "sst2"
+# default_dataset2_split = "validation"
+# default_text_column = "sentence"
+# default_threshold = 0.9
+
+# # Load the default datasets at startup
+# ds_default1 = load_dataset(default_dataset1_name, split=default_dataset1_split)
+# ds_default2 = load_dataset(default_dataset2_name, split=default_dataset2_split)
+
+# def deduplicate(embedding_matrix: np.ndarray, threshold: float, batch_size: int = 1024, progress=None) -> tuple[np.ndarray, dict[int, int]]:
+#     """
+#     Deduplicate embeddings and return the deduplicated indices and a mapping of removed indices to their corresponding original indices.
+#     """
+#     # Update progress to indicate building the index
+#     progress(0, desc="Building search index...")
+#     reach = Reach(vectors=embedding_matrix, items=[str(i) for i in range(len(embedding_matrix))])
+
+#     deduplicated_indices = set(range(len(embedding_matrix)))
+#     duplicate_to_original_mapping = {}
+
+#     # Finding nearest neighbors
+#     progress(0, desc="Finding nearest neighbors...")
+#     results = reach.nearest_neighbor_threshold(
+#         embedding_matrix,
+#         threshold=threshold,
+#         batch_size=batch_size,
+#         show_progressbar=True # Allow internal progress bar
+#     )
+
+#     # Processing duplicates with a progress bar
+#     total_items = len(embedding_matrix)
+#     for i, similar_items in enumerate(progress.tqdm(results, desc="Processing duplicates", total=total_items)):
+#         if i not in deduplicated_indices:
+#             continue
+
+#         similar_indices = [int(item[0]) for item in similar_items if int(item[0]) != i]
+
+#         for sim_idx in similar_indices:
+#             if sim_idx in deduplicated_indices:
+#                 deduplicated_indices.remove(sim_idx)
+#                 duplicate_to_original_mapping[sim_idx] = i
+
+#     return np.array(list(deduplicated_indices)), duplicate_to_original_mapping
+
+# def deduplicate_across_datasets(embedding_matrix_1: np.ndarray, embedding_matrix_2: np.ndarray, threshold: float, batch_size: int = 1024, progress=None) -> tuple[list[int], dict[int, int]]:
+#     """
+#     Deduplicate embeddings across two datasets and return the indices of duplicates between them.
+#     """
+#     # Update progress to indicate building the index
+#     progress(0, desc="Building search index from Dataset 1...")
+#     reach = Reach(vectors=embedding_matrix_1, items=[str(i) for i in range(len(embedding_matrix_1))])
+
+#     duplicate_indices_in_test = []
+#     duplicate_to_original_mapping = {}
+
+#     # Finding nearest neighbors between datasets
+#     progress(0, desc="Finding nearest neighbors between datasets...")
+#     results = reach.nearest_neighbor_threshold(
+#         embedding_matrix_2,
+#         threshold=threshold,
+#         batch_size=batch_size,
+#         show_progressbar=True # Allow internal progress bar
+#     )
+
+#     total_items = len(embedding_matrix_2)
+#     # Processing duplicates with a progress bar
+#     for i, similar_items in enumerate(progress.tqdm(results, desc="Processing duplicates across datasets", total=total_items)):
+#         similar_indices = [int(item[0]) for item in similar_items if item[1] >= threshold]
+
+#         if similar_indices:
+#             duplicate_indices_in_test.append(i)
+#             duplicate_to_original_mapping[i] = similar_indices[0]
+
+#     return duplicate_indices_in_test, duplicate_to_original_mapping
+
+# def display_word_differences(x: str, y: str) -> str:
+#     diff = ndiff(x.split(), y.split())
+#     return " ".join([word for word in diff if word.startswith(('+', '-'))])
+
+# def perform_deduplication(
+#     deduplication_type,
+#     dataset1_name,
+#     dataset1_split,
+#     dataset1_text_column,
+#     dataset2_name="",
+#     dataset2_split="",
+#     dataset2_text_column="",
+#     threshold=default_threshold,
+#     progress=gr.Progress(track_tqdm=True)
+# ):
+#     # Monkey-patch tqdm
+#     original_tqdm = tqdm.tqdm
+#     original_reach_tqdm = Reach.__dict__['tqdm'] if 'tqdm' in Reach.__dict__ else None
+#     tqdm.tqdm = progress.tqdm
+#     sys.modules['tqdm'].tqdm = progress.tqdm
+#     sys.modules['tqdm.auto'].tqdm = progress.tqdm
+#     Reach.tqdm = progress.tqdm # Monkey-patch reach's tqdm
+
+#     try:
+#         # Convert threshold to float
+#         threshold = float(threshold)
+
+#         if deduplication_type == "Single dataset":
+#             # Load Dataset 1
+#             progress(0, desc="Loading Dataset 1...")
+#             if dataset1_name == default_dataset1_name and dataset1_split == default_dataset1_split:
+#                 ds = ds_default1
+#             else:
+#                 ds = load_dataset(dataset1_name, split=dataset1_split)
+
+#             # Extract texts
+#             progress(0, desc="Extracting texts from Dataset 1...")
+#             texts = [example[dataset1_text_column] for example in ds]
+
+#             # Compute embeddings
+#             progress(0, desc="Computing embeddings for Dataset 1...")
+#             embedding_matrix = model.encode(texts, show_progressbar=True) # Enable internal progress bar
+
+#             # Deduplicate
+#             result_text = deduplicate_and_prepare_results_single(
+#                 embedding_matrix, texts, threshold, progress
+#             )
+
+#             return result_text
+
+#         elif deduplication_type == "Cross-dataset":
+#             # Load Dataset 1
+#             progress(0, desc="Loading Dataset 1...")
+#             if dataset1_name == default_dataset1_name and dataset1_split == default_dataset1_split:
+#                 ds1 = ds_default1
+#             else:
+#                 ds1 = load_dataset(dataset1_name, split=dataset1_split)
+
+#             # Load Dataset 2
+#             progress(0, desc="Loading Dataset 2...")
+#             if dataset2_name == default_dataset2_name and dataset2_split == default_dataset2_split:
+#                 ds2 = ds_default2
+#             else:
+#                 ds2 = load_dataset(dataset2_name, split=dataset2_split)
+
+#             # Extract texts from Dataset 1
+#             progress(0, desc="Extracting texts from Dataset 1...")
+#             texts1 = [example[dataset1_text_column] for example in ds1]
+
+#             # Extract texts from Dataset 2
+#             progress(0, desc="Extracting texts from Dataset 2...")
+#             texts2 = [example[dataset2_text_column] for example in ds2]
+
+#             # Compute embeddings for Dataset 1
+#             progress(0, desc="Computing embeddings for Dataset 1...")
+#             embedding_matrix1 = model.encode(texts1, show_progressbar=True)
+
+#             # Compute embeddings for Dataset 2
+#             progress(0, desc="Computing embeddings for Dataset 2...")
+#             embedding_matrix2 = model.encode(texts2, show_progressbar=True)
+
+#             # Deduplicate across datasets
+#             result_text = deduplicate_and_prepare_results_cross(
+#                 embedding_matrix1, embedding_matrix2, texts1, texts2, threshold, progress, dataset2_name, dataset2_split
+#             )
+
+#             return result_text
+
+#     finally:
+#         # Restore original tqdm
+#         tqdm.tqdm = original_tqdm
+#         sys.modules['tqdm'].tqdm = original_tqdm
+#         sys.modules['tqdm.auto'].tqdm = original_tqdm
+
+#         # Restore reach's original tqdm
+#         if original_reach_tqdm is not None:
+#             Reach.tqdm = original_reach_tqdm
+#         else:
+#             del Reach.tqdm # If it wasn't originally in Reach's __dict__
+
+# def deduplicate_and_prepare_results_single(embedding_matrix, texts, threshold, progress):
+#     # Deduplicate
+#     deduplicated_indices, duplicate_to_original_mapping = deduplicate(
+#         embedding_matrix, threshold, progress=progress
+#     )
+
+#     # Prepare the results
+#     num_duplicates = len(duplicate_to_original_mapping)
+#     num_total = len(texts)
+#     num_deduplicated = len(deduplicated_indices)
+
+#     result_text = f"**Total documents:** {num_total}\n"
+#     result_text += f"**Number of duplicates found:** {num_duplicates}\n"
+#     result_text += f"**Number of unique documents after deduplication:** {num_deduplicated}\n\n"
+
+#     # Show deduplicated examples
+#     if num_duplicates > 0:
+#         result_text += "**Examples of duplicates found:**\n\n"
+#         num_examples = min(5, num_duplicates)
+#         for duplicate_idx, original_idx in list(duplicate_to_original_mapping.items())[:num_examples]:
+#             original_text = texts[original_idx]
+#             duplicate_text = texts[duplicate_idx]
+#             differences = display_word_differences(original_text, duplicate_text)
+#             result_text += f"**Original text:**\n{original_text}\n\n"
+#             result_text += f"**Duplicate text:**\n{duplicate_text}\n\n"
+#             result_text += f"**Differences:**\n{differences}\n"
+#             result_text += "-" * 50 + "\n\n"
+#     else:
+#         result_text += "No duplicates found."
+
+#     return result_text
+
+# def deduplicate_and_prepare_results_cross(embedding_matrix1, embedding_matrix2, texts1, texts2, threshold, progress, dataset2_name, dataset2_split):
+#     # Deduplicate across datasets
+#     duplicate_indices_in_ds2, duplicate_to_original_mapping = deduplicate_across_datasets(
+#         embedding_matrix1, embedding_matrix2, threshold, progress=progress
+#     )
+
+#     num_duplicates = len(duplicate_indices_in_ds2)
+#     num_total_ds2 = len(texts2)
+#     num_unique_ds2 = num_total_ds2 - num_duplicates
+
+#     result_text = f"**Total documents in {dataset2_name}/{dataset2_split}:** {num_total_ds2}\n"
+#     result_text += f"**Number of duplicates found in {dataset2_name}/{dataset2_split}:** {num_duplicates}\n"
+#     result_text += f"**Number of unique documents in {dataset2_name}/{dataset2_split} after deduplication:** {num_unique_ds2}\n\n"
+
+#     # Show deduplicated examples
+#     if num_duplicates > 0:
+#         result_text += "**Examples of duplicates found in Dataset 2:**\n\n"
+#         num_examples = min(5, num_duplicates)
+#         for duplicate_idx in duplicate_indices_in_ds2[:num_examples]:
+#             original_idx = duplicate_to_original_mapping[duplicate_idx]
+#             original_text = texts1[original_idx]
+#             duplicate_text = texts2[duplicate_idx]
+#             differences = display_word_differences(original_text, duplicate_text)
+#             result_text += f"**Original text (Dataset 1):**\n{original_text}\n\n"
+#             result_text += f"**Duplicate text (Dataset 2):**\n{duplicate_text}\n\n"
+#             result_text += f"**Differences:**\n{differences}\n"
+#             result_text += "-" * 50 + "\n\n"
+#     else:
+#         result_text += "No duplicates found."
+
+#     return result_text
+
+# with gr.Blocks() as demo:
+#     gr.Markdown("# Semantic Deduplication")
+
+#     deduplication_type = gr.Radio(
+#         choices=["Single dataset", "Cross-dataset"],
+#         label="Deduplication Type",
+#         value="Single dataset"
+#     )
+
+#     with gr.Row():
+#         dataset1_name = gr.Textbox(value=default_dataset1_name, label="Dataset 1 Name")
+#         dataset1_split = gr.Textbox(value=default_dataset1_split, label="Dataset 1 Split")
+#         dataset1_text_column = gr.Textbox(value=default_text_column, label="Text Column Name")
+
+#     dataset2_inputs = gr.Column(visible=False)
+#     with dataset2_inputs:
+#         gr.Markdown("### Dataset 2")
+#         with gr.Row():
+#             dataset2_name = gr.Textbox(value=default_dataset2_name, label="Dataset 2 Name")
+#             dataset2_split = gr.Textbox(value=default_dataset2_split, label="Dataset 2 Split")
+#             dataset2_text_column = gr.Textbox(value=default_text_column, label="Text Column Name")
+
+#     threshold = gr.Slider(
+#         minimum=0.0,
+#         maximum=1.0,
+#         value=default_threshold,
+#         label="Similarity Threshold"
+#     )
+
+#     compute_button = gr.Button("Compute")
+
+#     output = gr.Markdown()
+
+#     # Function to update the visibility of dataset2_inputs
+#     def update_visibility(deduplication_type_value):
+#         if deduplication_type_value == "Cross-dataset":
+#             return gr.update(visible=True)
+#         else:
+#             return gr.update(visible=False)
+
+#     deduplication_type.change(
+#         update_visibility,
+#         inputs=deduplication_type,
+#         outputs=dataset2_inputs
+#     )
+
+#     compute_button.click(
+#         fn=perform_deduplication,
+#         inputs=[
+#             deduplication_type,
+#             dataset1_name,
+#             dataset1_split,
+#             dataset1_text_column,
+#             dataset2_name,
+#             dataset2_split,
+#             dataset2_text_column,
+#             threshold
+#         ],
+#         outputs=output
+#     )
+
+# demo.launch()
+
+
 
 
 # import gradio as gr
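Taken together, the cross-dataset path is: embed both datasets, index Dataset 1, then flag every Dataset 2 row whose nearest Dataset 1 neighbor clears the threshold. A hedged sketch with toy 2-D vectors (assumes deduplicate_across_datasets from this diff is defined in the session; the outputs shown are what cosine similarity should give):

import numpy as np

emb1 = np.array([[1.0, 0.0], [0.0, 1.0]], dtype=np.float32)              # "Dataset 1"
emb2 = np.array([[0.0, 1.0], [1.0, 0.0], [0.7, 0.7]], dtype=np.float32)  # "Dataset 2"

dupes_in_2, mapping = deduplicate_across_datasets(emb1, emb2, threshold=0.9)
print(dupes_in_2)  # expected: [0, 1] -- rows 0 and 1 duplicate Dataset 1 rows
print(mapping)     # expected: {0: 1, 1: 0} (duplicate row in ds2 -> original row in ds1)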