# Captured from a Hugging Face Spaces page; Space status at capture time: "Runtime error".
| from __future__ import annotations | |
| import reprlib | |
| from pathlib import Path | |
| from tempfile import mkdtemp | |
| import streamlit as st | |
| from huggingface_hub import hf_hub_download | |
| from skops import card | |
| from skops.card._model_card import PlotSection, split_subsection_names | |
| from utils import iterate_key_section_content, process_card_for_rendering | |
| from tasks import AddSectionTask, AddFigureTask, DeleteSectionTask, TaskState, UpdateFigureTask, UpdateSectionTask | |
# Abbreviating repr used for section titles shown in button labels; strings
# longer than ``maxstring`` characters are shortened with "...".
arepr = reprlib.Repr()
arepr.maxstring = 24

# Two separate scratch directories, created once at import time:
# the first for temporary files, the second for the hf repo.
tmp_path, hf_path = (Path(mkdtemp(prefix="skops-")) for _ in range(2))
def load_model_card_from_repo(repo_id: str) -> card.Card:
    """Fetch ``README.md`` from a Hub repository and parse it into a card.

    Parameters
    ----------
    repo_id : str
        Repository id, passed unchanged to ``hf_hub_download``.

    Returns
    -------
    card.Card
        Result of ``card.parse_modelcard`` on the downloaded file.
    """
    print("downloading model card")
    readme_path = hf_hub_download(repo_id, "README.md")
    return card.parse_modelcard(readme_path)
def _update_model_card(
    model_card: card.Card, key: str, section_name: str, content: str, is_fig: bool,
) -> None:
    """Queue an update task for one section from its submitted form values.

    Used as the ``on_click`` callback of a section/figure form's "Update"
    button. If neither the title nor the content changed, nothing is queued.

    Parameters
    ----------
    model_card : card.Card
        Card the queued task operates on.
    key : str
        Form key. The submitted values are read from
        ``st.session_state[f"{key}.title"]`` and
        ``st.session_state[f"{key}.content"]``.
    section_name : str
        Full (possibly nested) section name before the edit.
    content : str
        Section content before the edit.
    is_fig : bool
        True for figure sections, whose new content comes from a file uploader.
    """
    # This is a very roundabout way to update the model card but it's necessary
    # because of how streamlit handles session state. Basically, there have to
    # be "key" arguments, which have to be retrieved from the session_state, as
    # they are up-to-date. Just getting the Python variables is not enough, as
    # they can be out of date.

    # key names must match with those used in form
    new_title = st.session_state[f"{key}.title"]
    new_content = st.session_state[f"{key}.content"]

    # determine if title is the same; only the last component of a nested
    # section name is editable in the form
    old_title_split = split_subsection_names(section_name)
    new_title_split = old_title_split[:-1] + [new_title]
    is_title_same = old_title_split == new_title_split

    # determine if content is the same
    if is_fig:
        if isinstance(new_content, PlotSection):
            is_content_same = content == new_content
        else:
            # an empty uploader (nothing uploaded) means the figure is unchanged
            is_content_same = not bool(new_content)
    else:
        is_content_same = content == new_content

    if is_title_same and is_content_same:
        return

    if is_fig:
        fpath = None
        if new_content:  # new figure uploaded
            # The uploaded file name is supplied by the client. Keep only its
            # final path component so the target cannot land outside tmp_path.
            fname = Path(new_content.name).name.replace(" ", "_")
            fpath = tmp_path / fname
        task = UpdateFigureTask(
            model_card,
            key=key,
            old_name=section_name,
            new_name=new_title,
            data=new_content,
            path=fpath,
        )
    else:
        task = UpdateSectionTask(
            model_card,
            key=key,
            old_name=section_name,
            new_name=new_title,
            old_content=content,
            new_content=new_content,
        )
    st.session_state.task_state.add(task)
