Spaces:
Runtime error
Runtime error
import os | |
from pathlib import Path | |
import streamlit as st | |
from skops import hub_utils | |
from utils import get_rendered_model_card | |
def _add_back_button():
    """Render the "Back" button.

    Clicking it sets ``st.session_state.screen.state`` to ``"edit"`` so the
    user can continue editing the model card.
    """

    def _return_to_editing():
        st.session_state.screen.state = "edit"

    st.button(
        "Back",
        on_click=_return_to_editing,
        help="continue editing the model card",
    )
def _add_delete_button():
    """Render the "Delete" button that discards all progress.

    Clicking it sets the screen state to ``"start"``, removes the stored
    model card, resets ``task_state`` (when present) and forgets the
    repository name and token entered in the create-repo form.
    """

    def _start_over():
        state = st.session_state
        state.screen.state = "start"

        if "model_card" in state:
            del state["model_card"]

        # task_state is reset in place rather than removed
        if "task_state" in state:
            state.task_state.reset()

        # values entered in the create-repo form
        for key in ("create_repo_name", "hf_token"):
            if key in state:
                del state[key]

    st.button(
        "Delete",
        help="Start over from scratch (lose all progress)",
        on_click=_start_over,
    )
def _save_model_card(path: Path) -> None:
    """Write the rendered model card to ``README.md`` inside ``path``.

    Does nothing when the session state holds no (truthy) ``model_card``.

    Parameters
    ----------
    path : Path
        Directory in which ``README.md`` is created or overwritten.
    """
    model_card = st.session_state.get("model_card")
    if not model_card:
        return

    # do not use model_card.save, see doc of get_rendered_model_card
    rendered = get_rendered_model_card(
        model_card, hf_path=str(st.session_state.hf_path)
    )
    # Write as UTF-8 explicitly: without an encoding argument, open() uses the
    # platform's locale encoding, which can fail on (or mis-encode) non-ASCII
    # characters in the rendered card.
    with open(path / "README.md", "w", encoding="utf-8") as f:
        f.write(rendered)
def _display_repo_overview(path: Path) -> None:
    """Show a markdown bullet list of the entries in ``path`` and their sizes.

    Entries appear in the order returned by ``os.listdir``; each size is
    formatted in bytes with thousands separators.
    """
    header = "Files included in the repository:\n"
    bullets = [
        f"- `{name} ({os.path.getsize(path / name):,} bytes)`\n"
        for name in os.listdir(path)
    ]
    st.markdown(header + "".join(bullets))
def _display_private_box():
    """Render the "Make repo private" checkbox, ticked by default.

    The widget is registered under the key ``create_repo_private``.
    """
    # Help text fixed: the original read "can only seen by".
    tip = (
        "Private repositories can only be seen by you or members of the same "
        "organization, see https://huggingface.co/docs/hub/repositories-settings"
    )
    st.checkbox("Make repo private", value=True, help=tip, key="create_repo_private")
def _repo_id_field():
    """Render the text input for the repository name.

    The widget is registered under the key ``create_repo_name``.
    """
    label = "Name of the repository (e.g. 'User/MyRepo')"
    st.text_input(label, key="create_repo_name")
def _hf_token_field():
    """Render the input for the Hugging Face access token.

    The widget is registered under the key ``hf_token``.
    """
    tip = "The Hugging Face token can be found at https://hf.co/settings/token"
    st.text_input(
        "Enter your Hugging Face token ('hf_***')",
        key="hf_token",
        help=tip,
        # The token is a secret credential: mask it instead of echoing it in
        # cleartext on screen.
        type="password",
    )
def _create_hf_repo(path, repo_name, hf_token, private):
    """Push the contents of ``path`` to a newly created remote repository.

    Shows an error box containing the exception message when the push raises,
    and a success box otherwise.

    Parameters
    ----------
    path
        Passed to ``hub_utils.push`` as ``source``.
    repo_name
        Passed to ``hub_utils.push`` as ``repo_id``.
    hf_token
        Passed to ``hub_utils.push`` as ``token``.
    private
        Passed to ``hub_utils.push`` as ``private``.
    """
    try:
        hub_utils.push(
            repo_id=repo_name,
            source=path,
            token=hf_token,
            private=private,
            create_remote=True,
        )
    except Exception as exc:
        # UI boundary: report the failure to the user instead of crashing
        st.error(
            "Oops, something went wrong, please create an issue. "
            f"The error message is:\n\n{exc}"
        )
    else:
        st.success(
            f"Successfully created the repo 'https://huggingface.co/{repo_name}'"
        )
def _add_create_repo_button():
    """Render the "Create a new repository" button and its input hints.

    The button is disabled until both a repository name and a token are
    present in the session state. The arguments handed to ``_create_hf_repo``
    are the values read from the session state when this function runs.
    """
    state = st.session_state
    repo_name = state.get("create_repo_name")
    hf_token = state.get("hf_token")
    private = bool(state.get("create_repo_private"))
    path = state.get("hf_path")

    st.button(
        "Create a new repository",
        help="Creating a repo requires a name and a token",
        disabled=not (repo_name and hf_token),
        on_click=_create_hf_repo,
        args=(path, repo_name, hf_token, private),
    )

    # tell the user which inputs are still missing
    if not repo_name:
        st.info("Repository name is required")
    if not hf_token:
        st.info("Token is required")
def create_repo_input_form():
    """Draw the "create repo" screen.

    Returns immediately unless ``st.session_state.screen.state`` equals
    ``"create_repo"``. Otherwise it renders the Back and Delete buttons side
    by side, saves the model card into ``st.session_state.hf_path``, lists
    that directory's contents, and renders the privacy checkbox, the
    repository name and token inputs, and the create button.
    """
    if st.session_state.screen.state != "create_repo":
        return

    # four equally wide columns; only the first two hold buttons
    back_col, delete_col, *_unused = st.columns([2, 2, 2, 2])
    with back_col:
        _add_back_button()
    with delete_col:
        _add_delete_button()

    repo_dir = st.session_state.hf_path
    # save first so that README.md shows up in the overview below
    _save_model_card(repo_dir)
    _display_repo_overview(repo_dir)
    _display_private_box()

    st.markdown("---")

    _repo_id_field()
    _hf_token_field()
    _add_create_repo_button()