fix-plot-issue

#1
by asoria HF staff - opened
Files changed (1) hide show
  1. app.py +216 -209
app.py CHANGED
@@ -37,7 +37,6 @@ DATASETS_TOPICS_ORGANIZATION = os.getenv(
37
  "DATASETS_TOPICS_ORGANIZATION", "datasets-topics"
38
  )
39
  USE_CUML = int(os.getenv("USE_CUML", "1"))
40
- USE_LLM_TEXT_GENERATION = int(os.getenv("USE_LLM_TEXT_GENERATION", "1"))
41
 
42
  # Use cuml lib only if configured
43
  if USE_CUML:
@@ -53,19 +52,17 @@ logging.basicConfig(
53
  )
54
 
55
  api = HfApi(token=HF_TOKEN)
56
- sentence_model = SentenceTransformer("all-MiniLM-L6-v2")
57
 
58
- # Representation model
59
  model_id = "meta-llama/Meta-Llama-3-8B-Instruct"
60
-
61
- representation_model = KeyBERTInspired()
62
  vectorizer_model = CountVectorizer(stop_words="english")
 
63
 
64
  inference_client = InferenceClient(model_id)
65
 
66
 
67
  def calculate_embeddings(docs):
68
- return sentence_model.encode(docs, show_progress_bar=True, batch_size=32)
69
 
70
 
71
  def calculate_n_neighbors_and_components(n_rows):
@@ -95,7 +92,7 @@ def fit_model(docs, embeddings, n_neighbors, n_components):
95
  new_model = BERTopic(
96
  language="english",
97
  # Sub-models
98
- embedding_model=sentence_model, # Step 1 - Extract embeddings
99
  umap_model=umap_model, # Step 2 - UMAP model
100
  hdbscan_model=hdbscan_model, # Step 3 - Cluster reduced embeddings
101
  vectorizer_model=vectorizer_model, # Step 4 - Tokenize topics
@@ -169,44 +166,146 @@ def generate_topics(dataset, config, split, column, plot_type):
169
  "",
170
  )
171
 
172
- while offset < limit:
173
- logging.info(f"----> Getting records from {offset=} with {CHUNK_SIZE=}")
174
- docs = get_docs_from_parquet(parquet_urls, column, offset, CHUNK_SIZE)
175
- if not docs:
176
- break
177
- logging.info(f"Got {len(docs)} docs ✓")
178
- embeddings = calculate_embeddings(docs)
179
- new_model = fit_model(docs, embeddings, n_neighbors, n_components)
180
-
181
- if base_model is None:
182
- base_model = new_model
183
- logging.info(
184
- f"The following topics are newly found: {base_model.topic_labels_}"
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
185
  )
186
- else:
187
- updated_model = BERTopic.merge_models([base_model, new_model])
188
- nr_new_topics = len(set(updated_model.topics_)) - len(
189
- set(base_model.topics_)
 
 
 
 
190
  )
191
- new_topics = list(updated_model.topic_labels_.values())[-nr_new_topics:]
192
- logging.info(f"The following topics are newly found: {new_topics}")
193
- base_model = updated_model
194
 
195
- logging.info("Reducing embeddings to 2D")
196
- reduced_embeddings = reduce_umap_model.fit_transform(embeddings)
197
- reduced_embeddings_list.append(reduced_embeddings)
198
- logging.info("Reducing embeddings to 2D ✓")
 
 
 
199
 
200
- all_docs.extend(docs)
201
- reduced_embeddings_array = np.vstack(reduced_embeddings_list)
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
202
 
203
- topics_info = base_model.get_topic_info()
204
  all_topics = base_model.topics_
