"""Start page of the app.

This page is used to initialize a model card that is either:

1. based on the skops template,
2. empty, or
3. loaded from an existing model card on the Hugging Face Hub.

Optionally, users can add a model file, data, requirements, and choose a task.
"""

import glob
import io
import os
import pickle
import shutil
from pathlib import Path
from tempfile import mkdtemp

import pandas as pd
import sklearn
import streamlit as st
from huggingface_hub import snapshot_download
from huggingface_hub.utils import HFValidationError, RepositoryNotFoundError
from sklearn.base import BaseEstimator
from sklearn.dummy import DummyClassifier

import skops.io as sio
from skops import card, hub_utils

tmp_path = Path(mkdtemp(prefix="skops-"))  # temporary files

description = """Create a Hugging Face model repository for scikit learn models

This page aims to provide a simple interface to use the
[`skops`](https://skops.readthedocs.io/) model card and HF Hub creation
utilities.
"""


def load_model() -> None:
    """Load the uploaded model file into ``st.session_state.model``.

    If no file was uploaded, a ``DummyClassifier`` is stored as a placeholder.
    Files whose name ends with ``"skops"`` are loaded with ``skops.io``;
    anything else is unpickled.

    Raises
    ------
    TypeError
        If the loaded object is not an sklearn ``BaseEstimator``.
    """
    model_file = st.session_state.get("model_file")
    if model_file is None:
        st.session_state.model = DummyClassifier()
        return

    bytes_data = model_file.getvalue()
    # NOTE(security): both branches can execute arbitrary code contained in
    # the uploaded file -- ``trusted=True`` disables skops' type checks and
    # ``pickle.loads`` is inherently unsafe. Presumably acceptable only because
    # users upload their own model into their own session; confirm before
    # exposing this to untrusted uploads.
    if model_file.name.endswith("skops"):
        model = sio.loads(bytes_data, trusted=True)
    else:
        model = pickle.loads(bytes_data)

    # Explicit check rather than ``assert`` so it survives ``python -O``.
    if not isinstance(model, BaseEstimator):
        raise TypeError("model must be an sklearn model")
    st.session_state.model = model


def load_data() -> None:
    """Read the uploaded CSV into ``st.session_state.data``.

    An empty ``DataFrame`` is stored when no data file has been uploaded.
    """
    data_file = st.session_state.get("data_file")
    if data_file:
        df = pd.read_csv(io.BytesIO(data_file.getvalue()))
    else:
        df = pd.DataFrame([])
    st.session_state.data = df


def _clear_repo(path: str) -> None:
    """Delete every file, symlink and sub-directory directly inside *path*.

    ``glob`` with ``"*"`` does not match entries whose name starts with a dot,
    so hidden files/directories are left in place.
    """
    for file_path in glob.glob(str(Path(path) / "*")):
        if os.path.isfile(file_path) or os.path.islink(file_path):
            os.unlink(file_path)
        elif os.path.isdir(file_path):
            shutil.rmtree(file_path)


def init_repo() -> None:
    """(Re-)initialize the local repo at ``st.session_state.hf_path``.

    The directory is emptied, the current model is dumped to a temporary
    ``.skops`` file and ``hub_utils.init`` is called with the selected task,
    sample data and requirements from the session. Failures are printed to
    the console rather than raised.
    """
    path = st.session_state.hf_path
    _clear_repo(path)

    requirements = []
    task = "tabular-classification"
    data = pd.DataFrame([])

    if "requirements" in st.session_state:
        requirements = st.session_state.requirements.splitlines()
    if "task" in st.session_state:
        task = st.session_state.task
    if "data_file" in st.session_state:
        load_data()
        data = st.session_state.data

    # Text tasks receive the sample data as a list of rows rather than a
    # DataFrame.
    if task.startswith("text") and isinstance(data, pd.DataFrame):
        data = data.values.tolist()

    try:
        file_name = tmp_path / "model.skops"
        sio.dump(st.session_state.model, file_name)
        hub_utils.init(
            model=file_name,
            dst=path,
            task=task,
            data=data,
            requirements=requirements,
        )
    except Exception as exc:
        # Deliberately best-effort: this runs as a widget callback, so a
        # failure is reported on the console instead of crashing the page.
        print("Uh oh, something went wrong when initializing the repo:", exc)


def _open_editor(model_card: card.Card, card_type: str) -> None:
    """Store *model_card* in the session and switch the screen to the editor."""
    st.session_state.model_card = model_card
    st.session_state.model_card_type = card_type
    st.session_state.screen.state = "edit"


