Spaces:
Runtime error
Runtime error
| """The editing page of the app | |
| This is the meat of the application. On the sidebar, the content of the model | |
| card is displayed in the form of editable fields. On the right side, the | |
| rendered model card is shown. | |
| In the sidebar, users can: | |
| - edit the title and content of existing sections | |
| - delete sections | |
| - add new sections below the current section | |
| - add new figures below the current section | |
| Moreover, each action results in a "task" that is tracked in the task state. A | |
| task has a "do" and an "undo" method. This allows us to provide "undo" and | |
| "redo" features to the app, making it easier for users to experiment and deal | |
| with errors. The "reset" button undoes all the tasks, leading back to the | |
| initial model card. | |
| When the user is finished, there is a "save" button that downloads the model | |
| card. They can also click "delete" to start over again, leading them to the | |
| start page. | |
| """ | |
| from __future__ import annotations | |
| import reprlib | |
| from pathlib import Path | |
| from tempfile import mkdtemp | |
| import streamlit as st | |
| from huggingface_hub import hf_hub_download | |
| from skops import card | |
| from skops.card._model_card import PlotSection, split_subsection_names | |
| from utils import ( | |
| get_rendered_model_card, | |
| iterate_key_section_content, | |
| process_card_for_rendering, | |
| ) | |
| from tasks import ( | |
| AddMetricsTask, | |
| AddSectionTask, | |
| AddFigureTask, | |
| DeleteSectionTask, | |
| TaskState, | |
| UpdateFigureTask, | |
| UpdateSectionTask, | |
| ) | |
# Scratch directory for temporary files, created once when the module loads.
tmp_path = Path(mkdtemp(prefix="skops-"))

# Size-limited repr helper: string reprs are capped at 24 characters. It is
# used to abbreviate section titles in the "Delete" button labels.
arepr = reprlib.Repr()
arepr.maxstring = 24
def load_model_card_from_repo(repo_id: str) -> card.Card:
    """Download ``README.md`` from a Hub repository and parse it.

    Parameters
    ----------
    repo_id : str
        Repository identifier, forwarded to ``hf_hub_download``.

    Returns
    -------
    card.Card
        Whatever ``card.parse_modelcard`` returns for the downloaded file.
    """
    print("downloading model card")
    readme_path = hf_hub_download(repo_id, "README.md")
    return card.parse_modelcard(readme_path)
def _update_model_card(
    model_card: card.Card,
    key: str,
    section_name: str,
    content: str,
    is_fig: bool,
) -> None:
    """Form callback: register an update task if title or content changed.

    Parameters
    ----------
    model_card : card.Card
        The model card being edited.
    key : str
        Key of the section; also the prefix of the widget keys
        ``"{key}.title"`` and ``"{key}.content"`` used by the forms.
    section_name : str
        Full name of the section before the edit.
    content : str
        Content of the section before the edit.
    is_fig : bool
        Whether the section is a figure section.
    """
    # The submitted values are read from the session state rather than from
    # Python variables: streamlit only guarantees that the values stored under
    # the widget "key" are up-to-date, plain variables can be stale. The key
    # names below must match the ones used when building the forms.
    submitted_title = st.session_state[f"{key}.title"]
    submitted_content = st.session_state[f"{key}.content"]

    # Only the last component of the (possibly nested) section name is
    # editable, so compare the name parts with the last one swapped out.
    previous_parts = split_subsection_names(section_name)
    updated_parts = previous_parts[:-1] + [submitted_title]
    title_unchanged = previous_parts == updated_parts

    if is_fig and not isinstance(submitted_content, PlotSection):
        # for figures, a falsy value means that no new file was uploaded
        content_unchanged = not submitted_content
    else:
        content_unchanged = content == submitted_content

    if title_unchanged and content_unchanged:
        return  # nothing to do

    if not is_fig:
        task = UpdateSectionTask(
            model_card,
            key=key,
            old_name=section_name,
            new_name=submitted_title,
            old_content=content,
            new_content=submitted_content,
        )
    else:
        new_path = old_path = None
        if submitted_content:  # new figure uploaded
            file_name = submitted_content.name.replace(" ", "_")
            new_path = st.session_state.hf_path / file_name
            old_path = new_path.parent / model_card.select(key).content.path
        task = UpdateFigureTask(
            model_card,
            key=key,
            old_name=section_name,
            new_name=submitted_title,
            data=submitted_content,
            new_path=new_path,
            old_path=old_path,
        )

    st.session_state.task_state.add(task)