205
- logging.info(f"Preparing topics {plot_type} plot")
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
206
  topic_plot = (
207
  base_model.visualize_document_datamap(
208
  docs=all_docs,
209
  topics=all_topics,
 
210
  reduced_embeddings=reduced_embeddings_array,
211
  title="",
212
  sub_title=sub_title,
@@ -227,192 +326,100 @@ def generate_topics(dataset, config, split, column, plot_type):
227
  if plot_type == "DataMapPlot"
228
  else base_model.visualize_documents(
229
  docs=all_docs,
230
- topics=all_topics,
231
  reduced_embeddings=reduced_embeddings_array,
 
232
  title="",
233
  )
234
  )
235
- logging.info("Plot done ✓")
236
- rows_processed += len(docs)
237
- progress = min(rows_processed / limit, 1.0)
238
- logging.info(f"Progress: {progress} % - {rows_processed} of {limit}")
239
- message = (
240
- f"Processing topics for full dataset: {rows_processed} of {limit}"
241
- if full_processing
242
- else f"Processing topics for partial dataset: {rows_processed} of {limit} rows"
243
- )
244
 
 
 
 
 
 
 
 
 
 
245
  yield (
246
  gr.Accordion(open=False),
247
  topics_info,
248
  topic_plot,
249
- gr.Label({"⏳ " + message: progress}, visible=True),
 
 
 
 
 
 
 
250
  "",
251
  )
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
252
 
253
- offset += CHUNK_SIZE
254
- del docs, embeddings, new_model, reduced_embeddings
255
- logging.info("Finished processing all data")
256
-
257
- yield (
258
- gr.Accordion(open=False),
259
- topics_info,
260
- topic_plot,
261
- gr.Label(
262
- {
263
- "✅ " + message: 1.0,
264
- f"⏳ Generating topic names with {model_id}": 0.0,
265
- },
266
- visible=True,
267
- ),
268
- "",
269
- )
270
-
271
- all_topics = base_model.topics_
272
- topics_info = base_model.get_topic_info()
273
 
274
- new_topics_by_text_generation = {}
275
- for _, row in topics_info.iterrows():
276
- logging.info(
277
- f"Processing topic: {row['Topic']} - Representation: {row['Representation']}"
278
- )
279
- prompt = f"{LLAMA_3_8B_PROMPT.replace('[KEYWORDS]', ','.join(row['Representation']))}"
280
- prompt_messages = [
281
- {
282
- "role": "system",
283
- "content": "You are a helpful, respectful and honest assistant for labeling topics.",
284
- },
285
- {"role": "user", "content": prompt},
286
- ]
287
- output = inference_client.chat_completion(
288
- messages=prompt_messages,
289
- stream=False,
290
- max_tokens=500,
291
- top_p=0.8,
292
- seed=42,
293
  )
294
- inference_response = output.choices[0].message.content
295
- logging.info("Inference response:")
296
- logging.info(inference_response)
297
- new_topics_by_text_generation[row["Topic"]] = inference_response.replace(
298
- "Topic=", ""
299
- ).strip()
300
- base_model.set_topic_labels(new_topics_by_text_generation)
301
-
302
- topics_info = base_model.get_topic_info()
303
-
304
- topic_plot = (
305
- base_model.visualize_document_datamap(
306
- docs=all_docs,
307
- topics=all_topics,
308
- custom_labels=True,
309
- reduced_embeddings=reduced_embeddings_array,
310
- title="",
311
- sub_title=sub_title,
312
- width=800,
313
- height=700,
314
- arrowprops={
315
- "arrowstyle": "wedge,tail_width=0.5",
316
- "connectionstyle": "arc3,rad=0.05",
317
- "linewidth": 0,
318
- "fc": "#33333377",
319
- },
320
- dynamic_label_size=True,
321
- # label_wrap_width=12,
322
- label_over_points=True,
323
- max_font_size=36,
324
- min_font_size=4,
325
  )
326
- if plot_type == "DataMapPlot"
327
- else base_model.visualize_documents(
328
- docs=all_docs,
329
- reduced_embeddings=reduced_embeddings_array,
330
- custom_labels=True,
331
- title="",
 
 
332
  )
333
- )
334
-
335
- dataset_clear_name = dataset.replace("/", "-")
336
- plot_png = f"{dataset_clear_name}-{plot_type.lower()}.png"
337
- if plot_type == "DataMapPlot":
338
- topic_plot.savefig(plot_png, format="png", dpi=300)
339
- else:
340
- topic_plot.write_image(plot_png)
341
-
342
- custom_labels = base_model.custom_labels_
343
- topic_names_array = [custom_labels[doc_topic + 1] for doc_topic in all_topics]
344
- yield (
345
- gr.Accordion(open=False),
346
- topics_info,
347
- topic_plot,
348
- gr.Label(
349
- {
350
- "✅ " + message: 1.0,
351
- f"✅ Generating topic names with {model_id}": 1.0,
352
- "⏳ Creating Interactive Space": 0.0,
353
- },
354
- visible=True,
355
- ),
356
- "",
357
- )
358
- interactive_plot = datamapplot.create_interactive_plot(
359
- reduced_embeddings_array,
360
- topic_names_array,
361
- hover_text=all_docs,
362
- title=dataset,
363
- sub_title=sub_title.replace(
364
- "dataset",
365
- f"<a href='https://huggingface.co/datasets/{dataset}/viewer/{config}/{split}' target='_blank'>dataset</a>",
366
- ),
367
- enable_search=True,
368
- # TODO: Export data to .arrow and also serve it
369
- inline_data=True,
370
- # offline_data_prefix=dataset_clear_name,
371
- initial_zoom_fraction=0.9,
372
- cluster_boundary_polygons=True
373
- )
374
- html_content = str(interactive_plot)
375
- html_file_path = f"{dataset_clear_name}.html"
376
- with open(html_file_path, "w", encoding="utf-8") as html_file:
377
- html_file.write(html_content)
378
-
379
- repo_id = f"{DATASETS_TOPICS_ORGANIZATION}/{dataset_clear_name}"
380
-
381
- space_id = create_space_with_content(
382
- api=api,
383
- repo_id=repo_id,
384
- dataset_id=dataset,
385
- html_file_path=html_file_path,
386
- plot_file_path=plot_png,
387
- space_card=SPACE_REPO_CARD_CONTENT,
388
- token=HF_TOKEN,
389
- )
390
-
391
- space_link = f"https://huggingface.co/spaces/{space_id}"
392
- yield (
393
- gr.Accordion(open=False),
394
- topics_info,
395
- topic_plot,
396
- gr.Label(
397
- {
398
- "✅ " + message: 1.0,
399
- f"✅ Generating topic names with {model_id}": 1.0,
400
- "✅ Creating Interactive Space": 1.0,
401
- },
402
- visible=True,
403
- ),
404
- f"[![Go to interactive plot](https://img.shields.io/badge/%F0%9F%A4%97%20Hugging%20Face-Space-blue)]({space_link})",
405
- )
406
- del reduce_umap_model, all_docs, reduced_embeddings_list
407
- del (
408
- base_model,
409
- all_topics,
410
- topics_info,
411
- topic_plot,
412
- topic_names_array,
413
- interactive_plot,
414
- )
415
- cuda.empty_cache()
416
 
417
 
418
  with gr.Blocks() as demo:
@@ -461,11 +468,11 @@ with gr.Blocks() as demo:
461
  generate_button = gr.Button("Generate Topics", variant="primary")
462
 
463
  gr.Markdown("## Data map")
464
- full_topics_generation_label = gr.Label(visible=False, show_label=False)
465
  open_space_label = gr.Markdown()
466
  topics_plot = gr.Plot()
467
- with gr.Accordion("Topics Info", open=False):
468
- topics_df = gr.DataFrame(interactive=False, visible=True)
469
  gr.HTML(
470
  f"<p style='text-align: center; color:orange;'>⚠ This space processes datasets in batches of <b>{CHUNK_SIZE}</b>, with a maximum of <b>{MAX_ROWS}</b> rows. If you need further assistance, please open a new issue in the Community tab.</p>"
471
  )
@@ -487,7 +494,7 @@ with gr.Blocks() as demo:
487
  data_details_accordion,
488
  topics_df,
489
  topics_plot,
490
- full_topics_generation_label,
491
  open_space_label,
492
  ],
