Spaces:
Sleeping
Sleeping
import logging
import os
from urllib.parse import quote

# import spaces
import duckdb
import gradio as gr
import numpy as np
import requests
from bertopic import BERTopic
from bertopic.representation import (
    KeyBERTInspired,
    TextGeneration,
)
from dotenv import load_dotenv
from gradio_huggingfacehub_search import HuggingfaceHubSearch
from hdbscan import HDBSCAN
from sentence_transformers import SentenceTransformer
from sklearn.feature_extraction.text import CountVectorizer
from torch import bfloat16, cuda
from transformers import (
    AutoModelForCausalLM,
    AutoTokenizer,
    BitsAndBytesConfig,
    pipeline,
)
from umap import UMAP

from prompts import REPRESENTATION_PROMPT
# Load variables from a local .env file (if present) before reading them.
load_dotenv()

HF_TOKEN = os.getenv("HF_TOKEN")
# Raise explicitly instead of using ``assert``: assertions are stripped when
# Python runs with ``-O``, and an empty token is as useless as a missing one.
# NOTE(review): the token is not passed explicitly anywhere in this file;
# presumably huggingface_hub picks it up from the environment — confirm.
if not HF_TOKEN:
    raise RuntimeError("You need to set HF_TOKEN in your environment variables")

logging.basicConfig(
    level=logging.INFO, format="%(asctime)s - %(name)s - %(levelname)s - %(message)s"
)

# Upper bound on rows read from a split, and rows handled per incremental step.
MAX_ROWS = 1_000
CHUNK_SIZE = 300

# Shared HTTP session (connection reuse) for datasets-server requests.
session = requests.Session()

# Sentence-transformer used to embed every document.
sentence_model = SentenceTransformer("all-MiniLM-L6-v2")
# KeyBERT-inspired topic representation aspect.
keybert = KeyBERTInspired()
# Bag-of-words vectorizer for topic terms; English stop words are dropped.
vectorizer_model = CountVectorizer(stop_words="english")

model_id = "meta-llama/Llama-2-7b-chat-hf"
device = f"cuda:{cuda.current_device()}" if cuda.is_available() else "cpu"
logging.info(device)
# 4-bit NF4 quantization so the 7B chat model fits in limited GPU memory.
bnb_config = BitsAndBytesConfig(
    load_in_4bit=True,  # 4-bit quantization
    bnb_4bit_quant_type="nf4",  # Normalized float 4
    bnb_4bit_use_double_quant=True,  # Second quantization after the first
    bnb_4bit_compute_dtype=bfloat16,  # Computation type
)

tokenizer = AutoTokenizer.from_pretrained(model_id)
# Llama 2 is implemented natively in transformers, so ``trust_remote_code`` is
# not needed; leaving it off avoids executing arbitrary code from the Hub repo.
# No gradient checkpointing either: the model is only used for generation,
# where checkpointing (a training/backprop memory saver) has no effect.
model = AutoModelForCausalLM.from_pretrained(
    model_id,
    quantization_config=bnb_config,
    device_map="auto",
    offload_folder="offload",  # Offloading part of the model to CPU to save GPU memory
)

generator = pipeline(
    model=model,
    tokenizer=tokenizer,
    task="text-generation",
    temperature=0.1,
    max_new_tokens=200,  # Reduced max_new_tokens to limit memory consumption
    repetition_penalty=1.1,
)

# LLM-based topic representation driven by the project prompt template.
llama2 = TextGeneration(generator, prompt=REPRESENTATION_PROMPT)

# Representation aspects computed for every topic; the "Llama2" key is also
# used to look up the LLM-generated labels.
representation_model = {
    "KeyBERT": keybert,
    "Llama2": llama2,
}
# TODO: N_NEIGHBORS should be proportional to the number of rows.
# For small datasets (1-200 rows), 2 neighbors worked fine.
N_NEIGHBORS = 15

# Settings shared by both UMAP projections; only ``n_components`` differs.
_UMAP_SHARED_KWARGS = {
    "n_neighbors": N_NEIGHBORS,
    "min_dist": 0.0,
    "metric": "cosine",
    "random_state": 42,
}

# 5-D projection handed to BERTopic as its dimensionality-reduction step.
umap_model = UMAP(n_components=5, **_UMAP_SHARED_KWARGS)

# Clusterer handed to BERTopic; minimum cluster size is tied to N_NEIGHBORS.
hdbscan_model = HDBSCAN(
    min_cluster_size=N_NEIGHBORS,
    metric="euclidean",
    cluster_selection_method="eom",
    prediction_data=True,
)

