Pringled commited on
Commit
9f13004
1 Parent(s): 5422464

Updated app with code for deduplication

Browse files
Files changed (1) hide show
  1. app.py +35 -16
app.py CHANGED
@@ -108,24 +108,31 @@ def perform_deduplication(
108
  # Convert threshold to float
109
  threshold = float(threshold)
110
 
 
 
 
111
  if deduplication_type == "Single dataset":
112
  # Load Dataset 1
113
- gr.print("Loading Dataset 1...")
 
114
  if dataset1_name == default_dataset1_name and dataset1_split == default_dataset1_split:
115
  ds = ds_default1
116
  else:
117
  ds = load_dataset(dataset1_name, split=dataset1_split)
118
 
119
  # Extract texts
120
- gr.print("Extracting texts from Dataset 1...")
 
121
  texts = [example[dataset1_text_column] for example in ds]
122
 
123
  # Compute embeddings
124
- gr.print("Computing embeddings for Dataset 1...")
 
125
  embedding_matrix = model.encode(texts, show_progressbar=True) # Enable internal progress bar
126
 
127
  # Deduplicate
128
- gr.print("Deduplicating embeddings...")
 
129
  deduplicated_indices, duplicate_to_original_mapping = deduplicate(
130
  embedding_matrix, threshold
131
  )
@@ -154,41 +161,50 @@ def perform_deduplication(
154
  else:
155
  result_text += "No duplicates found."
156
 
157
- return result_text
 
 
158
 
159
  elif deduplication_type == "Cross-dataset":
160
  # Load Dataset 1
161
- gr.print("Loading Dataset 1...")
 
162
  if dataset1_name == default_dataset1_name and dataset1_split == default_dataset1_split:
163
  ds1 = ds_default1
164
  else:
165
  ds1 = load_dataset(dataset1_name, split=dataset1_split)
166
 
167
  # Load Dataset 2
168
- gr.print("Loading Dataset 2...")
 
169
  if dataset2_name == default_dataset2_name and dataset2_split == default_dataset2_split:
170
  ds2 = ds_default2
171
  else:
172
  ds2 = load_dataset(dataset2_name, split=dataset2_split)
173
 
174
  # Extract texts from Dataset 1
175
- gr.print("Extracting texts from Dataset 1...")
 
176
  texts1 = [example[dataset1_text_column] for example in ds1]
177
 
178
  # Extract texts from Dataset 2
179
- gr.print("Extracting texts from Dataset 2...")
 
180
  texts2 = [example[dataset2_text_column] for example in ds2]
181
 
182
  # Compute embeddings for Dataset 1
183
- gr.print("Computing embeddings for Dataset 1...")
 
184
  embedding_matrix1 = model.encode(texts1, show_progressbar=True)
185
 
186
  # Compute embeddings for Dataset 2
187
- gr.print("Computing embeddings for Dataset 2...")
 
188
  embedding_matrix2 = model.encode(texts2, show_progressbar=True)
189
 
190
  # Deduplicate across datasets
191
- gr.print("Deduplicating embeddings across datasets...")
 
192
  duplicate_indices_in_ds2, duplicate_to_original_mapping = deduplicate_across_datasets(
193
  embedding_matrix1, embedding_matrix2, threshold
194
  )
@@ -217,7 +233,9 @@ def perform_deduplication(
217
  else:
218
  result_text += "No duplicates found."
219
 
220
- return result_text
 
 
221
 
222
  finally:
223
  # Restore original tqdm
@@ -257,7 +275,8 @@ with gr.Blocks() as demo:
257
 
258
  compute_button = gr.Button("Compute")
259
 
260
- output = gr.Markdown()
 
261
 
262
  # Function to update the visibility of dataset2_inputs
263
  def update_visibility(deduplication_type_value):
@@ -284,9 +303,9 @@ with gr.Blocks() as demo:
284
  dataset2_text_column,
285
  threshold
286
  ],
287
- outputs=output
288
  )
289
-
290
  demo.launch()
291
 
292
 
 
108
  # Convert threshold to float
109
  threshold = float(threshold)
110
 
111
+ # Initialize status message
112
+ status = ""
113
+
114
  if deduplication_type == "Single dataset":
115
  # Load Dataset 1
116
+ status = "Loading Dataset 1..."
117
+ yield status, ""
118
  if dataset1_name == default_dataset1_name and dataset1_split == default_dataset1_split:
119
  ds = ds_default1
120
  else:
121
  ds = load_dataset(dataset1_name, split=dataset1_split)
122
 
123
  # Extract texts
124
+ status = "Extracting texts from Dataset 1..."
125
+ yield status, ""
126
  texts = [example[dataset1_text_column] for example in ds]
127
 
128
  # Compute embeddings
129
+ status = "Computing embeddings for Dataset 1..."
130
+ yield status, ""
131
  embedding_matrix = model.encode(texts, show_progressbar=True) # Enable internal progress bar
132
 
133
  # Deduplicate
134
+ status = "Deduplicating embeddings..."
135
+ yield status, ""
136
  deduplicated_indices, duplicate_to_original_mapping = deduplicate(
137
  embedding_matrix, threshold
138
  )
 
161
  else:
162
  result_text += "No duplicates found."
163
 
164
+ # Final status
165
+ status = "Deduplication completed."
166
+ yield status, result_text
167
 
168
  elif deduplication_type == "Cross-dataset":
169
  # Load Dataset 1
170
+ status = "Loading Dataset 1..."
171
+ yield status, ""
172
  if dataset1_name == default_dataset1_name and dataset1_split == default_dataset1_split:
173
  ds1 = ds_default1
174
  else:
175
  ds1 = load_dataset(dataset1_name, split=dataset1_split)
176
 
177
  # Load Dataset 2
178
+ status = "Loading Dataset 2..."
179
+ yield status, ""
180
  if dataset2_name == default_dataset2_name and dataset2_split == default_dataset2_split:
181
  ds2 = ds_default2
182
  else:
183
  ds2 = load_dataset(dataset2_name, split=dataset2_split)
184
 
185
  # Extract texts from Dataset 1
186
+ status = "Extracting texts from Dataset 1..."
187
+ yield status, ""
188
  texts1 = [example[dataset1_text_column] for example in ds1]
189
 
190
  # Extract texts from Dataset 2
191
+ status = "Extracting texts from Dataset 2..."
192
+ yield status, ""
193
  texts2 = [example[dataset2_text_column] for example in ds2]
194
 
195
  # Compute embeddings for Dataset 1
196
+ status = "Computing embeddings for Dataset 1..."
197
+ yield status, ""
198
  embedding_matrix1 = model.encode(texts1, show_progressbar=True)
199
 
200
  # Compute embeddings for Dataset 2
201
+ status = "Computing embeddings for Dataset 2..."
202
+ yield status, ""
203
  embedding_matrix2 = model.encode(texts2, show_progressbar=True)
204
 
205
  # Deduplicate across datasets
206
+ status = "Deduplicating embeddings across datasets..."
207
+ yield status, ""
208
  duplicate_indices_in_ds2, duplicate_to_original_mapping = deduplicate_across_datasets(
209
  embedding_matrix1, embedding_matrix2, threshold
210
  )
 
233
  else:
234
  result_text += "No duplicates found."
235
 
236
+ # Final status
237
+ status = "Deduplication completed."
238
+ yield status, result_text
239
 
240
  finally:
241
  # Restore original tqdm
 
275
 
276
  compute_button = gr.Button("Compute")
277
 
278
+ status_output = gr.Markdown()
279
+ result_output = gr.Markdown()
280
 
281
  # Function to update the visibility of dataset2_inputs
282
  def update_visibility(deduplication_type_value):
 
303
  dataset2_text_column,
304
  threshold
305
  ],
306
+ outputs=[status_output, result_output]
307
  )
308
+
309
  demo.launch()
310
 
311