Pringled committed on
Commit
1d331c4
1 Parent(s): d90d4c0
Files changed (1) hide show
  1. app.py +38 -6
app.py CHANGED
@@ -20,6 +20,35 @@ default_threshold = 0.9
20
  ds_default1 = load_dataset(default_dataset1_name, split=default_dataset1_split)
21
  ds_default2 = load_dataset(default_dataset2_name, split=default_dataset2_split)
22
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
23
  def batch_iterable(iterable, batch_size):
24
  """Helper function to create batches from an iterable."""
25
  for i in range(0, len(iterable), batch_size):
@@ -114,15 +143,18 @@ def perform_deduplication(
114
  yield status, ""
115
  texts = [example[dataset1_text_column] for example in ds]
116
 
 
 
117
  # Compute embeddings
118
  status = "Computing embeddings for Dataset 1..."
119
  yield status, ""
120
- embedding_matrix = compute_embeddings(
121
- texts,
122
- batch_size=64,
123
- progress=progress,
124
- desc="Computing embeddings for Dataset 1",
125
- )
 
126
 
127
  # Deduplicate
128
  status = "Deduplicating embeddings..."
 
20
  ds_default1 = load_dataset(default_dataset1_name, split=default_dataset1_split)
21
  ds_default2 = load_dataset(default_dataset2_name, split=default_dataset2_split)
22
 
23
+ from tqdm import tqdm as original_tqdm
24
+ # Patch tqdm to use Gradio's progress bar
25
def patch_tqdm_for_gradio(progress):
    """Build a tqdm subclass that mirrors its progress into a callback.

    Args:
        progress: Callable invoked with the completed fraction (a float in
            ``[0, 1]``) after every ``update`` call. Presumably a Gradio
            progress tracker -- confirm against the caller.

    Returns:
        A subclass of ``original_tqdm`` that behaves like the base bar and
        additionally reports ``n / total`` to ``progress``.
    """

    class GradioTqdm(original_tqdm):
        """Progress bar that forwards its completion ratio to ``progress``."""

        def __init__(self, *args, **kwargs):
            super().__init__(*args, **kwargs)
            self.progress = progress
            # Reuse the total the base class resolved (keyword, positional or
            # len(iterable)) instead of re-deriving it from args/kwargs; this
            # relies on the base exposing it as ``self.total``, as tqdm does.
            # Unknown or zero totals fall back to 1 so the ratio stays defined.
            total = getattr(self, "total", None)
            self.total_batches = total if total else 1

        def update(self, n=1):
            displayed = super().update(n)
            # Clamp so an over-running bar never reports more than 100%.
            self.progress(min(self.n / self.total_batches, 1.0))
            return displayed

    return GradioTqdm
37
+ # Function to patch the original encode function with our Gradio tqdm
38
def original_encode_with_tqdm(original_encode_func, patched_tqdm):
    """Wrap an encode callable so it sees ``patched_tqdm`` while it runs.

    The name ``tqdm`` is rebound in the global namespace of the function
    that implements ``original_encode_func`` (not in this module), because
    that is where the callee resolves the name. The previous binding is
    restored when the call finishes, including when it raises.

    Args:
        original_encode_func: The encode function or bound method to wrap.
            Passing a wrapper previously returned by this function replaces
            its patch instead of stacking a second one on top.
        patched_tqdm: Replacement bound to ``tqdm`` for the duration of
            each call.

    Returns:
        A callable with the same call signature as ``original_encode_func``.

    Note:
        This only has an effect when the callee looks ``tqdm`` up as a
        module-level global. The rebinding is module-wide while a call is in
        flight, so concurrent calls would observe each other's patch.
    """
    import functools
    import inspect

    # Collapse repeated wrapping: always patch around the unwrapped callable.
    encode_func = getattr(
        original_encode_func, "_unpatched_encode", original_encode_func
    )
    # Look through functools.wraps-style decorators to reach the function
    # whose globals actually hold the ``tqdm`` name used at call time.
    encode_globals = getattr(inspect.unwrap(encode_func), "__globals__", None)
    missing = object()

    @functools.wraps(encode_func)
    def new_encode(*args, **kwargs):
        if encode_globals is None:
            # Not a Python-level function (no globals to patch): call through.
            return encode_func(*args, **kwargs)
        previous = encode_globals.get("tqdm", missing)
        encode_globals["tqdm"] = patched_tqdm
        try:
            return encode_func(*args, **kwargs)
        finally:
            # Put back exactly what was there before, or remove our binding
            # if the module had no ``tqdm`` global to begin with.
            if previous is missing:
                encode_globals.pop("tqdm", None)
            else:
                encode_globals["tqdm"] = previous

    new_encode._unpatched_encode = encode_func
    return new_encode
51
+
52
  def batch_iterable(iterable, batch_size):
53
  """Helper function to create batches from an iterable."""
54
  for i in range(0, len(iterable), batch_size):
 
143
  yield status, ""
144
  texts = [example[dataset1_text_column] for example in ds]
145
 
146
+ patched_tqdm = patch_tqdm_for_gradio(progress)
147
+ model.encode = original_encode_with_tqdm(model.encode, patched_tqdm)
148
  # Compute embeddings
149
  status = "Computing embeddings for Dataset 1..."
150
  yield status, ""
151
+ embedding_matrix = model.encode(texts, show_progressbar=True)
152
+ # embedding_matrix = compute_embeddings(
153
+ # texts,
154
+ # batch_size=64,
155
+ # progress=progress,
156
+ # desc="Computing embeddings for Dataset 1",
157
+ # )
158
 
159
  # Deduplicate
160
  status = "Deduplicating embeddings..."