# 2-D projection used to lay out documents in the topics plot.
reduce_umap_model = UMAP(n_components=2, **_UMAP_SHARED_KWARGS)

# Most recently fitted BERTopic model; reassigned by ``fit_model``.
global_topic_model = None
def get_split_rows(dataset, config, split):
    """Return the number of rows in one split of a Hub dataset config.

    Queries the datasets-server ``/size`` endpoint for ``dataset``/``config``
    and picks the entry whose ``split`` matches.

    Args:
        dataset: Hub dataset id, e.g. ``"user/name"``.
        config: dataset config (subset) name.
        split: split name, e.g. ``"train"``.

    Returns:
        The ``num_rows`` value reported for the split.

    Raises:
        Exception: if the server reports an error or the split is not listed.
    """
    # ``params=`` URL-encodes the values (ids may contain reserved characters).
    config_size = session.get(
        "https://datasets-server.huggingface.co/size",
        params={"dataset": dataset, "config": config},
        timeout=20,
    ).json()
    if "error" in config_size:
        raise Exception(f"Error fetching config size: {config_size['error']}")
    split_size = next(
        (s for s in config_size["size"]["splits"] if s["split"] == split),
        None,
    )
    if split_size is None:
        raise Exception(f"Error fetching split {split} in config {config}")
    return split_size["num_rows"]
def get_parquet_urls(dataset, config, split):
    """Return the split's parquet file URLs as a SQL list body for DuckDB.

    Queries the datasets-server ``/parquet`` endpoint and formats each URL as
    a single-quoted SQL string literal, comma-separated, ready to be placed
    inside ``read_parquet([...])``.

    Raises:
        Exception: if the server reports an error or lists no parquet files.
    """
    # ``params=`` URL-encodes the values (ids may contain reserved characters).
    parquet_files = session.get(
        "https://datasets-server.huggingface.co/parquet",
        params={"dataset": dataset, "config": config, "split": split},
        timeout=20,
    ).json()
    if "error" in parquet_files:
        raise Exception(f"Error fetching parquet files: {parquet_files['error']}")
    parquet_urls = [file["url"] for file in parquet_files["parquet_files"]]
    if not parquet_urls:
        # ``read_parquet([])`` would fail later with a far less helpful error.
        raise Exception(
            f"No parquet files found for dataset {dataset} config {config} split {split}"
        )
    logging.debug(f"Parquet files: {parquet_urls}")
    # Double any single quote so a URL can never terminate its SQL literal early.
    return ",".join("'" + url.replace("'", "''") + "'" for url in parquet_urls)
def get_docs_from_parquet(parquet_urls, column, offset, limit):
    """Read one page of ``column`` values from parquet files with DuckDB.

    Args:
        parquet_urls: comma-separated, single-quoted URL literals, as produced
            by ``get_parquet_urls``.
        column: name of the column to read.
        offset: number of rows to skip.
        limit: maximum number of rows to return.

    Returns:
        A list with the column's values for the requested rows.
    """
    # Quote the column as a SQL identifier (doubling embedded quotes) so names
    # with spaces, keywords or quotes work and cannot inject SQL.
    quoted_column = '"' + str(column).replace('"', '""') + '"'
    sql_query = (
        f"SELECT {quoted_column} FROM read_parquet([{parquet_urls}]) "
        f"LIMIT {int(limit)} OFFSET {int(offset)};"
    )
    df = duckdb.sql(sql_query).to_df()
    logging.debug(f"Dataframe: {df.head(5)}")
    return df[column].tolist()
# @spaces.GPU
def calculate_embeddings(docs, batch_size=32):
    """Encode ``docs`` with the shared sentence-transformer model.

    Args:
        docs: list of document strings.
        batch_size: documents encoded per forward pass. Lower it to reduce peak
            memory during embedding; raise it for throughput. Defaults to 32.

    Returns:
        Whatever ``SentenceTransformer.encode`` returns for ``docs`` (by
        default one embedding row per document).
    """
    return sentence_model.encode(docs, show_progress_bar=True, batch_size=batch_size)