def create_skops_model_card() -> None:
    """Create a model card from the default skops template and open the editor."""
    init_repo()
    metadata = card.metadata_from_config(st.session_state.hf_path)
    model_card = card.Card(model=st.session_state.model, metadata=metadata)
    _open_editor(model_card, "skops")


def create_empty_model_card() -> None:
    """Create a template-free model card with one placeholder section."""
    init_repo()
    metadata = card.metadata_from_config(st.session_state.hf_path)
    model_card = card.Card(
        model=st.session_state.model, metadata=metadata, template=None
    )
    model_card.add(**{"Untitled": "[More Information Needed]"})
    _open_editor(model_card, "empty")


def create_hf_model_card() -> None:
    """Download an existing model card from the HF Hub and open the editor.

    The repo ID is read from ``st.session_state.hf_repo_id`` (surrounding
    whitespace and quotes are stripped). Does nothing if it is empty, and
    shows an error if the repository cannot be found.
    """
    repo_id = st.session_state.get("hf_repo_id", "").strip().strip("'").strip('"')
    if not repo_id:
        return

    # Only fetch the card and the files it may reference (text and images);
    # the patterns are fnmatch-style, hence the leading ``*``.
    allow_patterns = [
        "*.md",
        "*.txt",
        "*.png",
        "*.gif",
        "*.jpg",
        "*.jpeg",
        "*.bmp",
        "*.webp",
    ]
    try:
        path = snapshot_download(repo_id, allow_patterns=allow_patterns)
    except (HFValidationError, RepositoryNotFoundError):
        st.error(
            f"Repository '{repo_id}' could not be found on HF Hub, "
            "please check that the repo ID is correct."
        )
        return

    # move everything to the hf_path and working dir
    # ``hf_path`` may be stored as ``str`` or ``Path``; normalise it so the
    # ``/`` join below works in both cases.
    hf_path = Path(st.session_state.hf_path)
    shutil.copytree(path, hf_path, dirs_exist_ok=True)
    shutil.copytree(path, ".", dirs_exist_ok=True)

    model_card = card.parse_modelcard(hf_path / "README.md")
    _open_editor(model_card, "loaded")


def add_help_button():
    """Render a button that switches the screen to the help page."""

    def fn():
        st.session_state.screen.state = "help"

    st.button(
        "Get help",
        on_click=fn,
        help="Detailed explanation of this space",
        key="get_help",
    )


def start_input_form():
    """Render the start page: uploads, task/requirements and card options."""
    # Session defaults used by the callbacks above.
    if "model" not in st.session_state:
        st.session_state.model = DummyClassifier()
    if "data" not in st.session_state:
        st.session_state.data = pd.DataFrame([])
    if "model_card" not in st.session_state:
        st.session_state.model_card = None

    st.markdown(description)
    add_help_button()
    st.markdown("---")

    st.text(
        "Upload an sklearn model (strongly recommended)\n"
        "The model can be used to automatically populate fields in the model card."
    )
    # The uploader is only rendered while no model file is in the session.
    # NOTE(review): Streamlit may drop the state of widgets that are not
    # rendered in a run; confirm the uploaded file survives later reruns.
    if not st.session_state.get("model_file"):
        st.file_uploader(
            "Upload an sklearn model (pickle or skops format)",
            on_change=load_model,
            key="model_file",
        )
    st.markdown("---")

    st.text(
        "Upload samples from your data (in csv format)\n"
        "This sample data can be attached to the metadata of the model card"
    )
    st.file_uploader(
        "Upload input data (csv)", type=["csv"], on_change=load_data, key="data_file"
    )
    st.markdown("---")

    # Changing the task or the requirements re-initializes the local repo.
    st.selectbox(
        label="Choose the task type",
        options=[
            "tabular-classification",
            "tabular-regression",
            "text-classification",
            "text-regression",
        ],
        key="task",
        on_change=init_repo,
    )
    st.markdown("---")

    st.text_area(
        label="Requirements",
        value=f"scikit-learn=={sklearn.__version__}\n",
        key="requirements",
        on_change=init_repo,
    )
    st.markdown("---")

    st.markdown("Choose one of the options below to get started:")
    col_0, col_1, col_2 = st.columns([2, 2, 2])
    with col_0:
        st.button("Create a new skops model card", on_click=create_skops_model_card)
    with col_1:
        st.button("Create a new empty model card", on_click=create_empty_model_card)
    with col_2:
        with st.form("Load existing model card from HF Hub", clear_on_submit=False):
            st.markdown("Load existing model card from HF Hub")
            st.text_input("Repo name (e.g. 'gpt2')", key="hf_repo_id")
            st.form_submit_button("Load", on_click=create_hf_model_card)


start_input_form()