Pringled committed on
Commit 72c7e2c
1 Parent(s): 73a84b9
Files changed (1)
  1. app.py +269 -56
app.py CHANGED
@@ -21,16 +21,7 @@ def deduplicate_embeddings(
     batch_size: int = 1024,
     progress=None
 ) -> tuple[np.ndarray, dict[int, int]]:
-    """
-    Deduplicate embeddings within one dataset or across two datasets.
-
-    :param embeddings_a: Embeddings of Dataset 1.
-    :param embeddings_b: Optional, embeddings of Dataset 2.
-    :param threshold: Similarity threshold for deduplication.
-    :param batch_size: Batch size for similarity computation.
-    :param progress: Gradio progress tracker for feedback.
-    :return: Deduplicated indices and a mapping of removed indices to their original counterparts.
-    """
+    """Deduplicate embeddings within one dataset or across two datasets."""
     if embeddings_b is None:
         reach = Reach(vectors=embeddings_a, items=[str(i) for i in range(len(embeddings_a))])
         duplicate_to_original = {}
@@ -58,39 +49,13 @@ def deduplicate_embeddings(
         return duplicate_indices_in_b, duplicate_to_original

 def display_word_differences(x: str, y: str) -> str:
-    """
-    Display the word-level differences between two texts, formatted to avoid
-    misinterpretation of Markdown syntax.
-
-    :param x: First text.
-    :param y: Second text.
-    :return: A string showing word-level differences, wrapped in a code block.
-    """
+    """Display word-level differences between two texts, avoiding Markdown issues."""
     diff = ndiff(x.split(), y.split())
-    # Wrap differences in a code block to prevent interpretation as Markdown
     formatted_diff = "\n".join(word for word in diff if word.startswith(("+", "-")))
     return f"```\n{formatted_diff}\n```"

-# def display_word_differences(x: str, y: str) -> str:
-#     """
-#     Display the word-level differences between two texts.
-
-#     :param x: First text.
-#     :param y: Second text.
-#     :return: A string showing word-level differences.
-#     """
-#     diff = ndiff(x.split(), y.split())
-#     return " ".join(word for word in diff if word.startswith(("+", "-")))
-
 def load_dataset_texts(dataset_name: str, dataset_split: str, text_column: str) -> list[str]:
-    """
-    Load texts from a specified dataset and split.
-
-    :param dataset_name: Name of the dataset.
-    :param dataset_split: Split of the dataset (e.g., 'train', 'validation').
-    :param text_column: Name of the text column.
-    :return: A list of texts from the dataset.
-    """
+    """Load texts from a specified dataset and split."""
     ds = load_dataset(dataset_name, split=dataset_split)
     return [example[text_column] for example in ds]
@@ -105,20 +70,7 @@ def perform_deduplication(
     threshold: float = default_threshold,
     progress: gr.Progress = gr.Progress(track_tqdm=True)
 ):
-    """
-    Perform deduplication on one or two datasets based on the deduplication type.
-
-    :param deduplication_type: 'Single dataset' or 'Cross-dataset'.
-    :param dataset1_name: Name of the first dataset.
-    :param dataset1_split: Split of the first dataset.
-    :param dataset1_text_column: Text column of the first dataset.
-    :param dataset2_name: Optional, name of the second dataset (for cross-dataset deduplication).
-    :param dataset2_split: Optional, split of the second dataset.
-    :param dataset2_text_column: Optional, text column of the second dataset.
-    :param threshold: Similarity threshold for deduplication.
-    :param progress: Gradio progress tracker.
-    :return: Status updates and result text for the Gradio interface.
-    """
+    """Perform deduplication on one or two datasets."""
     try:
         threshold = float(threshold)

@@ -200,13 +152,13 @@ def perform_deduplication(
         yield f"An error occurred: {e}", ""
         raise e

+# Gradio app with stop button support
 with gr.Blocks(css="#status_output { height: 50px; overflow: auto; }") as demo:
     gr.Markdown("# Semantic Deduplication")
     gr.Markdown("""
-    This demo showcases semantic deduplication using Model2Vec for HuggingFace datasets.
-    It can be used to identify duplicate texts within a single dataset or across two datasets.
-    You can adjust the similarity threshold to control the strictness of the deduplication.\n
-    NOTE: this demo runs on a free CPU backend, so it may be slow for large datasets. For faster results, please run the code locally.
+    This demo showcases a semantic deduplication process where we identify duplicate texts within a single dataset or across two datasets.
+    The deduplication is based on cosine similarity between the embeddings of the texts.
+    You can adjust the similarity threshold to control the strictness of the deduplication.
     """)

     deduplication_type = gr.Radio(
@@ -230,6 +182,7 @@ with gr.Blocks(css="#status_output { height: 50px; overflow: auto; }") as demo:
 
     threshold = gr.Slider(0.0, 1.0, value=default_threshold, label="Similarity Threshold")
     compute_button = gr.Button("Compute")
+    stop_button = gr.Button("Stop")
     status_output = gr.Markdown(elem_id="status_output")
     result_output = gr.Markdown()
 
@@ -253,5 +206,265 @@ with gr.Blocks(css="#status_output { height: 50px; overflow: auto; }") as demo:
         outputs=[status_output, result_output],
     )

+    # Stop button functionality
+    stop_button.click(lambda: demo.stop(), None, None)
+
 demo.launch()
 
+# import gradio as gr
+# from datasets import load_dataset
+# import numpy as np
+# from model2vec import StaticModel
+# from reach import Reach
+# from difflib import ndiff
+
+# # Load the model
+# model = StaticModel.from_pretrained("minishlab/M2V_base_output")
+
+# # Default parameters
+# default_dataset_name = "sst2"
+# default_dataset_split = "train"
+# default_text_column = "sentence"
+# default_threshold = 0.9
+
+# def deduplicate_embeddings(
+#     embeddings_a: np.ndarray,
+#     embeddings_b: np.ndarray = None,
+#     threshold: float = 0.9,
+#     batch_size: int = 1024,
+#     progress=None
+# ) -> tuple[np.ndarray, dict[int, int]]:
+#     """
+#     Deduplicate embeddings within one dataset or across two datasets.
+
+#     :param embeddings_a: Embeddings of Dataset 1.
+#     :param embeddings_b: Optional, embeddings of Dataset 2.
+#     :param threshold: Similarity threshold for deduplication.
+#     :param batch_size: Batch size for similarity computation.
+#     :param progress: Gradio progress tracker for feedback.
+#     :return: Deduplicated indices and a mapping of removed indices to their original counterparts.
+#     """
+#     if embeddings_b is None:
+#         reach = Reach(vectors=embeddings_a, items=[str(i) for i in range(len(embeddings_a))])
+#         duplicate_to_original = {}
+#         results = reach.nearest_neighbor_threshold(
+#             embeddings_a, threshold=threshold, batch_size=batch_size, show_progressbar=False
+#         )
+#         for i, similar_items in enumerate(progress.tqdm(results, desc="Processing duplicates", total=len(embeddings_a))):
+#             for sim_idx, _ in similar_items:
+#                 sim_idx = int(sim_idx)
+#                 if sim_idx != i and sim_idx not in duplicate_to_original:
+#                     duplicate_to_original[sim_idx] = i
+#         deduplicated_indices = set(range(len(embeddings_a))) - set(duplicate_to_original.keys())
+#         return deduplicated_indices, duplicate_to_original
+#     else:
+#         reach = Reach(vectors=embeddings_a, items=[str(i) for i in range(len(embeddings_a))])
+#         duplicate_indices_in_b = []
+#         duplicate_to_original = {}
+#         results = reach.nearest_neighbor_threshold(
+#             embeddings_b, threshold=threshold, batch_size=batch_size, show_progressbar=False
+#         )
+#         for i, similar_items in enumerate(progress.tqdm(results, desc="Processing duplicates", total=len(embeddings_b))):
+#             if similar_items:
+#                 duplicate_indices_in_b.append(i)
+#                 duplicate_to_original[i] = int(similar_items[0][0])
+#         return duplicate_indices_in_b, duplicate_to_original
+
+# def display_word_differences(x: str, y: str) -> str:
+#     """
+#     Display the word-level differences between two texts, formatted to avoid
+#     misinterpretation of Markdown syntax.
+
+#     :param x: First text.
+#     :param y: Second text.
+#     :return: A string showing word-level differences, wrapped in a code block.
+#     """
+#     diff = ndiff(x.split(), y.split())
+#     # Wrap differences in a code block to prevent interpretation as Markdown
+#     formatted_diff = "\n".join(word for word in diff if word.startswith(("+", "-")))
+#     return f"```\n{formatted_diff}\n```"
+
+# # def display_word_differences(x: str, y: str) -> str:
+# #     """
+# #     Display the word-level differences between two texts.
+
+# #     :param x: First text.
+# #     :param y: Second text.
+# #     :return: A string showing word-level differences.
+# #     """
+# #     diff = ndiff(x.split(), y.split())
+# #     return " ".join(word for word in diff if word.startswith(("+", "-")))
+
+# def load_dataset_texts(dataset_name: str, dataset_split: str, text_column: str) -> list[str]:
+#     """
+#     Load texts from a specified dataset and split.
+
+#     :param dataset_name: Name of the dataset.
+#     :param dataset_split: Split of the dataset (e.g., 'train', 'validation').
+#     :param text_column: Name of the text column.
+#     :return: A list of texts from the dataset.
+#     """
+#     ds = load_dataset(dataset_name, split=dataset_split)
+#     return [example[text_column] for example in ds]
+
+# def perform_deduplication(
+#     deduplication_type: str,
+#     dataset1_name: str,
+#     dataset1_split: str,
+#     dataset1_text_column: str,
+#     dataset2_name: str = "",
+#     dataset2_split: str = "",
+#     dataset2_text_column: str = "",
+#     threshold: float = default_threshold,
+#     progress: gr.Progress = gr.Progress(track_tqdm=True)
+# ):
+#     """
+#     Perform deduplication on one or two datasets based on the deduplication type.
+
+#     :param deduplication_type: 'Single dataset' or 'Cross-dataset'.
+#     :param dataset1_name: Name of the first dataset.
+#     :param dataset1_split: Split of the first dataset.
+#     :param dataset1_text_column: Text column of the first dataset.
+#     :param dataset2_name: Optional, name of the second dataset (for cross-dataset deduplication).
+#     :param dataset2_split: Optional, split of the second dataset.
+#     :param dataset2_text_column: Optional, text column of the second dataset.
+#     :param threshold: Similarity threshold for deduplication.
+#     :param progress: Gradio progress tracker.
+#     :return: Status updates and result text for the Gradio interface.
+#     """
+#     try:
+#         threshold = float(threshold)
+
+#         # Load and process Dataset 1
+#         yield "Loading Dataset 1...", ""
+#         texts1 = load_dataset_texts(dataset1_name, dataset1_split, dataset1_text_column)
+#         yield "Computing embeddings for Dataset 1...", ""
+#         embeddings1 = model.encode(texts1, show_progressbar=True)
+
+#         if deduplication_type == "Single dataset":
+#             # Deduplicate within Dataset 1
+#             yield "Deduplicating within Dataset 1...", ""
+#             deduplicated_indices, duplicate_mapping = deduplicate_embeddings(
+#                 embeddings1, threshold=threshold, progress=progress
+#             )
+
+#             num_duplicates = len(duplicate_mapping)
+#             result_text = (
+#                 f"**Total documents:** {len(texts1)}\n\n"
+#                 f"**Duplicates found:** {num_duplicates}\n\n"
+#                 f"**Unique documents after deduplication:** {len(deduplicated_indices)}\n\n"
+#             )
+
+#             if num_duplicates > 0:
+#                 result_text += "**Sample duplicates:**\n\n"
+#                 for dup_idx, orig_idx in list(duplicate_mapping.items())[:5]:
+#                     orig_text = texts1[orig_idx]
+#                     dup_text = texts1[dup_idx]
+#                     differences = display_word_differences(orig_text, dup_text)
+#                     result_text += (
+#                         f"**Original:**\n{orig_text}\n\n"
+#                         f"**Duplicate:**\n{dup_text}\n\n"
+#                         f"**Differences:**\n{differences}\n"
+#                         + "-" * 50 + "\n\n"
+#                     )
+#             else:
+#                 result_text += "No duplicates found."
+
+#             yield "Deduplication completed.", result_text
+
+#         else:
+#             # Load and process Dataset 2
+#             yield "Loading Dataset 2...", ""
+#             texts2 = load_dataset_texts(dataset2_name, dataset2_split, dataset2_text_column)
+#             yield "Computing embeddings for Dataset 2...", ""
+#             embeddings2 = model.encode(texts2, show_progressbar=True)
+
+#             # Deduplicate Dataset 2 against Dataset 1
+#             yield "Deduplicating Dataset 2 against Dataset 1...", ""
+#             duplicate_indices, duplicate_mapping = deduplicate_embeddings(
+#                 embeddings1, embeddings_b=embeddings2, threshold=threshold, progress=progress
+#             )
+
+#             num_duplicates = len(duplicate_indices)
+#             result_text = (
+#                 f"**Total documents in {dataset2_name}/{dataset2_split}:** {len(texts2)}\n\n"
+#                 f"**Duplicates found in Dataset 2:** {num_duplicates}\n\n"
+#                 f"**Unique documents after deduplication:** {len(texts2) - num_duplicates}\n\n"
+#             )
+
+#             if num_duplicates > 0:
+#                 result_text += "**Sample duplicates from Dataset 2:**\n\n"
+#                 for idx in duplicate_indices[:5]:
+#                     orig_text = texts1[duplicate_mapping[idx]]
+#                     dup_text = texts2[idx]
+#                     differences = display_word_differences(orig_text, dup_text)
+#                     result_text += (
+#                         f"**Original (Dataset 1):**\n{orig_text}\n\n"
+#                         f"**Duplicate (Dataset 2):**\n{dup_text}\n\n"
+#                         f"**Differences:**\n{differences}\n"
+#                         + "-" * 50 + "\n\n"
+#                     )
+#             else:
+#                 result_text += "No duplicates found."
+
+#             yield "Deduplication completed.", result_text
+
+#     except Exception as e:
+#         yield f"An error occurred: {e}", ""
+#         raise e
+
+# with gr.Blocks(css="#status_output { height: 50px; overflow: auto; }") as demo:
+#     gr.Markdown("# Semantic Deduplication")
+#     gr.Markdown("""
+#     This demo showcases semantic deduplication using Model2Vec for HuggingFace datasets.
+#     It can be used to identify duplicate texts within a single dataset or across two datasets.
+#     You can adjust the similarity threshold to control the strictness of the deduplication.\n
+#     NOTE: this demo runs on a free CPU backend, so it may be slow for large datasets. For faster results, please run the code locally.
+#     """)
+
+#     deduplication_type = gr.Radio(
+#         choices=["Single dataset", "Cross-dataset"],
+#         label="Deduplication Type",
+#         value="Single dataset",
+#     )
+
+#     with gr.Row():
+#         dataset1_name = gr.Textbox(value=default_dataset_name, label="Dataset 1 Name")
+#         dataset1_split = gr.Textbox(value=default_dataset_split, label="Dataset 1 Split")
+#         dataset1_text_column = gr.Textbox(value=default_text_column, label="Text Column Name")
+
+#     dataset2_inputs = gr.Column(visible=False)
+#     with dataset2_inputs:
+#         gr.Markdown("### Dataset 2")
+#         with gr.Row():
+#             dataset2_name = gr.Textbox(value=default_dataset_name, label="Dataset 2 Name")
+#             dataset2_split = gr.Textbox(value=default_dataset_split, label="Dataset 2 Split")
+#             dataset2_text_column = gr.Textbox(value=default_text_column, label="Text Column Name")
+
+#     threshold = gr.Slider(0.0, 1.0, value=default_threshold, label="Similarity Threshold")
+#     compute_button = gr.Button("Compute")
+#     status_output = gr.Markdown(elem_id="status_output")
+#     result_output = gr.Markdown()
+
+#     def update_visibility(choice: str):
+#         return gr.update(visible=choice == "Cross-dataset")
+
+#     deduplication_type.change(update_visibility, inputs=deduplication_type, outputs=dataset2_inputs)
+
+#     compute_button.click(
+#         fn=perform_deduplication,
+#         inputs=[
+#             deduplication_type,
+#             dataset1_name,
+#             dataset1_split,
+#             dataset1_text_column,
+#             dataset2_name,
+#             dataset2_split,
+#             dataset2_text_column,
+#             threshold,
+#         ],
+#         outputs=[status_output, result_output],
+#     )
+
+# demo.launch()
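One caveat on the stop button added in this commit: `gr.Blocks` does not appear to expose a public `stop()` method (it has `close()` for shutting down the server), so `stop_button.click(lambda: demo.stop(), None, None)` will likely raise an `AttributeError` when clicked. The usual Gradio pattern for interrupting a running generator event is cancellation via the `cancels=` argument. A minimal sketch under that assumption, reusing the component and function names defined in app.py:

```python
# Sketch, not the committed code: cancel the in-flight compute event via
# `cancels=` instead of calling demo.stop(), which is not a documented
# gr.Blocks method. Component names match those defined in app.py.
compute_event = compute_button.click(
    fn=perform_deduplication,
    inputs=[
        deduplication_type,
        dataset1_name,
        dataset1_split,
        dataset1_text_column,
        dataset2_name,
        dataset2_split,
        dataset2_text_column,
        threshold,
    ],
    outputs=[status_output, result_output],
)

# Clicking Stop cancels the queued or running deduplication generator.
stop_button.click(fn=None, inputs=None, outputs=None, cancels=[compute_event])
```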
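Separately, `deduplicate_embeddings` calls `progress.tqdm(...)` unconditionally, so its `progress=None` default cannot actually be exercised outside Gradio. To run the same single-dataset flow standalone, here is a minimal sketch that inlines the loop; the model name and `reach` calls are taken from this commit, while the toy texts and printed summary are illustrative only:

```python
from model2vec import StaticModel
from reach import Reach

# Toy corpus: the second sentence is a near-duplicate of the first.
texts = ["the cat sat on the mat", "the cat sat on a mat", "dogs bark loudly"]

# Same model and encode call as in app.py above.
model = StaticModel.from_pretrained("minishlab/M2V_base_output")
embeddings = model.encode(texts, show_progressbar=False)

reach = Reach(vectors=embeddings, items=[str(i) for i in range(len(texts))])
results = reach.nearest_neighbor_threshold(
    embeddings, threshold=0.9, batch_size=1024, show_progressbar=False
)

# Mirror of the single-dataset loop in deduplicate_embeddings, minus progress.tqdm.
duplicate_to_original = {}
for i, similar_items in enumerate(results):
    for sim_idx, _ in similar_items:
        sim_idx = int(sim_idx)
        if sim_idx != i and sim_idx not in duplicate_to_original:
            duplicate_to_original[sim_idx] = i

kept = sorted(set(range(len(texts))) - set(duplicate_to_original))
print(f"Kept {len(kept)} of {len(texts)} texts; duplicate -> original: {duplicate_to_original}")
```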