def _add_section(model_card: card.Card, key: str) -> None:
    """Button callback: register a task adding an "Untitled" section.

    The new section is named ``"{key}/Untitled"`` and gets the placeholder
    text "[More Information Needed]" as content.
    """
    new_task = AddSectionTask(
        model_card,
        title=f"{key}/Untitled",
        content="[More Information Needed]",
    )
    st.session_state.task_state.add(new_task)
def _add_figure(model_card: card.Card, key: str) -> None:
    """Button callback: register a task adding an "Untitled" figure section.

    The new section is named ``"{key}/Untitled"``. The task receives the
    ``hf_path`` stored in the session state and the placeholder image name
    "cat.png" as content.
    """
    new_task = AddFigureTask(
        model_card,
        path=st.session_state.hf_path,
        title=f"{key}/Untitled",
        content="cat.png",
    )
    st.session_state.task_state.add(new_task)
def _delete_section(model_card: card.Card, key: str, path: Path) -> None:
    """Button callback: register a task deleting the section stored at ``key``.

    ``path`` is forwarded unchanged to ``DeleteSectionTask``.
    """
    st.session_state.task_state.add(
        DeleteSectionTask(model_card, key=key, path=path)
    )
def _add_section_form(
    model_card: card.Card, key: str, section_name: str, old_title: str, content: str
) -> None:
    """Render the edit form (title + text content) of a text section.

    Submitting the form calls ``_update_model_card`` with ``is_fig=False``.
    """
    # last item is the "is_fig" flag expected by the callback
    callback_args = (model_card, key, section_name, content, False)

    with st.form(key, clear_on_submit=False):
        st.header(section_name)
        # The widget keys make the current values available in the session
        # state, which is where the callback reads them from.
        st.text_input("Section name", value=old_title, key=f"{key}.title")
        st.text_area("Content", value=content, key=f"{key}.content")
        st.form_submit_button(
            "Update",
            on_click=_update_model_card,
            args=callback_args,
        )
def _add_fig_form(
    model_card: card.Card, key: str, section_name: str, old_title: str, content: str
) -> None:
    """Render the edit form (title + image upload) of a figure section.

    Submitting the form calls ``_update_model_card`` with ``is_fig=True``.
    """
    # last item is the "is_fig" flag expected by the callback
    callback_args = (model_card, key, section_name, content, True)

    with st.form(key, clear_on_submit=False):
        st.header(section_name)
        # The widget keys make the current values available in the session
        # state, which is where the callback reads them from.
        st.text_input("Section name", value=old_title, key=f"{key}.title")
        st.file_uploader("Upload image", key=f"{key}.content")
        st.form_submit_button(
            "Update",
            on_click=_update_model_card,
            args=callback_args,
        )
def create_form_from_section(
    model_card: card.Card,
    key: str,
    section_name: str,
    content: str,
    is_fig: bool = False,
) -> None:
    """Render the edit form and the action buttons of a single section.

    Parameters
    ----------
    model_card : card.Card
        The model card being edited.
    key : str
        Key of the section, used as form key and as prefix of the widget keys.
    section_name : str
        Full section name; its last component is the editable title.
    content : str
        Current content of the section. For figure sections, its ``path``
        attribute is read to locate the file handed to the delete callback.
    is_fig : bool, default=False
        Whether to render a figure form instead of a text form.
    """
    old_title = split_subsection_names(section_name)[-1]

    # both form builders share the same signature, so just pick one
    build_form = _add_fig_form if is_fig else _add_section_form
    build_form(
        model_card=model_card,
        key=key,
        section_name=section_name,
        old_title=old_title,
        content=content,
    )

    delete_col, section_col, figure_col = st.columns([4, 2, 2])

    with delete_col:
        if is_fig:
            path = st.session_state.hf_path / content.path
        else:
            path = None
        st.button(
            f"Delete '{arepr.repr(old_title)}'",
            on_click=_delete_section,
            args=(model_card, key, path),
            key=f"{key}.delete",
            help="Delete this section, including all its subsections",
        )

    with section_col:
        st.button(
            "add section below",
            on_click=_add_section,
            args=(model_card, key),
            key=f"{key}.add",
            help="Add a new subsection below this section",
        )

    with figure_col:
        st.button(
            "add figure below",
            on_click=_add_figure,
            args=(model_card, key),
            key=f"{key}.fig",
            help="Add a new figure below this section",
        )