493
  )
 
37
  "DATASETS_TOPICS_ORGANIZATION", "datasets-topics"
38
  )
39
  USE_CUML = int(os.getenv("USE_CUML", "1"))
 
40
 
41
  # Use cuml lib only if configured
42
  if USE_CUML:
 
52
  )
53
 
54
  api = HfApi(token=HF_TOKEN)
 
55
 
 
56
  model_id = "meta-llama/Meta-Llama-3-8B-Instruct"
57
+ embedding_model = SentenceTransformer("all-MiniLM-L6-v2")
 
58
  vectorizer_model = CountVectorizer(stop_words="english")
59
+ representation_model = KeyBERTInspired()
60
 
61
  inference_client = InferenceClient(model_id)
62
 
63
 
64
  def calculate_embeddings(docs):
65
+ return embedding_model.encode(docs, show_progress_bar=True, batch_size=32)
66
 
67
 
68
  def calculate_n_neighbors_and_components(n_rows):
 
92
  new_model = BERTopic(
93
  language="english",
94
  # Sub-models
95
+ embedding_model=embedding_model, # Step 1 - Extract embeddings
96
  umap_model=umap_model, # Step 2 - UMAP model
97
  hdbscan_model=hdbscan_model, # Step 3 - Cluster reduced embeddings
98
  vectorizer_model=vectorizer_model, # Step 4 - Tokenize topics
 
166
  "",
167
  )
