asoria HF staff committed on
Commit
9c726b4
1 Parent(s): cd5f2d1

Adding num rows

Browse files
Files changed (1) hide show
  1. app.py +26 -14
app.py CHANGED
@@ -15,7 +15,7 @@ from sentence_transformers import SentenceTransformer
15
  from dotenv import load_dotenv
16
  import os
17
 
18
- import spaces
19
  import gradio as gr
20
 
21
 
@@ -81,7 +81,7 @@ def get_docs_from_parquet(parquet_urls, column, offset, limit):
81
  return df[column].tolist()
82
 
83
 
84
- @spaces.GPU
85
  def calculate_embeddings(docs):
86
  return sentence_model.encode(docs, show_progress_bar=True, batch_size=32)
87
 
@@ -92,7 +92,7 @@ def calculate_n_neighbors_and_components(n_rows):
92
  return n_neighbors, n_components
93
 
94
 
95
- @spaces.GPU
96
  def fit_model(docs, embeddings, n_neighbors, n_components):
97
  global global_topic_model
98
 
@@ -116,11 +116,11 @@ def fit_model(docs, embeddings, n_neighbors, n_components):
116
  new_model = BERTopic(
117
  language="english",
118
  # Sub-models
119
- embedding_model=sentence_model,
120
- umap_model=umap_model,
121
- hdbscan_model=hdbscan_model,
122
- representation_model=representation_model,
123
- vectorizer_model=vectorizer_model,
124
  # Hyperparameters
125
  top_n_words=10,
126
  verbose=True,
@@ -162,12 +162,16 @@ def generate_topics(dataset, config, split, column, nested_column):
162
  all_docs = []
163
  reduced_embeddings_list = []
164
  topics_info, topic_plot = None, None
 
 
 
 
 
 
165
  yield (
166
  gr.DataFrame(value=[], interactive=False, visible=True),
167
  gr.Plot(value=None, visible=True),
168
- gr.Label(
169
- {f"⚙️ Generating topics {dataset}": rows_processed / limit}, visible=True
170
- ),
171
  )
172
  while offset < limit:
173
  docs = get_docs_from_parquet(parquet_urls, column, offset, CHUNK_SIZE)
@@ -207,10 +211,16 @@ def generate_topics(dataset, config, split, column, nested_column):
207
  rows_processed += len(docs)
208
  progress = min(rows_processed / limit, 1.0)
209
  logging.info(f"Progress: {progress} % - {rows_processed} of {limit}")
 
 
 
 
 
 
210
  yield (
211
  topics_info,
212
  topic_plot,
213
- gr.Label({f"⚙️ Generating topics {dataset}": progress}, visible=True),
214
  )
215
 
216
  offset += CHUNK_SIZE
@@ -219,7 +229,9 @@ def generate_topics(dataset, config, split, column, nested_column):
219
  yield (
220
  topics_info,
221
  topic_plot,
222
- gr.Label({f"✅ Generating topics {dataset}": 1.0}, visible=True),
 
 
223
  )
224
  cuda.empty_cache()
225
 
@@ -260,7 +272,7 @@ with gr.Blocks() as demo:
260
 
261
  generate_button = gr.Button("Generate Topics", variant="primary")
262
 
263
- gr.Markdown("## Datamap")
264
  full_topics_generation_label = gr.Label(visible=False, show_label=False)
265
  topics_plot = gr.Plot()
266
  with gr.Accordion("Topics Info", open=False):
 
15
  from dotenv import load_dotenv
16
  import os
17
 
18
+ # import spaces
19
  import gradio as gr
20
 
21
 
 
81
  return df[column].tolist()
82
 
83
 
84
+ # @spaces.GPU
85
  def calculate_embeddings(docs):
86
  return sentence_model.encode(docs, show_progress_bar=True, batch_size=32)
87
 
 
92
  return n_neighbors, n_components
93
 
94
 
95
+ # @spaces.GPU
96
  def fit_model(docs, embeddings, n_neighbors, n_components):
97
  global global_topic_model
98
 
 
116
  new_model = BERTopic(
117
  language="english",
118
  # Sub-models
119
+ embedding_model=sentence_model, # Step 1 - Extract embeddings
120
+ umap_model=umap_model, # Step 2 - UMAP model
121
+ hdbscan_model=hdbscan_model, # Step 3 - Cluster reduced embeddings
122
+ vectorizer_model=vectorizer_model, # Step 4 - Tokenize topics
123
+ representation_model=representation_model, # Step 5 - Label topics
124
  # Hyperparameters
125
  top_n_words=10,
126
  verbose=True,
 
162
  all_docs = []
163
  reduced_embeddings_list = []
164
  topics_info, topic_plot = None, None
165
+ full_processing = split_rows <= MAX_ROWS
166
+ message = (
167
+ f"⚙️ Processing full dataset: 0 of ({split_rows} rows)"
168
+ if full_processing
169
+ else f"⚙️ Processing partial dataset 0 of ({limit} rows)"
170
+ )
171
  yield (
172
  gr.DataFrame(value=[], interactive=False, visible=True),
173
  gr.Plot(value=None, visible=True),
174
+ gr.Label({message: rows_processed / limit}, visible=True),
 
 
175
  )
176
  while offset < limit:
177
  docs = get_docs_from_parquet(parquet_urls, column, offset, CHUNK_SIZE)
 
211
  rows_processed += len(docs)
212
  progress = min(rows_processed / limit, 1.0)
213
  logging.info(f"Progress: {progress} % - {rows_processed} of {limit}")
214
+ message = (
215
+ f"⚙️ Processing full dataset: {rows_processed} of {limit}"
216
+ if full_processing
217
+ else f"⚙️ Processing partial dataset: {rows_processed} of {limit} rows"
218
+ )
219
+
220
  yield (
221
  topics_info,
222
  topic_plot,
223
+ gr.Label({message: progress}, visible=True),
224
  )
225
 
226
  offset += CHUNK_SIZE
 
229
  yield (
230
  topics_info,
231
  topic_plot,
232
+ gr.Label(
233
+ {f"✅ Done: {rows_processed} rows have been processed": 1.0}, visible=True
234
+ ),
235
  )
236
  cuda.empty_cache()
237
 
 
272
 
273
  generate_button = gr.Button("Generate Topics", variant="primary")
274
 
275
+ gr.Markdown("## Data map")
276
  full_topics_generation_label = gr.Label(visible=False, show_label=False)
277
  topics_plot = gr.Plot()
278
  with gr.Accordion("Topics Info", open=False):