def display_sections(model_card: card.Card) -> None:
    """Render one edit form (plus buttons) per section of the model card."""
    for info in iterate_key_section_content(model_card._data):
        create_form_from_section(
            model_card,
            info.return_key,
            info.title,
            info.content,
            info.is_fig,
        )
def display_toc(model_card: card.Card) -> None:
    """Render the table of contents of the model card as a Markdown list.

    Each section contributes one bullet showing the last component of its
    name, indented according to the ``level`` reported for the section.
    """
    # Markdown only treats a bullet as a child of the previous one when it is
    # indented at least as far as the parent's text, which for "- " bullets
    # means two columns per nesting level. A single space per level would
    # flatten the hierarchy.
    indent = "  "
    toc_lines = [
        indent * info.level + "- " + split_subsection_names(info.title)[-1]
        for info in iterate_key_section_content(model_card._data)
    ]
    st.markdown("\n".join(toc_lines))
def display_model_card(model_card: card.Card) -> None:
    """Render the model card: metadata, table of contents, then the body.

    The rendered card is passed through ``process_card_for_rendering``, which
    returns the metadata and the remaining text separately. The metadata and
    the table of contents each go into their own expander.
    """
    metadata, body = process_card_for_rendering(model_card.render())

    with st.expander("show metadata"):
        st.text(metadata)

    with st.expander("Table of Contents"):
        display_toc(model_card)

    st.markdown(body, unsafe_allow_html=True)
def reset_model_card() -> None:
    """Undo every task that has been done so far.

    Does nothing when the session state holds no task state or no model
    card, since there is nothing to reset in that case.
    """
    state = st.session_state
    # The previous code tried to delete the "model_card" entry exactly when it
    # was absent, which raises a KeyError. With no card present, simply bail.
    if "task_state" not in state or "model_card" not in state:
        return

    while state.task_state.done_list:
        state.task_state.undo()
def delete_model_card() -> None:
    """Switch to the "start" screen and drop the model card and its tasks.

    The "model_card" entry is removed from the session state and the task
    state is reset; each step is skipped if the entry does not exist.
    """
    state = st.session_state
    state.screen.state = "start"

    if "model_card" in state:
        del state["model_card"]

    if "task_state" in state:
        state.task_state.reset()
def undo_last():
    """Button callback: undo the most recent task, then display the card."""
    state = st.session_state
    state.task_state.undo()
    display_model_card(state.model_card)
def redo_last():
    """Button callback: redo the last undone task, then display the card."""
    state = st.session_state
    state.task_state.redo()
    display_model_card(state.model_card)
def add_download_model_card_button():
    """Render the button that downloads the model card as ``README.md``.

    The payload is produced by ``get_rendered_model_card`` from the model card
    and the ``hf_path`` stored in the session state.
    """
    state = st.session_state
    payload = get_rendered_model_card(
        state.model_card, hf_path=str(state.hf_path)
    )
    st.download_button(
        "Save (md)",
        data=payload,
        help="Download the generated model card as markdown file",
        file_name="README.md",
    )
def add_create_repo_button():
    """Render the "Create Repo" button.

    Clicking it switches the screen state to "create_repo". The button is
    disabled while the session state holds no (truthy) model card.
    """

    def _switch_to_create_repo():
        st.session_state.screen.state = "create_repo"

    has_model_card = bool(st.session_state.get("model_card"))
    st.button(
        "Create Repo",
        help="Create a model repository on Hugging Face Hub",
        on_click=_switch_to_create_repo,
        disabled=not has_model_card,
    )