def _add_section(model_card: card.Card, key: str) -> None:
    """Queue a task adding a placeholder text section titled ``f"{key}/Untitled"``."""
    st.session_state.task_state.add(
        AddSectionTask(
            model_card,
            title=f"{key}/Untitled",
            content="[More Information Needed]",
        )
    )
def _add_figure(model_card: card.Card, key: str) -> None:
    """Queue a task adding a figure section titled ``f"{key}/Untitled"``.

    The content is the fixed string ``"cat.png"`` (presumably a placeholder
    image path; confirm with AddFigureTask).
    """
    st.session_state.task_state.add(
        AddFigureTask(model_card, title=f"{key}/Untitled", content="cat.png")
    )
def _delete_section(model_card: card.Card, key: str) -> None:
    """Queue a task deleting the section identified by ``key``."""
    st.session_state.task_state.add(DeleteSectionTask(model_card, key=key))
def _add_section_form(
    model_card: card.Card, key: str, section_name: str, old_title: str, content: str
) -> None:
    """Render an editable form for a text section.

    Shows the full section name as a header, a text input pre-filled with
    ``old_title`` and a text area pre-filled with ``content``. Submitting calls
    ``_update_model_card`` with ``is_fig=False``.
    """
    title_key = f"{key}.title"
    content_key = f"{key}.content"
    with st.form(key, clear_on_submit=False):
        st.header(section_name)
        # Widget keys must match the ones _update_model_card reads from session_state.
        st.text_input("Section name", value=old_title, key=title_key)
        st.text_area("Content", value=content, key=content_key)
        callback_args = (model_card, key, section_name, content, False)
        st.form_submit_button("Update", on_click=_update_model_card, args=callback_args)
def _add_fig_form(
    model_card: card.Card, key: str, section_name: str, old_title: str, content: str
) -> None:
    """Render an editable form for a figure section.

    Shows the full section name as a header, a text input pre-filled with
    ``old_title`` and a file uploader for a replacement image. ``content`` is
    not displayed; it is only forwarded to ``_update_model_card`` (called with
    ``is_fig=True``) so the old and new values can be compared.
    """
    title_key = f"{key}.title"
    upload_key = f"{key}.content"
    with st.form(key, clear_on_submit=False):
        st.header(section_name)
        # Widget keys must match the ones _update_model_card reads from session_state.
        st.text_input("Section name", value=old_title, key=title_key)
        st.file_uploader("Upload image", key=upload_key)
        callback_args = (model_card, key, section_name, content, True)
        st.form_submit_button("Update", on_click=_update_model_card, args=callback_args)
def create_form_from_section(
    model_card: card.Card, key: str, section_name: str, content: str, is_fig: bool = False
) -> None:
    """Render the edit form for one section followed by its action buttons.

    Parameters
    ----------
    model_card : card.Card
        Card the form's callbacks operate on.
    key : str
        Unique key for the section. It is used as the form key and as the
        prefix of the button keys.
    section_name : str
        Full (possibly nested) section name. Its last component is the
        editable title.
    content : str
        Current section content.
    is_fig : bool, default False
        Render a figure form (file uploader) instead of a text form.
    """
    old_title = split_subsection_names(section_name)[-1]
    render_form = _add_fig_form if is_fig else _add_section_form
    render_form(
        model_card=model_card,
        key=key,
        section_name=section_name,
        old_title=old_title,
        content=content,
    )

    col_0, col_1, col_2 = st.columns([4, 2, 2])
    with col_0:
        # ``arepr.repr`` already returns the title wrapped in quotes, so no
        # extra quotes are added here (previously the label read ``''Title''``).
        st.button(
            f"delete {arepr.repr(old_title)}",
            on_click=_delete_section,
            args=(model_card, key),
            key=f"{key}.delete",
        )
    with col_1:
        st.button(
            "add section below",
            on_click=_add_section,
            args=(model_card, key),
            key=f"{key}.add",
        )
    with col_2:
        st.button(
            "add figure below",
            on_click=_add_figure,
            args=(model_card, key),
            key=f"{key}.fig",
        )