# @spaces.GPU
def fit_model(docs, embeddings):
    """Train a new BERTopic model on one chunk and publish it globally.

    The module-level sub-models (embedding model, UMAP, HDBSCAN, vectorizer
    and representation aspects) are reused for every call. ``embeddings`` are
    the precomputed embeddings for ``docs``. The fitted model is stored in
    ``global_topic_model``.
    """
    global global_topic_model

    sub_models = dict(
        embedding_model=sentence_model,
        umap_model=umap_model,
        hdbscan_model=hdbscan_model,
        representation_model=representation_model,
        vectorizer_model=vectorizer_model,
    )
    # NOTE(review): with a custom ``hdbscan_model`` BERTopic may not use
    # ``min_topic_size`` for clustering; cluster size comes from HDBSCAN's
    # ``min_cluster_size`` (N_NEIGHBORS) — confirm whether 15 should track it.
    hyperparameters = dict(top_n_words=10, verbose=True, min_topic_size=15)
    candidate = BERTopic(language="english", **sub_models, **hyperparameters)

    logging.info("Fitting new model")
    candidate.fit(docs, embeddings)
    logging.info("End fitting new model")

    global_topic_model = candidate
    logging.info("Global model updated")
def _extract_texts(raw_values, nested_column=None):
    """Return the non-blank strings to model from one chunk of column values.

    When ``nested_column`` is set and a value is dict-like (a struct cell),
    the text is read from that key; plain values are used as-is, so a stale
    nested selection does not break a flat text column. NULLs, missing keys,
    non-strings and blank strings are dropped, since the embedding model and
    the vectorizer need text input.
    """
    texts = []
    for value in raw_values:
        if nested_column and isinstance(value, dict):
            value = value.get(nested_column)
        if isinstance(value, str) and value.strip():
            texts.append(value)
    return texts


def generate_topics(dataset, config, split, column, nested_column):
    """Incrementally discover topics in a dataset split, streaming results.

    Reads at most ``MAX_ROWS`` rows of ``column`` (or of its ``nested_column``
    field) in chunks of at most ``CHUNK_SIZE``. Each chunk is embedded and
    fitted with a fresh BERTopic model, which is merged into the running model
    with ``BERTopic.merge_models``. Topics are labelled with the first line of
    the "Llama2" representation.

    Yields:
        ``(topics table, documents plot, progress label)`` tuples after the
        start, after every chunk and once more when finished.

    Raises:
        gr.Error: if no dataset or text column was selected.
    """
    if not dataset or not column:
        raise gr.Error("Select a dataset and a text column before generating topics.")

    logging.info(
        f"Generating topics for {dataset} with config {config} {split} {column} {nested_column}"
    )

    parquet_urls = get_parquet_urls(dataset, config, split)
    split_rows = get_split_rows(dataset, config, split)
    logging.info(f"Split rows: {split_rows}")

    limit = min(split_rows, MAX_ROWS)
    offset = 0  # rows read so far; also the next read position
    base_model = None
    all_docs = []
    all_embeddings = []
    topics_info, topic_plot = None, None

    yield (
        gr.DataFrame(interactive=False, visible=True),
        gr.Plot(visible=True),
        gr.Label({f"⚙️ Generating topics {dataset}": 0.0}, visible=True),
    )

    try:
        while offset < limit:
            # Cap the final chunk so no more than ``limit`` rows are read overall.
            chunk_limit = min(CHUNK_SIZE, limit - offset)
            raw_values = get_docs_from_parquet(parquet_urls, column, offset, chunk_limit)
            if not raw_values:
                break
            docs = _extract_texts(raw_values, nested_column)
            logging.info(
                f"----> Processing chunk: {offset=} {chunk_limit=} with {len(docs)} docs"
            )

            if docs:
                embeddings = calculate_embeddings(docs)
                fit_model(docs, embeddings)

                if base_model is None:
                    base_model = global_topic_model
                else:
                    updated_model = BERTopic.merge_models([base_model, global_topic_model])
                    nr_new_topics = len(set(updated_model.topics_)) - len(
                        set(base_model.topics_)
                    )
                    # ``[-0:]`` would select every label, so guard the zero case.
                    new_topics = (
                        list(updated_model.topic_labels_.values())[-nr_new_topics:]
                        if nr_new_topics > 0
                        else []
                    )
                    logging.info(f"The following topics are newly found: {new_topics}")
                    base_model = updated_model

                # Use the first line of the LLM output as each topic's label.
                repr_model_topics = {
                    key: label[0][0].split("\n")[0]
                    for key, label in base_model.get_topics(full=True)["Llama2"].items()
                }
                base_model.set_topic_labels(repr_model_topics)

                all_docs.extend(docs)
                all_embeddings.append(embeddings)
                # Project every document seen so far in a single UMAP fit, so
                # all plotted points share one 2-D coordinate space (separate
                # per-chunk fits produce unrelated layouts).
                reduced_embeddings = reduce_umap_model.fit_transform(
                    np.vstack(all_embeddings)
                )

                topics_info = base_model.get_topic_info()
                topic_plot = base_model.visualize_documents(
                    all_docs,
                    reduced_embeddings=reduced_embeddings,
                    custom_labels=True,
                )
                logging.info(f"Topics: {repr_model_topics}")
            else:
                logging.info("Chunk contains no text documents; skipping")

            offset += len(raw_values)
            progress = min(offset / limit, 1.0)
            logging.info(f"Progress: {progress:.0%} - {offset} of {limit}")
            yield (
                topics_info,
                topic_plot,
                gr.Label({f"⚙️ Generating topics {dataset}": progress}, visible=True),
            )

        logging.info("Finished processing all data")
        yield (
            topics_info,
            topic_plot,
            gr.Label({f"✅ Generating topics {dataset}": 1.0}, visible=True),
        )
    finally:
        # Release cached GPU memory even if processing fails or is cancelled.
        cuda.empty_cache()
