Pringled committed on
Commit
a81fb12
1 Parent(s): 82a1d00
Files changed (1) hide show
  1. app.py +38 -39
app.py CHANGED
@@ -1,13 +1,14 @@
1
  import gradio as gr
2
  from datasets import load_dataset
3
  import numpy as np
4
- from model2vec import StaticModel
5
  import model2vec
6
  from reach import Reach
7
  from difflib import ndiff
8
 
 
9
  # Load the model at startup
10
- model = StaticModel.from_pretrained("minishlab/M2V_base_output")
11
 
12
  # Default dataset parameters
13
  default_dataset1_name = "sst2"
@@ -23,43 +24,43 @@ ds_default2 = load_dataset(default_dataset2_name, split=default_dataset2_split)
23
 
24
 
25
  # Patch tqdm to use Gradio's progress bar
26
- from tqdm import tqdm as original_tqdm
27
 
28
  # Patch tqdm to use Gradio's progress bar
29
  # Patch tqdm to use Gradio's progress bar
30
- def patch_tqdm_for_gradio(progress):
31
- class GradioTqdm(original_tqdm):
32
- def __init__(self, *args, **kwargs):
33
- super().__init__(*args, **kwargs)
34
- self.progress = progress
35
- self.total_batches = kwargs.get('total', len(args[0])) if len(args) > 0 else 1
36
- self.update_interval = max(1, self.total_batches // 100) # Update every 1%
37
-
38
- def update(self, n=1):
39
- super().update(n)
40
- # Update Gradio progress bar every update_interval steps
41
- if self.n % self.update_interval == 0 or self.n == self.total_batches:
42
- self.progress(self.n / self.total_batches)
43
-
44
- return GradioTqdm
45
-
46
- def patch_model2vec_tqdm(progress):
47
- patched_tqdm = patch_tqdm_for_gradio(progress)
48
- model2vec.tqdm = patched_tqdm # Replace tqdm in model2vec
49
-
50
- # Function to patch the original encode function with our Gradio tqdm
51
- def original_encode_with_tqdm(original_encode_func, patched_tqdm):
52
- def new_encode(*args, **kwargs):
53
- original_tqdm_backup = original_tqdm
54
- try:
55
- # Patch the `tqdm` within encode
56
- globals()['tqdm'] = patched_tqdm
57
- return original_encode_func(*args, **kwargs)
58
- finally:
59
- # Restore original tqdm after calling encode
60
- globals()['tqdm'] = original_tqdm_backup
61
-
62
- return new_encode
63
 
64
 
65
  def batch_iterable(iterable, batch_size):
@@ -157,12 +158,10 @@ def perform_deduplication(
157
  texts = [example[dataset1_text_column] for example in ds]
158
 
159
  #patched_tqdm = patch_tqdm_for_gradio(progress)
160
- patch_model2vec_tqdm(progress)
161
  #model.encode = original_encode_with_tqdm(model.encode, patched_tqdm)
162
  # Compute embeddings
163
  status = "Computing embeddings for Dataset 1..."
164
-
165
- # Remove?
166
  yield status, ""
167
 
168
 
 
1
  import gradio as gr
2
  from datasets import load_dataset
3
  import numpy as np
4
+ #from model2vec import StaticModel
5
  import model2vec
6
  from reach import Reach
7
  from difflib import ndiff
8
 
9
+
10
  # Load the model at startup
11
+ model = model2vec.StaticModel.from_pretrained("minishlab/M2V_base_output")
12
 
13
  # Default dataset parameters
14
  default_dataset1_name = "sst2"
 
24
 
25
 
26
  # Patch tqdm to use Gradio's progress bar
27
+ #from tqdm import tqdm as original_tqdm
28
 
29
  # Patch tqdm to use Gradio's progress bar
30
  # Patch tqdm to use Gradio's progress bar
31
+ # def patch_tqdm_for_gradio(progress):
32
+ # class GradioTqdm(original_tqdm):
33
+ # def __init__(self, *args, **kwargs):
34
+ # super().__init__(*args, **kwargs)
35
+ # self.progress = progress
36
+ # self.total_batches = kwargs.get('total', len(args[0])) if len(args) > 0 else 1
37
+ # self.update_interval = max(1, self.total_batches // 100) # Update every 1%
38
+
39
+ # def update(self, n=1):
40
+ # super().update(n)
41
+ # # Update Gradio progress bar every update_interval steps
42
+ # if self.n % self.update_interval == 0 or self.n == self.total_batches:
43
+ # self.progress(self.n / self.total_batches)
44
+
45
+ # return GradioTqdm
46
+
47
+ # def patch_model2vec_tqdm(progress):
48
+ # patched_tqdm = patch_tqdm_for_gradio(progress)
49
+ # model2vec.tqdm = patched_tqdm # Replace tqdm in model2vec
50
+
51
+ # # Function to patch the original encode function with our Gradio tqdm
52
+ # def original_encode_with_tqdm(original_encode_func, patched_tqdm):
53
+ # def new_encode(*args, **kwargs):
54
+ # original_tqdm_backup = original_tqdm
55
+ # try:
56
+ # # Patch the `tqdm` within encode
57
+ # globals()['tqdm'] = patched_tqdm
58
+ # return original_encode_func(*args, **kwargs)
59
+ # finally:
60
+ # # Restore original tqdm after calling encode
61
+ # globals()['tqdm'] = original_tqdm_backup
62
+
63
+ # return new_encode
64
 
65
 
66
  def batch_iterable(iterable, batch_size):
 
158
  texts = [example[dataset1_text_column] for example in ds]
159
 
160
  #patched_tqdm = patch_tqdm_for_gradio(progress)
161
+ #patch_model2vec_tqdm(progress)
162
  #model.encode = original_encode_with_tqdm(model.encode, patched_tqdm)
163
  # Compute embeddings
164
  status = "Computing embeddings for Dataset 1..."
 
 
165
  yield status, ""
166
 
167