Spaces:
Runtime error
Runtime error
| """Start page of the app | |
| This page is used to initialize a model card that is either: | |
| 1. based on the skops template | |
| 2. empty | |
| 3. loaded from an existing model card | |
| Optionally, users can add a model file, data, requirements, and choose a task. | |
| """ | |
| import glob | |
| import io | |
| import os | |
| import pickle | |
| import shutil | |
| from pathlib import Path | |
| from tempfile import mkdtemp | |
| import pandas as pd | |
| import sklearn | |
| import streamlit as st | |
| from huggingface_hub import snapshot_download | |
| from huggingface_hub.utils import HFValidationError, RepositoryNotFoundError | |
| from sklearn.base import BaseEstimator | |
| from sklearn.dummy import DummyClassifier | |
| import skops.io as sio | |
| from skops import card, hub_utils | |
# Scratch directory for files written by this page (init_repo stores the
# serialized model here before handing it to hub_utils.init).
# NOTE(review): nothing visible in this file removes the directory again, and
# mkdtemp is called every time this module's top level runs — confirm that
# cleanup happens elsewhere.
tmp_path = Path(mkdtemp(prefix="skops-"))  # temporary files

# Introductory markdown text shown at the top of the start page.
description = """Create a Hugging Face model repository for scikit learn models
This page aims to provide a simple interface to use the
[`skops`](https://skops.readthedocs.io/) model card and HF Hub creation
utilities.
"""
def load_model() -> None:
    """Load the uploaded model file into ``st.session_state.model``.

    The file is read from the ``model_file`` session key. If no file is
    present, a ``DummyClassifier`` is stored as a placeholder. Files whose
    name ends with ``"skops"`` are deserialized with ``skops.io``; every other
    file is treated as a pickle.

    Raises
    ------
    TypeError
        If the deserialized object is not a scikit-learn ``BaseEstimator``.
    """
    model_file = st.session_state.get("model_file")
    if model_file is None:
        st.session_state.model = DummyClassifier()
        return

    bytes_data = model_file.getvalue()
    if model_file.name.endswith("skops"):
        # NOTE(review): ``trusted=True`` presumably disables skops' check of
        # the types contained in the file — confirm this is acceptable for
        # user uploads (and still supported by the installed skops version).
        model = sio.loads(bytes_data, trusted=True)
    else:
        # SECURITY: unpickling can execute arbitrary code. The bytes come
        # straight from a user upload, so this is only safe when the app runs
        # in an isolated environment.
        model = pickle.loads(bytes_data)

    # Validate explicitly rather than with ``assert``: assertions are removed
    # when Python runs with -O, which would let arbitrary objects through.
    if not isinstance(model, BaseEstimator):
        raise TypeError("model must be an sklearn model")
    st.session_state.model = model
def load_data() -> None:
    """Store the uploaded CSV sample as a DataFrame in ``st.session_state.data``.

    When the ``data_file`` session key is missing or falsy, an empty DataFrame
    is stored instead.
    """
    uploaded = st.session_state.get("data_file")
    if not uploaded:
        st.session_state.data = pd.DataFrame([])
        return

    # Wrap the raw bytes in a file-like object for pandas to parse.
    buffer = io.BytesIO(uploaded.getvalue())
    st.session_state.data = pd.read_csv(buffer)
| def _clear_repo(path: str) -> None: | |
| for file_path in glob.glob(str(Path(path) / "*")): | |
| if os.path.isfile(file_path) or os.path.islink(file_path): | |
| os.unlink(file_path) | |
| elif os.path.isdir(file_path): | |
| shutil.rmtree(file_path) | |
def init_repo() -> None:
    """(Re)create the local model repository at ``st.session_state.hf_path``.

    The target directory is emptied first. The current session model is then
    dumped to ``model.skops`` inside ``tmp_path`` and passed to
    ``hub_utils.init`` together with the task, sample data and requirements
    found in the session state. Defaults are used for anything not set: no
    requirements, task ``"tabular-classification"`` and an empty DataFrame.

    Failures while dumping the model or initializing the repo do not raise;
    they are printed and shown to the user as an error message.
    """
    path = st.session_state.hf_path
    _clear_repo(path)

    requirements = []
    task = "tabular-classification"
    data = pd.DataFrame([])

    if "requirements" in st.session_state:
        # Skip blank lines (e.g. from a trailing newline or spacing added by
        # the user) so that no empty requirement strings are passed on.
        requirements = [
            line.strip()
            for line in st.session_state.requirements.splitlines()
            if line.strip()
        ]
    if "task" in st.session_state:
        task = st.session_state.task
    if "data_file" in st.session_state:
        load_data()
        data = st.session_state.data

    # For text tasks the sample data is passed as a plain list, not a frame.
    if task.startswith("text") and isinstance(data, pd.DataFrame):
        data = data.values.tolist()

    try:
        file_name = tmp_path / "model.skops"
        sio.dump(st.session_state.model, file_name)
        hub_utils.init(
            model=file_name,
            dst=path,
            task=task,
            data=data,
            requirements=requirements,
        )
    except Exception as exc:
        # Deliberately best-effort: this also runs as a widget callback, so
        # report the problem instead of raising. Show it in the UI as well,
        # because a message on stdout alone never reaches the user.
        message = f"Uh oh, something went wrong when initializing the repo: {exc}"
        print(message)
        st.error(message)
def create_skops_model_card() -> None:
    """Create a model card with the default skops template and open the editor.

    Initializes the repo, builds the card from the session model plus the
    metadata read from the repo config, and switches the screen to "edit".
    """
    init_repo()

    state = st.session_state
    repo_metadata = card.metadata_from_config(state.hf_path)
    state.model_card = card.Card(model=state.model, metadata=repo_metadata)
    state.model_card_type = "skops"
    state.screen.state = "edit"
