Spaces:
Runtime error
Runtime error
# HF Space creator that starts from a scikit-learn model
from __future__ import annotations

import base64
import glob
import html
import io
import os
import pickle
import re
import shutil
from pathlib import Path
from tempfile import mkdtemp

import pandas as pd
import sklearn
import streamlit as st
from huggingface_hub import hf_hub_download
from sklearn.base import BaseEstimator

import skops.io as sio
from skops import card, hub_utils
st.set_page_config(layout="wide")
st.title("Skops space creator for sklearn")

PLACEHOLDER = "[More Information Needed]"
PLOT_PREFIX = "__plot__:"

# store session state
if "custom_sections" not in st.session_state:
    st.session_state.custom_sections = {}

# the tmp_path is used to upload the sklearn model to; uploaded plots are also
# written here and referenced by absolute path from the custom sections.
# NOTE(review): a new directory is created every time this line runs and none
# of them is removed by this file.
tmp_path = Path(mkdtemp(prefix="skops-"))

# the hf_path is the actual repo used for init(). init_repo() empties it before
# every use, so a single directory per session suffices; keeping it in the
# session state avoids leaving a fresh copy of the repo behind on every rerun.
if "hf_path" not in st.session_state:
    st.session_state.hf_path = Path(mkdtemp(prefix="skops-"))
hf_path = st.session_state.hf_path

# NOTE(review): not referenced anywhere else in this file; the custom sections
# actually live in st.session_state.custom_sections.
CUSTOM_SECTIONS_CACHE_FILE = ".custom-sections.json"
def _clear_custom_section_cache():
    """Forget every custom section and plot registered in this session."""
    sections = st.session_state.custom_sections
    sections.clear()
def _remove_custom_section(key):
    """Remove the custom section ``key`` together with all its subsections.

    An entry counts as a subsection when its name starts with ``key + "/"``
    or ``key + " /"``.
    """
    sections = st.session_state.custom_sections
    child_prefixes = (key + "/", key + " /")
    # collect first, delete afterwards: the dict must not change while iterated
    doomed = [
        name
        for name in sections
        if name == key or name.startswith(child_prefixes)
    ]
    for name in doomed:
        del sections[name]
def _clear_repo(path):
    """Delete everything inside the directory ``path``, keeping ``path`` itself.

    Regular files and symlinks (including symlinks to directories) are
    unlinked; real sub-directories are removed recursively. Hidden entries
    (names starting with ".") are removed as well. If ``path`` is not an
    existing directory, nothing happens.

    Parameters
    ----------
    path : str or pathlib.Path
        Directory to empty.
    """
    root = Path(path)
    if not root.is_dir():
        return
    # iterdir() also yields dotfiles, unlike glob("*")
    for entry in root.iterdir():
        if entry.is_file() or entry.is_symlink():
            entry.unlink()
        elif entry.is_dir():
            shutil.rmtree(entry)
def _write_plot(plot_name, plot_file):
    """Write the raw bytes ``plot_file`` to the file ``plot_name``.

    An existing file at that location is overwritten.
    """
    target = Path(plot_name)
    with target.open("wb") as handle:
        handle.write(plot_file)
def init_repo():
    """(Re)initialize the local Hugging Face repo at ``hf_path``.

    Uses the module-level ``model``, ``data``, ``task`` and ``requirements``
    values set by the sidebar. ``hf_path`` is emptied first, the model is
    serialized with skops to ``tmp_path / "model.skops"``, and
    ``hub_utils.init`` then creates the repo from that file.

    Any exception raised while doing so is printed to stdout instead of being
    propagated (best effort).
    """
    _clear_repo(hf_path)
    try:
        file_name = tmp_path / "model.skops"
        sio.dump(model, file_name)
        # one requirement per line; tolerate trailing commas and skip blank
        # lines so that no empty requirement string is passed to init()
        reqs = []
        for line in requirements.splitlines():
            req = line.strip().rstrip(",").strip()
            if req:
                reqs.append(req)
        hub_utils.init(
            model=file_name,
            dst=hf_path,
            task=task,
            data=data,
            requirements=reqs,
        )
    except Exception as exc:
        print("Uh oh, something went wrong when initializing the repo:", exc)