with gr.Blocks() as demo:
    gr.Markdown("# 💠 Dataset Topic Discovery 🔭")
    gr.Markdown("## Select dataset and text column")
    with gr.Accordion("Data details", open=True):
        with gr.Row():
            with gr.Column(scale=3):
                dataset_name = HuggingfaceHubSearch(
                    label="Hub Dataset ID",
                    placeholder="Search for dataset id on Huggingface",
                    search_type="dataset",
                )
            subset_dropdown = gr.Dropdown(label="Subset", visible=False)
            split_dropdown = gr.Dropdown(label="Split", visible=False)

        with gr.Accordion("Dataset preview", open=False):
            dataset_preview = gr.HTML()

            def embed(name, subset, split):
                """Return an iframe embedding the Hub dataset viewer for the selection.

                Renders nothing until dataset, subset and split are all chosen.
                Path segments are URL-quoted so user input cannot break out of
                the ``src`` attribute.
                """
                if not (name and subset and split):
                    return gr.HTML(value="")
                src = (
                    "https://huggingface.co/datasets/"
                    f"{quote(name, safe='/')}/embed/viewer/"
                    f"{quote(subset, safe='')}/{quote(split, safe='')}"
                )
                html_code = f"""
                <iframe
                    src="{src}"
                    frameborder="0"
                    width="100%"
                    height="600px"
                ></iframe>
                """
                return gr.HTML(value=html_code)

        with gr.Row():
            text_column_dropdown = gr.Dropdown(label="Text column name")
            nested_text_column_dropdown = gr.Dropdown(
                label="Nested text column name", visible=False
            )

        generate_button = gr.Button("Generate Topics", variant="primary")

    gr.Markdown("## Datamap")
    full_topics_generation_label = gr.Label(visible=False, show_label=False)
    topics_plot = gr.Plot()
    with gr.Accordion("Topics Info", open=False):
        topics_df = gr.DataFrame(interactive=False, visible=True)

    generate_button.click(
        generate_topics,
        inputs=[
            dataset_name,
            subset_dropdown,
            split_dropdown,
            text_column_dropdown,
            nested_text_column_dropdown,
        ],
        outputs=[topics_df, topics_plot, full_topics_generation_label],
    )

    def _is_string_feature(feature):
        """True when a datasets-server feature spec describes a string value."""
        return isinstance(feature, dict) and feature.get("dtype") == "string"

    def _empty_selection() -> dict:
        """Dropdown updates used when no valid dataset info is available."""
        return {
            subset_dropdown: gr.Dropdown(visible=False),
            split_dropdown: gr.Dropdown(visible=False),
            text_column_dropdown: gr.Dropdown(label="Text column name"),
            # Clear the value too: a hidden dropdown still sends its value.
            nested_text_column_dropdown: gr.Dropdown(value=None, visible=False),
        }

    def _resolve_dataset_selection(
        dataset: str, default_subset: str, default_split: str, text_feature
    ) -> dict:
        """Compute dropdown updates for a dataset / subset / split / column choice.

        Fetches the dataset's ``/info`` from datasets-server, then picks the
        requested subset and split (falling back to the first available), and
        lists string columns plus struct columns that contain string fields.
        If ``text_feature`` is one of those struct columns, the nested-field
        dropdown is shown with its string fields.

        Returns:
            A dict mapping each of the four selection dropdowns to its update.
        """
        if not dataset or "/" not in dataset.strip().strip("/"):
            return _empty_selection()

        # ``params=`` URL-encodes the dataset id.
        info_resp = session.get(
            "https://datasets-server.huggingface.co/info",
            params={"dataset": dataset},
            timeout=20,
        ).json()
        if "error" in info_resp or not info_resp.get("dataset_info"):
            return _empty_selection()

        subsets: list[str] = list(info_resp["dataset_info"])
        subset = default_subset if default_subset in subsets else subsets[0]
        splits: list[str] = list(info_resp["dataset_info"][subset]["splits"])
        if not splits:
            return _empty_selection()
        split = default_split if default_split in splits else splits[0]
        features = info_resp["dataset_info"][subset]["features"]

        text_features = [
            feature_name
            for feature_name, feature in features.items()
            if _is_string_feature(feature)
        ]
        # Struct-like features: dicts whose first value is itself a feature dict.
        nested_features = [
            feature_name
            for feature_name, feature in features.items()
            if isinstance(feature, dict)
            and feature
            and isinstance(next(iter(feature.values())), dict)
        ]
        nested_text_features = [
            feature_name
            for feature_name in nested_features
            if any(
                _is_string_feature(nested_feature)
                for nested_feature in features[feature_name].values()
            )
        ]

        selection = {
            subset_dropdown: gr.Dropdown(
                value=subset, choices=subsets, visible=len(subsets) > 1
            ),
            split_dropdown: gr.Dropdown(
                value=split, choices=splits, visible=len(splits) > 1
            ),
            text_column_dropdown: gr.Dropdown(
                choices=text_features + nested_text_features,
                label="Text column name",
            ),
            # Clear the value too: a hidden dropdown still sends its value.
            nested_text_column_dropdown: gr.Dropdown(value=None, visible=False),
        }
        if text_feature and text_feature in nested_text_features:
            nested_keys = [
                feature_name
                for feature_name, feature in features[text_feature].items()
                if _is_string_feature(feature)
            ]
            selection[nested_text_column_dropdown] = gr.Dropdown(
                value=nested_keys[0],
                choices=nested_keys,
                label="Nested text column name",
                visible=True,
            )
        return selection

    def show_input_from_dataset_search(dataset: str) -> dict:
        """Handle a new dataset id: default to the "default" subset and "train" split."""
        return _resolve_dataset_selection(
            dataset, default_subset="default", default_split="train", text_feature=None
        )

    def show_input_from_subset_dropdown(dataset: str, subset: str) -> dict:
        """Handle a subset change: keep the subset, default to the "train" split."""
        return _resolve_dataset_selection(
            dataset, default_subset=subset, default_split="train", text_feature=None
        )

    def show_input_from_split_dropdown(dataset: str, subset: str, split: str) -> dict:
        """Handle a split change: keep both subset and split."""
        return _resolve_dataset_selection(
            dataset, default_subset=subset, default_split=split, text_feature=None
        )

    def show_input_from_text_column_dropdown(
        dataset: str, subset: str, split: str, text_column
    ) -> dict:
        """Handle a text-column change: show nested fields for struct columns."""
        return _resolve_dataset_selection(
            dataset,
            default_subset=subset,
            default_split=split,
            text_feature=text_column,
        )

    # Every handler returns updates for the same four selection dropdowns.
    selection_outputs = [
        subset_dropdown,
        split_dropdown,
        text_column_dropdown,
        nested_text_column_dropdown,
    ]
    dataset_name.change(
        show_input_from_dataset_search,
        inputs=[dataset_name],
        outputs=selection_outputs,
    )
    subset_dropdown.change(
        show_input_from_subset_dropdown,
        inputs=[dataset_name, subset_dropdown],
        outputs=selection_outputs,
    )
    split_dropdown.change(
        show_input_from_split_dropdown,
        inputs=[dataset_name, subset_dropdown, split_dropdown],
        outputs=selection_outputs,
    )
    text_column_dropdown.change(
        show_input_from_text_column_dropdown,
        inputs=[dataset_name, subset_dropdown, split_dropdown, text_column_dropdown],
        outputs=selection_outputs,
    )
    # Refresh the viewer preview whenever any part of the selection changes.
    for preview_trigger in (dataset_name, subset_dropdown, split_dropdown):
        preview_trigger.change(
            embed,
            inputs=[dataset_name, subset_dropdown, split_dropdown],
            outputs=dataset_preview,
        )

if __name__ == "__main__":
    demo.launch()