Updated app with code for deduplication
app.py CHANGED
@@ -108,24 +108,31 @@ def perform_deduplication(
         # Convert threshold to float
         threshold = float(threshold)
 
+        # Initialize status message
+        status = ""
+
         if deduplication_type == "Single dataset":
             # Load Dataset 1
-
+            status = "Loading Dataset 1..."
+            yield status, ""
             if dataset1_name == default_dataset1_name and dataset1_split == default_dataset1_split:
                 ds = ds_default1
             else:
                 ds = load_dataset(dataset1_name, split=dataset1_split)
 
             # Extract texts
-
+            status = "Extracting texts from Dataset 1..."
+            yield status, ""
             texts = [example[dataset1_text_column] for example in ds]
 
             # Compute embeddings
-
+            status = "Computing embeddings for Dataset 1..."
+            yield status, ""
             embedding_matrix = model.encode(texts, show_progressbar=True)  # Enable internal progress bar
 
             # Deduplicate
-
+            status = "Deduplicating embeddings..."
+            yield status, ""
             deduplicated_indices, duplicate_to_original_mapping = deduplicate(
                 embedding_matrix, threshold
             )
@@ -154,41 +161,50 @@ def perform_deduplication(
             else:
                 result_text += "No duplicates found."
 
-
+            # Final status
+            status = "Deduplication completed."
+            yield status, result_text
 
         elif deduplication_type == "Cross-dataset":
             # Load Dataset 1
-
+            status = "Loading Dataset 1..."
+            yield status, ""
             if dataset1_name == default_dataset1_name and dataset1_split == default_dataset1_split:
                 ds1 = ds_default1
             else:
                 ds1 = load_dataset(dataset1_name, split=dataset1_split)
 
             # Load Dataset 2
-
+            status = "Loading Dataset 2..."
+            yield status, ""
             if dataset2_name == default_dataset2_name and dataset2_split == default_dataset2_split:
                 ds2 = ds_default2
             else:
                 ds2 = load_dataset(dataset2_name, split=dataset2_split)
 
             # Extract texts from Dataset 1
-
+            status = "Extracting texts from Dataset 1..."
+            yield status, ""
             texts1 = [example[dataset1_text_column] for example in ds1]
 
             # Extract texts from Dataset 2
-
+            status = "Extracting texts from Dataset 2..."
+            yield status, ""
             texts2 = [example[dataset2_text_column] for example in ds2]
 
             # Compute embeddings for Dataset 1
-
+            status = "Computing embeddings for Dataset 1..."
+            yield status, ""
             embedding_matrix1 = model.encode(texts1, show_progressbar=True)
 
             # Compute embeddings for Dataset 2
-
+            status = "Computing embeddings for Dataset 2..."
+            yield status, ""
             embedding_matrix2 = model.encode(texts2, show_progressbar=True)
 
             # Deduplicate across datasets
-
+            status = "Deduplicating embeddings across datasets..."
+            yield status, ""
             duplicate_indices_in_ds2, duplicate_to_original_mapping = deduplicate_across_datasets(
                 embedding_matrix1, embedding_matrix2, threshold
             )
@@ -217,7 +233,9 @@ def perform_deduplication(
             else:
                 result_text += "No duplicates found."
 
-
+            # Final status
+            status = "Deduplication completed."
+            yield status, result_text
 
     finally:
         # Restore original tqdm
@@ -257,7 +275,8 @@ with gr.Blocks() as demo:
 
     compute_button = gr.Button("Compute")
 
-
+    status_output = gr.Markdown()
+    result_output = gr.Markdown()
 
     # Function to update the visibility of dataset2_inputs
    def update_visibility(deduplication_type_value):
@@ -284,9 +303,9 @@ with gr.Blocks() as demo:
             dataset2_text_column,
             threshold
         ],
-        outputs=
+        outputs=[status_output, result_output]
     )
-
+
 demo.launch()
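
What the change relies on is Gradio's handling of generator event handlers: when the click handler yields a tuple, each element updates the corresponding component in outputs, so intermediate status messages stream to the UI before the final result arrives. Below is a minimal, self-contained sketch of that wiring, separate from app.py; the component names and the placeholder work loop are illustrative only.

import time

import gradio as gr


def run_with_status(n_steps):
    # Generator handler: every yield sends (status, result) to the two outputs below.
    n_steps = int(n_steps)
    for step in range(n_steps):
        yield f"Working on step {step + 1} of {n_steps}...", ""
        time.sleep(0.5)  # stand-in for a slow stage (loading, encoding, deduplicating)
    yield "Done.", f"Processed {n_steps} steps."


with gr.Blocks() as demo:
    steps = gr.Number(value=3, label="Steps")
    run_button = gr.Button("Run")
    status_box = gr.Markdown()   # receives the first element of each yielded tuple
    result_box = gr.Markdown()   # receives the second element

    run_button.click(
        fn=run_with_status,
        inputs=[steps],
        outputs=[status_box, result_box],
    )

demo.queue()  # generator output streams through the queue (enabled by default in recent Gradio)
demo.launch()

This is the same shape as the diff above: perform_deduplication becomes a generator that yields (status, result_text) pairs, and the new status_output and result_output Markdown components receive them through outputs=[status_output, result_output].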
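The handler itself leans on deduplicate(embedding_matrix, threshold) and deduplicate_across_datasets(embedding_matrix1, embedding_matrix2, threshold), which are defined elsewhere in app.py and are not part of this diff. As a rough sketch of the shapes involved, and not the app's actual implementation, a brute-force cosine-similarity version with the same inputs and return values (assuming the embeddings are row-wise NumPy arrays) could look like this:

import numpy as np


def deduplicate_sketch(embedding_matrix, threshold):
    # Return (deduplicated_indices, duplicate_to_original_mapping) for a single dataset.
    # A row counts as a duplicate if its cosine similarity to an earlier kept row is >= threshold.
    norms = np.linalg.norm(embedding_matrix, axis=1, keepdims=True)
    normalized = embedding_matrix / np.clip(norms, 1e-12, None)

    deduplicated_indices = []
    duplicate_to_original_mapping = {}
    for i, row in enumerate(normalized):
        if deduplicated_indices:
            sims = normalized[deduplicated_indices] @ row
            best = int(np.argmax(sims))
            if sims[best] >= threshold:
                duplicate_to_original_mapping[i] = deduplicated_indices[best]
                continue
        deduplicated_indices.append(i)
    return deduplicated_indices, duplicate_to_original_mapping


def deduplicate_across_datasets_sketch(embedding_matrix1, embedding_matrix2, threshold):
    # Return (duplicate_indices_in_ds2, duplicate_to_original_mapping): rows of dataset 2
    # that are too similar to some row of dataset 1, mapped to their closest original.
    def normalize(m):
        return m / np.clip(np.linalg.norm(m, axis=1, keepdims=True), 1e-12, None)

    a, b = normalize(embedding_matrix1), normalize(embedding_matrix2)
    sims = b @ a.T                                   # (n2, n1) cosine similarities
    best = sims.argmax(axis=1)
    is_duplicate = sims[np.arange(len(b)), best] >= threshold
    duplicate_indices_in_ds2 = np.where(is_duplicate)[0].tolist()
    duplicate_to_original_mapping = {int(i): int(best[i]) for i in duplicate_indices_in_ds2}
    return duplicate_indices_in_ds2, duplicate_to_original_mapping

The real functions presumably avoid materializing the full similarity matrix for large datasets; the sketch only illustrates what the returned indices and mapping mean when result_text is assembled.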