def load_model():
    """Unpickle the uploaded model file.

    Reads the module-level ``model_file`` set by the sidebar uploader.

    Returns
    -------
    sklearn.base.BaseEstimator or None
        The unpickled estimator, or None when no file has been uploaded.

    Raises
    ------
    TypeError
        If the unpickled object is not an sklearn estimator.
    """
    if model_file is None:
        return None
    bytes_data = model_file.getvalue()
    # NOTE(security): pickle.loads can execute arbitrary code contained in the
    # uploaded file; only trusted files should ever be uploaded here.
    model = pickle.loads(bytes_data)
    # explicit check instead of ``assert`` so it is not stripped by ``python -O``
    if not isinstance(model, BaseEstimator):
        raise TypeError("model must be an sklearn model")
    return model
def load_data():
    """Read the uploaded CSV file into a DataFrame.

    Reads the module-level ``data_file`` set by the sidebar uploader.

    Returns
    -------
    pandas.DataFrame or None
        The parsed data, or None when no file has been uploaded.
    """
    if data_file is not None:
        return pd.read_csv(io.BytesIO(data_file.getvalue()))
    return None
def _parse_metrics(metrics):
    """Parse ``name = value`` lines into a metrics dictionary.

    Each non-blank line is split at the first ``=``. Name and value are
    stripped of surrounding whitespace; the value is converted to ``float``
    when possible and kept as a string otherwise. Blank lines and lines with
    an empty name are skipped.

    Parameters
    ----------
    metrics : str
        Text with one metric per line, e.g. ``"accuracy = 0.95"``.

    Returns
    -------
    dict
        Mapping of metric name to float (or str) value.
    """
    metrics_table = {}
    for line in metrics.splitlines():
        name, _, raw_value = line.partition("=")
        name = name.strip()
        if not name:
            # blank line (or "= value" without a name): nothing to record
            continue
        value = raw_value.strip()
        try:
            # try to coerce to float but keep the text if that fails
            value = float(value)
        except ValueError:
            pass
        metrics_table[name] = value
    return metrics_table
def _load_model_card_from_repo(repo_id: str) -> card.Card:
    """Download ``README.md`` from the Hub repo ``repo_id`` and parse it.

    Parameters
    ----------
    repo_id : str
        Hugging Face Hub repo id, e.g. ``"gpt2"``.

    Returns
    -------
    skops.card.Card
        The model card parsed by ``card.parse_modelcard``.
    """
    path = hf_hub_download(repo_id, "README.md")
    return card.parse_modelcard(path)
def _create_model_card():
    """Build the model card from the current sidebar inputs.

    The local repo is (re)initialized first. The card is then loaded from the
    Hub repo ``model_card_repo`` or, when that is empty, created from the skops
    template using the metadata of the local repo at ``hf_path``. Afterwards,
    the filled-in default sections, the metrics and all custom sections and
    plots stored in the session state are added, in that order.

    Returns
    -------
    The resulting model card.
    """
    init_repo()

    if not model_card_repo:
        # create a new model card from the skops template
        model_card = card.Card(
            model=model, metadata=card.metadata_from_config(hf_path)
        )
    else:
        # load the existing model card of a Hub repo
        model_card = _load_model_card_from_repo(model_card_repo)

    sections_before_metrics = (
        ("Model description", model_description),
        ("Model description/Intended uses & limitations", intended_uses),
    )
    for title, content in sections_before_metrics:
        if content:
            model_card.add(**{title: content})

    if metrics:
        model_card.add_metrics(**_parse_metrics(metrics))

    sections_after_metrics = (
        ("Model Card Authors", authors),
        ("Model Card Contact", contact),
        ("Citation", citation),
    )
    for title, content in sections_after_metrics:
        if content:
            model_card.add(**{title: content})

    for key, val in st.session_state.custom_sections.items():
        if not key:
            continue
        if key.startswith(PLOT_PREFIX):
            plot_title = key[len(PLOT_PREFIX) :]  # noqa: E203
            model_card.add_plot(**{plot_title: val})
        else:
            model_card.add(**{key: val})

    return model_card