def display_edit_buttons():
    """Render the two rows of global action buttons.

    First row: undo, redo and reset. Undo and redo show the number of
    available steps and are disabled when there are none. Second row:
    download, create repo and delete.
    """
    tasks = st.session_state.task_state

    # first row: undo + redo + reset
    undo_col, redo_col, reset_col, *_ = st.columns([2, 2, 2, 2])
    undo_disabled = not tasks.done_list
    redo_disabled = not tasks.undone_list

    with undo_col:
        st.button(
            f"UNDO ({len(tasks.done_list)})",
            on_click=undo_last,
            disabled=undo_disabled,
            help="Undo the last edit",
        )

    with redo_col:
        st.button(
            f"REDO ({len(tasks.undone_list)})",
            on_click=redo_last,
            disabled=redo_disabled,
            help="Redo the last undone edit",
        )

    with reset_col:
        st.button("Reset", on_click=reset_model_card, help="Undo all edits")

    # second row: download + create repo + delete
    download_col, repo_col, delete_col, *_ = st.columns([2, 2, 2, 2])

    with download_col:
        add_download_model_card_button()

    with repo_col:
        add_create_repo_button()

    with delete_col:
        st.button(
            "Delete",
            on_click=delete_model_card,
            help="Start over from scratch (lose all progress)",
        )
def _update_model_diagram():
    """Checkbox callback: add or remove the model plot of the model card.

    Reads the checkbox value stored under "special_model_diagram" (treated as
    checked when missing), stores it on the card's ``model_diagram`` attribute
    and then adds or deletes the model plot section accordingly.
    """
    show_diagram = st.session_state.get("special_model_diagram", True)
    model_card = st.session_state.model_card
    model_card.model_diagram = show_diagram

    # TODO: this may no longer be necessary once this issue is solved:
    # https://github.com/skops-dev/skops/issues/292
    if not show_diagram:
        model_card.delete("Model description/Training Procedure/Model Plot")
        return

    model_card.add_model_plot()
| def _parse_metrics(metrics: str) -> dict[str, str | float]: | |
| # parse metrics from text area, one per line, into a dict | |
| metrics_table = {} | |
| for line in metrics.splitlines(): | |
| line = line.strip() | |
| val: str | float | |
| name, _, val = line.partition("=") | |
| try: | |
| # try to coerce to float but don't error if it fails | |
| val = float(val.strip()) | |
| except ValueError: | |
| pass | |
| metrics_table[name.strip()] = val | |
| return metrics_table | |
def _update_metrics():
    """Form callback: register a task setting the metrics of the model card.

    The text stored under "special_metrics_text" in the session state is
    parsed with ``_parse_metrics``. No task is added when the parsed metrics
    equal the metrics already stored on the model card.
    """
    # The fallback has to be a string because it is parsed as text below; a
    # dict default would fail on ``.splitlines()``.
    metrics_text = st.session_state.get("special_metrics_text") or ""
    model_card = st.session_state.model_card
    metrics_table = _parse_metrics(metrics_text)

    # check if any change
    if metrics_table == model_card._metrics:
        return

    task = AddMetricsTask(model_card, metrics_table)
    st.session_state.task_state.add(task)
def display_skops_special_fields():
    """Render the widgets for the skops-specific model card features.

    These are a checkbox toggling the model diagram and, inside an expander,
    a form for entering metrics as text.
    """
    st.checkbox(
        "Show model diagram",
        value=True,
        on_change=_update_model_diagram,
        key="special_model_diagram",
    )

    with st.expander("Add metrics"), st.form(
        "special_metrics", clear_on_submit=False
    ):
        # the key makes the text available to the callback via session state
        st.text_area(
            "Add one metric per line, e.g. 'accuracy = 0.9'",
            key="special_metrics_text",
        )
        st.form_submit_button("Update", on_click=_update_metrics)
def edit_input_form():
    """Render the editing page.

    The sidebar receives the global action buttons, the skops-specific fields
    (only if the model card type is "skops") and one form per section. The
    rendered model card is shown outside of the sidebar. A task state is
    created in the session state if there is none yet.
    """
    state = st.session_state
    if "task_state" not in state:
        state.task_state = TaskState()

    with st.sidebar:
        # TOP ROW BUTTONS
        display_edit_buttons()

        # SHOW SPECIAL FIELDS IF SKOPS TEMPLATE WAS USED
        uses_skops_template = state.get("model_card_type", "") == "skops"
        if uses_skops_template:
            display_skops_special_fields()

        # SHOW EDITABLE SECTIONS
        if "model_card" in state:
            display_sections(state.model_card)

    # without a model card there is nothing to render
    if "model_card" not in state:
        return
    display_model_card(state.model_card)