# Repository page header captured with the file (not part of the program):
#   helenai's picture
#   Update app.py
#   7fb5607 verified
import io
import json
import re
import gradio as gr
import matplotlib.pyplot as plt
import pandas as pd
from datasets import load_dataset
from PIL import Image
from transformers import AutoTokenizer
# Model ids offered as choices in the "Tokenizer" dropdown. The mapping keys are
# short family labels kept only as in-source annotation; just the values, in
# insertion order, end up in `tokenizers`.
tokenizers = list(
    {
        "bert": "google-bert/bert-base-uncased",
        "bge-en": "BAAI/bge-base-en-v1.5",
        "bge-zh": "BAAI/bge-base-zh-v1.5",
        "blenderbot": "facebook/blenderbot-3B",
        "bloom": "bigscience/bloom-560m",
        "bloomz": "bigscience/bloomz-7b1",
        "chatglm3": "THUDM/chatglm3-6b",
        "falcon": "tiiuae/falcon-7b",
        "gemma": "fxmarty/tiny-random-GemmaForCausalLM",
        "gpt-neox": "EleutherAI/gpt-neox-20b",
        "llama": "TinyLlama/TinyLlama-1.1B-Chat-v0.6",
        "magicoder": "ise-uiuc/Magicoder-S-DS-6.7B",
        "mistral": "echarlaix/tiny-random-mistral",
        "mpt": "mosaicml/mpt-7b",
        "opt": "facebook/opt-2.7b",
        "phi-2": "microsoft/phi-2",
        "pythia": "EleutherAI/pythia-1.4b-deduped",
        "qwen": "Qwen/Qwen1.5-7B-Chat",
        "redpajama": "togethercomputer/RedPajama-INCITE-Chat-3B-v1",
        "roberta": "FacebookAI/roberta-base",
        "starcoder": "bigcode/starcoder2-7b",
        "t5": "google-t5/t5-base",
        "vicuna": "lmsys/vicuna-7b-v1.5",
        "zephyr": "HuggingFaceH4/zephyr-7b-beta",
    }.values()
)
def plot_histogram(data):
    """Render a histogram of *data* and return it as a PIL image.

    Args:
        data: Sequence of numbers to bin.

    Returns:
        A PIL ``Image`` opened from an in-memory PNG of the plot.
    """
    # Draw on a dedicated figure rather than pyplot's implicit "current"
    # figure, so nothing previously drawn there can end up in this plot.
    fig, ax = plt.subplots()
    try:
        ax.hist(data)
        ax.set_title("Histogram of number of tokens per dataset item")
        buf = io.BytesIO()
        fig.savefig(buf, format="png")
    finally:
        # Release the figure even when plotting or saving fails; otherwise
        # pyplot keeps it registered.
        plt.close(fig)
    buf.seek(0)
    # The buffer is deliberately left open: the returned image is backed by it.
    return Image.open(buf)
def _describe_counts(values, label, percentiles):
    """Return a one-row ``describe()`` summary of *values*, tagged with *label*."""
    summary = pd.DataFrame(values).describe(percentiles=percentiles).T
    summary.insert(0, "type", label)
    return summary


def count(model_id, dataset_id, config, split, column, add_special_tokens=True):
    """Compute token, word and character statistics for one dataset column.

    Args:
        model_id: Identifier passed to ``AutoTokenizer.from_pretrained``.
        dataset_id: Identifier passed to ``load_dataset``.
        config: Dataset configuration name. An empty value means "no
            configuration" and is passed to ``load_dataset`` as ``None``.
        split: Dataset split to load.
        column: Name of the column to measure. Its values are tokenized,
            measured with ``len`` and scanned with a regex, so they are
            assumed to be strings.
        add_special_tokens: Forwarded to the tokenizer call.

    Returns:
        A tuple ``(image, stats)``: the histogram image of the per-item token
        counts, and a DataFrame with one row each for "tokens", "words" and
        "chars" (in that order) holding the ``describe()`` statistics rounded
        to one decimal, without the "count" column.
    """
    # NOTE(review): trust_remote_code=True lets the referenced model/dataset
    # repositories run their own code, and both ids come from user input.
    # Confirm this is acceptable for the deployment.
    tokenizer = AutoTokenizer.from_pretrained(model_id, trust_remote_code=True)

    # The UI delivers an empty string when no config is entered; load_dataset
    # must receive None in that case, not "".
    if not config:
        config = None
    dataset = load_dataset(dataset_id, config, split=split, trust_remote_code=True)

    # Word counting is approximate: only runs of ASCII letters are counted.
    word_pattern = re.compile(r"[a-zA-Z]+")

    token_counts = []
    word_counts = []
    char_counts = []
    for item in dataset:
        text = item[column]
        input_ids = tokenizer(text, add_special_tokens=add_special_tokens)["input_ids"]
        token_counts.append(len(input_ids))
        char_counts.append(len(text))
        word_counts.append(len(word_pattern.findall(text)))

    percentiles = [0.25, 0.5, 0.75, 0.9, 0.95, 0.99]
    # Concatenating the per-metric rows keeps the statistic columns numeric,
    # so round(1) applies to them.
    stats = pd.concat(
        [
            _describe_counts(token_counts, "tokens", percentiles),
            _describe_counts(word_counts, "words", percentiles),
            _describe_counts(char_counts, "chars", percentiles),
        ],
        ignore_index=True,
    )
    stats = stats.round(1).drop(columns="count")
    return plot_histogram(token_counts), stats
# Gradio UI around `count`: the input widgets are listed in the order of the
# function's parameters, and the two outputs receive the values it returns
# (histogram image, statistics table).
demo = gr.Interface(
    title="Dataset token counts and distribution",
    fn=count,
    inputs=[
        gr.Dropdown(label="Tokenizer", choices=tokenizers, allow_custom_value=True),
        gr.Textbox(label="Dataset"),
        gr.Textbox(label="Config"),
        gr.Textbox(label="Split"),
        gr.Textbox(label="Column"),
        gr.Checkbox(label="Add special tokens", value=True),
    ],
    outputs=[
        gr.Image(),
        gr.Dataframe(label="Token, word and character counts per dataset item"),
    ],
    # Each example row gives tokenizer, dataset, config, split and column;
    # no value is listed for the "Add special tokens" checkbox.
    examples=[
        ["tiiuae/falcon-7b", "gsarti/flores_101", "eng", "dev", "sentence"],
        ["tiiuae/falcon-7b", "Muennighoff/flores200", "eng_Latn", "dev", "sentence"],
        ["tiiuae/falcon-7b", "hails/mmlu_no_train", "elementary_mathematics", "test", "question"],
        ["tiiuae/falcon-7b", "gsm8k", "main", "test", "question"],
        ["tiiuae/falcon-7b", "locuslab/TOFU", "world_facts", "train", "question"],
        ["tiiuae/falcon-7b", "imdb", "", "test", "text"],
        ["tiiuae/falcon-7b", "wikitext", "wikitext-2-v1", "validation", "text"],
        ["tiiuae/falcon-7b", "zeroshot/twitter-financial-news-sentiment", "", "validation", "text"],
        ["BAAI/bge-base-en-v1.5", "PolyAI/banking77", "", "test", "text"],
        ["BAAI/bge-base-en-v1.5", "mteb/amazon_massive_intent", "en", "test", "text"],
        ["BAAI/bge-base-en-v1.5", "mteb/sts16-sts", "", "test", "sentence1"],
    ],
    cache_examples=True,
)

# Start the web app.
demo.launch()