|
|
|
from urllib.parse import quote

import gradio as gr
import pandas as pd
import plotly.express as px
import requests
from gradio_huggingfacehub_search import HuggingfaceHubSearch
from huggingface_hub import HfApi
from huggingface_hub.errors import RepositoryNotFoundError
|
|
|
# One Hub API client, created at import time and shared by every lookup below.
HF_API = HfApi()
|
|
|
|
|
def format_repo_size(r_size: int) -> str:
    """Render a byte count as a human-readable string with two decimals.

    The value is divided by 1024 until it drops below 1024 or the largest
    unit (PB) is reached, e.g. ``1536`` becomes ``"1.50 KB"``.
    """
    suffixes = ("B", "KB", "MB", "GB", "TB", "PB")
    last_step = len(suffixes) - 1
    value = r_size
    step = 0
    # Stop at the last suffix even if the value is still >= 1024.
    while value >= 1024 and step != last_step:
        value = value / 1024
        step += 1
    return f"{value:.2f} {suffixes[step]}"
|
|
|
|
|
def repo_files(r_type: str, r_id: str) -> dict:
    """Aggregate total size and file count per file extension for a repository.

    Args:
        r_type: Repository type, forwarded to ``HF_API.repo_info`` as ``repo_type``.
        r_id: Repository ID, forwarded to ``HF_API.repo_info`` as ``repo_id``.

    Returns:
        A dict mapping each extension to ``{"size": <total bytes>, "count": <files>}``.
        The extension is the text after the last "." of the file *name*; a file
        whose name has no "." is keyed by its whole name (e.g. ``"LICENSE"``).
        The inner keys are always inserted in the order ``size``, ``count``.

    Raises:
        Whatever ``HF_API.repo_info`` raises (the caller in this file handles
        ``RepositoryNotFoundError``).
    """
    r_info = HF_API.repo_info(repo_id=r_id, repo_type=r_type, files_metadata=True)
    files = {}
    # NOTE(review): ``siblings`` is treated as possibly None (empty listing) —
    # confirm against the huggingface_hub version in use.
    for sibling in r_info.siblings or []:
        # Use only the last path component: a dot in a directory name
        # ("v1.0/LICENSE") must not be mistaken for an extension separator.
        filename = sibling.rfilename.rsplit("/", 1)[-1]
        ext = filename.split(".")[-1]
        stats = files.setdefault(ext, {"size": 0, "count": 0})
        # A missing size is counted as 0 bytes instead of raising TypeError.
        stats["size"] += sibling.size or 0
        stats["count"] += 1
    return files
|
|
|
|
|
def repo_size(r_type, r_id, *, timeout=1000):
    """Fetch the total tree size of every branch of a repository.

    Args:
        r_type: Repository type ("model" or "dataset" in this app); pluralised
            with an "s" to build the API path.
        r_id: Repository ID.
        timeout: Timeout in seconds passed to ``requests.get`` for each branch
            request. Defaults to the previously hard-coded value.

    Returns:
        A dict mapping branch name to its size in bytes. Branches whose request
        fails or whose response carries no ``size`` are omitted. An empty dict
        is returned as soon as any response reports a "restricted" error.

    Raises:
        Whatever ``HF_API.list_repo_refs`` raises (the caller in this file
        handles ``RepositoryNotFoundError``).
    """
    r_refs = HF_API.list_repo_refs(repo_id=r_id, repo_type=r_type)
    repo_sizes = {}
    for branch in r_refs.branches:
        # Percent-encode the branch name so that a "/" inside it is not read as
        # an extra URL path segment.
        revision = quote(branch.name, safe="")
        url = f"https://huggingface.co/api/{r_type}s/{r_id}/treesize/{revision}"
        try:
            payload = requests.get(url, timeout=timeout).json()
        except (requests.RequestException, ValueError):
            # Best effort: a network failure or a non-JSON body only means this
            # branch is skipped, it must not abort the whole lookup.
            payload = {}
        if not isinstance(payload, dict):
            payload = {}

        error = payload.get("error")
        if isinstance(error, str) and "restricted" in error:
            gr.Warning(f"Branch information for {r_id} not available.")
            return {}

        size = payload.get("size")
        if size is not None:
            repo_sizes[branch.name] = size
    return repo_sizes
|
|
|
|
|
def _hidden_repo_outputs():
    """Return updates that hide all five result components.

    The order matches the success path of ``get_repo_info``: results row,
    file table, file plot, branch row, branch table.
    """
    return (
        gr.Row(visible=False),
        gr.Dataframe(visible=False),
        gr.Plot(visible=False),
        gr.Row(visible=False),
        gr.Dataframe(visible=False),
    )