def create_empty_model_card() -> None:
    """Create a template-free model card and open the editor.

    Initializes the repo, builds a card without a template from the session
    model plus the metadata read from the repo config, adds one placeholder
    section named "Untitled", and switches the screen to "edit".
    """
    init_repo()

    state = st.session_state
    repo_metadata = card.metadata_from_config(state.hf_path)
    empty_card = card.Card(
        model=state.model,
        metadata=repo_metadata,
        template=None,
    )
    # Add a single placeholder section to the otherwise empty card.
    empty_card.add(Untitled="[More Information Needed]")

    state.model_card = empty_card
    state.model_card_type = "empty"
    state.screen.state = "edit"
def create_hf_model_card() -> None:
    """Load an existing model card from a Hugging Face Hub repository.

    The repo ID is read from the ``hf_repo_id`` session key, with surrounding
    whitespace and quotes removed. Only markdown, text and image files are
    downloaded. They are copied into ``st.session_state.hf_path`` and into the
    current working directory, after which ``README.md`` is parsed into a
    model card and the screen is switched to "edit".

    Nothing happens when the repo ID is empty. If the repository cannot be
    found or has no ``README.md``, an error is shown and the session state is
    left unchanged.
    """
    repo_id = st.session_state.get("hf_repo_id", "").strip().strip("'").strip('"')
    if not repo_id:
        return

    # Glob-style patterns: each one needs the leading "*" to match by file
    # extension (a bare ".txt" would only match a file literally named ".txt").
    allow_patterns = [
        "*.md",
        "*.txt",
        "*.png",
        "*.gif",
        "*.jpg",
        "*.jpeg",
        "*.bmp",
        "*.webp",
    ]
    try:
        path = snapshot_download(repo_id, allow_patterns=allow_patterns)
    except (HFValidationError, RepositoryNotFoundError):
        st.error(
            f"Repository '{repo_id}' could not be found on HF Hub, "
            "please check that the repo ID is correct."
        )
        return

    # Without a README there is no model card to parse; report this before
    # copying anything instead of failing later inside the parser.
    if not (Path(path) / "README.md").is_file():
        st.error(
            f"Repository '{repo_id}' does not contain a README.md, "
            "so there is no model card to load."
        )
        return

    # move everything to the hf_path and working dir
    hf_path = st.session_state.hf_path
    shutil.copytree(path, hf_path, dirs_exist_ok=True)
    shutil.copytree(path, ".", dirs_exist_ok=True)

    model_card = card.parse_modelcard(Path(hf_path) / "README.md")
    st.session_state.model_card = model_card
    st.session_state.model_card_type = "loaded"
    st.session_state.screen.state = "edit"
def add_help_button():
    """Render the "Get help" button, which switches the app to the help screen."""

    def _show_help():
        # Button callback: change the active screen to "help".
        st.session_state.screen.state = "help"

    st.button(
        "Get help",
        key="get_help",
        help="Detailed explanation of this space",
        on_click=_show_help,
    )
def start_input_form():
    """Render the start page: uploads, task, requirements and creation options.

    Session defaults for ``model``, ``data`` and ``model_card`` are set when
    missing. The widgets are then drawn top to bottom: description and help
    button, model upload, data upload, task selection, requirements, and
    finally the three ways of creating or loading a model card.
    """
    # Fill in session defaults without overwriting values from earlier runs.
    default_factories = {
        "model": DummyClassifier,
        "data": lambda: pd.DataFrame([]),
        "model_card": lambda: None,
    }
    for state_key, make_default in default_factories.items():
        if state_key not in st.session_state:
            st.session_state[state_key] = make_default()

    st.markdown(description)
    add_help_button()

    # --- model upload ---
    st.markdown("---")
    st.text(
        "Upload an sklearn model (strongly recommended)\n"
        "The model can be used to automatically populate fields in the model card."
    )
    model_already_uploaded = bool(st.session_state.get("model_file"))
    if not model_already_uploaded:
        st.file_uploader(
            "Upload an sklearn model (pickle or skops format)",
            key="model_file",
            on_change=load_model,
        )

    # --- sample data upload ---
    st.markdown("---")
    st.text(
        "Upload samples from your data (in csv format)\n"
        "This sample data can be attached to the metadata of the model card"
    )
    st.file_uploader(
        "Upload input data (csv)",
        type=["csv"],
        key="data_file",
        on_change=load_data,
    )

    # --- task type ---
    st.markdown("---")
    task_options = [
        "tabular-classification",
        "tabular-regression",
        "text-classification",
        "text-regression",
    ]
    st.selectbox(
        label="Choose the task type",
        options=task_options,
        key="task",
        on_change=init_repo,
    )

    # --- requirements ---
    st.markdown("---")
    default_requirements = f"scikit-learn=={sklearn.__version__}\n"
    st.text_area(
        label="Requirements",
        value=default_requirements,
        key="requirements",
        on_change=init_repo,
    )

    # --- entry points for creating / loading a model card ---
    st.markdown("---")
    st.markdown("Choose one of the options below to get started:")
    skops_col, empty_col, hub_col = st.columns([2, 2, 2])

    with skops_col:
        st.button("Create a new skops model card", on_click=create_skops_model_card)

    with empty_col:
        st.button("Create a new empty model card", on_click=create_empty_model_card)

    with hub_col, st.form("Load existing model card from HF Hub", clear_on_submit=False):
        st.markdown("Load existing model card from HF Hub")
        st.text_input("Repo name (e.g. 'gpt2')", key="hf_repo_id")
        st.form_submit_button("Load", on_click=create_hf_model_card)
# Draw the start page whenever this script's top level is executed.
start_input_form()