# Hugging Face Space app.py by lhoestq (HF staff) — commit 381fd46 "Update app.py" (verified).
from collections import Counter
from itertools import count, groupby, islice
from operator import itemgetter
from typing import Any, Iterable, Iterator, TypeVar

import gradio as gr
import pandas as pd
import requests
from datasets import Features
from gradio_huggingfacehub_search import HuggingfaceHubSearch
from requests.adapters import HTTPAdapter, Retry

from analyze import PresidioEntity, analyzer, get_column_description, get_columns_with_strings, mask, presidio_scan_entities
# Maximum number of rows streamed from the datasets-server and scanned per dataset.
MAX_ROWS = 100
T = TypeVar("T")

# Shared HTTP session that retries 502/503/504 responses with exponential backoff.
session = requests.Session()
retries = Retry(total=5, backoff_factor=1, status_forcelist=[502, 503, 504])
# Every request in this file targets https://datasets-server.huggingface.co, and
# requests selects adapters by URL prefix, so the retrying adapter must be mounted
# on "https://" for it to apply at all.
session.mount('https://', HTTPAdapter(max_retries=retries))
session.mount('http://', HTTPAdapter(max_retries=retries))

# Entity types pre-selected in the "Presidio entities" checkbox group.
DEFAULT_PRESIDIO_ENTITIES = sorted([
    'PERSON',
    'CREDIT_CARD',
    'US_SSN',
    'US_DRIVER_LICENSE',
    'PHONE_NUMBER',
    'US_PASSPORT',
    'EMAIL_ADDRESS',
    'IP_ADDRESS',
    'US_BANK_NUMBER',
    'IBAN_CODE',
    # NOTE(review): presumably a custom entity registered by `analyze` (Presidio's
    # built-in email entity is EMAIL_ADDRESS) — confirm it is in
    # analyzer.get_supported_entities(), otherwise it is not a valid choice.
    'EMAIL',
])
def stream_rows(dataset: str, config: str, split: str) -> Iterable[dict[str, Any]]:
    """Yield the rows of one dataset split from the datasets-server ``/rows`` endpoint.

    Rows are requested in pages of ``batch_size`` starting at offset 0, and streaming
    stops at the first page that comes back empty.

    Args:
        dataset: Hub dataset ID.
        config: Dataset configuration name.
        split: Split name.

    Yields:
        The ``"row"`` mapping of each returned row item.

    Raises:
        RuntimeError: If the server response contains an ``"error"`` field.
    """
    batch_size = 100
    for i in count():
        # Let requests URL-encode the query; config/split names may contain characters
        # such as "&", "#" or spaces that would corrupt a hand-built query string.
        rows_resp = session.get(
            "https://datasets-server.huggingface.co/rows",
            params={
                "dataset": dataset,
                "config": config,
                "split": split,
                "offset": i * batch_size,
                "length": batch_size,
            },
            timeout=10,
        ).json()
        if "error" in rows_resp:
            raise RuntimeError(rows_resp["error"])
        if not rows_resp["rows"]:
            break
        for row_item in rows_resp["rows"]:
            yield row_item["row"]
class track_iter:
    """Wrap an iterable and count how many items have been pulled from it.

    ``next_idx`` equals the number of items yielded so far, i.e. the index of the next
    item. ``analyze_dataset`` reads it to report how many rows were scanned.
    """

    def __init__(self, it: "Iterable[T]"):
        self.it = it
        # Number of items handed out so far (index of the next item to be yielded).
        self.next_idx = 0

    def __iter__(self) -> "Iterator[T]":
        # This is a generator method, so it returns an iterator over T, not a single T.
        for item in self.it:
            # Increment before yielding so the count already includes the item the
            # consumer is currently processing.
            self.next_idx += 1
            yield item
def presidio_report(presidio_entities: "list[PresidioEntity]", next_row_idx: int, num_rows: int) -> dict[str, float]:
    """Summarize scan results as a label -> fraction mapping for ``gr.Label``.

    The mapping contains:
      * a title, "Scan finished: N entities found" once ``next_row_idx == num_rows`` and
        "Scan in progress..." otherwise, whose value is the fraction of rows scanned
        (omitted while no row has been scanned);
      * one "% of rows with <TYPE>" entry per entity type, whose value is the fraction
        of rows containing at least one entity of that type.
    Entries are ordered from the largest to the smallest fraction.

    Args:
        presidio_entities: Detected entities; only their "row_idx" and "type" keys are read.
        next_row_idx: Number of rows scanned so far.
        num_rows: Total number of rows that will be scanned.
    """
    title = f"Scan finished: {len(presidio_entities)} entities found" if num_rows == next_row_idx else "Scan in progress..."
    counter = Counter([title] * next_row_idx)
    # Count each entity type at most once per row. Collecting distinct (row, type)
    # pairs is correct even when entities of one row are not contiguous in the list,
    # whereas itertools.groupby would split such a row and count it twice.
    types_per_row = {(entity["row_idx"], entity["type"]) for entity in presidio_entities}
    counter.update("% of rows with " + entity_type for _, entity_type in types_per_row)
    # Never divide by zero, and keep fractions <= 1 even if more rows were scanned
    # than ``num_rows`` announced.
    denominator = max(num_rows, next_row_idx, 1)
    return {label: row_count / denominator for label, row_count in counter.most_common()}