def get_repo_info(r_type, r_id):
    """Collect file and branch statistics for a repository and build UI updates.

    Args:
        r_type: Repository type, forwarded to ``repo_size`` and ``repo_files``.
        r_id: Repository ID to look up.

    Returns:
        A 5-tuple of Gradio components, in this order: results row, per-extension
        file table, file-distribution pie chart, branch row, branch-size table.
        All five are hidden when the ID is empty, the repository is not found,
        or it lists no files. The two branch components are hidden when no
        branch sizes could be retrieved.
    """
    # Nothing to look up: skip the network calls and tell the user why.
    if not r_id or not str(r_id).strip():
        gr.Warning("Please enter a repository ID before searching.")
        return _hidden_repo_outputs()

    try:
        repo_sizes = repo_size(r_type, r_id)
        repo_files_info = repo_files(r_type, r_id)
    except RepositoryNotFoundError:
        gr.Warning(
            "Repository not found. Make sure you've entered a valid repo ID and type that corresponds to the repository."
        )
        return _hidden_repo_outputs()

    # An empty mapping yields a frame without a "size" column, which would make
    # the sort below raise KeyError.
    if not repo_files_info:
        gr.Warning(f"No files found in {r_id}.")
        return _hidden_repo_outputs()

    # One row per extension, largest total size first.
    rf_sizes_df = (
        pd.DataFrame(repo_files_info)
        .T.reset_index(names="ext")
        .sort_values(by="size", ascending=False)
    )
    rf_sizes_df["formatted_size"] = rf_sizes_df["size"].apply(format_repo_size)
    # Rename by column name rather than by position so the result does not
    # depend on the key order of the per-extension dicts.
    rf_sizes_df = rf_sizes_df.rename(
        columns={
            "ext": "Extension",
            "size": "bytes",
            "count": "Count",
            "formatted_size": "Size",
        }
    )

    if not repo_sizes:
        r_sizes_component = gr.Dataframe(visible=False)
        b_block = gr.Row(visible=False)
    else:
        r_sizes_df = pd.DataFrame(repo_sizes, index=["size"]).T.reset_index(
            names="branch"
        )
        r_sizes_df["formatted_size"] = r_sizes_df["size"].apply(format_repo_size)
        r_sizes_df = r_sizes_df.rename(
            columns={"branch": "Branch", "size": "bytes", "formatted_size": "Size"}
        )
        r_sizes_component = gr.Dataframe(
            value=r_sizes_df[["Branch", "Size"]], visible=True
        )
        b_block = gr.Row(visible=True)

    # Slices are sized by raw bytes; the formatted size is shown on hover.
    rf_sizes_plot = px.pie(
        rf_sizes_df,
        values="bytes",
        names="Extension",
        hover_data=["Size"],
        title=f"File Distribution in {r_id}",
        hole=0.3,
    )
    return (
        gr.Row(visible=True),
        gr.Dataframe(
            value=rf_sizes_df[["Extension", "Count", "Size"]],
            visible=True,
        ),
        gr.Plot(rf_sizes_plot, visible=True),
        b_block,
        r_sizes_component,
    )
|
|
|
|
|
# UI definition: a search form on top and result panels below. The result
# panels start hidden and are toggled by the values get_repo_info returns.
# NOTE(review): the nesting of the result containers was reconstructed from
# the order of the components — confirm against the intended layout.
with gr.Blocks(theme="ocean") as demo:
    gr.Markdown("# Repository Information")
    gr.Markdown(
        "Search for a model or dataset repository using the autocomplete below, select the repository type, and get back information about the repository's files and branches."
    )

    # Search form: repository picker, type selector and submit button.
    with gr.Blocks():
        repo_id = HuggingfaceHubSearch(
            label="Hub Model ID",
            placeholder="Search for model id on Huggingface",
            search_type=["model", "dataset"],
        )
        repo_type = gr.Radio(
            choices=["model", "dataset"],
            label="Repository Type",
            value="model",
        )
        search_button = gr.Button(value="Search")

    # Result panels: file statistics (table + pie chart) and branch sizes.
    with gr.Blocks():
        with gr.Row(visible=False) as results_block:
            with gr.Column():
                gr.Markdown("## File Information")
                with gr.Row():
                    file_info = gr.Dataframe(visible=False)
                    file_info_plot = gr.Plot(visible=False)
                with gr.Row(visible=False) as branch_block:
                    with gr.Column():
                        gr.Markdown("## Branch Sizes")
                        branch_sizes = gr.Dataframe(visible=False)

    # The order of the outputs must match the tuple returned by get_repo_info.
    search_inputs = [repo_type, repo_id]
    search_outputs = [
        results_block,
        file_info,
        file_info_plot,
        branch_block,
        branch_sizes,
    ]
    search_button.click(
        get_repo_info,
        inputs=search_inputs,
        outputs=search_outputs,
    )

demo.launch()
|
|