def display_sections(model_card: card.Card) -> None:
    """Render one edit form per entry yielded by ``iterate_key_section_content``."""
    entries = iterate_key_section_content(model_card._data)
    for section_key, name, section_content, figure in entries:
        create_form_from_section(model_card, section_key, name, section_content, figure)
def display_model_card(model_card: card.Card) -> None:
    """Render the card preview.

    The metadata split off by ``process_card_for_rendering`` is shown as plain
    text inside a collapsed "show metadata" expander. The remaining body is
    rendered as markdown.
    """
    metadata, body = process_card_for_rendering(model_card.render())
    with st.expander("show metadata"):
        st.text(metadata)
    # unsafe_allow_html lets raw HTML contained in the rendered card display as HTML
    st.markdown(body, unsafe_allow_html=True)
def reset_model_card() -> None:
    """Revert all applied edits by undoing tasks until the done-list is empty.

    Does nothing when there is no task history or no model card loaded.
    """
    if "task_state" not in st.session_state:
        return
    # Without a loaded card there is nothing to revert. Deleting the missing
    # key here would raise KeyError, so return early instead.
    if "model_card" not in st.session_state:
        return

    task_state = st.session_state.task_state
    while task_state.done_list:
        task_state.undo()
def delete_model_card() -> None:
    """Remove the loaded model card and reset the task history, if either exists."""
    state = st.session_state
    if "model_card" in state:
        del state["model_card"]
    if "task_state" in state:
        state.task_state.reset()
def undo_last():
    """Undo the most recent task, then render the card preview.

    NOTE(review): this draws the preview from inside a button callback, and
    edit_input_form also draws it on every run. Confirm the card is not
    shown twice after an undo.
    """
    state = st.session_state
    state.task_state.undo()
    display_model_card(state.model_card)
def redo_last():
    """Redo the most recently undone task, then render the card preview.

    NOTE(review): this draws the preview from inside a button callback, and
    edit_input_form also draws it on every run. Confirm the card is not
    shown twice after a redo.
    """
    state = st.session_state
    state.task_state.redo()
    display_model_card(state.model_card)
def add_download_model_card_button():
    """Render a "Save (md)" button that downloads the rendered model card.

    The button is disabled when no model card is in the session state. In that
    case an empty payload is used instead of calling ``render()`` on ``None``,
    which raised AttributeError (e.g. on the rerun right after "Delete").
    """
    model_card = st.session_state.get("model_card")
    download_disabled = not bool(model_card)
    data = "" if download_disabled else model_card.render()
    st.download_button("Save (md)", data=data, disabled=download_disabled)
def edit_input_form():
    """Render the model card editor.

    Creates a ``TaskState`` in the session on first use. The sidebar holds a
    toolbar (undo/redo with pending counts, reset, save, delete) followed by
    one edit form per section. The rendered card preview goes in the main area.
    NOTE(review): the original indentation was lost; this sidebar/main split
    is a reconstruction and should be confirmed.
    """
    if "task_state" not in st.session_state:
        st.session_state.task_state = TaskState()
    tasks = st.session_state.task_state

    with st.sidebar:
        undo_col, redo_col, reset_col, save_col, delete_col = st.columns(
            [1.6, 1.5, 1.2, 2, 1.5]
        )
        with undo_col:
            st.button(
                f"UNDO ({len(tasks.done_list)})",
                on_click=undo_last,
                disabled=not tasks.done_list,
            )
        with redo_col:
            st.button(
                f"REDO ({len(tasks.undone_list)})",
                on_click=redo_last,
                disabled=not tasks.undone_list,
            )
        with reset_col:
            st.button("Reset", on_click=reset_model_card)
        with save_col:
            add_download_model_card_button()
        with delete_col:
            st.button("Delete", on_click=delete_model_card)

        if "model_card" in st.session_state:
            display_sections(st.session_state.model_card)

    if "model_card" in st.session_state:
        display_model_card(st.session_state.model_card)