def _process_card_for_rendering(rendered: str) -> tuple[str, str]:
    """Split a rendered model card into its metadata and a displayable body.

    If the card starts with a ``---``-delimited metadata header, the text
    between the delimiters is returned as ``metadata`` and the rest as the
    body. Otherwise ``metadata`` is ``""`` and the whole text is the body.

    Streamlit does not display images referenced from markdown
    (https://discuss.streamlit.io/t/image-in-markdown/13274/10), so local
    images in the body are replaced by ``<img>`` tags with the file inlined as
    base64. This is only done for display, not for the card produced for the
    hub. Images whose path cannot be read as a local file (e.g. remote URLs)
    are left untouched.

    Parameters
    ----------
    rendered : str
        Markdown text of the model card, including the metadata header.

    Returns
    -------
    metadata : str
        Text of the metadata header (without the delimiters).
    body : str
        The remaining markdown, with local images inlined as HTML.
    """
    # example image markdown:
    # ![Clock](assets/clock.png "Clock")
    image_pattern = re.compile(
        r'(!\[(?P<image_title>[^\]]+)\]\((?P<image_path>[^\)"\s]+)\s*([^\)]*)\))'
    )
    # file extensions whose MIME subtype differs from the extension itself
    mime_subtypes = {"jpg": "jpeg", "svg": "svg+xml", "tif": "tiff"}

    def img_to_html(img_path, img_alt):
        """Return an ``<img>`` tag embedding the file, or None if unreadable."""
        try:
            img_bytes = Path(img_path).read_bytes()
        except (OSError, ValueError):
            return None
        extension = img_path.rsplit(".", 1)[-1].lower()
        subtype = mime_subtypes.get(extension, extension)
        encoded = base64.b64encode(img_bytes).decode()
        return (
            f'<img src="data:image/{subtype};base64,{encoded}" '
            f'alt="{html.escape(img_alt, quote=True)}" '
            'style="max-width: 100%;">'
        )

    def insert_images(markdown):
        """Replace every readable markdown image by an inline ``<img>`` tag."""
        for image_markdown, image_alt, image_path, _ in image_pattern.findall(
            markdown
        ):
            img_html = img_to_html(image_path, image_alt)
            if img_html is not None:
                markdown = markdown.replace(image_markdown, img_html)
        return markdown

    metadata = ""
    body = rendered
    if rendered.startswith("---"):
        end = rendered.find("\n---", 1)
        if end != -1:
            metadata = rendered[3:end]
            body = rendered[end + 4 :]  # noqa: E203

    return metadata, insert_images(body)
def display_model_card(model_card):
    """Show ``model_card`` in the main area of the app.

    The metadata header is shown as plain text inside a "show metadata"
    expander. The rest of the card is shown as markdown with HTML allowed, so
    that the ``<img>`` tags produced by ``_process_card_for_rendering`` are
    displayed. Does nothing when ``model_card`` is None.
    """
    if model_card is None:
        return
    rendered = model_card.render()
    # strip metadata from the markdown body and inline images for display
    metadata, rendered = _process_card_for_rendering(rendered)
    with st.expander("show metadata"):
        st.text(metadata)
    st.markdown(rendered, unsafe_allow_html=True)
def download_model_card(model_card):
    """Return the card rendered as markdown, or ``""`` when there is no card."""
    return "" if model_card is None else model_card.render()
def add_custom_section():
    """Form callback: register the submitted custom section.

    The submitted name and content are copied into the module-level
    ``section_name`` / ``section_content`` globals. The section is stored in
    ``st.session_state.custom_sections`` only if both are non-empty.
    """
    # this is required to "refresh" these variables...
    global section_name, section_content
    state = st.session_state
    section_name = state.key_section_name
    section_content = state.key_section_content
    if section_name and section_content:
        state.custom_sections[section_name] = section_content
def add_custom_plot():
    """Form callback: save the uploaded figure and register it as a plot section.

    Reads the section name and the uploaded file from the form widgets
    (``key_plot_name`` / ``key_plot_file``); nothing happens unless both are
    set. The figure is written into ``tmp_path``, and its path is stored in
    ``st.session_state.custom_sections`` under ``PLOT_PREFIX + plot_name``.
    """
    plot_name = st.session_state.key_plot_name
    plot_file = st.session_state.key_plot_file
    if not plot_name or not plot_file:
        return

    # store plot in temp repo; keep only the last path component of the
    # client-supplied name so the file is always written inside tmp_path
    file_name = Path(plot_file.name).name.replace(" ", "_")
    if file_name in ("", ".", ".."):
        file_name = "plot"
    file_path = tmp_path / file_name
    file_path.write_bytes(plot_file.getvalue())
    st.session_state.custom_sections[str(PLOT_PREFIX + plot_name)] = str(file_path)
