|
"""Utility functions for gradio/external.py, designed for internal use."""
|
|
|
from __future__ import annotations |
|
|
|
import base64 |
|
import math |
|
import re |
|
import warnings |
|
|
|
import httpx |
|
import yaml |
|
from huggingface_hub import HfApi, ImageClassificationOutputElement, InferenceClient |
|
|
|
from gradio import components |
|
|
|
|
|
def get_model_info(model_name, hf_token=None):
    """Look up a model on the Hugging Face Hub.

    Args:
        model_name: Repository id of the model on the Hub.
        hf_token: Optional access token passed to ``HfApi``.

    Returns:
        A ``(pipeline_tag, tags)`` tuple read from the model's Hub metadata.
    """
    api = HfApi(token=hf_token)
    print(f"Fetching model from: https://huggingface.co/{model_name}")
    info = api.model_info(model_name)
    return info.pipeline_tag, info.tags
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
def get_tabular_examples(model_name: str) -> dict[str, list[float | str]]:
    """Load example rows for a tabular model from its README's YAML front matter.

    Fetches ``README.md`` from the model repository on the Hugging Face Hub,
    parses the ``---``-delimited YAML block and returns its
    ``widget.structuredData`` mapping (column name -> list of values). Float
    NaN cells are replaced in place by the string ``"NaN"``.

    Args:
        model_name: Repository id of the model on the Hub.

    Returns:
        The column-name -> values mapping from ``widget.structuredData``.

    Raises:
        ValueError: if the README cannot be fetched or has no usable
            ``widget.structuredData`` mapping.
    """
    readme = httpx.get(f"https://huggingface.co/{model_name}/resolve/main/README.md")
    example_data = {}
    if readme.status_code != 200:
        warnings.warn(f"Cannot load examples from README for {model_name}", UserWarning)
    else:
        yaml_regex = re.search(
            "(?:^|[\r\n])---[\n\r]+([\\S\\s]*?)[\n\r]+---([\n\r]|$)", readme.text
        )
        if yaml_regex is not None:
            # Parse only the text up to the end of the matched front matter.
            example_yaml = next(
                yaml.safe_load_all(readme.text[: yaml_regex.span()[-1]])
            )
            # The front matter (or its "widget" entry) may be empty or not a
            # mapping (e.g. a list of widget examples); treat that as "no
            # examples" instead of failing with an AttributeError on `.get`.
            widget = (
                example_yaml.get("widget") if isinstance(example_yaml, dict) else None
            )
            structured = (
                widget.get("structuredData") if isinstance(widget, dict) else None
            )
            if isinstance(structured, dict):
                example_data = structured
    if not example_data:
        raise ValueError(
            f"No example data found in README.md of {model_name} - Cannot build gradio demo. "
            "See the README.md here: https://huggingface.co/scikit-learn/tabular-playground/blob/main/README.md "
            "for a reference on how to provide example data to your model."
        )

    for column in example_data.values():
        # A column may be null in the YAML (cols_to_rows accepts None columns);
        # there is nothing to normalise in that case.
        if not isinstance(column, list):
            continue
        for i, val in enumerate(column):
            if isinstance(val, float) and math.isnan(val):
                column[i] = "NaN"
    return example_data
|
|
|
|
|
def cols_to_rows(
    example_data: dict[str, list[float | str] | None],
) -> tuple[list[str], list[list[float | str]]]:
    """Transpose column-oriented example data into a header list and rows.

    Columns may have different lengths or be ``None``; missing cells are
    padded with the string ``"NaN"`` so every row has one value per header.

    Args:
        example_data: Mapping of column name to that column's values (or
            ``None`` for an empty column).

    Returns:
        ``(headers, rows)``: ``headers`` follows the mapping's key order and
        ``rows[r][c]`` is the ``r``-th value of column ``headers[c]``. An
        empty mapping yields ``([], [])``.
    """
    headers = list(example_data.keys())
    columns = [example_data[header] or [] for header in headers]
    # default=0 so an empty mapping produces no rows instead of max() raising.
    n_rows = max((len(col) for col in columns), default=0)
    data = [
        [col[row_index] if row_index < len(col) else "NaN" for col in columns]
        for row_index in range(n_rows)
    ]
    return headers, data
|
|
|
|
|
def rows_to_cols(incoming_data: dict) -> dict[str, dict[str, dict[str, list[str]]]]:
    """Turn a row-oriented table into the column-wise inference payload.

    ``incoming_data`` holds ``"headers"`` (column names) and ``"data"`` (rows).
    The result is ``{"inputs": {"data": {header: [cell, ...]}}}`` with every
    cell converted with ``str``.
    """
    columns = {
        name: [str(row[col_idx]) for row in incoming_data["data"]]
        for col_idx, name in enumerate(incoming_data["headers"])
    }
    return {"inputs": {"data": columns}}
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
def postprocess_label(scores: list[ImageClassificationOutputElement]) -> dict:
    """Map each classification result's ``label`` to its ``score``."""
    return dict((element.label, element.score) for element in scores)