168
 
169
+ try:
170
+ while offset < limit:
171
+ logging.info(f"----> Getting records from {offset=} with {CHUNK_SIZE=}")
172
+ docs = get_docs_from_parquet(parquet_urls, column, offset, CHUNK_SIZE)
173
+ if not docs:
174
+ break
175
+ logging.info(f"Got {len(docs)} docs ✓")
176
+ embeddings = calculate_embeddings(docs)
177
+ new_model = fit_model(docs, embeddings, n_neighbors, n_components)
178
+
179
+ if base_model is None:
180
+ base_model = new_model
181
+ logging.info(
182
+ f"The following topics are newly found: {base_model.topic_labels_}"
183
+ )
184
+ else:
185
+ updated_model = BERTopic.merge_models([base_model, new_model])
186
+ nr_new_topics = len(set(updated_model.topics_)) - len(
187
+ set(base_model.topics_)
188
+ )
189
+ new_topics = list(updated_model.topic_labels_.values())[-nr_new_topics:]
190
+ logging.info(f"The following topics are newly found: {new_topics}")
191
+ base_model = updated_model
192
+
193
+ logging.info("Reducing embeddings to 2D")
194
+ reduced_embeddings = reduce_umap_model.fit_transform(embeddings)
195
+ reduced_embeddings_list.append(reduced_embeddings)
196
+
197
+ all_docs.extend(docs)
198
+ reduced_embeddings_array = np.vstack(reduced_embeddings_list)
199
+ logging.info("Reducing embeddings to 2D ✓")
200
+
201
+ topics_info = base_model.get_topic_info()
202
+ all_topics = base_model.topics_
203
+ logging.info(f"Preparing topics {plot_type} plot")
204
+
205
+ topic_plot = (
206
+ base_model.visualize_document_datamap(
207
+ docs=all_docs,
208
+ topics=all_topics,
209
+ reduced_embeddings=reduced_embeddings_array,
210
+ title="",
211
+ sub_title=sub_title,
212
+ width=800,
213
+ height=700,
214
+ arrowprops={
215
+ "arrowstyle": "wedge,tail_width=0.5",
216
+ "connectionstyle": "arc3,rad=0.05",
217
+ "linewidth": 0,
218
+ "fc": "#33333377",
219
+ },
220
+ dynamic_label_size=True,
221
+ # label_wrap_width=12,
222
+ label_over_points=True,
223
+ max_font_size=36,
224
+ min_font_size=4,
225
+ )
226
+ if plot_type == "DataMapPlot"
227
+ else base_model.visualize_documents(
228
+ docs=all_docs,
229
+ topics=all_topics,
230
+ reduced_embeddings=reduced_embeddings_array,
231
+ custom_labels=True,
232
+ title="",
233
+ )
234
  )
