Pringled committed on
Commit
3bd0812
1 Parent(s): 9f13004

Updated app with code for deduplication

Files changed (1)
  1. app.py +366 -36
app.py CHANGED
@@ -26,13 +26,13 @@ def deduplicate(embedding_matrix: np.ndarray, threshold: float, batch_size: int
     """
     Deduplicate embeddings and return the deduplicated indices and a mapping of removed indices to their corresponding original indices.
     """
-    # Building the index
+    # Build the index
     reach = Reach(vectors=embedding_matrix, items=[str(i) for i in range(len(embedding_matrix))])

     deduplicated_indices = set(range(len(embedding_matrix)))
     duplicate_to_original_mapping = {}

-    # Finding nearest neighbors
+    # Find nearest neighbors
     results = reach.nearest_neighbor_threshold(
         embedding_matrix,
         threshold=threshold,
@@ -40,7 +40,7 @@ def deduplicate(embedding_matrix: np.ndarray, threshold: float, batch_size: int
         show_progressbar=True # Allow internal progress bar
     )

-    # Processing duplicates
+    # Process duplicates
     for i, similar_items in enumerate(results):
         if i not in deduplicated_indices:
             continue
@@ -58,13 +58,13 @@ def deduplicate_across_datasets(embedding_matrix_1: np.ndarray, embedding_matrix
     """
     Deduplicate embeddings across two datasets and return the indices of duplicates between them.
     """
-    # Building the index from Dataset 1
+    # Build the index from Dataset 1
     reach = Reach(vectors=embedding_matrix_1, items=[str(i) for i in range(len(embedding_matrix_1))])

     duplicate_indices_in_test = []
     duplicate_to_original_mapping = {}

-    # Finding nearest neighbors between datasets
+    # Find nearest neighbors between datasets
     results = reach.nearest_neighbor_threshold(
         embedding_matrix_2,
         threshold=threshold,
@@ -72,7 +72,7 @@ def deduplicate_across_datasets(embedding_matrix_1: np.ndarray, embedding_matrix
         show_progressbar=True # Allow internal progress bar
     )

-    # Processing duplicates
+    # Process duplicates
     for i, similar_items in enumerate(results):
         similar_indices = [int(item[0]) for item in similar_items if item[1] >= threshold]

@@ -97,17 +97,35 @@ def perform_deduplication(
     threshold=default_threshold,
     progress=gr.Progress(track_tqdm=True)
 ):
-    # Deep Monkey-Patching of tqdm
-    original_tqdm = tqdm.tqdm
+    # Custom tqdm class that wraps progress.tqdm and includes module-level attributes
+    class TqdmWrapper(tqdm.std.tqdm):
+        def __init__(self, *args, **kwargs):
+            super().__init__(*args, **kwargs)
+
+    # Copy module-level attributes from original tqdm module
+    TqdmWrapper.format_interval = staticmethod(tqdm.format_interval)
+    TqdmWrapper.format_num = staticmethod(tqdm.format_num)
+    TqdmWrapper.format_sizeof = staticmethod(tqdm.format_sizeof)
+
+    # Monkey-patch tqdm.tqdm with our wrapper
+    original_tqdm_tqdm = tqdm.tqdm
     tqdm.tqdm = progress.tqdm
-    for mod_name in list(sys.modules.keys()):
-        if 'tqdm' in mod_name:
-            sys.modules[mod_name].tqdm = progress.tqdm
+
+    # Monkey-patch model2vec's tqdm reference if needed
+    import model2vec.model
+    if hasattr(model2vec.model, 'tqdm'):
+        original_model2vec_tqdm = model2vec.model.tqdm
+        model2vec.model.tqdm = TqdmWrapper
+
+    # Monkey-patch reach's tqdm reference if needed
+    if hasattr(Reach, 'tqdm'):
+        original_reach_tqdm = Reach.tqdm
+        Reach.tqdm = TqdmWrapper

     try:
         # Convert threshold to float
         threshold = float(threshold)
-
+
         # Initialize status message
         status = ""

@@ -119,33 +137,33 @@ def perform_deduplication(
                 ds = ds_default1
             else:
                 ds = load_dataset(dataset1_name, split=dataset1_split)
-
+
             # Extract texts
             status = "Extracting texts from Dataset 1..."
             yield status, ""
             texts = [example[dataset1_text_column] for example in ds]
-
+
             # Compute embeddings
             status = "Computing embeddings for Dataset 1..."
             yield status, ""
             embedding_matrix = model.encode(texts, show_progressbar=True) # Enable internal progress bar
-
+
             # Deduplicate
             status = "Deduplicating embeddings..."
             yield status, ""
             deduplicated_indices, duplicate_to_original_mapping = deduplicate(
                 embedding_matrix, threshold
             )
-
+
             # Prepare the results
             num_duplicates = len(duplicate_to_original_mapping)
             num_total = len(texts)
             num_deduplicated = len(deduplicated_indices)
-
+
             result_text = f"**Total documents:** {num_total}\n"
             result_text += f"**Number of duplicates found:** {num_duplicates}\n"
             result_text += f"**Number of unique documents after deduplication:** {num_deduplicated}\n\n"
-
+
             # Show deduplicated examples
             if num_duplicates > 0:
                 result_text += "**Examples of duplicates found:**\n\n"
@@ -160,11 +178,11 @@ def perform_deduplication(
                     result_text += "-" * 50 + "\n\n"
             else:
                 result_text += "No duplicates found."
-
+
             # Final status
             status = "Deduplication completed."
             yield status, result_text
-
+
         elif deduplication_type == "Cross-dataset":
             # Load Dataset 1
             status = "Loading Dataset 1..."
@@ -173,7 +191,7 @@ def perform_deduplication(
                 ds1 = ds_default1
             else:
                 ds1 = load_dataset(dataset1_name, split=dataset1_split)
-
+
             # Load Dataset 2
             status = "Loading Dataset 2..."
             yield status, ""
@@ -181,42 +199,42 @@ def perform_deduplication(
                 ds2 = ds_default2
             else:
                 ds2 = load_dataset(dataset2_name, split=dataset2_split)
-
+
             # Extract texts from Dataset 1
             status = "Extracting texts from Dataset 1..."
             yield status, ""
             texts1 = [example[dataset1_text_column] for example in ds1]
-
+
             # Extract texts from Dataset 2
             status = "Extracting texts from Dataset 2..."
             yield status, ""
             texts2 = [example[dataset2_text_column] for example in ds2]
-
+
             # Compute embeddings for Dataset 1
             status = "Computing embeddings for Dataset 1..."
             yield status, ""
             embedding_matrix1 = model.encode(texts1, show_progressbar=True)
-
+
             # Compute embeddings for Dataset 2
             status = "Computing embeddings for Dataset 2..."
             yield status, ""
             embedding_matrix2 = model.encode(texts2, show_progressbar=True)
-
+
             # Deduplicate across datasets
             status = "Deduplicating embeddings across datasets..."
             yield status, ""
             duplicate_indices_in_ds2, duplicate_to_original_mapping = deduplicate_across_datasets(
                 embedding_matrix1, embedding_matrix2, threshold
             )
-
+
             num_duplicates = len(duplicate_indices_in_ds2)
             num_total_ds2 = len(texts2)
             num_unique_ds2 = num_total_ds2 - num_duplicates
-
+
             result_text = f"**Total documents in {dataset2_name}/{dataset2_split}:** {num_total_ds2}\n"
             result_text += f"**Number of duplicates found in {dataset2_name}/{dataset2_split}:** {num_duplicates}\n"
             result_text += f"**Number of unique documents in {dataset2_name}/{dataset2_split} after deduplication:** {num_unique_ds2}\n\n"
-
+
             # Show deduplicated examples
             if num_duplicates > 0:
                 result_text += "**Examples of duplicates found in Dataset 2:**\n\n"
@@ -232,17 +250,18 @@ def perform_deduplication(
                     result_text += "-" * 50 + "\n\n"
             else:
                 result_text += "No duplicates found."
-
+
             # Final status
             status = "Deduplication completed."
             yield status, result_text

     finally:
-        # Restore original tqdm
-        tqdm.tqdm = original_tqdm
-        for mod_name in list(sys.modules.keys()):
-            if 'tqdm' in mod_name:
-                sys.modules[mod_name].tqdm = original_tqdm
+        # Restore original tqdm functions
+        tqdm.tqdm = original_tqdm_tqdm
+        if hasattr(model2vec.model, 'tqdm'):
+            model2vec.model.tqdm = original_model2vec_tqdm
+        if hasattr(Reach, 'tqdm'):
+            Reach.tqdm = original_reach_tqdm

 with gr.Blocks() as demo:
     gr.Markdown("# Semantic Deduplication")
@@ -305,10 +324,321 @@ with gr.Blocks() as demo:
         ],
         outputs=[status_output, result_output]
     )
-
+
 demo.launch()


+# import gradio as gr
+# from datasets import load_dataset
+# import numpy as np
+# from model2vec import StaticModel
+# from reach import Reach
+# from difflib import ndiff
+# import sys
+# import tqdm
+
+# # Load the model at startup
+# model = StaticModel.from_pretrained("minishlab/M2V_base_output")
+
+# # Update default dataset to 'sst2' and set default threshold to 0.9
+# default_dataset1_name = "sst2"
+# default_dataset1_split = "train"
+# default_dataset2_name = "sst2"
+# default_dataset2_split = "validation"
+# default_text_column = "sentence"
+# default_threshold = 0.9
+
+# # Load the default datasets at startup
+# ds_default1 = load_dataset(default_dataset1_name, split=default_dataset1_split)
+# ds_default2 = load_dataset(default_dataset2_name, split=default_dataset2_split)
+
+# def deduplicate(embedding_matrix: np.ndarray, threshold: float, batch_size: int = 1024) -> tuple[np.ndarray, dict[int, int]]:
+#     """
+#     Deduplicate embeddings and return the deduplicated indices and a mapping of removed indices to their corresponding original indices.
+#     """
+#     # Building the index
+#     reach = Reach(vectors=embedding_matrix, items=[str(i) for i in range(len(embedding_matrix))])
+
+#     deduplicated_indices = set(range(len(embedding_matrix)))
+#     duplicate_to_original_mapping = {}
+
+#     # Finding nearest neighbors
+#     results = reach.nearest_neighbor_threshold(
+#         embedding_matrix,
+#         threshold=threshold,
+#         batch_size=batch_size,
+#         show_progressbar=True # Allow internal progress bar
+#     )
+
+#     # Processing duplicates
+#     for i, similar_items in enumerate(results):
+#         if i not in deduplicated_indices:
+#             continue
+
+#         similar_indices = [int(item[0]) for item in similar_items if int(item[0]) != i]
+
+#         for sim_idx in similar_indices:
+#             if sim_idx in deduplicated_indices:
+#                 deduplicated_indices.remove(sim_idx)
+#                 duplicate_to_original_mapping[sim_idx] = i
+
+#     return np.array(list(deduplicated_indices)), duplicate_to_original_mapping
+
+# def deduplicate_across_datasets(embedding_matrix_1: np.ndarray, embedding_matrix_2: np.ndarray, threshold: float, batch_size: int = 1024) -> tuple[list[int], dict[int, int]]:
+#     """
+#     Deduplicate embeddings across two datasets and return the indices of duplicates between them.
+#     """
+#     # Building the index from Dataset 1
+#     reach = Reach(vectors=embedding_matrix_1, items=[str(i) for i in range(len(embedding_matrix_1))])
+
+#     duplicate_indices_in_test = []
+#     duplicate_to_original_mapping = {}
+
+#     # Finding nearest neighbors between datasets
+#     results = reach.nearest_neighbor_threshold(
+#         embedding_matrix_2,
+#         threshold=threshold,
+#         batch_size=batch_size,
+#         show_progressbar=True # Allow internal progress bar
+#     )
+
+#     # Processing duplicates
+#     for i, similar_items in enumerate(results):
+#         similar_indices = [int(item[0]) for item in similar_items if item[1] >= threshold]
+
+#         if similar_indices:
+#             duplicate_indices_in_test.append(i)
+#             duplicate_to_original_mapping[i] = similar_indices[0]
+
+#     return duplicate_indices_in_test, duplicate_to_original_mapping
+
+# def display_word_differences(x: str, y: str) -> str:
+#     diff = ndiff(x.split(), y.split())
+#     return " ".join([word for word in diff if word.startswith(('+', '-'))])
+
+# def perform_deduplication(
+#     deduplication_type,
+#     dataset1_name,
+#     dataset1_split,
+#     dataset1_text_column,
+#     dataset2_name="",
+#     dataset2_split="",
+#     dataset2_text_column="",
+#     threshold=default_threshold,
+#     progress=gr.Progress(track_tqdm=True)
+# ):
+#     # Deep Monkey-Patching of tqdm
+#     original_tqdm = tqdm.tqdm
+#     tqdm.tqdm = progress.tqdm
+#     for mod_name in list(sys.modules.keys()):
+#         if 'tqdm' in mod_name:
+#             sys.modules[mod_name].tqdm = progress.tqdm
+
+#     try:
+#         # Convert threshold to float
+#         threshold = float(threshold)
+
+#         # Initialize status message
+#         status = ""
+
+#         if deduplication_type == "Single dataset":
+#             # Load Dataset 1
+#             status = "Loading Dataset 1..."
+#             yield status, ""
+#             if dataset1_name == default_dataset1_name and dataset1_split == default_dataset1_split:
+#                 ds = ds_default1
+#             else:
+#                 ds = load_dataset(dataset1_name, split=dataset1_split)
+
+#             # Extract texts
+#             status = "Extracting texts from Dataset 1..."
+#             yield status, ""
+#             texts = [example[dataset1_text_column] for example in ds]
+
+#             # Compute embeddings
+#             status = "Computing embeddings for Dataset 1..."
+#             yield status, ""
+#             embedding_matrix = model.encode(texts, show_progressbar=True) # Enable internal progress bar
+
+#             # Deduplicate
+#             status = "Deduplicating embeddings..."
+#             yield status, ""
+#             deduplicated_indices, duplicate_to_original_mapping = deduplicate(
+#                 embedding_matrix, threshold
+#             )
+
+#             # Prepare the results
+#             num_duplicates = len(duplicate_to_original_mapping)
+#             num_total = len(texts)
+#             num_deduplicated = len(deduplicated_indices)
+
+#             result_text = f"**Total documents:** {num_total}\n"
+#             result_text += f"**Number of duplicates found:** {num_duplicates}\n"
+#             result_text += f"**Number of unique documents after deduplication:** {num_deduplicated}\n\n"
+
+#             # Show deduplicated examples
+#             if num_duplicates > 0:
+#                 result_text += "**Examples of duplicates found:**\n\n"
+#                 num_examples = min(5, num_duplicates)
+#                 for duplicate_idx, original_idx in list(duplicate_to_original_mapping.items())[:num_examples]:
+#                     original_text = texts[original_idx]
+#                     duplicate_text = texts[duplicate_idx]
+#                     differences = display_word_differences(original_text, duplicate_text)
+#                     result_text += f"**Original text:**\n{original_text}\n\n"
+#                     result_text += f"**Duplicate text:**\n{duplicate_text}\n\n"
+#                     result_text += f"**Differences:**\n{differences}\n"
+#                     result_text += "-" * 50 + "\n\n"
+#             else:
+#                 result_text += "No duplicates found."
+
+#             # Final status
+#             status = "Deduplication completed."
+#             yield status, result_text
+
+#         elif deduplication_type == "Cross-dataset":
+#             # Load Dataset 1
+#             status = "Loading Dataset 1..."
+#             yield status, ""
+#             if dataset1_name == default_dataset1_name and dataset1_split == default_dataset1_split:
+#                 ds1 = ds_default1
+#             else:
+#                 ds1 = load_dataset(dataset1_name, split=dataset1_split)
+
+#             # Load Dataset 2
+#             status = "Loading Dataset 2..."
+#             yield status, ""
+#             if dataset2_name == default_dataset2_name and dataset2_split == default_dataset2_split:
+#                 ds2 = ds_default2
+#             else:
+#                 ds2 = load_dataset(dataset2_name, split=dataset2_split)
+
+#             # Extract texts from Dataset 1
+#             status = "Extracting texts from Dataset 1..."
+#             yield status, ""
+#             texts1 = [example[dataset1_text_column] for example in ds1]
+
+#             # Extract texts from Dataset 2
+#             status = "Extracting texts from Dataset 2..."
+#             yield status, ""
+#             texts2 = [example[dataset2_text_column] for example in ds2]
+
+#             # Compute embeddings for Dataset 1
+#             status = "Computing embeddings for Dataset 1..."
+#             yield status, ""
+#             embedding_matrix1 = model.encode(texts1, show_progressbar=True)
+
+#             # Compute embeddings for Dataset 2
+#             status = "Computing embeddings for Dataset 2..."
+#             yield status, ""
+#             embedding_matrix2 = model.encode(texts2, show_progressbar=True)
+
+#             # Deduplicate across datasets
+#             status = "Deduplicating embeddings across datasets..."
+#             yield status, ""
+#             duplicate_indices_in_ds2, duplicate_to_original_mapping = deduplicate_across_datasets(
+#                 embedding_matrix1, embedding_matrix2, threshold
+#             )
+
+#             num_duplicates = len(duplicate_indices_in_ds2)
+#             num_total_ds2 = len(texts2)
+#             num_unique_ds2 = num_total_ds2 - num_duplicates
+
+#             result_text = f"**Total documents in {dataset2_name}/{dataset2_split}:** {num_total_ds2}\n"
+#             result_text += f"**Number of duplicates found in {dataset2_name}/{dataset2_split}:** {num_duplicates}\n"
+#             result_text += f"**Number of unique documents in {dataset2_name}/{dataset2_split} after deduplication:** {num_unique_ds2}\n\n"
+
+#             # Show deduplicated examples
+#             if num_duplicates > 0:
+#                 result_text += "**Examples of duplicates found in Dataset 2:**\n\n"
+#                 num_examples = min(5, num_duplicates)
+#                 for duplicate_idx in duplicate_indices_in_ds2[:num_examples]:
+#                     original_idx = duplicate_to_original_mapping[duplicate_idx]
+#                     original_text = texts1[original_idx]
+#                     duplicate_text = texts2[duplicate_idx]
+#                     differences = display_word_differences(original_text, duplicate_text)
+#                     result_text += f"**Original text (Dataset 1):**\n{original_text}\n\n"
+#                     result_text += f"**Duplicate text (Dataset 2):**\n{duplicate_text}\n\n"
+#                     result_text += f"**Differences:**\n{differences}\n"
+#                     result_text += "-" * 50 + "\n\n"
+#             else:
+#                 result_text += "No duplicates found."
+
+#             # Final status
+#             status = "Deduplication completed."
+#             yield status, result_text
+
+#     finally:
+#         # Restore original tqdm
+#         tqdm.tqdm = original_tqdm
+#         for mod_name in list(sys.modules.keys()):
+#             if 'tqdm' in mod_name:
+#                 sys.modules[mod_name].tqdm = original_tqdm
+
+# with gr.Blocks() as demo:
+#     gr.Markdown("# Semantic Deduplication")
+
+#     deduplication_type = gr.Radio(
+#         choices=["Single dataset", "Cross-dataset"],
+#         label="Deduplication Type",
+#         value="Single dataset"
+#     )
+
+#     with gr.Row():
+#         dataset1_name = gr.Textbox(value=default_dataset1_name, label="Dataset 1 Name")
+#         dataset1_split = gr.Textbox(value=default_dataset1_split, label="Dataset 1 Split")
+#         dataset1_text_column = gr.Textbox(value=default_text_column, label="Text Column Name")
+
+#     dataset2_inputs = gr.Column(visible=False)
+#     with dataset2_inputs:
+#         gr.Markdown("### Dataset 2")
+#         with gr.Row():
+#             dataset2_name = gr.Textbox(value=default_dataset2_name, label="Dataset 2 Name")
+#             dataset2_split = gr.Textbox(value=default_dataset2_split, label="Dataset 2 Split")
+#             dataset2_text_column = gr.Textbox(value=default_text_column, label="Text Column Name")
+
+#     threshold = gr.Slider(
+#         minimum=0.0,
+#         maximum=1.0,
+#         value=default_threshold,
+#         label="Similarity Threshold"
+#     )
+
+#     compute_button = gr.Button("Compute")
+
+#     status_output = gr.Markdown()
+#     result_output = gr.Markdown()
+
+#     # Function to update the visibility of dataset2_inputs
+#     def update_visibility(deduplication_type_value):
+#         if deduplication_type_value == "Cross-dataset":
+#             return gr.update(visible=True)
+#         else:
+#             return gr.update(visible=False)
+
+#     deduplication_type.change(
+#         update_visibility,
+#         inputs=deduplication_type,
+#         outputs=dataset2_inputs
+#     )
+
+#     compute_button.click(
+#         fn=perform_deduplication,
+#         inputs=[
+#             deduplication_type,
+#             dataset1_name,
+#             dataset1_split,
+#             dataset1_text_column,
+#             dataset2_name,
+#             dataset2_split,
+#             dataset2_text_column,
+#             threshold
+#         ],
+#         outputs=[status_output, result_output]
+#     )
+
+# demo.launch()
+
+
 # import gradio as gr
 # from datasets import load_dataset
 # import numpy as np