|
|
|
|
|
def postprocess_mask_tokens(scores: list[dict[str, str | float]]) -> dict:
    """Map each fill-mask candidate's ``token_str`` to its ``score``."""
    result = {}
    for candidate in scores:
        token = candidate["token_str"]
        result[token] = candidate["score"]
    return result
|
|
|
|
|
def postprocess_question_answering(answer: dict) -> tuple[str, dict]:
    """Return the answer text and a ``{answer_text: score}`` mapping."""
    text = answer["answer"]
    confidence = {text: answer["score"]}
    return text, confidence
|
|
|
|
|
def postprocess_visual_question_answering(scores: list[dict[str, str | float]]) -> dict:
    """Map each candidate answer to its score."""
    by_answer = {}
    for item in scores:
        key = item["answer"]
        by_answer[key] = item["score"]
    return by_answer
|
|
|
|
|
def zero_shot_classification_wrapper(client: InferenceClient):
    """Adapt ``client.zero_shot_classification`` to Gradio text inputs.

    The returned function takes the candidate labels as a single
    comma-separated string and forwards them as a list.
    """

    def zero_shot_classification_inner(input: str, labels: str, multi_label: bool):
        # Strip whitespace so "cat, dog" yields ["cat", "dog"] rather than
        # ["cat", " dog"], and drop empty entries left by stray commas.
        candidate_labels = [
            label.strip() for label in labels.split(",") if label.strip()
        ]
        return client.zero_shot_classification(
            input, candidate_labels, multi_label=multi_label
        )

    return zero_shot_classification_inner
|
|
|
|
|
def sentence_similarity_wrapper(client: InferenceClient):
    """Adapt ``client.sentence_similarity`` to Gradio text inputs.

    The returned function takes the comparison sentences as one string with
    one sentence per line.
    """

    def sentence_similarity_inner(input: str, sentences: str):
        # splitlines() also handles "\r\n"; blank lines (e.g. from a trailing
        # newline) are skipped so they are not sent as empty sentences.
        candidates = [line for line in sentences.splitlines() if line.strip()]
        return client.sentence_similarity(input, candidates)

    return sentence_similarity_inner
|
|
|
|
|
def text_generation_wrapper(client: InferenceClient):
    """Return a function that yields the prompt followed by the model's continuation."""

    def text_generation_inner(input: str):
        continuation = client.text_generation(input)
        return input + continuation

    return text_generation_inner
|
|
|
|
|
def conversational_wrapper(client: InferenceClient):
    """Adapt ``client.chat_completion`` to a ``(message, history)`` chat function.

    ``history`` is expected to be a list of ``{"role": ..., "content": ...}``
    message dicts (or empty/None). The new user message is appended to a copy,
    and the content of the first returned choice is returned.
    """

    def chat_fn(message, history):
        # Build a fresh list so the caller's history object is not mutated.
        messages = [*(history or []), {"role": "user", "content": message}]
        result = client.chat_completion(messages)
        return result.choices[0].message.content

    return chat_fn
|
|
|
|
|
def encode_to_base64(r: httpx.Response) -> str:
    """Convert an inference API response into a ``data:`` URI string.

    Three response shapes are handled:

    1. The body is already a data URI (``data:<type>;base64,<payload>``):
       it is returned as text unchanged.
    2. The body is JSON (media type ``application/json``) whose first element
       has ``"content-type"`` and ``"blob"`` keys: those supply the media
       type and the base64 payload.
    3. Otherwise the raw body is base64-encoded and labelled with the
       response's ``content-type`` header.

    Raises:
        ValueError: if a JSON body does not have the expected shape.
    """
    data_prefix = b";base64,"
    # Case 1. Inspect the raw body, not its base64 encoding: ";" and "," are
    # not in the base64 alphabet, so they can only appear in a body that is
    # already a data URI.
    if r.content.startswith(b"data:") and data_prefix in r.content:
        return r.content.decode("utf-8")

    content_type = r.headers.get("content-type")
    base64_repr = base64.b64encode(r.content).decode("utf-8")
    # Compare only the media type so parameters such as "; charset=utf-8"
    # do not hide a JSON response.
    media_type = (
        content_type.split(";", 1)[0].strip().lower() if content_type else None
    )
    if media_type == "application/json":
        # Case 2: the media type and payload are carried inside the JSON.
        try:
            data = r.json()[0]
            content_type = data["content-type"]
            base64_repr = data["blob"]
        except (KeyError, IndexError, TypeError) as err:
            raise ValueError(
                "Cannot determine content type returned by external API."
            ) from err
    # Case 3 (non-JSON) falls through with the header's content type.
    return f"data:{content_type};base64,{base64_repr}"
