Pringled committed on
Commit
58d8f1a
1 Parent(s): a9118ee
Files changed (1)
  1. app.py +78 -60
app.py CHANGED
@@ -4,7 +4,8 @@ import numpy as np
 from model2vec import StaticModel
 from reach import Reach
 from difflib import ndiff
-import concurrent.futures
+import tqdm
+from contextlib import contextmanager
 
 # Load the model at startup
 model = StaticModel.from_pretrained("minishlab/M2V_base_output")
@@ -26,9 +27,65 @@ def batch_iterable(iterable, batch_size):
     for i in range(0, len(iterable), batch_size):
         yield iterable[i:i + batch_size]
 
+@contextmanager
+def tqdm_redirect(progress):
+    original_tqdm = tqdm.tqdm
+    try:
+        tqdm.tqdm = progress.tqdm
+        yield
+    finally:
+        tqdm.tqdm = original_tqdm
+
+def compute_embeddings(texts, batch_size, progress, desc="Computing embeddings"):
+    with tqdm_redirect(progress):
+        embeddings = model.encode(texts, show_progressbar=True, batch_size=batch_size)
+    return embeddings
+
+def deduplicate(
+    embedding_matrix: np.ndarray,
+    threshold: float,
+    batch_size: int = 1024,
+    progress=None
+) -> tuple[np.ndarray, dict[int, int]]:
+    # Existing deduplication code remains unchanged
+    # Building the index
+    progress(0, desc="Building search index...")
+    reach = Reach(
+        vectors=embedding_matrix, items=[str(i) for i in range(len(embedding_matrix))]
+    )
+
+    deduplicated_indices = set(range(len(embedding_matrix)))
+    duplicate_to_original_mapping = {}
+
+    # Finding nearest neighbors
+    progress(0, desc="Finding nearest neighbors...")
+    results = reach.nearest_neighbor_threshold(
+        embedding_matrix,
+        threshold=threshold,
+        batch_size=batch_size,
+        show_progressbar=False,  # Disable internal progress bar
+    )
+
+    # Processing duplicates with a progress bar
+    total_items = len(embedding_matrix)
+    for i, similar_items in enumerate(
+        progress.tqdm(results, desc="Processing duplicates", total=total_items)
+    ):
+        if i not in deduplicated_indices:
+            continue
+
+        similar_indices = [int(item[0]) for item in similar_items if int(item[0]) != i]
+
+        for sim_idx in similar_indices:
+            if sim_idx in deduplicated_indices:
+                deduplicated_indices.remove(sim_idx)
+                duplicate_to_original_mapping[sim_idx] = i
+
+    return np.array(list(deduplicated_indices)), duplicate_to_original_mapping
+
 def display_word_differences(x: str, y: str) -> str:
     diff = ndiff(x.split(), y.split())
-    return " ".join([word for word in diff if word.startswith(('+', '-'))])
+    return " ".join([word for word in diff if word.startswith(("+", "-"))])
 
 def perform_deduplication(
     deduplication_type,
@@ -39,7 +96,7 @@ def perform_deduplication(
     dataset2_split="",
     dataset2_text_column="",
     threshold=default_threshold,
-    progress=gr.Progress(track_tqdm=True)
+    progress=gr.Progress(track_tqdm=True),
 ):
     try:
         # Convert threshold to float
@@ -52,7 +109,10 @@ def perform_deduplication(
         # Load Dataset 1
         status = "Loading Dataset 1..."
         yield status, ""
-        if dataset1_name == default_dataset1_name and dataset1_split == default_dataset1_split:
+        if (
+            dataset1_name == default_dataset1_name
+            and dataset1_split == default_dataset1_split
+        ):
             ds = ds_default1
         else:
             ds = load_dataset(dataset1_name, split=dataset1_split)
@@ -65,15 +125,12 @@ def perform_deduplication(
         # Compute embeddings
         status = "Computing embeddings for Dataset 1..."
         yield status, ""
-        embeddings = []
-        batch_size = 64
-        total_batches = (len(texts) + batch_size - 1) // batch_size
-
-        for batch_texts in progress.tqdm(batch_iterable(texts, batch_size), desc="Computing embeddings for Dataset 1", total=total_batches):
-            batch_embeddings = model.encode(batch_texts, show_progressbar=False)
-            embeddings.append(batch_embeddings)
-
-        embedding_matrix = np.concatenate(embeddings, axis=0)
+        embedding_matrix = compute_embeddings(
+            texts,
+            batch_size=64,
+            progress=progress,
+            desc="Computing embeddings for Dataset 1",
+        )
 
         # Deduplicate
         status = "Deduplicating embeddings..."
@@ -89,7 +146,9 @@ def perform_deduplication(
 
         result_text = f"**Total documents:** {num_total}\n"
         result_text += f"**Number of duplicates found:** {num_duplicates}\n"
-        result_text += f"**Number of unique documents after deduplication:** {num_deduplicated}\n\n"
+        result_text += (
+            f"**Number of unique documents after deduplication:** {num_deduplicated}\n\n"
+        )
 
         # Show deduplicated examples
         if num_duplicates > 0:
@@ -119,49 +178,13 @@ def perform_deduplication(
         yield f"An error occurred: {e}", ""
         raise e
 
-
-def deduplicate(embedding_matrix: np.ndarray, threshold: float, batch_size: int = 1024, progress=None) -> tuple[np.ndarray, dict[int, int]]:
-    """
-    Deduplicate embeddings and return the deduplicated indices and a mapping of removed indices to their corresponding original indices.
-    """
-    # Building the index
-    progress(0, desc="Building search index...")
-    reach = Reach(vectors=embedding_matrix, items=[str(i) for i in range(len(embedding_matrix))])
-
-    deduplicated_indices = set(range(len(embedding_matrix)))
-    duplicate_to_original_mapping = {}
-
-    # Finding nearest neighbors
-    progress(0, desc="Finding nearest neighbors...")
-    results = reach.nearest_neighbor_threshold(
-        embedding_matrix,
-        threshold=threshold,
-        batch_size=batch_size,
-        show_progressbar=False  # Disable internal progress bar
-    )
-
-    # Processing duplicates with a progress bar
-    total_items = len(embedding_matrix)
-    for i, similar_items in enumerate(progress.tqdm(results, desc="Processing duplicates", total=total_items)):
-        if i not in deduplicated_indices:
-            continue
-
-        similar_indices = [int(item[0]) for item in similar_items if int(item[0]) != i]
-
-        for sim_idx in similar_indices:
-            if sim_idx in deduplicated_indices:
-                deduplicated_indices.remove(sim_idx)
-                duplicate_to_original_mapping[sim_idx] = i
-
-    return np.array(list(deduplicated_indices)), duplicate_to_original_mapping
-
 with gr.Blocks() as demo:
     gr.Markdown("# Semantic Deduplication")
 
     deduplication_type = gr.Radio(
         choices=["Single dataset", "Cross-dataset"],
         label="Deduplication Type",
-        value="Single dataset"
+        value="Single dataset",
     )
 
     with gr.Row():
@@ -178,10 +201,7 @@ with gr.Blocks() as demo:
             dataset2_text_column = gr.Textbox(value=default_text_column, label="Text Column Name")
 
     threshold = gr.Slider(
-        minimum=0.0,
-        maximum=1.0,
-        value=default_threshold,
-        label="Similarity Threshold"
+        minimum=0.0, maximum=1.0, value=default_threshold, label="Similarity Threshold"
     )
 
     compute_button = gr.Button("Compute")
@@ -197,9 +217,7 @@ with gr.Blocks() as demo:
         return gr.update(visible=False)
 
     deduplication_type.change(
-        update_visibility,
-        inputs=deduplication_type,
-        outputs=dataset2_inputs
+        update_visibility, inputs=deduplication_type, outputs=dataset2_inputs
     )
 
     compute_button.click(
@@ -212,9 +230,9 @@ with gr.Blocks() as demo:
             dataset2_name,
             dataset2_split,
             dataset2_text_column,
-            threshold
+            threshold,
         ],
-        outputs=[status_output, result_output]
+        outputs=[status_output, result_output],
     )
 
 demo.launch()
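The pattern worth noting in this commit is the tqdm redirection: gr.Progress(track_tqdm=True) lets Gradio mirror tqdm bars into the UI, and tqdm_redirect temporarily rebinds the module-level tqdm.tqdm to progress.tqdm so that bars created inside library code (here model.encode) surface in the interface without modifying the library. A minimal standalone sketch of the pattern; slow_work is a hypothetical stand-in for a library call that drives its own tqdm loop, and SimpleNamespace stands in for the gr.Progress object available inside a handler:

import tqdm
from contextlib import contextmanager
from types import SimpleNamespace

@contextmanager
def tqdm_redirect(progress):
    # Rebind the module-level tqdm.tqdm to the progress wrapper for the
    # duration of the block; the finally clause restores it even if the
    # body raises.
    original_tqdm = tqdm.tqdm
    try:
        tqdm.tqdm = progress.tqdm
        yield
    finally:
        tqdm.tqdm = original_tqdm

def slow_work(n):
    # Hypothetical stand-in for a library call (like model.encode) that
    # creates its own tqdm bar internally.
    for _ in tqdm.tqdm(range(n), desc="working"):
        pass

# Outside Gradio, any object exposing a callable .tqdm can stand in for
# gr.Progress; inside the app, the progress argument plays this role.
progress = SimpleNamespace(tqdm=tqdm.tqdm)
with tqdm_redirect(progress):
    slow_work(3)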
 
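The relocated deduplicate function can also be sanity-checked in isolation. A small sketch assuming the definitions above; NoProgress is a hypothetical stub that satisfies the progress callbacks outside a Gradio handler:

import numpy as np

class NoProgress:
    # Hypothetical stand-in for gr.Progress when running outside the app.
    def __call__(self, *args, **kwargs):
        pass
    @staticmethod
    def tqdm(iterable, **kwargs):
        return iterable

# Three one-hot rows: row 2 is an exact duplicate of row 0, while the
# remaining rows are mutually orthogonal, so only the true duplicate
# crosses a high cosine-similarity threshold.
emb = np.eye(3, 8, dtype=np.float32)
emb[2] = emb[0]

kept, mapping = deduplicate(emb, threshold=0.99, progress=NoProgress())
print(sorted(kept.tolist()))  # expected: [0, 1]
print(mapping)                # expected: {2: 0}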