235
+ logging.info("Plot done ✓")
236
+ rows_processed += len(docs)
237
+ progress = min(rows_processed / limit, 1.0)
238
+ logging.info(f"Progress: {progress} % - {rows_processed} of {limit}")
239
+ message = (
240
+ f"Processing topics for full dataset: {rows_processed} of {limit}"
241
+ if full_processing
242
+ else f"Processing topics for partial dataset: {rows_processed} of {limit} rows"
243
  )
 
 
 
244
 
245
+ yield (
246
+ gr.Accordion(open=False),
247
+ topics_info,
248
+ topic_plot,
249
+ gr.Label({"⏳ " + message: progress}, visible=True),
250
+ "",
251
+ )
252
 
253
+ offset += CHUNK_SIZE
254
+ del docs, embeddings, new_model, reduced_embeddings
255
+ logging.info("Finished processing topic modeling data")
256
+
257
+ yield (
258
+ gr.Accordion(open=False),
259
+ topics_info,
260
+ topic_plot,
261
+ gr.Label(
262
+ {
263
+ "✅ " + message: 1.0,
264
+ f"⏳ Generating topic names with {model_id}": 0.0,
265
+ },
266
+ visible=True,
267
+ ),
268
+ "",
269
+ )
270
 
 
271
  all_topics = base_model.topics_