with st.sidebar:
    # This contains every element required to edit the model card
    # Per-run defaults of module-level names that the helper functions above
    # read (e.g. init_repo() uses model and data); model and data are replaced
    # further down once files have been uploaded.
    model = None
    data = None
    section_name = None
    section_content = None
    st.title("Model Card Editor")
    # NOTE(review): load_model/load_data/init_repo are also registered as
    # on_change callbacks, but their return values are discarded there; the
    # values actually used come from the explicit calls further down.
    model_file = st.file_uploader("Upload a model*", on_change=load_model)
    data_file = st.file_uploader(
        "Upload X data (csv)*", type=["csv"], on_change=load_data
    )
    task = st.selectbox(
        label="Choose the task type*",
        options=[
            "tabular-classification",
            "tabular-regression",
            "text-classification",
            "text-regression",
        ],
        on_change=init_repo,
    )
    # one requirement per line (parsed in init_repo); pre-filled with the
    # installed scikit-learn version
    requirements = st.text_area(
        label="Requirements*",
        value=f"scikit-learn=={sklearn.__version__}\n",
        on_change=init_repo,
    )
    if model_file is not None:
        model = load_model()
    if data_file is not None:
        data = load_data()
    # NOTE(review): _create_model_card() calls init_repo() again, so this call
    # may be redundant - confirm before removing.
    if model is not None and data is not None:
        init_repo()
    model_card_repo = st.text_input(
        "Optional: HF repo to load model card from (e.g. 'gpt2'), "
        "leave empty to use default skops template",
        value="",
    )
    # DEFAULT SKOPS SECTIONS
    # Only offered when the card is built from the skops template. When a card
    # is loaded from a repo, the values are set to None, so
    # _create_model_card() skips them.
    if not model_card_repo:
        model_description = st.text_input("Model description", value=PLACEHOLDER)
        intended_uses = st.text_area(
            "Intended uses & limitations", height=2, value=PLACEHOLDER
        )
        # parsed by _parse_metrics()
        metrics = st.text_area("Metrics (e.g. 'accuracy = 0.95'), one metric per line")
        authors = st.text_area(
            "Authors",
            value="This model card is written by following authors:\n\n" + PLACEHOLDER,
        )
        contact = st.text_area(
            "Contact",
            value="You can contact the model card authors through following channels:\n\n"
            + PLACEHOLDER,
        )
        citation = st.text_area(
            "Citation",
            value="Below you can find information related to citation.\n\nBibTex:\n\n```\n"
            + PLACEHOLDER
            + "\n```",
            height=5,
        )
    else:
        model_description = None
        intended_uses = None
        metrics = None
        authors = None
        contact = None
        citation = None
    # ADD CUSTOM SECTIONS
    # The submit callback add_custom_section() reads these widgets through
    # their session-state keys and stores the section in
    # st.session_state.custom_sections.
    with st.form("custom-section", clear_on_submit=True):
        section_name = st.text_input(
            "Section name (use '/' for subsections, e.g. 'Model description/My new"
            " section')",
            key="key_section_name",
        )
        section_content = st.text_area(
            "Content of the new section", key="key_section_content"
        )
        # NOTE(review): submit_new_section is not used afterwards
        submit_new_section = st.form_submit_button(
            "Create new section", on_click=add_custom_section
        )
    # ADD A PLOT
    # The submit callback add_custom_plot() saves the figure and registers it
    # under PLOT_PREFIX + section name.
    with st.form("custom-plots", clear_on_submit=True):
        plot_name = st.text_input(
            "Section name (use '/' for subsections, e.g. 'Model description/My new"
            " plot')",
            key="key_plot_name",
        )
        plot_file = st.file_uploader("Upload a figure*", key="key_plot_file")
        # NOTE(review): submit_new_plot is not used afterwards
        submit_new_plot = st.form_submit_button("Add plot", on_click=add_custom_plot)
    # one "remove" button per registered custom section / plot; plots are
    # labelled without their PLOT_PREFIX
    for key in st.session_state.custom_sections:
        if not key:
            continue
        if key.startswith(PLOT_PREFIX):
            st.button(
                f"Remove plot '{key[len(PLOT_PREFIX):]}'",
                on_click=_remove_custom_section,
                args=(key,),
            )
        else:
            st.button(
                f"Remove section '{key}'", on_click=_remove_custom_section, args=(key,)
            )
    if st.session_state.custom_sections:
        st.button(
            f"Remove all ({len(st.session_state.custom_sections)}) custom elements",
            on_click=_clear_custom_section_cache,
        )
# Main area: build, offer for download and display the model card once both
# a model and data are available.
model_card = None
if model is None:
    st.text("*add a model to render the model card*")
if data is None:
    st.text("*add data to render the model card*")
if (model is not None) and (data is not None):
    model_card = _create_model_card()

# this contains the rendered model card ("" when there is no card yet)
rendered = download_model_card(model_card)
if rendered:
    # the card is stored as README.md on the Hub, so offer it under that name
    st.download_button(
        label="Download model card (markdown format)",
        data=rendered,
        file_name="README.md",
        mime="text/markdown",
    )
display_model_card(model_card)