Updates
Browse files
app.py
CHANGED
@@ -20,6 +20,35 @@ default_threshold = 0.9
|
|
20 |
ds_default1 = load_dataset(default_dataset1_name, split=default_dataset1_split)
|
21 |
ds_default2 = load_dataset(default_dataset2_name, split=default_dataset2_split)
|
22 |
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
23 |
def batch_iterable(iterable, batch_size):
|
24 |
"""Helper function to create batches from an iterable."""
|
25 |
for i in range(0, len(iterable), batch_size):
|
@@ -114,15 +143,18 @@ def perform_deduplication(
|
|
114 |
yield status, ""
|
115 |
texts = [example[dataset1_text_column] for example in ds]
|
116 |
|
|
|
|
|
117 |
# Compute embeddings
|
118 |
status = "Computing embeddings for Dataset 1..."
|
119 |
yield status, ""
|
120 |
-
embedding_matrix =
|
121 |
-
|
122 |
-
|
123 |
-
|
124 |
-
|
125 |
-
|
|
|
126 |
|
127 |
# Deduplicate
|
128 |
status = "Deduplicating embeddings..."
|
|
|
20 |
ds_default1 = load_dataset(default_dataset1_name, split=default_dataset1_split)
|
21 |
ds_default2 = load_dataset(default_dataset2_name, split=default_dataset2_split)
|
22 |
|
23 |
+
from tqdm import tqdm as original_tqdm
|
24 |
+
# Patch tqdm to use Gradio's progress bar
|
25 |
+
def patch_tqdm_for_gradio(progress):
    """Build a tqdm subclass that mirrors its progress into a callback.

    Args:
        progress: Callable invoked with the completed fraction (a float
            capped at 1.0) every time the bar's ``update`` runs.

    Returns:
        A subclass of ``original_tqdm`` whose instances report to
        ``progress``.
    """

    class GradioTqdm(original_tqdm):
        """tqdm bar that forwards each update to ``progress``."""

        def __init__(self, *args, **kwargs):
            super().__init__(*args, **kwargs)
            self.progress = progress
            # Rely on the total tqdm resolved for itself (explicit ``total``
            # or ``len(iterable)``) instead of calling ``len(args[0])``,
            # which raises TypeError for generators even when ``total`` is
            # supplied. Fall back to 1 when the total is unknown or zero so
            # ``update`` never divides by zero.
            resolved_total = getattr(self, "total", None)
            self.total_batches = resolved_total if resolved_total else 1

        def update(self, n=1):
            outcome = super().update(n)
            # Cap at 1.0 so an unknown or underestimated total cannot report
            # more than 100 %.
            self.progress(min(self.n / self.total_batches, 1.0))
            return outcome

    return GradioTqdm
|
37 |
+
# Function to patch the original encode function with our Gradio tqdm
|
38 |
+
def original_encode_with_tqdm(original_encode_func, patched_tqdm):
    """Wrap an encode callable so ``tqdm`` is swapped out while it runs.

    During each call the global name ``tqdm`` is rebound to ``patched_tqdm``
    in this module and, when the wrapped callable exposes ``__globals__``
    and already binds ``tqdm`` to a class there, in that namespace as well.
    Every touched namespace is put back exactly as it was (including
    removing the name if it did not exist) once the call returns or raises.

    Args:
        original_encode_func: The callable to wrap; invoked with the
            wrapper's positional and keyword arguments unchanged.
        patched_tqdm: Replacement object bound to ``tqdm`` for the duration
            of the call.

    Returns:
        A function with the same call signature as ``original_encode_func``
        that returns whatever that callable returns.
    """
    absent = object()  # sentinel: the name was not bound before the call
    own_namespace = globals()
    encode_namespace = getattr(original_encode_func, "__globals__", None)

    def new_encode(*args, **kwargs):
        namespaces = [own_namespace]
        # Rebinding ``tqdm`` only in this module has no effect on a function
        # defined elsewhere, because it looks globals up in its own module.
        # Patch that namespace too, but only when ``tqdm`` is a class there,
        # so a module-style ``import tqdm`` binding is left untouched.
        if (
            isinstance(encode_namespace, dict)
            and encode_namespace is not own_namespace
            and isinstance(encode_namespace.get("tqdm"), type)
        ):
            namespaces.append(encode_namespace)

        previous = [ns.get("tqdm", absent) for ns in namespaces]
        try:
            for ns in namespaces:
                ns["tqdm"] = patched_tqdm
            return original_encode_func(*args, **kwargs)
        finally:
            # Restore the exact prior state rather than unconditionally
            # writing a fixed value.
            for ns, old in zip(namespaces, previous):
                if old is absent:
                    ns.pop("tqdm", None)
                else:
                    ns["tqdm"] = old

    return new_encode
|
51 |
+
|
52 |
def batch_iterable(iterable, batch_size):
|
53 |
"""Helper function to create batches from an iterable."""
|
54 |
for i in range(0, len(iterable), batch_size):
|
|
|
143 |
yield status, ""
|
144 |
texts = [example[dataset1_text_column] for example in ds]
|
145 |
|
146 |
+
patched_tqdm = patch_tqdm_for_gradio(progress)
|
147 |
+
model.encode = original_encode_with_tqdm(model.encode, patched_tqdm)
|
148 |
# Compute embeddings
|
149 |
status = "Computing embeddings for Dataset 1..."
|
150 |
yield status, ""
|
151 |
+
embedding_matrix = model.encode(texts, show_progressbar=True)
|
152 |
+
# embedding_matrix = compute_embeddings(
|
153 |
+
# texts,
|
154 |
+
# batch_size=64,
|
155 |
+
# progress=progress,
|
156 |
+
# desc="Computing embeddings for Dataset 1",
|
157 |
+
# )
|
158 |
|
159 |
# Deduplicate
|
160 |
status = "Deduplicating embeddings..."
|