Pringled committed on
Commit
892ceeb
1 Parent(s): 7a1cd7a

Updated app with code for deduplication

Browse files
Files changed (1) hide show
  1. app.py +10 -14
app.py CHANGED
@@ -3,10 +3,9 @@ from datasets import load_dataset
3
  import numpy as np
4
  from model2vec import StaticModel
5
  from reach import Reach
6
- from tqdm import tqdm
7
  from difflib import ndiff
8
 
9
- def deduplicate(embedding_matrix: np.ndarray, threshold: float, batch_size: int = 1024) -> tuple[np.ndarray, dict[int, int]]:
10
  """
11
  Deduplicate embeddings and return the deduplicated indices and a mapping of removed indices to their corresponding original indices.
12
  """
@@ -24,7 +23,7 @@ def deduplicate(embedding_matrix: np.ndarray, threshold: float, batch_size: int
24
  )
25
 
26
  # Process duplicates
27
- for i, similar_items in enumerate(tqdm(results)):
28
  if i not in deduplicated_indices:
29
  continue # Skip already marked duplicates
30
 
@@ -39,7 +38,7 @@ def deduplicate(embedding_matrix: np.ndarray, threshold: float, batch_size: int
39
 
40
  return np.array(list(deduplicated_indices)), duplicate_to_original_mapping
41
 
42
- def deduplicate_across_datasets(embedding_matrix_1: np.ndarray, embedding_matrix_2: np.ndarray, threshold: float, batch_size: int = 1024) -> tuple[list[int], dict[int, int]]:
43
  """
44
  Deduplicate embeddings across two datasets and return the indices of duplicates between them.
45
  """
@@ -58,7 +57,7 @@ def deduplicate_across_datasets(embedding_matrix_1: np.ndarray, embedding_matrix
58
  )
59
 
60
  # Process duplicates
61
- for i, similar_items in enumerate(tqdm(results)):
62
  # Similar items are returned as (index, score), we are only interested in the index
63
  similar_indices = [int(item[0]) for item in similar_items if item[1] >= threshold] # Keep those above the threshold
64
 
@@ -71,7 +70,7 @@ def deduplicate_across_datasets(embedding_matrix_1: np.ndarray, embedding_matrix
71
 
72
  def display_word_differences(x: str, y: str) -> str:
73
  diff = ndiff(x.split(), y.split())
74
- return " ".join([f"{word}" for word in diff if word.startswith(('+', '-'))])
75
 
76
  def perform_deduplication(
77
  deduplication_type,
@@ -81,7 +80,8 @@ def perform_deduplication(
81
  dataset2_name,
82
  dataset2_split,
83
  dataset2_text_column,
84
- threshold
 
85
  ):
86
  # Convert threshold to float
87
  threshold = float(threshold)
@@ -98,8 +98,7 @@ def perform_deduplication(
98
  embedding_matrix = model.encode(texts, show_progressbar=True)
99
 
100
  # Deduplicate
101
- with gr.Progress(track_tqdm=True):
102
- deduplicated_indices, duplicate_to_original_mapping = deduplicate(embedding_matrix, threshold)
103
 
104
  # Prepare the results
105
  num_duplicates = len(duplicate_to_original_mapping)
@@ -114,9 +113,7 @@ def perform_deduplication(
114
  result_text += "**Examples of duplicates found:**\n\n"
115
  num_examples = min(5, num_duplicates)
116
  examples_shown = 0
117
- for duplicate_idx, original_idx in duplicate_to_original_mapping.items():
118
- if examples_shown >= num_examples:
119
- break
120
  original_text = texts[original_idx]
121
  duplicate_text = texts[duplicate_idx]
122
  differences = display_word_differences(original_text, duplicate_text)
@@ -143,8 +140,7 @@ def perform_deduplication(
143
  embedding_matrix2 = model.encode(texts2, show_progressbar=True)
144
 
145
  # Deduplicate across datasets
146
- with gr.Progress(track_tqdm=True):
147
- duplicate_indices_in_ds2, duplicate_to_original_mapping = deduplicate_across_datasets(embedding_matrix1, embedding_matrix2, threshold)
148
 
149
  num_duplicates = len(duplicate_indices_in_ds2)
150
  num_total_ds2 = len(texts2)
 
3
  import numpy as np
4
  from model2vec import StaticModel
5
  from reach import Reach
 
6
  from difflib import ndiff
7
 
8
+ def deduplicate(embedding_matrix: np.ndarray, threshold: float, batch_size: int = 1024, progress=gr.Progress(track_tqdm=True)) -> tuple[np.ndarray, dict[int, int]]:
9
  """
10
  Deduplicate embeddings and return the deduplicated indices and a mapping of removed indices to their corresponding original indices.
11
  """
 
23
  )
24
 
25
  # Process duplicates
26
+ for i, similar_items in enumerate(progress.tqdm(results, desc="Processing duplicates")):
27
  if i not in deduplicated_indices:
28
  continue # Skip already marked duplicates
29
 
 
38
 
39
  return np.array(list(deduplicated_indices)), duplicate_to_original_mapping
40
 
41
+ def deduplicate_across_datasets(embedding_matrix_1: np.ndarray, embedding_matrix_2: np.ndarray, threshold: float, batch_size: int = 1024, progress=gr.Progress(track_tqdm=True)) -> tuple[list[int], dict[int, int]]:
42
  """
43
  Deduplicate embeddings across two datasets and return the indices of duplicates between them.
44
  """
 
57
  )
58
 
59
  # Process duplicates
60
+ for i, similar_items in enumerate(progress.tqdm(results, desc="Processing duplicates across datasets")):
61
  # Similar items are returned as (index, score), we are only interested in the index
62
  similar_indices = [int(item[0]) for item in similar_items if item[1] >= threshold] # Keep those above the threshold
63
 
 
70
 
71
  def display_word_differences(x: str, y: str) -> str:
72
  diff = ndiff(x.split(), y.split())
73
+ return " ".join([word for word in diff if word.startswith(('+', '-'))])
74
 
75
  def perform_deduplication(
76
  deduplication_type,
 
80
  dataset2_name,
81
  dataset2_split,
82
  dataset2_text_column,
83
+ threshold,
84
+ progress=gr.Progress(track_tqdm=True)
85
  ):
86
  # Convert threshold to float
87
  threshold = float(threshold)
 
98
  embedding_matrix = model.encode(texts, show_progressbar=True)
99
 
100
  # Deduplicate
101
+ deduplicated_indices, duplicate_to_original_mapping = deduplicate(embedding_matrix, threshold, progress=progress)
 
102
 
103
  # Prepare the results
104
  num_duplicates = len(duplicate_to_original_mapping)
 
113
  result_text += "**Examples of duplicates found:**\n\n"
114
  num_examples = min(5, num_duplicates)
115
  examples_shown = 0
116
+ for duplicate_idx, original_idx in list(duplicate_to_original_mapping.items())[:num_examples]:
 
 
117
  original_text = texts[original_idx]
118
  duplicate_text = texts[duplicate_idx]
119
  differences = display_word_differences(original_text, duplicate_text)
 
140
  embedding_matrix2 = model.encode(texts2, show_progressbar=True)
141
 
142
  # Deduplicate across datasets
143
+ duplicate_indices_in_ds2, duplicate_to_original_mapping = deduplicate_across_datasets(embedding_matrix1, embedding_matrix2, threshold, progress=progress)
 
144
 
145
  num_duplicates = len(duplicate_indices_in_ds2)
146
  num_total_ds2 = len(texts2)