def analyze_dataset(
    dataset: str,
    enabled_presidio_entities: list[str] = DEFAULT_PRESIDIO_ENTITIES,
    show_texts_without_masks: bool = False,
) -> Iterator[tuple[Any, pd.DataFrame]]:
    """Scan the first rows of a Hub dataset with Presidio, streaming progress to the UI.

    This is a generator. Each yielded item is a ``(report, entities)`` pair, where
    ``report`` is either an error message string or the label -> fraction mapping built
    by ``presidio_report``, and ``entities`` is a DataFrame of the entities kept so far.

    The config "default" is preferred, otherwise the first listed config is used.
    The split "train" is preferred, otherwise the first listed split is used.
    At most ``MAX_ROWS`` rows are scanned.

    Args:
        dataset: Hub dataset ID.
        enabled_presidio_entities: Entity types to keep; entities of other types are dropped.
        show_texts_without_masks: If False, each kept entity's "text" is passed through ``mask``.
    """
    if not dataset:
        yield "❌ Please enter a dataset ID", pd.DataFrame()
        return
    # Let requests URL-encode the dataset ID instead of splicing it into the URL.
    info_resp = session.get(
        "https://datasets-server.huggingface.co/info", params={"dataset": dataset}, timeout=3
    ).json()
    if "error" in info_resp:
        yield "❌ " + info_resp["error"], pd.DataFrame()
        return
    dataset_info = info_resp.get("dataset_info") or {}
    config = "default" if "default" in dataset_info else next(iter(dataset_info), None)
    splits = dataset_info[config]["splits"] if config is not None else {}
    split = "train" if "train" in splits else next(iter(splits), None)
    if split is None:
        # A bare next(iter(...)) on an empty mapping raises StopIteration, which Python
        # turns into RuntimeError inside a generator; report a readable message instead.
        yield "❌ No config or split information available for " + dataset, pd.DataFrame()
        return
    features = Features.from_dict(dataset_info[config]["features"])
    num_rows = min(splits[split]["num_examples"], MAX_ROWS)
    scanned_columns = get_columns_with_strings(features)
    columns_descriptions = [
        get_column_description(column_name, features[column_name]) for column_name in scanned_columns
    ]
    # Set membership; also tolerate None when no entity type is selected.
    enabled_types = set(enabled_presidio_entities or ())
    rows = track_iter(islice(stream_rows(dataset, config, split), MAX_ROWS))
    presidio_entities = []
    for presidio_entity in presidio_scan_entities(
        rows, scanned_columns=scanned_columns, columns_descriptions=columns_descriptions
    ):
        # Filter first so entities that are discarded are not masked needlessly.
        if presidio_entity["type"] not in enabled_types:
            continue
        if not show_texts_without_masks:
            presidio_entity["text"] = mask(presidio_entity["text"])
        presidio_entities.append(presidio_entity)
        yield presidio_report(presidio_entities, next_row_idx=rows.next_idx, num_rows=num_rows), pd.DataFrame(presidio_entities)
    # The scan is over. Report against the number of rows actually read, so the title
    # switches to "Scan finished" and the fractions stay correct even when the split
    # yielded fewer rows than its announced num_examples.
    yield presidio_report(presidio_entities, next_row_idx=rows.next_idx, num_rows=rows.next_idx), pd.DataFrame(presidio_entities)
# Custom CSS to fix a bug with gr.DataFrame, see https://github.com/radames/gradio-custom-components/issues/1
with gr.Blocks(css=".table {border-collapse: separate}") as demo:
    gr.Markdown("# Scan datasets using Presidio")
    gr.Markdown("The space takes an HF dataset name as an input, and returns the list of entities detected by Presidio in the first samples.")
    # Order matches analyze_dataset's parameters: dataset, enabled_presidio_entities,
    # show_texts_without_masks.
    inputs = [
        HuggingfaceHubSearch(
            label="Hub Dataset ID",
            placeholder="Search for dataset id on Huggingface",
            search_type="dataset",
        ),
        gr.CheckboxGroup(
            label="Presidio entities",
            choices=sorted(analyzer.get_supported_entities()),
            value=DEFAULT_PRESIDIO_ENTITIES,
            interactive=True,
        ),
        gr.Checkbox(label="Show texts without masks", value=False),
    ]
    button = gr.Button("Run Presidio Scan")
    # Order matches the (report, entities) pairs yielded by analyze_dataset.
    outputs = [
        gr.Label(show_label=False),
        gr.DataFrame(),
    ]
    button.click(analyze_dataset, inputs, outputs)
    gr.Examples(
        [
            ["microsoft/orca-math-word-problems-200k"],
            ["tatsu-lab/alpaca"],
            ["Anthropic/hh-rlhf"],
            ["OpenAssistant/oasst1"],
            ["sidhq/email-thread-summary"],
            ["lhoestq/fake_name_and_ssn"],
        ],
        inputs,
        outputs,
        fn=analyze_dataset,
        run_on_click=True,
        cache_examples=False,
    )

# Launch only when run as a script, so importing this module (e.g. for tests or
# tooling that looks up `demo`) does not start a server as a side effect.
if __name__ == "__main__":
    demo.launch()