helenai's picture
Change examples to ungated repo
b7c6d1d verified
raw
history blame contribute delete
No virus
2.41 kB
import pprint
import subprocess
from pathlib import Path
import gradio as gr
from test_prompt_generator.test_prompt_generator import _preset_tokenizers, generate_prompt
def log_command_output(command):
    """Run *command* and pretty-print its captured stdout for debugging.

    Args:
        command: The argument list passed to ``subprocess.run``.

    Returns:
        The captured stdout text, or ``None`` when the command could not be
        started (for example because the executable is not installed).
    """
    try:
        completed = subprocess.run(command, text=True, capture_output=True)
    except OSError as error:
        # Diagnostics are best-effort: a missing tool such as `lscpu`
        # must not prevent the app from starting.
        print(f"Could not run {command!r}: {error}")
        return None
    pprint.pprint(completed.stdout)
    return completed.stdout


# log system info for debugging purposes
for command in (["lscpu"], ["pip", "freeze"]):
    log_command_output(command)
def generate(tokenizer_id, num_tokens, prefix=None, source_text=None):
    """Generate a prompt with the requested number of tokens and save it to disk.

    Args:
        tokenizer_id: A key of ``_preset_tokenizers`` or any other tokenizer
            identifier accepted by ``generate_prompt``.
        num_tokens: Requested prompt length; converted with ``int()`` so the
            file name and the generated prompt use the same value.
        prefix: Optional text the prompt should start with. An empty string
            is treated as "not given".
        source_text: Optional text to build the prompt from. An empty string
            is treated as "not given", so ``generate_prompt`` uses its default.

    Returns:
        A tuple ``(prompt, output_file, tokenizer)``: the value returned by
        ``generate_prompt``, the path of the written ``.jsonl`` file as a
        string, and the tokenizer id with preset names resolved through
        ``_preset_tokenizers``.
    """
    num_tokens = int(num_tokens)
    output_path = Path(f"prompt_{num_tokens}.jsonl")

    # Remove a stale file from a previous run; it is fine if none exists.
    try:
        output_path.unlink()
    except FileNotFoundError:
        pass

    # Optional text inputs left blank arrive as "", which must not be
    # mistaken for a real prefix or source text.
    prefix = prefix or None
    source_text = source_text or None

    prompt = generate_prompt(
        tokenizer_id,
        num_tokens,
        prefix=prefix,
        source_text=source_text,
        output_file=output_path,
    )

    # Report the resolved identifier when a preset name was selected.
    if tokenizer_id in _preset_tokenizers:
        tokenizer_id = _preset_tokenizers[tokenizer_id]
    return prompt, str(output_path), tokenizer_id
# Gradio UI: four inputs (tokenizer, token count, optional prefix, optional
# source text) mapped onto `generate`, whose three return values fill the
# prompt textbox, the downloadable file and the tokenizer display.
demo = gr.Interface(
    fn=generate,
    title="Test Prompt Generator",
    description="Generate prompts with a given number of tokens for testing transformer models. "
    "Prompt source: https://archive.org/stream/alicesadventures19033gut/19033.txt",
    inputs=[
        gr.Dropdown(
            label="Tokenizer",
            # Offer the preset names explicitly (the keys of the mapping).
            choices=list(_preset_tokenizers),
            value="mistral",
            allow_custom_value=True,
            info="Select a tokenizer from this list or paste a model_id from a model on the Hugging Face Hub",
        ),
        gr.Number(
            label="Number of Tokens", minimum=4, maximum=2048, value=32, info="Enter a number between 4 and 2048."
        ),
        gr.Textbox(
            label="Prefix (optional)",
            info="If given, the start of the prompt will be this prefix. Example: 'Summarize the following text:'",
        ),
        gr.Textbox(
            label="Source text (optional)",
            info="By default, prompts will be generated from Alice in Wonderland. Enter text here to use that instead.",
        ),
    ],
    outputs=[
        gr.Textbox(label="prompt", show_copy_button=True),
        gr.File(label="Json file"),
        gr.Markdown(label="tokenizer"),
    ],
    # Each example supplies only the tokenizer and the token count; the
    # optional text inputs are left unset.
    examples=[
        ["falcon", 32],
        ["falcon", 64],
        ["falcon", 128],
        ["falcon", 512],
        ["falcon", 1024],
        ["falcon", 2048],
    ],
    cache_examples=False,
    # `allow_flagging` takes a string mode; "never" disables flagging
    # without relying on the deprecated boolean form.
    allow_flagging="never",
)
demo.launch()