Pringled committed on
Commit 1a5f99b
1 Parent(s): 39a5b1c
Files changed (1)
  1. app.py +330 -87
app.py CHANGED
@@ -1,4 +1,3 @@
-
 import gradio as gr
 from datasets import load_dataset
 import numpy as np
@@ -38,23 +37,18 @@ def deduplicate(embedding_matrix: np.ndarray, threshold: float, batch_size: int
     """
     Deduplicate embeddings and return the deduplicated indices and a mapping of removed indices to their corresponding original indices.
     """
-    # Building the index
-    progress(0, desc="Building search index...")
     reach = Reach(vectors=embedding_matrix, items=[str(i) for i in range(len(embedding_matrix))])

     deduplicated_indices = set(range(len(embedding_matrix)))
     duplicate_to_original_mapping = {}

-    # Finding nearest neighbors
-    progress(0, desc="Finding nearest neighbors...")
     results = reach.nearest_neighbor_threshold(
         embedding_matrix,
         threshold=threshold,
         batch_size=batch_size,
-        show_progressbar=False # Disable internal progress bar
+        show_progressbar=False
     )

-    # Processing duplicates with a progress bar
     total_items = len(embedding_matrix)
     for i, similar_items in enumerate(progress.tqdm(results, desc="Processing duplicates", total=total_items)):
         if i not in deduplicated_indices:
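The logic this hunk keeps is: treat every row as unique to start with, query the index for neighbors at or above the similarity threshold, and drop any row that turns out to be a near-duplicate of a row that is being kept. A rough NumPy-only sketch of that idea, for reference; it is not the app's Reach-based implementation, and greedy_dedup is a hypothetical helper name:

import numpy as np

def greedy_dedup(embeddings: np.ndarray, threshold: float) -> tuple[list[int], dict[int, int]]:
    """Keep the first occurrence of each near-duplicate group; map removed rows to the kept row."""
    # Normalize so a dot product equals cosine similarity.
    norms = np.linalg.norm(embeddings, axis=1, keepdims=True)
    unit = embeddings / np.clip(norms, 1e-12, None)
    kept: list[int] = []
    removed_to_kept: dict[int, int] = {}
    for i, vec in enumerate(unit):
        if kept:
            sims = unit[kept] @ vec           # cosine similarity to rows already kept
            j = int(np.argmax(sims))
            if sims[j] >= threshold:
                removed_to_kept[i] = kept[j]  # row i duplicates an earlier kept row
                continue
        kept.append(i)
    return kept, removed_to_kept

# Tiny smoke test: rows 0 and 2 are near-identical, so row 2 maps to row 0.
emb = np.array([[1.0, 0.0], [0.0, 1.0], [0.999, 0.01]])
print(greedy_dedup(emb, threshold=0.95))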
@@ -73,24 +67,19 @@ def deduplicate_across_datasets(embedding_matrix_1: np.ndarray, embedding_matrix
     """
     Deduplicate embeddings across two datasets and return the indices of duplicates between them.
     """
-    # Building the index from Dataset 1
-    progress(0, desc="Building search index from Dataset 1...")
     reach = Reach(vectors=embedding_matrix_1, items=[str(i) for i in range(len(embedding_matrix_1))])

     duplicate_indices_in_test = []
     duplicate_to_original_mapping = {}

-    # Finding nearest neighbors between datasets
-    progress(0, desc="Finding nearest neighbors between datasets...")
     results = reach.nearest_neighbor_threshold(
         embedding_matrix_2,
         threshold=threshold,
         batch_size=batch_size,
-        show_progressbar=False # Disable internal progress bar
+        show_progressbar=False
     )

     total_items = len(embedding_matrix_2)
-    # Processing duplicates with a progress bar
     for i, similar_items in enumerate(progress.tqdm(results, desc="Processing duplicates across datasets", total=total_items)):
         similar_indices = [int(item[0]) for item in similar_items if item[1] >= threshold]

@@ -116,39 +105,15 @@ def perform_deduplication(
     progress=gr.Progress(track_tqdm=True)
 ):
     try:
-        # Convert threshold to float
         threshold = float(threshold)

-        # Initialize status message
-        status = ""
-
         if deduplication_type == "Single dataset":
-            # Load Dataset 1
-            status = "Loading Dataset 1..."
-            yield status, ""
-            if dataset1_name == default_dataset1_name and dataset1_split == default_dataset1_split:
-                ds = ds_default1
-            else:
-                ds = load_dataset(dataset1_name, split=dataset1_split)
-
-            # Extract texts
-            status = "Extracting texts from Dataset 1..."
-            yield status, ""
+            ds = ds_default1 if dataset1_name == default_dataset1_name and dataset1_split == default_dataset1_split else load_dataset(dataset1_name, split=dataset1_split)
             texts = [example[dataset1_text_column] for example in ds]

-            # Compute embeddings
-            status = "Computing embeddings for Dataset 1..."
-            yield status, ""
             embedding_matrix = compute_embeddings(texts, batch_size=64, progress=progress, desc="Computing embeddings for Dataset 1")
+            deduplicated_indices, duplicate_to_original_mapping = deduplicate(embedding_matrix, threshold, progress=progress)

-            # Deduplicate
-            status = "Deduplicating embeddings..."
-            yield status, ""
-            deduplicated_indices, duplicate_to_original_mapping = deduplicate(
-                embedding_matrix, threshold, progress=progress
-            )
-
-            # Prepare the results
             num_duplicates = len(duplicate_to_original_mapping)
             num_total = len(texts)
             num_deduplicated = len(deduplicated_indices)
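With this change the handler leans on gr.Progress(track_tqdm=True) instead of yielding interim status strings: any loop wrapped in progress.tqdm is mirrored by the UI's progress bar, so only the final Markdown needs to be returned or yielded. A minimal self-contained sketch of that pattern (component and function names here are made up, not taken from app.py):

import time
import gradio as gr

def process(n, progress=gr.Progress(track_tqdm=True)):
    # Any loop wrapped in progress.tqdm is reflected in the UI progress bar,
    # so the handler no longer needs to yield intermediate status strings.
    total = 0
    for i in progress.tqdm(range(int(n)), desc="Working"):
        time.sleep(0.01)
        total += i
    return f"**Done.** Sum of 0..{int(n) - 1} is {total}."

with gr.Blocks() as sketch:
    n = gr.Number(value=100, label="Iterations")
    btn = gr.Button("Run")
    out = gr.Markdown()
    btn.click(process, inputs=n, outputs=out)

# sketch.launch()  # uncomment to try it locally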
@@ -157,7 +122,6 @@ def perform_deduplication(
             result_text += f"**Number of duplicates found:** {num_duplicates}\n"
             result_text += f"**Number of unique documents after deduplication:** {num_deduplicated}\n\n"

-            # Show deduplicated examples
             if num_duplicates > 0:
                 result_text += "**Examples of duplicates found:**\n\n"
                 num_examples = min(5, num_duplicates)
@@ -172,53 +136,19 @@ def perform_deduplication(
             else:
                 result_text += "No duplicates found."

-            # Final status
-            status = "Deduplication completed."
-            yield status, result_text
+            yield result_text

         elif deduplication_type == "Cross-dataset":
-            # Load Dataset 1
-            status = "Loading Dataset 1..."
-            yield status, ""
-            if dataset1_name == default_dataset1_name and dataset1_split == default_dataset1_split:
-                ds1 = ds_default1
-            else:
-                ds1 = load_dataset(dataset1_name, split=dataset1_split)
+            ds1 = ds_default1 if dataset1_name == default_dataset1_name and dataset1_split == default_dataset1_split else load_dataset(dataset1_name, split=dataset1_split)
+            ds2 = ds_default2 if dataset2_name == default_dataset2_name and dataset2_split == default_dataset2_split else load_dataset(dataset2_name, split=dataset2_split)

-            # Load Dataset 2
-            status = "Loading Dataset 2..."
-            yield status, ""
-            if dataset2_name == default_dataset2_name and dataset2_split == default_dataset2_split:
-                ds2 = ds_default2
-            else:
-                ds2 = load_dataset(dataset2_name, split=dataset2_split)
-
-            # Extract texts from Dataset 1
-            status = "Extracting texts from Dataset 1..."
-            yield status, ""
             texts1 = [example[dataset1_text_column] for example in ds1]
-
-            # Extract texts from Dataset 2
-            status = "Extracting texts from Dataset 2..."
-            yield status, ""
             texts2 = [example[dataset2_text_column] for example in ds2]

-            # Compute embeddings for Dataset 1
-            status = "Computing embeddings for Dataset 1..."
-            yield status, ""
             embedding_matrix1 = compute_embeddings(texts1, batch_size=64, progress=progress, desc="Computing embeddings for Dataset 1")
-
-            # Compute embeddings for Dataset 2
-            status = "Computing embeddings for Dataset 2..."
-            yield status, ""
             embedding_matrix2 = compute_embeddings(texts2, batch_size=64, progress=progress, desc="Computing embeddings for Dataset 2")

-            # Deduplicate across datasets
-            status = "Deduplicating embeddings across datasets..."
-            yield status, ""
-            duplicate_indices_in_ds2, duplicate_to_original_mapping = deduplicate_across_datasets(
-                embedding_matrix1, embedding_matrix2, threshold, progress=progress
-            )
+            duplicate_indices_in_ds2, duplicate_to_original_mapping = deduplicate_across_datasets(embedding_matrix1, embedding_matrix2, threshold, progress=progress)

             num_duplicates = len(duplicate_indices_in_ds2)
             num_total_ds2 = len(texts2)
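The cross-dataset branch asks a slightly different question than the single-dataset path: for each row of Dataset 2, is there any row of Dataset 1 at or above the threshold, and if so which one? A brute-force sketch under the assumption that both embedding matrices fit in memory (the app instead uses Reach's batched nearest_neighbor_threshold query); cross_dataset_duplicates is an illustrative name, and it records the best match rather than the first neighbor above the threshold:

import numpy as np

def cross_dataset_duplicates(emb1: np.ndarray, emb2: np.ndarray, threshold: float) -> tuple[list[int], dict[int, int]]:
    """Indices in dataset 2 that duplicate something in dataset 1, plus a ds2 -> ds1 mapping."""
    unit1 = emb1 / np.clip(np.linalg.norm(emb1, axis=1, keepdims=True), 1e-12, None)
    unit2 = emb2 / np.clip(np.linalg.norm(emb2, axis=1, keepdims=True), 1e-12, None)
    sims = unit2 @ unit1.T                    # (len ds2, len ds1) cosine similarities
    duplicate_indices, mapping = [], {}
    for i, row in enumerate(sims):
        j = int(np.argmax(row))
        if row[j] >= threshold:
            duplicate_indices.append(i)       # row i of dataset 2 already exists in dataset 1
            mapping[i] = j
    return duplicate_indices, mapping

emb_a = np.array([[1.0, 0.0], [0.0, 1.0]])
emb_b = np.array([[0.99, 0.02], [0.5, 0.5]])
print(cross_dataset_duplicates(emb_a, emb_b, threshold=0.95))  # ([0], {0: 0})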
@@ -228,7 +158,6 @@ def perform_deduplication(
             result_text += f"**Number of duplicates found in {dataset2_name}/{dataset2_split}:** {num_duplicates}\n"
             result_text += f"**Number of unique documents in {dataset2_name}/{dataset2_split} after deduplication:** {num_unique_ds2}\n\n"

-            # Show deduplicated examples
             if num_duplicates > 0:
                 result_text += "**Examples of duplicates found in Dataset 2:**\n\n"
                 num_examples = min(5, num_duplicates)
@@ -244,15 +173,13 @@ def perform_deduplication(
             else:
                 result_text += "No duplicates found."

-            # Final status
-            status = "Deduplication completed."
-            yield status, result_text
+            yield result_text

     except Exception as e:
         yield f"An error occurred: {e}", ""
-        raise e

-with gr.Blocks() as demo:
+# Adjust the height of the status_output and result_output components
+with gr.Blocks(css="#status_output { height: 300px; overflow: auto; } #result_output { height: 300px; overflow: auto; }") as demo:
     gr.Markdown("# Semantic Deduplication")

     deduplication_type = gr.Radio(
@@ -283,10 +210,9 @@ with gr.Blocks() as demo:

     compute_button = gr.Button("Compute")

-    status_output = gr.Markdown()
-    result_output = gr.Markdown()
+    status_output = gr.Markdown(elem_id="status_output")
+    result_output = gr.Markdown(elem_id="result_output")

-    # Function to update the visibility of dataset2_inputs
     def update_visibility(deduplication_type_value):
         if deduplication_type_value == "Cross-dataset":
             return gr.update(visible=True)
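The replacement UI code pairs elem_id on the two Markdown outputs with a css rule on gr.Blocks so that long result text scrolls inside a fixed-height box. A minimal sketch of that pairing, assuming a single output component (names here are illustrative, not from app.py):

import gradio as gr

# Constrain a Markdown area the same way the commit does for #status_output / #result_output.
CSS = "#result_output { height: 300px; overflow: auto; }"

with gr.Blocks(css=CSS) as box_demo:
    result_output = gr.Markdown(elem_id="result_output")  # elem_id ties the component to the CSS selector
    gr.Button("Fill").click(lambda: "\n\n".join(f"line {i}" for i in range(200)), outputs=result_output)

# box_demo.launch()  # uncomment to try it locally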
@@ -316,6 +242,323 @@

 demo.launch()

+# import gradio as gr
+# from datasets import load_dataset
+# import numpy as np
+# from model2vec import StaticModel
+# from reach import Reach
+# from difflib import ndiff
+# import tqdm
+
+# # Load the model at startup
+# model = StaticModel.from_pretrained("minishlab/M2V_base_output")
+
+# # Update default dataset to 'sst2' and set default threshold to 0.9
+# default_dataset1_name = "sst2"
+# default_dataset1_split = "train"
+# default_dataset2_name = "sst2"
+# default_dataset2_split = "validation"
+# default_text_column = "sentence"
+# default_threshold = 0.9
+
+# # Load the default datasets at startup
+# ds_default1 = load_dataset(default_dataset1_name, split=default_dataset1_split)
+# ds_default2 = load_dataset(default_dataset2_name, split=default_dataset2_split)
+
+# def batch_iterable(iterable, batch_size):
+# """Helper function to create batches from an iterable."""
+# for i in range(0, len(iterable), batch_size):
+# yield iterable[i:i + batch_size]
+
+# def compute_embeddings(texts, batch_size, progress, desc="Computing embeddings"):
+# embeddings = []
+# for batch in progress.tqdm(batch_iterable(texts, batch_size), total=(len(texts) + batch_size - 1) // batch_size, desc=desc):
+# batch_embeddings = model.encode(batch, show_progressbar=False)
+# embeddings.append(batch_embeddings)
+# return np.concatenate(embeddings, axis=0)
+
+# def deduplicate(embedding_matrix: np.ndarray, threshold: float, batch_size: int = 1024, progress=None) -> tuple[np.ndarray, dict[int, int]]:
+# """
+# Deduplicate embeddings and return the deduplicated indices and a mapping of removed indices to their corresponding original indices.
+# """
+# # Building the index
+# progress(0, desc="Building search index...")
+# reach = Reach(vectors=embedding_matrix, items=[str(i) for i in range(len(embedding_matrix))])
+
+# deduplicated_indices = set(range(len(embedding_matrix)))
+# duplicate_to_original_mapping = {}
+
+# # Finding nearest neighbors
+# progress(0, desc="Finding nearest neighbors...")
+# results = reach.nearest_neighbor_threshold(
+# embedding_matrix,
+# threshold=threshold,
+# batch_size=batch_size,
+# show_progressbar=False # Disable internal progress bar
+# )
+
+# # Processing duplicates with a progress bar
+# total_items = len(embedding_matrix)
+# for i, similar_items in enumerate(progress.tqdm(results, desc="Processing duplicates", total=total_items)):
+# if i not in deduplicated_indices:
+# continue
+
+# similar_indices = [int(item[0]) for item in similar_items if int(item[0]) != i]
+
+# for sim_idx in similar_indices:
+# if sim_idx in deduplicated_indices:
+# deduplicated_indices.remove(sim_idx)
+# duplicate_to_original_mapping[sim_idx] = i
+
+# return np.array(list(deduplicated_indices)), duplicate_to_original_mapping
+
+# def deduplicate_across_datasets(embedding_matrix_1: np.ndarray, embedding_matrix_2: np.ndarray, threshold: float, batch_size: int = 1024, progress=None) -> tuple[list[int], dict[int, int]]:
+# """
+# Deduplicate embeddings across two datasets and return the indices of duplicates between them.
+# """
+# # Building the index from Dataset 1
+# progress(0, desc="Building search index from Dataset 1...")
+# reach = Reach(vectors=embedding_matrix_1, items=[str(i) for i in range(len(embedding_matrix_1))])
+
+# duplicate_indices_in_test = []
+# duplicate_to_original_mapping = {}
+
+# # Finding nearest neighbors between datasets
+# progress(0, desc="Finding nearest neighbors between datasets...")
+# results = reach.nearest_neighbor_threshold(
+# embedding_matrix_2,
+# threshold=threshold,
+# batch_size=batch_size,
+# show_progressbar=False # Disable internal progress bar
+# )
+
+# total_items = len(embedding_matrix_2)
+# # Processing duplicates with a progress bar
+# for i, similar_items in enumerate(progress.tqdm(results, desc="Processing duplicates across datasets", total=total_items)):
+# similar_indices = [int(item[0]) for item in similar_items if item[1] >= threshold]
+
+# if similar_indices:
+# duplicate_indices_in_test.append(i)
+# duplicate_to_original_mapping[i] = similar_indices[0]
+
+# return duplicate_indices_in_test, duplicate_to_original_mapping
+
+# def display_word_differences(x: str, y: str) -> str:
+# diff = ndiff(x.split(), y.split())
+# return " ".join([word for word in diff if word.startswith(('+', '-'))])
+
+# def perform_deduplication(
+# deduplication_type,
+# dataset1_name,
+# dataset1_split,
+# dataset1_text_column,
+# dataset2_name="",
+# dataset2_split="",
+# dataset2_text_column="",
+# threshold=default_threshold,
+# progress=gr.Progress(track_tqdm=True)
+# ):
+# try:
+# # Convert threshold to float
+# threshold = float(threshold)
+
+# # Initialize status message
+# status = ""
+
+# if deduplication_type == "Single dataset":
+# # Load Dataset 1
+# status = "Loading Dataset 1..."
+# yield status, ""
+# if dataset1_name == default_dataset1_name and dataset1_split == default_dataset1_split:
+# ds = ds_default1
+# else:
+# ds = load_dataset(dataset1_name, split=dataset1_split)
+
+# # Extract texts
+# status = "Extracting texts from Dataset 1..."
+# yield status, ""
+# texts = [example[dataset1_text_column] for example in ds]
+
+# # Compute embeddings
+# status = "Computing embeddings for Dataset 1..."
+# yield status, ""
+# embedding_matrix = compute_embeddings(texts, batch_size=64, progress=progress, desc="Computing embeddings for Dataset 1")
+
+# # Deduplicate
+# status = "Deduplicating embeddings..."
+# yield status, ""
+# deduplicated_indices, duplicate_to_original_mapping = deduplicate(
+# embedding_matrix, threshold, progress=progress
+# )
+
+# # Prepare the results
+# num_duplicates = len(duplicate_to_original_mapping)
+# num_total = len(texts)
+# num_deduplicated = len(deduplicated_indices)
+
+# result_text = f"**Total documents:** {num_total}\n"
+# result_text += f"**Number of duplicates found:** {num_duplicates}\n"
+# result_text += f"**Number of unique documents after deduplication:** {num_deduplicated}\n\n"
+
+# # Show deduplicated examples
+# if num_duplicates > 0:
+# result_text += "**Examples of duplicates found:**\n\n"
+# num_examples = min(5, num_duplicates)
+# for duplicate_idx, original_idx in list(duplicate_to_original_mapping.items())[:num_examples]:
+# original_text = texts[original_idx]
+# duplicate_text = texts[duplicate_idx]
+# differences = display_word_differences(original_text, duplicate_text)
+# result_text += f"**Original text:**\n{original_text}\n\n"
+# result_text += f"**Duplicate text:**\n{duplicate_text}\n\n"
+# result_text += f"**Differences:**\n{differences}\n"
+# result_text += "-" * 50 + "\n\n"
+# else:
+# result_text += "No duplicates found."
+
+# # Final status
+# status = "Deduplication completed."
+# yield status, result_text
+
+# elif deduplication_type == "Cross-dataset":
+# # Load Dataset 1
+# status = "Loading Dataset 1..."
+# yield status, ""
+# if dataset1_name == default_dataset1_name and dataset1_split == default_dataset1_split:
+# ds1 = ds_default1
+# else:
+# ds1 = load_dataset(dataset1_name, split=dataset1_split)
+
+# # Load Dataset 2
+# status = "Loading Dataset 2..."
+# yield status, ""
+# if dataset2_name == default_dataset2_name and dataset2_split == default_dataset2_split:
+# ds2 = ds_default2
+# else:
+# ds2 = load_dataset(dataset2_name, split=dataset2_split)
+
+# # Extract texts from Dataset 1
+# status = "Extracting texts from Dataset 1..."
+# yield status, ""
+# texts1 = [example[dataset1_text_column] for example in ds1]
+
+# # Extract texts from Dataset 2
+# status = "Extracting texts from Dataset 2..."
+# yield status, ""
+# texts2 = [example[dataset2_text_column] for example in ds2]
+
+# # Compute embeddings for Dataset 1
+# status = "Computing embeddings for Dataset 1..."
+# yield status, ""
+# embedding_matrix1 = compute_embeddings(texts1, batch_size=64, progress=progress, desc="Computing embeddings for Dataset 1")
+
+# # Compute embeddings for Dataset 2
+# status = "Computing embeddings for Dataset 2..."
+# yield status, ""
+# embedding_matrix2 = compute_embeddings(texts2, batch_size=64, progress=progress, desc="Computing embeddings for Dataset 2")
+
+# # Deduplicate across datasets
+# status = "Deduplicating embeddings across datasets..."
+# yield status, ""
+# duplicate_indices_in_ds2, duplicate_to_original_mapping = deduplicate_across_datasets(
+# embedding_matrix1, embedding_matrix2, threshold, progress=progress
+# )
+
+# num_duplicates = len(duplicate_indices_in_ds2)
+# num_total_ds2 = len(texts2)
+# num_unique_ds2 = num_total_ds2 - num_duplicates
+
+# result_text = f"**Total documents in {dataset2_name}/{dataset2_split}:** {num_total_ds2}\n"
+# result_text += f"**Number of duplicates found in {dataset2_name}/{dataset2_split}:** {num_duplicates}\n"
+# result_text += f"**Number of unique documents in {dataset2_name}/{dataset2_split} after deduplication:** {num_unique_ds2}\n\n"
+
+# # Show deduplicated examples
+# if num_duplicates > 0:
+# result_text += "**Examples of duplicates found in Dataset 2:**\n\n"
+# num_examples = min(5, num_duplicates)
+# for duplicate_idx in duplicate_indices_in_ds2[:num_examples]:
+# original_idx = duplicate_to_original_mapping[duplicate_idx]
+# original_text = texts1[original_idx]
+# duplicate_text = texts2[duplicate_idx]
+# differences = display_word_differences(original_text, duplicate_text)
+# result_text += f"**Original text (Dataset 1):**\n{original_text}\n\n"
+# result_text += f"**Duplicate text (Dataset 2):**\n{duplicate_text}\n\n"
+# result_text += f"**Differences:**\n{differences}\n"
+# result_text += "-" * 50 + "\n\n"
+# else:
+# result_text += "No duplicates found."
+
+# # Final status
+# status = "Deduplication completed."
+# yield status, result_text
+
+# except Exception as e:
+# yield f"An error occurred: {e}", ""
+# raise e
+
+# with gr.Blocks() as demo:
+# gr.Markdown("# Semantic Deduplication")
+
+# deduplication_type = gr.Radio(
+# choices=["Single dataset", "Cross-dataset"],
+# label="Deduplication Type",
+# value="Single dataset"
+# )
+
+# with gr.Row():
+# dataset1_name = gr.Textbox(value=default_dataset1_name, label="Dataset 1 Name")
+# dataset1_split = gr.Textbox(value=default_dataset1_split, label="Dataset 1 Split")
+# dataset1_text_column = gr.Textbox(value=default_text_column, label="Text Column Name")
+
+# dataset2_inputs = gr.Column(visible=False)
+# with dataset2_inputs:
+# gr.Markdown("### Dataset 2")
+# with gr.Row():
+# dataset2_name = gr.Textbox(value=default_dataset2_name, label="Dataset 2 Name")
+# dataset2_split = gr.Textbox(value=default_dataset2_split, label="Dataset 2 Split")
+# dataset2_text_column = gr.Textbox(value=default_text_column, label="Text Column Name")
+
+# threshold = gr.Slider(
+# minimum=0.0,
+# maximum=1.0,
+# value=default_threshold,
+# label="Similarity Threshold"
+# )
+
+# compute_button = gr.Button("Compute")
+
+# status_output = gr.Markdown()
+# result_output = gr.Markdown()
+
+# # Function to update the visibility of dataset2_inputs
+# def update_visibility(deduplication_type_value):
+# if deduplication_type_value == "Cross-dataset":
+# return gr.update(visible=True)
+# else:
+# return gr.update(visible=False)
+
+# deduplication_type.change(
+# update_visibility,
+# inputs=deduplication_type,
+# outputs=dataset2_inputs
+# )
+
+# compute_button.click(
+# fn=perform_deduplication,
+# inputs=[
+# deduplication_type,
+# dataset1_name,
+# dataset1_split,
+# dataset1_text_column,
+# dataset2_name,
+# dataset2_split,
+# dataset2_text_column,
+# threshold
+# ],
+# outputs=[status_output, result_output]
+# )
+
+# demo.launch()
+

 # import gradio as gr
 # from datasets import load_dataset
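For completeness, the show/hide wiring for the Dataset 2 inputs (unchanged by this commit and visible in the commented-out copy above) follows the standard Gradio pattern of returning gr.update(visible=...) from a change handler. A trimmed-down sketch with shortened component names:

import gradio as gr

# Minimal sketch of the show/hide wiring used for the Dataset 2 inputs
# (component names shortened; the real app toggles a gr.Column of three textboxes).
with gr.Blocks() as toggle_demo:
    mode = gr.Radio(choices=["Single dataset", "Cross-dataset"], value="Single dataset", label="Deduplication Type")
    second = gr.Column(visible=False)
    with second:
        gr.Markdown("### Dataset 2")
        gr.Textbox(label="Dataset 2 Name")

    def update_visibility(mode_value):
        return gr.update(visible=(mode_value == "Cross-dataset"))

    mode.change(update_visibility, inputs=mode, outputs=second)

# toggle_demo.launch()  # uncomment to try it locally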