Pringled committed on
Commit
f5eb405
1 Parent(s): 7ed3881

Updated app with code for deduplication

Files changed (1)
  1. app.py +366 -101
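The commit reworks `app.py` around two helpers, `deduplicate` and `deduplicate_across_datasets`: texts are embedded with a model2vec `StaticModel` and near-duplicates are found with Reach's threshold-based nearest-neighbour search. For orientation, here is a minimal standalone sketch of that idea (not part of the commit; the example texts and the 0.8 threshold are placeholders):

```python
import numpy as np
from model2vec import StaticModel
from reach import Reach

# Placeholder corpus; in the app these come from a Hugging Face dataset column.
texts = ["first news story", "first news story!", "an unrelated story"]

model = StaticModel.from_pretrained("minishlab/M2V_base_output")
embeddings = model.encode(texts, show_progressbar=False)

reach = Reach(vectors=embeddings, items=[str(i) for i in range(len(texts))])
results = reach.nearest_neighbor_threshold(
    embeddings, threshold=0.8, batch_size=1024, show_progressbar=False
)

keep = set(range(len(texts)))
for i, similar_items in enumerate(results):
    if i not in keep:
        continue  # already marked as a duplicate of an earlier text
    for item in similar_items:
        j = int(item[0])
        if j != i and j in keep:
            keep.remove(j)  # j is a near-duplicate of i (similarity above the threshold)

print(sorted(keep))  # indices of the texts that survive deduplication
```
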
app.py CHANGED
@@ -4,67 +4,74 @@ import numpy as np
from model2vec import StaticModel
from reach import Reach
from difflib import ndiff

- def deduplicate(embedding_matrix: np.ndarray, threshold: float, batch_size: int = 1024, progress=gr.Progress(track_tqdm=True)) -> tuple[np.ndarray, dict[int, int]]:
    """
    Deduplicate embeddings and return the deduplicated indices and a mapping of removed indices to their corresponding original indices.
    """
    reach = Reach(vectors=embedding_matrix, items=[str(i) for i in range(len(embedding_matrix))])

-     # Use a set for deduplicated indices and keep track of duplicates
-     deduplicated_indices = set(range(len(embedding_matrix)))  # Start with all indices as deduplicated
    duplicate_to_original_mapping = {}

    results = reach.nearest_neighbor_threshold(
-         embedding_matrix,
-         threshold=threshold,
-         batch_size=batch_size,
-         show_progressbar=True
    )

    # Process duplicates
    for i, similar_items in enumerate(progress.tqdm(results, desc="Processing duplicates")):
        if i not in deduplicated_indices:
-             continue  # Skip already marked duplicates

-         # Similar items are returned as (index, score), we are only interested in the index
        similar_indices = [int(item[0]) for item in similar_items if int(item[0]) != i]

-         # Mark similar documents as duplicates and map them to the original
        for sim_idx in similar_indices:
            if sim_idx in deduplicated_indices:
                deduplicated_indices.remove(sim_idx)
-                 duplicate_to_original_mapping[sim_idx] = i  # Map duplicate to original

    return np.array(list(deduplicated_indices)), duplicate_to_original_mapping

- def deduplicate_across_datasets(embedding_matrix_1: np.ndarray, embedding_matrix_2: np.ndarray, threshold: float, batch_size: int = 1024, progress=gr.Progress(track_tqdm=True)) -> tuple[list[int], dict[int, int]]:
    """
    Deduplicate embeddings across two datasets and return the indices of duplicates between them.
    """
    reach = Reach(vectors=embedding_matrix_1, items=[str(i) for i in range(len(embedding_matrix_1))])

-     # Keep track of duplicates in the second dataset
    duplicate_indices_in_test = []
    duplicate_to_original_mapping = {}

-     # Find nearest neighbors from the test set in the train set
    results = reach.nearest_neighbor_threshold(
-         embedding_matrix_2,
-         threshold=threshold,
-         batch_size=batch_size,
-         show_progressbar=True
    )

    # Process duplicates
    for i, similar_items in enumerate(progress.tqdm(results, desc="Processing duplicates across datasets")):
-         # Similar items are returned as (index, score), we are only interested in the index
-         similar_indices = [int(item[0]) for item in similar_items if item[1] >= threshold]  # Keep those above the threshold

-         # If we find a similar item in the train set, mark it as a duplicate
        if similar_indices:
            duplicate_indices_in_test.append(i)
-             duplicate_to_original_mapping[i] = similar_indices[0]  # Map duplicate in test to original in train

    return duplicate_indices_in_test, duplicate_to_original_mapping
 
@@ -83,85 +90,114 @@ def perform_deduplication(
    threshold=0.8,
    progress=gr.Progress(track_tqdm=True)
):
-     # Convert threshold to float
-     threshold = float(threshold)
-
-     if deduplication_type == "Single dataset":
-         # Load the dataset
-         ds = load_dataset(dataset1_name, split=dataset1_split)
-
-         # Extract texts
-         texts = [example[dataset1_text_column] for example in ds]
-
-         # Compute embeddings
-         model = StaticModel.from_pretrained("minishlab/M2V_base_output")
-         embedding_matrix = model.encode(texts, show_progressbar=True)
-
-         # Deduplicate
-         deduplicated_indices, duplicate_to_original_mapping = deduplicate(embedding_matrix, threshold, progress=progress)
-
-         # Prepare the results
-         num_duplicates = len(duplicate_to_original_mapping)
-         num_total = len(texts)
-         num_deduplicated = len(deduplicated_indices)
-
-         result_text = f"**Total documents:** {num_total}\n"
-         result_text += f"**Number of duplicates found:** {num_duplicates}\n"
-         result_text += f"**Number of unique documents after deduplication:** {num_deduplicated}\n\n"
-
-         # Show deduplicated examples
-         result_text += "**Examples of duplicates found:**\n\n"
-         num_examples = min(5, num_duplicates)
-         for duplicate_idx, original_idx in list(duplicate_to_original_mapping.items())[:num_examples]:
-             original_text = texts[original_idx]
-             duplicate_text = texts[duplicate_idx]
-             differences = display_word_differences(original_text, duplicate_text)
-             result_text += f"**Original text:**\n{original_text}\n\n"
-             result_text += f"**Duplicate text:**\n{duplicate_text}\n\n"
-             result_text += f"**Differences:**\n{differences}\n"
-             result_text += "-" * 50 + "\n\n"
-
-         return result_text
-
-     elif deduplication_type == "Cross-dataset":
-         # Load datasets
-         ds1 = load_dataset(dataset1_name, split=dataset1_split)
-         ds2 = load_dataset(dataset2_name, split=dataset2_split)
-
-         # Extract texts
-         texts1 = [example[dataset1_text_column] for example in ds1]
-         texts2 = [example[dataset2_text_column] for example in ds2]
-
-         # Compute embeddings
-         model = StaticModel.from_pretrained("minishlab/M2V_base_output")
-         embedding_matrix1 = model.encode(texts1, show_progressbar=True)
-         embedding_matrix2 = model.encode(texts2, show_progressbar=True)
-
-         # Deduplicate across datasets
-         duplicate_indices_in_ds2, duplicate_to_original_mapping = deduplicate_across_datasets(embedding_matrix1, embedding_matrix2, threshold, progress=progress)
-
-         num_duplicates = len(duplicate_indices_in_ds2)
-         num_total_ds2 = len(texts2)
-         num_unique_ds2 = num_total_ds2 - num_duplicates
-
-         result_text = f"**Total documents in {dataset2_name}/{dataset2_split}:** {num_total_ds2}\n"
-         result_text += f"**Number of duplicates found in {dataset2_name}/{dataset2_split}:** {num_duplicates}\n"
-         result_text += f"**Number of unique documents in {dataset2_name}/{dataset2_split} after deduplication:** {num_unique_ds2}\n\n"
-
-         # Show deduplicated examples
-         result_text += "**Examples of duplicates found in Dataset 2:**\n\n"
-         num_examples = min(5, num_duplicates)
-         for duplicate_idx in duplicate_indices_in_ds2[:num_examples]:
-             original_idx = duplicate_to_original_mapping[duplicate_idx]
-             original_text = texts1[original_idx]
-             duplicate_text = texts2[duplicate_idx]
-             differences = display_word_differences(original_text, duplicate_text)
-             result_text += f"**Original text (Dataset 1):**\n{original_text}\n\n"
-             result_text += f"**Duplicate text (Dataset 2):**\n{duplicate_text}\n\n"
-             result_text += f"**Differences:**\n{differences}\n"
-             result_text += "-" * 50 + "\n\n"
-
-         return result_text

with gr.Blocks() as demo:
    gr.Markdown("# Semantic Deduplication")
@@ -225,3 +261,232 @@ with gr.Blocks() as demo:
    )

demo.launch()
from model2vec import StaticModel
from reach import Reach
from difflib import ndiff
+ import sys
+ import tqdm

+ # Load the model at startup
+ model = StaticModel.from_pretrained("minishlab/M2V_base_output")
+
+ # Load the default datasets at startup
+ default_dataset1_name = "ag_news"
+ default_dataset1_split = "train"
+ default_dataset2_name = "ag_news"
+ default_dataset2_split = "test"
+
+ ds_default1 = load_dataset(default_dataset1_name, split=default_dataset1_split)
+ ds_default2 = load_dataset(default_dataset2_name, split=default_dataset2_split)
+
+ def deduplicate(embedding_matrix: np.ndarray, threshold: float, batch_size: int = 1024, progress=None) -> tuple[np.ndarray, dict[int, int]]:
    """
    Deduplicate embeddings and return the deduplicated indices and a mapping of removed indices to their corresponding original indices.
    """
    reach = Reach(vectors=embedding_matrix, items=[str(i) for i in range(len(embedding_matrix))])

+     deduplicated_indices = set(range(len(embedding_matrix)))
    duplicate_to_original_mapping = {}

    results = reach.nearest_neighbor_threshold(
+         embedding_matrix,
+         threshold=threshold,
+         batch_size=batch_size,
+         show_progressbar=False  # Disable internal progress bar
    )

    # Process duplicates
    for i, similar_items in enumerate(progress.tqdm(results, desc="Processing duplicates")):
        if i not in deduplicated_indices:
+             continue

        similar_indices = [int(item[0]) for item in similar_items if int(item[0]) != i]

        for sim_idx in similar_indices:
            if sim_idx in deduplicated_indices:
                deduplicated_indices.remove(sim_idx)
+                 duplicate_to_original_mapping[sim_idx] = i

    return np.array(list(deduplicated_indices)), duplicate_to_original_mapping
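Note that the reworked `deduplicate` defaults `progress=None` but still calls `progress.tqdm(...)`, so it currently expects a progress object from the caller. A rough sketch of calling it outside the Gradio app, with an illustrative no-op stand-in for `gr.Progress` and random placeholder embeddings:

```python
import numpy as np

class NoProgress:
    """Illustrative stand-in for gr.Progress: just passes iterables through."""
    def tqdm(self, iterable, desc=None):
        return iterable

rng = np.random.default_rng(0)
base = rng.normal(size=(100, 64)).astype(np.float32)
embeddings = np.vstack([base, base[:10]])  # last 10 rows duplicate the first 10

kept, dup_map = deduplicate(embeddings, threshold=0.99, progress=NoProgress())
print(len(kept), "rows kept;", dup_map)  # the 10 copies should map back to their originals
```
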
 
+ def deduplicate_across_datasets(embedding_matrix_1: np.ndarray, embedding_matrix_2: np.ndarray, threshold: float, batch_size: int = 1024, progress=None) -> tuple[list[int], dict[int, int]]:
    """
    Deduplicate embeddings across two datasets and return the indices of duplicates between them.
    """
    reach = Reach(vectors=embedding_matrix_1, items=[str(i) for i in range(len(embedding_matrix_1))])

    duplicate_indices_in_test = []
    duplicate_to_original_mapping = {}

    results = reach.nearest_neighbor_threshold(
+         embedding_matrix_2,
+         threshold=threshold,
+         batch_size=batch_size,
+         show_progressbar=False  # Disable internal progress bar
    )

    # Process duplicates
    for i, similar_items in enumerate(progress.tqdm(results, desc="Processing duplicates across datasets")):
+         similar_indices = [int(item[0]) for item in similar_items if item[1] >= threshold]

        if similar_indices:
            duplicate_indices_in_test.append(i)
+             duplicate_to_original_mapping[i] = similar_indices[0]

    return duplicate_indices_in_test, duplicate_to_original_mapping
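`deduplicate_across_datasets` is the asymmetric variant: it reports which rows of the second matrix have a neighbour above the threshold in the first, and maps each such row back to its match. A small illustrative call (placeholder data; the same kind of no-op progress stand-in as above):

```python
import numpy as np

class NoProgress:
    def tqdm(self, iterable, desc=None):
        return iterable

rng = np.random.default_rng(1)
train_emb = rng.normal(size=(50, 64)).astype(np.float32)                              # "dataset 1" embeddings
test_emb = np.vstack([train_emb[:5], rng.normal(size=(10, 64))]).astype(np.float32)   # first 5 rows copied from train

dup_in_test, test_to_train = deduplicate_across_datasets(
    train_emb, test_emb, threshold=0.99, progress=NoProgress()
)
print(dup_in_test)    # with these exact copies, expected: [0, 1, 2, 3, 4]
print(test_to_train)  # expected: {0: 0, 1: 1, 2: 2, 3: 3, 4: 4}
```
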
 
 
    threshold=0.8,
    progress=gr.Progress(track_tqdm=True)
):
+     # Monkey-patch tqdm
+     original_tqdm = tqdm.tqdm
+     tqdm.tqdm = progress.tqdm
+     sys.modules['tqdm'].tqdm = progress.tqdm
+     sys.modules['tqdm.auto'].tqdm = progress.tqdm
+
+     try:
+         # Convert threshold to float
+         threshold = float(threshold)
+
+         if deduplication_type == "Single dataset":
+             # Check if the dataset is the default one
+             if dataset1_name == default_dataset1_name and dataset1_split == default_dataset1_split:
+                 ds = ds_default1
+             else:
+                 ds = load_dataset(dataset1_name, split=dataset1_split)
+
+             # Extract texts
+             texts = [example[dataset1_text_column] for example in ds]
+
+             # Compute embeddings
+             embedding_matrix = model.encode(texts, show_progressbar=False)  # Disable internal progress bar
+
+             # Show progress bar for embedding computation
+             embedding_matrix = progress.tqdm(embedding_matrix, desc="Computing embeddings")
+
+             # Deduplicate
+             deduplicated_indices, duplicate_to_original_mapping = deduplicate(embedding_matrix, threshold, progress=progress)
+
+             # Prepare the results
+             num_duplicates = len(duplicate_to_original_mapping)
+             num_total = len(texts)
+             num_deduplicated = len(deduplicated_indices)
+
+             result_text = f"**Total documents:** {num_total}\n"
+             result_text += f"**Number of duplicates found:** {num_duplicates}\n"
+             result_text += f"**Number of unique documents after deduplication:** {num_deduplicated}\n\n"
+
+             # Show deduplicated examples
+             result_text += "**Examples of duplicates found:**\n\n"
+             num_examples = min(5, num_duplicates)
+             for duplicate_idx, original_idx in list(duplicate_to_original_mapping.items())[:num_examples]:
+                 original_text = texts[original_idx]
+                 duplicate_text = texts[duplicate_idx]
+                 differences = display_word_differences(original_text, duplicate_text)
+                 result_text += f"**Original text:**\n{original_text}\n\n"
+                 result_text += f"**Duplicate text:**\n{duplicate_text}\n\n"
+                 result_text += f"**Differences:**\n{differences}\n"
+                 result_text += "-" * 50 + "\n\n"
+
+             return result_text
+
+         elif deduplication_type == "Cross-dataset":
+             # Dataset 1
+             if dataset1_name == default_dataset1_name and dataset1_split == default_dataset1_split:
+                 ds1 = ds_default1
+             else:
+                 ds1 = load_dataset(dataset1_name, split=dataset1_split)
+
+             # Dataset 2
+             if dataset2_name == default_dataset2_name and dataset2_split == default_dataset2_split:
+                 ds2 = ds_default2
+             else:
+                 ds2 = load_dataset(dataset2_name, split=dataset2_split)
+
+             # Extract texts
+             texts1 = [example[dataset1_text_column] for example in ds1]
+             texts2 = [example[dataset2_text_column] for example in ds2]
+
+             # Compute embeddings
+             embedding_matrix1 = model.encode(texts1, show_progressbar=False)  # Disable internal progress bar
+             embedding_matrix2 = model.encode(texts2, show_progressbar=False)  # Disable internal progress bar
+
+             # Show progress bar for embedding computation
+             embedding_matrix1 = progress.tqdm(embedding_matrix1, desc="Computing embeddings for Dataset 1")
+             embedding_matrix2 = progress.tqdm(embedding_matrix2, desc="Computing embeddings for Dataset 2")
+
+             # Deduplicate across datasets
+             duplicate_indices_in_ds2, duplicate_to_original_mapping = deduplicate_across_datasets(embedding_matrix1, embedding_matrix2, threshold, progress=progress)
+
+             num_duplicates = len(duplicate_indices_in_ds2)
+             num_total_ds2 = len(texts2)
+             num_unique_ds2 = num_total_ds2 - num_duplicates
+
+             result_text = f"**Total documents in {dataset2_name}/{dataset2_split}:** {num_total_ds2}\n"
+             result_text += f"**Number of duplicates found in {dataset2_name}/{dataset2_split}:** {num_duplicates}\n"
+             result_text += f"**Number of unique documents in {dataset2_name}/{dataset2_split} after deduplication:** {num_unique_ds2}\n\n"
+
+             # Show deduplicated examples
+             result_text += "**Examples of duplicates found in Dataset 2:**\n\n"
+             num_examples = min(5, num_duplicates)
+             for duplicate_idx in duplicate_indices_in_ds2[:num_examples]:
+                 original_idx = duplicate_to_original_mapping[duplicate_idx]
+                 original_text = texts1[original_idx]
+                 duplicate_text = texts2[duplicate_idx]
+                 differences = display_word_differences(original_text, duplicate_text)
+                 result_text += f"**Original text (Dataset 1):**\n{original_text}\n\n"
+                 result_text += f"**Duplicate text (Dataset 2):**\n{duplicate_text}\n\n"
+                 result_text += f"**Differences:**\n{differences}\n"
+                 result_text += "-" * 50 + "\n\n"
+
+             return result_text
+
+     finally:
+         # Restore original tqdm
+         tqdm.tqdm = original_tqdm
+         sys.modules['tqdm'].tqdm = original_tqdm
+         sys.modules['tqdm.auto'].tqdm = original_tqdm
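The handler temporarily points `tqdm` at Gradio's progress tracker so that progress bars created inside library code surface in the UI, and restores the original in `finally`. The same pattern pulled out into a reusable helper (an illustrative sketch; `with_gradio_tqdm` is not part of the commit, and a `gr.Progress` instance is only available inside a Gradio event handler):

```python
import sys
import tqdm
import tqdm.auto

def with_gradio_tqdm(progress, fn, *args, **kwargs):
    """Run fn with tqdm.tqdm temporarily redirected to a gr.Progress tracker."""
    original_tqdm = tqdm.tqdm
    tqdm.tqdm = progress.tqdm
    sys.modules['tqdm'].tqdm = progress.tqdm
    sys.modules['tqdm.auto'].tqdm = progress.tqdm
    try:
        return fn(*args, **kwargs)
    finally:
        # Always restore, even if fn raises, so later calls see the real tqdm.
        tqdm.tqdm = original_tqdm
        sys.modules['tqdm'].tqdm = original_tqdm
        sys.modules['tqdm.auto'].tqdm = original_tqdm
```
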
 
with gr.Blocks() as demo:
    gr.Markdown("# Semantic Deduplication")

    )

demo.launch()
+
+
+ # import gradio as gr
+ # from datasets import load_dataset
+ # import numpy as np
+ # from model2vec import StaticModel
+ # from reach import Reach
+ # from difflib import ndiff
+
+ # def deduplicate(embedding_matrix: np.ndarray, threshold: float, batch_size: int = 1024, progress=gr.Progress(track_tqdm=True)) -> tuple[np.ndarray, dict[int, int]]:
+ # """
+ # Deduplicate embeddings and return the deduplicated indices and a mapping of removed indices to their corresponding original indices.
+ # """
+ # reach = Reach(vectors=embedding_matrix, items=[str(i) for i in range(len(embedding_matrix))])
+
+ # # Use a set for deduplicated indices and keep track of duplicates
+ # deduplicated_indices = set(range(len(embedding_matrix))) # Start with all indices as deduplicated
+ # duplicate_to_original_mapping = {}
+
+ # results = reach.nearest_neighbor_threshold(
+ # embedding_matrix,
+ # threshold=threshold,
+ # batch_size=batch_size,
+ # show_progressbar=True
+ # )
+
+ # # Process duplicates
+ # for i, similar_items in enumerate(progress.tqdm(results, desc="Processing duplicates")):
+ # if i not in deduplicated_indices:
+ # continue # Skip already marked duplicates
+
+ # # Similar items are returned as (index, score), we are only interested in the index
+ # similar_indices = [int(item[0]) for item in similar_items if int(item[0]) != i]
+
+ # # Mark similar documents as duplicates and map them to the original
+ # for sim_idx in similar_indices:
+ # if sim_idx in deduplicated_indices:
+ # deduplicated_indices.remove(sim_idx)
+ # duplicate_to_original_mapping[sim_idx] = i # Map duplicate to original
+
+ # return np.array(list(deduplicated_indices)), duplicate_to_original_mapping
+
+ # def deduplicate_across_datasets(embedding_matrix_1: np.ndarray, embedding_matrix_2: np.ndarray, threshold: float, batch_size: int = 1024, progress=gr.Progress(track_tqdm=True)) -> tuple[list[int], dict[int, int]]:
+ # """
+ # Deduplicate embeddings across two datasets and return the indices of duplicates between them.
+ # """
+ # reach = Reach(vectors=embedding_matrix_1, items=[str(i) for i in range(len(embedding_matrix_1))])
+
+ # # Keep track of duplicates in the second dataset
+ # duplicate_indices_in_test = []
+ # duplicate_to_original_mapping = {}
+
+ # # Find nearest neighbors from the test set in the train set
+ # results = reach.nearest_neighbor_threshold(
+ # embedding_matrix_2,
+ # threshold=threshold,
+ # batch_size=batch_size,
+ # show_progressbar=True
+ # )
+
+ # # Process duplicates
+ # for i, similar_items in enumerate(progress.tqdm(results, desc="Processing duplicates across datasets")):
+ # # Similar items are returned as (index, score), we are only interested in the index
+ # similar_indices = [int(item[0]) for item in similar_items if item[1] >= threshold] # Keep those above the threshold
+
+ # # If we find a similar item in the train set, mark it as a duplicate
+ # if similar_indices:
+ # duplicate_indices_in_test.append(i)
+ # duplicate_to_original_mapping[i] = similar_indices[0] # Map duplicate in test to original in train
+
+ # return duplicate_indices_in_test, duplicate_to_original_mapping
+
+ # def display_word_differences(x: str, y: str) -> str:
+ # diff = ndiff(x.split(), y.split())
+ # return " ".join([word for word in diff if word.startswith(('+', '-'))])
+
+ # def perform_deduplication(
+ # deduplication_type,
+ # dataset1_name,
+ # dataset1_split,
+ # dataset1_text_column,
+ # dataset2_name="",
+ # dataset2_split="",
+ # dataset2_text_column="",
+ # threshold=0.8,
+ # progress=gr.Progress(track_tqdm=True)
+ # ):
+ # # Convert threshold to float
+ # threshold = float(threshold)
+
+ # if deduplication_type == "Single dataset":
+ # # Load the dataset
+ # ds = load_dataset(dataset1_name, split=dataset1_split)
+
+ # # Extract texts
+ # texts = [example[dataset1_text_column] for example in ds]
+
+ # # Compute embeddings
+ # model = StaticModel.from_pretrained("minishlab/M2V_base_output")
+ # embedding_matrix = model.encode(texts, show_progressbar=True)
+
+ # # Deduplicate
+ # deduplicated_indices, duplicate_to_original_mapping = deduplicate(embedding_matrix, threshold, progress=progress)
+
+ # # Prepare the results
+ # num_duplicates = len(duplicate_to_original_mapping)
+ # num_total = len(texts)
+ # num_deduplicated = len(deduplicated_indices)
+
+ # result_text = f"**Total documents:** {num_total}\n"
+ # result_text += f"**Number of duplicates found:** {num_duplicates}\n"
+ # result_text += f"**Number of unique documents after deduplication:** {num_deduplicated}\n\n"
+
+ # # Show deduplicated examples
+ # result_text += "**Examples of duplicates found:**\n\n"
+ # num_examples = min(5, num_duplicates)
+ # for duplicate_idx, original_idx in list(duplicate_to_original_mapping.items())[:num_examples]:
+ # original_text = texts[original_idx]
+ # duplicate_text = texts[duplicate_idx]
+ # differences = display_word_differences(original_text, duplicate_text)
+ # result_text += f"**Original text:**\n{original_text}\n\n"
+ # result_text += f"**Duplicate text:**\n{duplicate_text}\n\n"
+ # result_text += f"**Differences:**\n{differences}\n"
+ # result_text += "-" * 50 + "\n\n"
+
+ # return result_text
+
+ # elif deduplication_type == "Cross-dataset":
+ # # Load datasets
+ # ds1 = load_dataset(dataset1_name, split=dataset1_split)
+ # ds2 = load_dataset(dataset2_name, split=dataset2_split)
+
+ # # Extract texts
+ # texts1 = [example[dataset1_text_column] for example in ds1]
+ # texts2 = [example[dataset2_text_column] for example in ds2]
+
+ # # Compute embeddings
+ # model = StaticModel.from_pretrained("minishlab/M2V_base_output")
+ # embedding_matrix1 = model.encode(texts1, show_progressbar=True)
+ # embedding_matrix2 = model.encode(texts2, show_progressbar=True)
+
+ # # Deduplicate across datasets
+ # duplicate_indices_in_ds2, duplicate_to_original_mapping = deduplicate_across_datasets(embedding_matrix1, embedding_matrix2, threshold, progress=progress)
+
+ # num_duplicates = len(duplicate_indices_in_ds2)
+ # num_total_ds2 = len(texts2)
+ # num_unique_ds2 = num_total_ds2 - num_duplicates
+
+ # result_text = f"**Total documents in {dataset2_name}/{dataset2_split}:** {num_total_ds2}\n"
+ # result_text += f"**Number of duplicates found in {dataset2_name}/{dataset2_split}:** {num_duplicates}\n"
+ # result_text += f"**Number of unique documents in {dataset2_name}/{dataset2_split} after deduplication:** {num_unique_ds2}\n\n"
+
+ # # Show deduplicated examples
+ # result_text += "**Examples of duplicates found in Dataset 2:**\n\n"
+ # num_examples = min(5, num_duplicates)
+ # for duplicate_idx in duplicate_indices_in_ds2[:num_examples]:
+ # original_idx = duplicate_to_original_mapping[duplicate_idx]
+ # original_text = texts1[original_idx]
+ # duplicate_text = texts2[duplicate_idx]
+ # differences = display_word_differences(original_text, duplicate_text)
+ # result_text += f"**Original text (Dataset 1):**\n{original_text}\n\n"
+ # result_text += f"**Duplicate text (Dataset 2):**\n{duplicate_text}\n\n"
+ # result_text += f"**Differences:**\n{differences}\n"
+ # result_text += "-" * 50 + "\n\n"
+
+ # return result_text
+
+ # with gr.Blocks() as demo:
+ # gr.Markdown("# Semantic Deduplication")
+
+ # deduplication_type = gr.Radio(
+ # choices=["Single dataset", "Cross-dataset"],
+ # label="Deduplication Type",
+ # value="Single dataset"
+ # )
+
+ # with gr.Row():
+ # dataset1_name = gr.Textbox(value="ag_news", label="Dataset 1 Name")
+ # dataset1_split = gr.Textbox(value="train", label="Dataset 1 Split")
+ # dataset1_text_column = gr.Textbox(value="text", label="Text Column Name")
+
+ # dataset2_inputs = gr.Column(visible=False)
+ # with dataset2_inputs:
+ # gr.Markdown("### Dataset 2")
+ # with gr.Row():
+ # dataset2_name = gr.Textbox(value="ag_news", label="Dataset 2 Name")
+ # dataset2_split = gr.Textbox(value="test", label="Dataset 2 Split")
+ # dataset2_text_column = gr.Textbox(value="text", label="Text Column Name")
+
+ # threshold = gr.Slider(
+ # minimum=0.0,
+ # maximum=1.0,
+ # value=0.8,
+ # label="Similarity Threshold"
+ # )
+
+ # compute_button = gr.Button("Compute")
+
+ # output = gr.Markdown()
+
+ # # Function to update the visibility of dataset2_inputs
+ # def update_visibility(deduplication_type_value):
+ # if deduplication_type_value == "Cross-dataset":
+ # return gr.update(visible=True)
+ # else:
+ # return gr.update(visible=False)
+
+ # deduplication_type.change(
+ # update_visibility,
+ # inputs=deduplication_type,
+ # outputs=dataset2_inputs
+ # )
+
+ # compute_button.click(
+ # fn=perform_deduplication,
+ # inputs=[
+ # deduplication_type,
+ # dataset1_name,
+ # dataset1_split,
+ # dataset1_text_column,
+ # dataset2_name,
+ # dataset2_split,
+ # dataset2_text_column,
+ # threshold
+ # ],
+ # outputs=output
+ # )
+
+ # demo.launch()