|
|
|
|
|
def format_ner_list(input_string: str, ner_groups: list[dict[str, str | int]]):
    """Split text into ``(span, label)`` pairs suitable for highlighted output.

    Each entity becomes ``(text[start:end], label)``. The text between
    entities, before the first and after the last, is emitted with label
    ``None`` (possibly as an empty string). With no entities, the whole
    string is returned as a single unlabelled span.

    Args:
        input_string: The text that was classified.
        ner_groups: Entities in text order, each with ``start``/``end``
            character offsets and an ``entity_group`` (or ``entity``) label.
    """
    if not ner_groups:
        return [(input_string, None)]

    output = []
    prev_end = 0
    for group in ner_groups:
        # Aggregated results label spans with "entity_group"; fall back to
        # "entity" for results produced without aggregation.
        entity = group.get("entity_group") or group.get("entity")
        start, end = group["start"], group["end"]
        output.append((input_string[prev_end:start], None))
        output.append((input_string[start:end], entity))
        prev_end = end

    output.append((input_string[prev_end:], None))
    return output
|
|
|
|
|
def token_classification_wrapper(client: InferenceClient):
    """Return a function that runs token classification and formats the spans."""

    def token_classification_inner(input: str):
        return format_ner_list(input, client.token_classification(input))

    return token_classification_inner
|
|
|
|
|
def object_detection_wrapper(client: InferenceClient):
    """Return a function that runs object detection on an image input.

    The function returns ``(input, [((xmin, ymin, xmax, ymax), label), ...])``,
    one entry per detection reported by the client.
    """

    def object_detection_inner(input: str):
        detections = client.object_detection(input)
        boxes_with_labels = []
        for detection in detections:
            box = detection["box"]
            corners = (box["xmin"], box["ymin"], box["xmax"], box["ymax"])
            boxes_with_labels.append((corners, detection["label"]))
        return (input, boxes_with_labels)

    return object_detection_inner
|
|
|
|
|
def chatbot_preprocess(text, state):
    """Unpack conversational state into ``(text, generated_responses, past_user_inputs)``.

    A falsy ``state`` means no prior conversation, so both histories are empty.
    """
    if state:
        conversation = state["conversation"]
        return (
            text,
            conversation["generated_responses"],
            conversation["past_user_inputs"],
        )
    return text, [], []
|
|
|
|
|
def chatbot_postprocess(response): |
|
chatbot_history = list( |
|
zip( |
|
response["conversation"]["past_user_inputs"], |
|
response["conversation"]["generated_responses"], |
|
strict=False, |
|
) |
|
) |
|
return chatbot_history, response |
|
|
|
|
|
def tabular_wrapper(client: InferenceClient, pipeline: str):
    """Return a function that runs tabular classification or regression.

    Args:
        client: Inference client; its ``model`` attribute must be set when the
            returned function is called.
        pipeline: ``"tabular_classification"`` or ``"tabular_regression"``.
            Hyphenated forms (``"tabular-classification"``) are also accepted.

    The returned function raises ``TypeError`` for any other pipeline and
    ``ValueError`` if ``client.model`` is not set.
    """

    def tabular_inner(data):
        task = pipeline.replace("-", "_") if isinstance(pipeline, str) else pipeline
        if task not in ("tabular_classification", "tabular_regression"):
            raise TypeError(f"pipeline type {pipeline!r} not supported")
        # Explicit check instead of `assert`, which is stripped under `python -O`.
        if not client.model:
            raise ValueError("The inference client has no model set.")
        # The model is passed explicitly rather than relying on the client's
        # default; presumably these methods do not pick it up on their own.
        if task == "tabular_classification":
            return client.tabular_classification(data, model=client.model)
        return client.tabular_regression(data, model=client.model)

    return tabular_inner
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
def streamline_spaces_interface(config: dict) -> dict:
    """Streamlines the interface config dictionary to remove unnecessary keys.

    Builds component instances from ``config["input_components"]`` and
    ``config["output_components"]`` and returns a NEW dict containing only
    ``inputs``, ``outputs``, ``article``, ``description``,
    ``flagging_options`` and ``title``. The caller's ``config`` is not
    modified. The four text/option keys are read with ``.get``, so any that
    are missing from ``config`` come back as ``None``.

    Raises:
        KeyError: if ``input_components`` or ``output_components`` is missing.
    """
    return {
        "inputs": [
            components.get_component_instance(component)
            for component in config["input_components"]
        ],
        "outputs": [
            components.get_component_instance(component)
            for component in config["output_components"]
        ],
        "article": config.get("article"),
        "description": config.get("description"),
        "flagging_options": config.get("flagging_options"),
        "title": config.get("title"),
    }
|
|