272
+ topics_info = base_model.get_topic_info()
273
+
274
+ new_topics_by_text_generation = {}
275
+ for _, row in topics_info.iterrows():
276
+ logging.info(
277
+ f"Processing topic: {row['Topic']} - Representation: {row['Representation']}"
278
+ )
279
+ prompt = f"{LLAMA_3_8B_PROMPT.replace('[KEYWORDS]', ','.join(row['Representation']))}"
280
+ prompt_messages = [
281
+ {
282
+ "role": "system",
283
+ "content": "You are a helpful, respectful and honest assistant for labeling topics.",
284
+ },
285
+ {"role": "user", "content": prompt},
286
+ ]
287
+ output = inference_client.chat_completion(
288
+ messages=prompt_messages,
289
+ stream=False,
290
+ max_tokens=500,
291
+ top_p=0.8,
292
+ seed=42,
293
+ )
294
+ inference_response = output.choices[0].message.content
295
+ logging.info("Inference response:")
296
+ logging.info(inference_response)
297
+ new_topics_by_text_generation[row["Topic"]] = inference_response.replace(
298
+ "Topic=", ""
299
+ ).strip()
300
+ base_model.set_topic_labels(new_topics_by_text_generation)
301
+
302
+ topics_info = base_model.get_topic_info()
303
+
304
  topic_plot = (
305
  base_model.visualize_document_datamap(
306
  docs=all_docs,
307
  topics=all_topics,
308
+ custom_labels=True,
309
  reduced_embeddings=reduced_embeddings_array,
310
  title="",
311
  sub_title=sub_title,
 
326
  if plot_type == "DataMapPlot"
327
  else base_model.visualize_documents(
328
  docs=all_docs,
 
329
  reduced_embeddings=reduced_embeddings_array,
330
+ custom_labels=True,
331
  title="",
332
  )
333
  )
 
 
 
 
 
 
 
 
 
334
 
335
+ dataset_clear_name = dataset.replace("/", "-")
336
+ plot_png = f"{dataset_clear_name}-{plot_type.lower()}.png"
337
+ if plot_type == "DataMapPlot":
338
+ topic_plot.savefig(plot_png, format="png", dpi=300)
339
+ else:
340
+ topic_plot.write_image(plot_png)
341
+
342
+ custom_labels = base_model.custom_labels_
343
+ topic_names_array = [custom_labels[doc_topic + 1] for doc_topic in all_topics]
344
  yield (
345
  gr.Accordion(open=False),
346
  topics_info,
347
  topic_plot,
348
+ gr.Label(
349
+ {
350
+ "✅ " + message: 1.0,
351
+ f"✅ Generating topic names with {model_id}": 1.0,
352
+ "⏳ Creating Interactive Space": 0.0,
353
+ },
354
+ visible=True,
355
+ ),
356
  "",
357
  )
358
+ interactive_plot = datamapplot.create_interactive_plot(
359
+ reduced_embeddings_array,
360
+ topic_names_array,
361
+ hover_text=all_docs,
362
+ title=dataset,
363
+ sub_title=sub_title.replace(
364
+ "dataset",
365
+ f"<a href='https://huggingface.co/datasets/{dataset}/viewer/{config}/{split}' target='_blank'>dataset</a>",
366
+ ),
367
+ enable_search=True,
368
+ # TODO: Export data to .arrow and also serve it
369
+ inline_data=True,
370
+ # offline_data_prefix=dataset_clear_name,
371
+ initial_zoom_fraction=0.8,
372
+ )
373
+ html_content = str(interactive_plot)
374
+ html_file_path = f"{dataset_clear_name}.html"
375
+ with open(html_file_path, "w", encoding="utf-8") as html_file:
376
+ html_file.write(html_content)
377
+
378
+ repo_id = f"{DATASETS_TOPICS_ORGANIZATION}/{dataset_clear_name}"
379
+
380
+ space_id = create_space_with_content(
381
+ api=api,
382
+ repo_id=repo_id,
383
+ dataset_id=dataset,
384
+ html_file_path=html_file_path,
385
+ plot_file_path=plot_png,
386
+ space_card=SPACE_REPO_CARD_CONTENT,
387
+ token=HF_TOKEN,
388
+ )
389
 
390
+ space_link = f"https://huggingface.co/spaces/{space_id}"
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
391
 
392
+ yield (
393
+ gr.Accordion(open=False),
394
+ topics_info,
395
+ topic_plot,
396
+ gr.Label(
397
+ {
398
+ "✅ " + message: 1.0,
399
+ f"✅ Generating topic names with {model_id}": 1.0,
400
+ "✅ Creating Interactive Space": 1.0,
401
+ },
402
+ visible=True,
403
+ ),
404
+ f"[![Go to interactive plot](https://img.shields.io/badge/%F0%9F%A4%97%20Hugging%20Face-Space-blue)]({space_link})",
 
 
 
 
 
 
405
  )
406
+ del reduce_umap_model, all_docs, reduced_embeddings_list
407
+ del (
408
+ base_model,
409
+ all_topics,
410
+ topics_info,
411
+ topic_names_array,
412
+ interactive_plot,
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
413
  )
414
+ cuda.empty_cache()
415
+ except Exception as error:
416
+ return (
417
+ gr.Accordion(open=True),
418
+ gr.DataFrame(value=[], interactive=False, visible=True),
419
+ gr.Plot(value=None, visible=True),
420
+ gr.Label({f"❌ Error: {error}": 0.0}, visible=True),
421
+ "",
422
  )
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
423
 
424
 
425
  with gr.Blocks() as demo:
 
468
  generate_button = gr.Button("Generate Topics", variant="primary")
469
 
470
  gr.Markdown("## Data map")
471
+ progress_label = gr.Label(visible=False, show_label=False)
472
  open_space_label = gr.Markdown()
473
  topics_plot = gr.Plot()
474
+ # with gr.Accordion("Topics Info", open=False):
475
+ topics_df = gr.DataFrame(interactive=False, visible=True)
476
  gr.HTML(
477
  f"<p style='text-align: center; color:orange;'>⚠ This space processes datasets in batches of <b>{CHUNK_SIZE}</b>, with a maximum of <b>{MAX_ROWS}</b> rows. If you need further assistance, please open a new issue in the Community tab.</p>"
478
  )
 
494
  data_details_accordion,
495
  topics_df,
496
  topics_plot,
497
+ progress_label,
498
  open_space_label,
499
  ],
500
  )