Spaces:
Sleeping
Sleeping
| """The main page for the Trackio UI.""" | |
| import os | |
| import re | |
| import secrets | |
| import shutil | |
| from dataclasses import dataclass | |
| from typing import Any | |
| import gradio as gr | |
| import numpy as np | |
| import pandas as pd | |
| try: | |
| import trackio.utils as utils | |
| from trackio.file_storage import FileStorage | |
| from trackio.media import TrackioImage, TrackioVideo | |
| from trackio.sqlite_storage import SQLiteStorage | |
| from trackio.table import Table | |
| from trackio.typehints import LogEntry, UploadEntry | |
| from trackio.ui import fns | |
| from trackio.ui.helpers.run_selection import RunSelection | |
| from trackio.ui.run_detail import run_detail_page | |
| from trackio.ui.runs import run_page | |
| except ImportError: | |
| import utils | |
| from file_storage import FileStorage | |
| from media import TrackioImage, TrackioVideo | |
| from sqlite_storage import SQLiteStorage | |
| from table import Table | |
| from typehints import LogEntry, UploadEntry | |
| from ui import fns | |
| from ui.helpers.run_selection import RunSelection | |
| from ui.run_detail import run_detail_page | |
| from ui.runs import run_page | |
# Markdown shown on an empty dashboard when running on a Space. The text is
# passed through str.format(space_id): the single "{}" receives the Space id,
# which is why the literal braces of the example dict are doubled.
# The loop body of the example is indented so the snippet is valid Python
# when users copy it.
INSTRUCTIONS_SPACES = """
## Start logging with Trackio 🤗
To start logging to this Trackio dashboard, first make sure you have the Trackio library installed. You can do this by running:
```bash
pip install trackio
```
Then, start logging to this Trackio dashboard by passing in the `space_id` to `trackio.init()`:
```python
import trackio
trackio.init(project="my-project", space_id="{}")
```
Then call `trackio.log()` to log metrics.
```python
for i in range(10):
    trackio.log({{"loss": 1/(i+1)}})
```
Finally, call `trackio.finish()` to finish the run.
```python
trackio.finish()
```
"""
# Markdown shown on an empty dashboard when running locally. Unlike
# INSTRUCTIONS_SPACES this text is used verbatim (no str.format call), so the
# braces of the example dict are single.
# The loop body of the example is indented so the snippet is valid Python
# when users copy it.
INSTRUCTIONS_LOCAL = """
## Start logging with Trackio 🤗
You can create a new project by calling `trackio.init()`:
```python
import trackio
trackio.init(project="my-project")
```
Then call `trackio.log()` to log metrics.
```python
for i in range(10):
    trackio.log({"loss": 1/(i+1)})
```
Finally, call `trackio.finish()` to finish the run.
```python
trackio.finish()
```
Read the [Trackio documentation](https://huggingface.co/docs/trackio/en/index) for more examples.
"""
def get_runs(project) -> list[str]:
    """Return the run names stored for ``project``.

    A falsy project (``None`` or an empty string) yields an empty list
    without touching the storage layer.
    """
    return SQLiteStorage.get_runs(project) if project else []
def get_available_metrics(project: str, runs: list[str]) -> list[str]:
    """Collect the metric names that can serve as an x-axis.

    Every numeric, non-reserved column found in the logs of the given runs is
    a candidate. The result always starts with ``"step"`` and ``"time"``,
    followed by the remaining names in the order produced by
    ``utils.sort_metrics_by_prefix``.

    Args:
        project: Project name; a falsy value short-circuits to the defaults.
        runs: Run names to inspect; an empty list short-circuits as well.

    Returns:
        Ordered list of x-axis candidates without duplicates.
    """
    if not (project and runs):
        return ["step", "time"]

    discovered = {"step", "time"}
    for run_name in runs:
        rows = SQLiteStorage.get_logs(project, run_name)
        if not rows:
            continue
        numeric = pd.DataFrame(rows).select_dtypes(include="number").columns
        discovered.update(
            column for column in numeric if column not in utils.RESERVED_KEYS
        )

    ordered = utils.sort_metrics_by_prefix(list(discovered))
    # dict.fromkeys keeps first occurrences in order, so the two defaults stay
    # in front and nothing is listed twice.
    return list(dict.fromkeys(["step", "time", *ordered]))
| class MediaData: | |
| caption: str | None | |
| file_path: str | |
def extract_media(logs: list[dict]) -> dict[str, list[MediaData]]:
    """Group the image/video entries found in ``logs`` by their metric key.

    Log entries are visited in ascending ``step`` order (entries without a
    step sort as 0). A value counts as media when it is a dict whose
    ``"_type"`` equals the image or video type tag. A key is registered as
    soon as one media value is seen for it, even if building the item fails;
    failures are reported on stdout and skipped.

    Args:
        logs: Raw log dictionaries of a single run.

    Returns:
        Mapping from metric key to the list of its media items.
    """
    media_types = (TrackioImage.TYPE, TrackioVideo.TYPE)
    collected: dict[str, list[MediaData]] = {}
    for entry in sorted(logs, key=lambda row: row.get("step", 0)):
        for name, payload in entry.items():
            if not isinstance(payload, dict):
                continue
            if payload.get("_type") not in media_types:
                continue
            bucket = collected.setdefault(name, [])
            try:
                bucket.append(
                    MediaData(
                        file_path=utils.MEDIA_DIR / payload.get("file_path"),
                        caption=payload.get("caption"),
                    )
                )
            except Exception as e:
                # Best effort: one broken entry must not hide the others.
                print(f"Media currently unavailable: {name}: {e}")
    return collected
| def load_run_data( | |
| project: str | None, | |
| run: str | None, | |
| smoothing_granularity: int, | |
| x_axis: str, | |
| log_scale: bool = False, | |
| ) -> tuple[pd.DataFrame, dict]: | |
| if not project or not run: | |
| return None, None | |
| logs = SQLiteStorage.get_logs(project, run) | |
| if not logs: | |
| return None, None | |
| media = extract_media(logs) | |
| df = pd.DataFrame(logs) | |
| if "step" not in df.columns: | |
| df["step"] = range(len(df)) | |
| if x_axis == "time" and "timestamp" in df.columns: | |
| df["timestamp"] = pd.to_datetime(df["timestamp"]) | |
| first_timestamp = df["timestamp"].min() | |
| df["time"] = (df["timestamp"] - first_timestamp).dt.total_seconds() | |
| x_column = "time" | |
| elif x_axis == "step": | |
| x_column = "step" | |
| else: | |
| x_column = x_axis | |
| if log_scale and x_column in df.columns: | |
| x_vals = df[x_column] | |
| if (x_vals <= 0).any(): | |
| df[x_column] = np.log10(np.maximum(x_vals, 0) + 1) | |
| else: | |
| df[x_column] = np.log10(x_vals) | |
| if smoothing_granularity > 0: | |
| numeric_cols = df.select_dtypes(include="number").columns | |
| numeric_cols = [c for c in numeric_cols if c not in utils.RESERVED_KEYS] | |
| df_original = df.copy() | |
| df_original["run"] = run | |
| df_original["data_type"] = "original" | |
| df_smoothed = df.copy() | |
| window_size = max(3, min(smoothing_granularity, len(df))) | |
| df_smoothed[numeric_cols] = ( | |
| df_smoothed[numeric_cols] | |
| .rolling(window=window_size, center=True, min_periods=1) | |
| .mean() | |
| ) | |
| df_smoothed["run"] = f"{run}_smoothed" | |
| df_smoothed["data_type"] = "smoothed" | |
| combined_df = pd.concat([df_original, df_smoothed], ignore_index=True) | |
| combined_df["x_axis"] = x_column | |
| return combined_df, media | |
| else: | |
| df["run"] = run | |
| df["data_type"] = "original" | |
| df["x_axis"] = x_column | |
| return df, media | |
def refresh_runs(
    project: str | None,
    filter_text: str | None,
    selection: RunSelection,
    selected_runs_from_url: list[str] | None = None,
):
    """Recompute the run list for the sidebar and sync the selection with it.

    Args:
        project: Current project, or ``None`` for "no runs".
        filter_text: Substring a run name must contain to be listed.
        selection: Selection state; its choices are updated in place through
            ``update_choices``.
        selected_runs_from_url: Runs requested via the URL. Those that are
            currently listed are handed to ``update_choices`` as preferred.

    Returns:
        ``(checkbox_update, textbox_update, selection)``. The checkbox update
        is a bare ``gr.CheckboxGroup()`` unless ``update_choices`` reported a
        change.
    """
    runs: list[str] = [] if project is None else get_runs(project)
    if filter_text:
        runs = [name for name in runs if filter_text in name]

    preferred = (
        [name for name in runs if name in selected_runs_from_url]
        if selected_runs_from_url
        else None
    )

    if selection.update_choices(runs, preferred):
        checkbox_update = fns.run_checkbox_update(selection)
    else:
        checkbox_update = gr.CheckboxGroup()
    textbox_update = gr.Textbox(label=f"Runs ({len(runs)})")
    return checkbox_update, textbox_update, selection
def generate_embed(project: str, metrics: str, selection: RunSelection) -> str:
    """Build the embed snippet for the project, metric filter and selected runs."""
    chosen_runs = selection.selected
    return utils.generate_embed_code(project, metrics, chosen_runs)
def update_x_axis_choices(project, selection):
    """Rebuild the x-axis dropdown from the metrics of the selected runs.

    The dropdown value is reset to ``"step"`` on every call.
    """
    choices = get_available_metrics(project, selection.selected)
    return gr.Dropdown(label="X-axis", choices=choices, value="step")
def toggle_timer(cb_value):
    """Return a timer update that is active exactly when the checkbox is ticked."""
    return gr.Timer(active=bool(cb_value))
def upload_db_to_space(
    project: str, uploaded_db: gr.FileData, hf_token: str | None
) -> None:
    """
    Uploads the database of a local Trackio project to a Hugging Face Space.

    The token is checked for write access first. An existing database for the
    project is never overwritten: a ``gr.Error`` is raised instead.
    """
    fns.check_hf_token_has_write_access(hf_token)

    target = SQLiteStorage.get_project_db_path(project)
    if os.path.exists(target):
        raise gr.Error(
            f"Trackio database file already exists for project {project}, cannot overwrite."
        )

    parent_dir = os.path.dirname(target)
    os.makedirs(parent_dir, exist_ok=True)
    shutil.copy(uploaded_db["path"], target)
def bulk_upload_media(uploads: list[UploadEntry], hf_token: str | None) -> None:
    """
    Uploads media files to a Trackio dashboard.

    Each entry provides the keys ``"project"``, ``"run"``, ``"step"`` and
    ``"uploaded_file"``; the uploaded file's ``"path"`` is copied into the
    media location prepared for that project/run/step. The token is checked
    for write access before anything is copied.
    """
    fns.check_hf_token_has_write_access(hf_token)
    for entry in uploads:
        destination = FileStorage.init_project_media_path(
            entry["project"], entry["run"], entry["step"]
        )
        source = entry["uploaded_file"]["path"]
        shutil.copy(source, destination)
def log(
    project: str,
    run: str,
    metrics: dict[str, Any],
    step: int | None,
    hf_token: str | None,
) -> None:
    """
    Store a single metrics entry after checking the token for write access.

    Note: this method is not used in the latest versions of Trackio (replaced by bulk_log) but
    is kept for backwards compatibility for users who are connecting to a newer version of
    a Trackio Spaces dashboard with an older version of Trackio installed locally.
    """
    fns.check_hf_token_has_write_access(hf_token)
    record = {"project": project, "run": run, "metrics": metrics, "step": step}
    SQLiteStorage.log(**record)
def bulk_log(
    logs: list[LogEntry],
    hf_token: str | None,
) -> None:
    """
    Logs a list of metrics to a Trackio dashboard. Each entry in the list is a dictionary of the project, run, a dictionary of metrics, and optionally, a step and config.

    Entries are grouped by ``(project, run)`` in order of first appearance and
    each group is written with one ``SQLiteStorage.bulk_log`` call. The first
    truthy config seen for a run is the one stored for it.
    """
    fns.check_hf_token_has_write_access(hf_token)

    grouped: dict[tuple[str, str], dict[str, Any]] = {}
    for entry in logs:
        bucket = grouped.setdefault(
            (entry["project"], entry["run"]),
            {"metrics": [], "steps": [], "config": None},
        )
        bucket["metrics"].append(entry["metrics"])
        bucket["steps"].append(entry.get("step"))
        if bucket["config"] is None and entry.get("config"):
            bucket["config"] = entry["config"]

    for (project, run), bucket in grouped.items():
        SQLiteStorage.bulk_log(
            project=project,
            run=run,
            metrics_list=bucket["metrics"],
            steps=bucket["steps"],
            config=bucket["config"],
        )
def get_metric_values(
    project: str,
    run: str,
    metric_name: str,
) -> list[dict]:
    """
    Fetch every logged value of one metric for a project/run.
    Returns a list of dictionaries with timestamp, step, and value.
    """
    values = SQLiteStorage.get_metric_values(project, run, metric_name)
    return values
def get_runs_for_project(
    project: str,
) -> list[str]:
    """
    List the runs stored for a project.
    Returns a list of run names.
    """
    run_names = SQLiteStorage.get_runs(project)
    return run_names
def get_metrics_for_run(
    project: str,
    run: str,
) -> list[str]:
    """
    List the metrics recorded for a given project and run.
    Returns a list of metric names.
    """
    metric_names = SQLiteStorage.get_all_metrics_for_run(project, run)
    return metric_names
def filter_metrics_by_regex(metrics: list[str], filter_pattern: str) -> list[str]:
    """
    Keep the metric names matched by a case-insensitive regex.

    A blank pattern returns ``metrics`` unchanged. A pattern that is not a
    valid regular expression falls back to a case-insensitive substring test.

    Args:
        metrics: List of metric names to filter
        filter_pattern: Regex pattern to match against metric names
    Returns:
        List of metric names that match the pattern
    """
    if filter_pattern.strip() == "":
        return metrics

    try:
        matcher = re.compile(filter_pattern, re.IGNORECASE)
    except re.error:
        needle = filter_pattern.lower()
        return [name for name in metrics if needle in name.lower()]

    return [name for name in metrics if matcher.search(name)]
def get_all_projects() -> list[str]:
    """
    List every project known to the storage layer.
    Returns a list of project names.
    """
    project_names = SQLiteStorage.get_projects()
    return project_names
def get_project_summary(project: str) -> dict:
    """
    Summarize a project: its runs and the highest step reached by any of them.

    Args:
        project: Project name
    Returns:
        Dictionary with the keys ``project``, ``num_runs``, ``runs`` and
        ``last_activity``. ``last_activity`` is the largest value returned by
        ``SQLiteStorage.get_max_steps_for_runs`` or ``None`` when there is
        nothing to report.
    """
    run_names = SQLiteStorage.get_runs(project)
    if not run_names:
        return {"project": project, "num_runs": 0, "runs": [], "last_activity": None}

    max_steps = SQLiteStorage.get_max_steps_for_runs(project)
    if max_steps:
        last_activity = max(max_steps.values())
    else:
        last_activity = None

    summary = {"project": project}
    summary["num_runs"] = len(run_names)
    summary["runs"] = run_names
    summary["last_activity"] = last_activity
    return summary
def get_run_summary(project: str, run: str) -> dict:
    """
    Get a summary of a specific run including metrics and configuration.

    Args:
        project: Project name
        run: Run name
    Returns:
        Dictionary with the keys ``project``, ``run``, ``num_logs``,
        ``metrics``, ``config`` and ``last_step``. ``config`` is read from the
        first log entry. ``last_step`` is the largest logged step, or
        ``num_logs - 1`` when no entry carries a step. For a run without logs
        the counts are zero/empty and ``config``/``last_step`` are ``None``.
    """
    logs = SQLiteStorage.get_logs(project, run)
    if not logs:
        return {
            "project": project,
            "run": run,
            "num_logs": 0,
            "metrics": [],
            "config": None,
            "last_step": None,
        }

    # Only query metric names once we know there is something to summarize.
    metrics = SQLiteStorage.get_all_metrics_for_run(project, run)

    # Take the max over the raw values instead of going through a DataFrame:
    # the result stays a plain Python number (no numpy scalar, no int->float
    # promotion when some entries lack a step) in the API response.
    steps = [entry["step"] for entry in logs if entry.get("step") is not None]
    last_step = max(steps) if steps else len(logs) - 1

    return {
        "project": project,
        "run": run,
        "num_logs": len(logs),
        "metrics": metrics,
        "config": logs[0].get("config"),
        "last_step": last_step,
    }
def configure(request: gr.Request):
    """Translate URL query parameters into initial UI state.

    Recognized parameters:
        sidebar: ``"collapsed"`` (closed but visible) or ``"hidden"``
            (closed and invisible); anything else opens the sidebar.
        metrics: Passed through unchanged (empty string when absent).
        runs: Comma-separated run names, split into a list.
        navbar: ``"hidden"`` hides the navbar; anything else shows it.

    Returns:
        ``([], sidebar, metrics_param, selected_runs, navbar)``.
    """
    params = request.query_params

    # (open, visible) per recognized sidebar mode; default is open and visible.
    sidebar_modes = {"collapsed": (False, True), "hidden": (False, False)}
    is_open, is_visible = sidebar_modes.get(params.get("sidebar"), (True, True))
    sidebar = gr.Sidebar(open=is_open, visible=is_visible)

    metrics_param = params.get("metrics", "")

    runs_param = params.get("runs", "")
    selected_runs = runs_param.split(",") if runs_param else []

    navbar = gr.Navbar(visible=params.get("navbar") != "hidden")

    return [], sidebar, metrics_param, selected_runs, navbar
def create_media_section(media_by_run: dict[str, dict[str, list[MediaData]]]):
    """Render the "media" accordion: one tab per run, one gallery per media key.

    Args:
        media_by_run: Mapping run name -> media key -> media items. Each item
            contributes a ``(file_path, caption)`` pair to its gallery.
    """
    with gr.Accordion(label="media"), gr.Group(elem_classes=("media-group")):
        for run_name, galleries in media_by_run.items():
            with gr.Tab(label=run_name, elem_classes=("media-tab")):
                for gallery_label, items in galleries.items():
                    pairs = [(item.file_path, item.caption) for item in items]
                    gr.Gallery(
                        pairs,
                        label=gallery_label,
                        columns=6,
                        elem_classes=("media-gallery"),
                    )
# Stylesheet handed to gr.Blocks(css=...). It covers, in order: compact
# spacing for the run checkbox list (#run-cb), light/dark logo switching via
# the .dark class, the checkbox-driven expandable ".info-*" helpers, and the
# layout of the media galleries/tabs.
css = """
#run-cb .wrap { gap: 2px; }
#run-cb .wrap label {
line-height: 1;
padding: 6px;
}
.logo-light { display: block; }
.logo-dark { display: none; }
.dark .logo-light { display: none; }
.dark .logo-dark { display: block; }
.dark .caption-label { color: white; }
.info-container {
position: relative;
display: inline;
}
.info-checkbox {
position: absolute;
opacity: 0;
pointer-events: none;
}
.info-icon {
border-bottom: 1px dotted;
cursor: pointer;
user-select: none;
color: var(--color-accent);
}
.info-expandable {
display: none;
opacity: 0;
transition: opacity 0.2s ease-in-out;
}
.info-checkbox:checked ~ .info-expandable {
display: inline;
opacity: 1;
}
.info-icon:hover { opacity: 0.8; }
.accent-link { font-weight: bold; }
.media-gallery .fixed-height { min-height: 275px; }
.media-group, .media-group > div { background: none; }
.media-group .tabs { padding: 0.5em; }
.media-tab { max-height: 500px; overflow-y: scroll; }
"""
# Script injected through gr.Blocks(head=...). It defines setCookie/getCookie
# helpers and, on page load, copies a "write_token" URL parameter into the
# "trackio_write_token" cookie (7-day expiry). Outside of an iframe the
# parameter is then stripped from the address bar; inside an iframe it is kept
# in the URL because cookies may be blocked there.
javascript = """
<script>
function setCookie(name, value, days) {
var expires = "";
if (days) {
var date = new Date();
date.setTime(date.getTime() + (days * 24 * 60 * 60 * 1000));
expires = "; expires=" + date.toUTCString();
}
document.cookie = name + "=" + (value || "") + expires + "; path=/; SameSite=Lax";
}
function getCookie(name) {
var nameEQ = name + "=";
var ca = document.cookie.split(';');
for(var i=0;i < ca.length;i++) {
var c = ca[i];
while (c.charAt(0)==' ') c = c.substring(1,c.length);
if (c.indexOf(nameEQ) == 0) return c.substring(nameEQ.length,c.length);
}
return null;
}
(function() {
const urlParams = new URLSearchParams(window.location.search);
const writeToken = urlParams.get('write_token');
if (writeToken) {
setCookie('trackio_write_token', writeToken, 7);
// Only remove write_token from URL if not in iframe
// In iframes, keep it in URL as cookies may be blocked
const inIframe = window.self !== window.top;
if (!inIframe) {
urlParams.delete('write_token');
const newUrl = window.location.pathname +
(urlParams.toString() ? '?' + urlParams.toString() : '') +
window.location.hash;
window.history.replaceState({}, document.title, newUrl);
}
}
})();
</script>
"""
# Register the media directory as a Gradio static path.
gr.set_static_paths(paths=[utils.MEDIA_DIR])

with gr.Blocks(title="Trackio Dashboard", css=css, head=javascript) as demo:
    # ------------------------------------------------------------------
    # Sidebar: project/run selection and plot options.
    # ------------------------------------------------------------------
    with gr.Sidebar(open=False) as sidebar:
        logo = gr.Markdown(
            f"""
<img src='/gradio_api/file={utils.TRACKIO_LOGO_DIR}/trackio_logo_type_light_transparent.png' width='80%' class='logo-light'>
<img src='/gradio_api/file={utils.TRACKIO_LOGO_DIR}/trackio_logo_type_dark_transparent.png' width='80%' class='logo-dark'>
"""
        )
        project_dd = gr.Dropdown(label="Project", allow_custom_value=True)
        # The embed snippet is only shown when SPACE_HOST is set.
        embed_code = gr.Code(
            label="Embed this view",
            max_lines=2,
            lines=2,
            language="html",
            visible=bool(os.environ.get("SPACE_HOST")),
        )
        with gr.Group():
            run_tb = gr.Textbox(label="Runs", placeholder="Type to filter...")
            run_group_by_dd = gr.Dropdown(label="Group by...", choices=[], value=None)
            # Filled by render_grouped_runs; shown instead of run_cb when a
            # group-by field is chosen (see toggle_group_view).
            grouped_runs_panel = gr.Group(visible=False)
            run_cb = gr.CheckboxGroup(
                label="Runs",
                choices=[],
                interactive=True,
                elem_id="run-cb",
                show_select_all=True,
            )
        gr.HTML("<hr>")
        realtime_cb = gr.Checkbox(label="Refresh metrics realtime", value=True)
        smoothing_slider = gr.Slider(
            label="Smoothing Factor",
            minimum=0,
            maximum=20,
            value=10,
            step=1,
            info="0 = no smoothing",
        )
        x_axis_dd = gr.Dropdown(
            label="X-axis",
            choices=["step", "time"],
            value="step",
        )
        log_scale_cb = gr.Checkbox(label="Log scale X-axis", value=False)
        metric_filter_tb = gr.Textbox(
            label="Metric Filter (regex)",
            placeholder="e.g., loss|ndcg@10|gpu",
            value="",
            info="Filter metrics using regex patterns. Leave empty to show all metrics.",
        )

    navbar = gr.Navbar(value=[("Metrics", ""), ("Runs", "/runs")], main_page_name=False)
    timer = gr.Timer(value=1)
    metrics_subset = gr.State([])
    selected_runs_from_url = gr.State([])
    run_selection_state = gr.State(RunSelection())

    # ------------------------------------------------------------------
    # Page load: apply URL parameters and list the projects.
    # ------------------------------------------------------------------
    gr.on(
        [demo.load],
        fn=configure,
        outputs=[
            metrics_subset,
            sidebar,
            metric_filter_tb,
            selected_runs_from_url,
            navbar,
        ],
        queue=False,
        api_name=False,
    )
    gr.on(
        [demo.load],
        fn=fns.get_projects,
        outputs=project_dd,
        show_progress="hidden",
        queue=False,
        api_name=False,
    )

    # ------------------------------------------------------------------
    # Timer ticks: keep the run list and the project info current.
    # ------------------------------------------------------------------
    gr.on(
        [timer.tick],
        fn=refresh_runs,
        inputs=[project_dd, run_tb, run_selection_state, selected_runs_from_url],
        outputs=[run_cb, run_tb, run_selection_state],
        show_progress="hidden",
        api_name=False,
    )
    gr.on(
        [timer.tick],
        fn=lambda: gr.Dropdown(info=fns.get_project_info()),
        outputs=[project_dd],
        show_progress="hidden",
        api_name=False,
    )

    # ------------------------------------------------------------------
    # Project change / load: refresh runs, then everything derived from them.
    # ------------------------------------------------------------------
    gr.on(
        [demo.load, project_dd.change],
        fn=refresh_runs,
        inputs=[project_dd, run_tb, run_selection_state, selected_runs_from_url],
        outputs=[run_cb, run_tb, run_selection_state],
        show_progress="hidden",
        queue=False,
        api_name=False,
    ).then(
        fn=update_x_axis_choices,
        inputs=[project_dd, run_selection_state],
        outputs=x_axis_dd,
        show_progress="hidden",
        queue=False,
        api_name=False,
    ).then(
        fn=generate_embed,
        inputs=[project_dd, metric_filter_tb, run_selection_state],
        outputs=[embed_code],
        show_progress="hidden",
        api_name=False,
        queue=False,
    ).then(
        fns.update_navbar_value,
        inputs=[project_dd],
        outputs=[navbar],
        show_progress="hidden",
        api_name=False,
        queue=False,
    ).then(
        fn=fns.get_group_by_fields,
        inputs=[project_dd],
        outputs=[run_group_by_dd],
        show_progress="hidden",
        api_name=False,
        queue=False,
    )

    gr.on(
        [metric_filter_tb.change, run_cb.change],
        fn=generate_embed,
        inputs=[project_dd, metric_filter_tb, run_selection_state],
        outputs=embed_code,
        show_progress="hidden",
        api_name=False,
        queue=False,
    )

    def toggle_group_view(group_by_dd):
        """Show the flat run list or the grouped panel, never both."""
        grouped = bool(group_by_dd)
        return (
            gr.CheckboxGroup(visible=not grouped),
            gr.Group(visible=grouped),
        )

    gr.on(
        [run_group_by_dd.change],
        fn=toggle_group_view,
        inputs=[run_group_by_dd],
        outputs=[run_cb, grouped_runs_panel],
        show_progress="hidden",
        api_name=False,
        queue=False,
    )

    realtime_cb.change(
        fn=toggle_timer,
        inputs=realtime_cb,
        outputs=timer,
        api_name=False,
        queue=False,
    )

    # A manual checkbox change first updates the selection state; the x-axis
    # choices and the embed code are derived from that state, so they are
    # chained behind it rather than registered as independent listeners on the
    # same event (which could read the selection before it was updated).
    run_cb.input(
        fn=fns.handle_run_checkbox_change,
        inputs=[run_cb, run_selection_state],
        outputs=run_selection_state,
        api_name=False,
        queue=False,
    ).then(
        fn=update_x_axis_choices,
        inputs=[project_dd, run_selection_state],
        outputs=x_axis_dd,
        show_progress="hidden",
        queue=False,
        api_name=False,
    ).then(
        fn=generate_embed,
        inputs=[project_dd, metric_filter_tb, run_selection_state],
        outputs=embed_code,
        show_progress="hidden",
        api_name=False,
        queue=False,
    )

    run_tb.input(
        fn=refresh_runs,
        inputs=[project_dd, run_tb, run_selection_state],
        outputs=[run_cb, run_tb, run_selection_state],
        api_name=False,
        queue=False,
        show_progress="hidden",
    )

    # ------------------------------------------------------------------
    # Programmatic API: every endpoint is exposed under its function name.
    # ------------------------------------------------------------------
    for api_fn in (
        upload_db_to_space,
        bulk_upload_media,
        log,
        bulk_log,
        get_metric_values,
        get_runs_for_project,
        get_metrics_for_run,
        get_all_projects,
        get_project_summary,
        get_run_summary,
    ):
        gr.api(fn=api_fn, api_name=api_fn.__name__)

    # Current zoom range of the plots (None = full range) and the last step
    # seen per run (used to notice new data).
    x_lim = gr.State(None)
    last_steps = gr.State({})

    def update_x_lim(select_data: gr.SelectData):
        """Use the range selected on a plot as the shared x-limit."""
        return select_data.index

    def update_last_steps(project):
        """Check the last step for each run to detect when new data is available."""
        if not project:
            return {}
        return SQLiteStorage.get_max_steps_for_runs(project)

    timer.tick(
        fn=update_last_steps,
        inputs=[project_dd],
        outputs=last_steps,
        show_progress="hidden",
        api_name=False,
    )

    # NOTE(review): the visible source defined update_dashboard without ever
    # registering or calling it, so nothing was rendered. The registration
    # below is reconstructed: inputs follow the parameter order, triggers are
    # the change events of those inputs plus page load and new data
    # (last_steps). Confirm the trigger list against the intended behavior.
    @gr.render(
        triggers=[
            demo.load,
            run_cb.change,
            last_steps.change,
            smoothing_slider.change,
            x_lim.change,
            x_axis_dd.change,
            log_scale_cb.change,
            metric_filter_tb.change,
        ],
        inputs=[
            project_dd,
            run_cb,
            smoothing_slider,
            metrics_subset,
            x_lim,
            x_axis_dd,
            log_scale_cb,
            metric_filter_tb,
        ],
        show_progress="hidden",
        queue=False,
    )
    def update_dashboard(
        project,
        runs,
        smoothing_granularity,
        metrics_subset,
        x_lim_value,
        x_axis,
        log_scale,
        metric_filter,
    ):
        """Render the plots, media galleries and tables for the selected runs.

        Args:
            project: Current project name.
            runs: Names of the runs to display.
            smoothing_granularity: Smoothing window; 0 disables smoothing.
            metrics_subset: When non-empty, only these metrics are shown.
            x_lim_value: Zoom range applied to every plot, or ``None``.
            x_axis: Name of the x-axis column.
            log_scale: Whether the x axis is log-scaled.
            metric_filter: Regex further restricting the metrics shown.

        When no data is available, logging instructions are shown instead.
        """
        dfs = []
        images_by_run = {}
        original_runs = runs.copy()
        for run in runs:
            df, images_by_key = load_run_data(
                project, run, smoothing_granularity, x_axis, log_scale
            )
            if df is not None:
                dfs.append(df)
                images_by_run[run] = images_by_key

        if not dfs:
            master_df = pd.DataFrame()
        elif smoothing_granularity > 0:
            # All original rows first, then all smoothed rows.
            originals = [d[d["data_type"] == "original"] for d in dfs]
            smoothed = [d[d["data_type"] == "smoothed"] for d in dfs]
            parts = [part for part in originals + smoothed if not part.empty]
            master_df = pd.concat(parts, ignore_index=True) if parts else pd.DataFrame()
        else:
            master_df = pd.concat(dfs, ignore_index=True)

        if master_df.empty:
            if space_id := utils.get_space():
                gr.Markdown(INSTRUCTIONS_SPACES.format(space_id))
            else:
                gr.Markdown(INSTRUCTIONS_LOCAL)
            return

        x_column = "step"
        if dfs and not dfs[0].empty and "x_axis" in dfs[0].columns:
            x_column = dfs[0]["x_axis"].iloc[0]

        numeric_cols = master_df.select_dtypes(include="number").columns
        numeric_cols = [c for c in numeric_cols if c not in utils.RESERVED_KEYS]
        if x_column and x_column in numeric_cols:
            numeric_cols.remove(x_column)
        if metrics_subset:
            numeric_cols = [c for c in numeric_cols if c in metrics_subset]
        if metric_filter and metric_filter.strip():
            numeric_cols = filter_metrics_by_regex(list(numeric_cols), metric_filter)

        nested_metric_groups = utils.group_metrics_with_subprefixes(list(numeric_cols))
        color_map = utils.get_color_mapping(original_runs, smoothing_granularity > 0)

        # Metrics with at least one non-null value; computed once instead of
        # re-scanning the frame for every label and every plot.
        plottable = {m for m in numeric_cols if master_df[m].notna().any()}
        color = "run" if "run" in master_df.columns else None
        metric_idx = 0

        def render_metric_plots(metric_names):
            """Create one line plot (with zoom handlers) per plottable metric."""
            nonlocal metric_idx
            for metric_name in metric_names:
                if metric_name in plottable:
                    metric_df = master_df.dropna(subset=[metric_name])
                    plot = gr.LinePlot(
                        utils.downsample(
                            metric_df,
                            x_column,
                            metric_name,
                            color,
                            x_lim_value,
                        ),
                        x=x_column,
                        y=metric_name,
                        y_title=metric_name.split("/")[-1],
                        color=color,
                        color_map=color_map,
                        title=metric_name,
                        key=f"plot-{metric_idx}",
                        preserved_by_key=None,
                        x_lim=x_lim_value,
                        show_fullscreen_button=True,
                        min_width=400,
                        show_export_button=True,
                    )
                    # Selecting a range zooms all plots; double click resets.
                    plot.select(
                        update_x_lim,
                        outputs=x_lim,
                        key=f"select-{metric_idx}",
                    )
                    plot.double_click(
                        lambda: None,
                        outputs=x_lim,
                        key=f"double-{metric_idx}",
                    )
                # Advance for every metric so component keys stay unique.
                metric_idx += 1

        for group_name in sorted(nested_metric_groups):
            group_data = nested_metric_groups[group_name]
            direct_metrics = group_data["direct_metrics"]
            subgroups = group_data["subgroups"]

            total_plot_count = sum(m in plottable for m in direct_metrics) + sum(
                m in plottable for members in subgroups.values() for m in members
            )
            group_label = (
                f"{group_name} ({total_plot_count})"
                if total_plot_count > 0
                else group_name
            )
            with gr.Accordion(
                label=group_label,
                open=True,
                key=f"accordion-{group_name}",
                preserved_by_key=["value", "open"],
            ):
                if direct_metrics:
                    with gr.Draggable(
                        key=f"row-{group_name}-direct", orientation="row"
                    ):
                        render_metric_plots(direct_metrics)

                for subgroup_name in sorted(subgroups):
                    subgroup_metrics = subgroups[subgroup_name]
                    subgroup_plot_count = sum(m in plottable for m in subgroup_metrics)
                    subgroup_label = (
                        f"{subgroup_name} ({subgroup_plot_count})"
                        if subgroup_plot_count > 0
                        else subgroup_name
                    )
                    with gr.Accordion(
                        label=subgroup_label,
                        open=True,
                        key=f"accordion-{group_name}-{subgroup_name}",
                        preserved_by_key=["value", "open"],
                    ):
                        with gr.Draggable(key=f"row-{group_name}-{subgroup_name}"):
                            render_metric_plots(subgroup_metrics)

        if images_by_run and any(any(images) for images in images_by_run.values()):
            create_media_section(images_by_run)

        table_cols = master_df.select_dtypes(include="object").columns
        table_cols = [c for c in table_cols if c not in utils.RESERVED_KEYS]
        if metrics_subset:
            table_cols = [c for c in table_cols if c in metrics_subset]
        if metric_filter and metric_filter.strip():
            table_cols = filter_metrics_by_regex(list(table_cols), metric_filter)

        # A column is a table when its latest non-null value is a dict tagged
        # with the Table type. The column index is kept for the component key.
        table_entries = []
        for table_idx, metric_name in enumerate(table_cols):
            non_null = master_df[metric_name].dropna()
            if non_null.empty:
                continue
            latest = non_null.iloc[-1]
            if isinstance(latest, dict) and latest.get("_type") == Table.TYPE:
                table_entries.append((table_idx, metric_name, latest))

        if table_entries:
            with gr.Accordion(f"tables ({len(table_entries)})", open=True):
                with gr.Row(key="row"):
                    for table_idx, metric_name, latest in table_entries:
                        try:
                            gr.DataFrame(
                                pd.DataFrame(latest["_value"]),
                                label=f"{metric_name} (latest)",
                                key=f"table-{table_idx}",
                                wrap=True,
                            )
                        except Exception as e:
                            gr.Warning(
                                f"Column {metric_name} failed to render as a table: {e}"
                            )

    with grouped_runs_panel:

        # NOTE(review): as with update_dashboard, this function was defined
        # but never registered in the visible source; the registration is
        # reconstructed from its parameters. Confirm the trigger list.
        @gr.render(
            triggers=[
                demo.load,
                project_dd.change,
                run_group_by_dd.change,
                run_tb.input,
                run_selection_state.change,
            ],
            inputs=[project_dd, run_group_by_dd, run_tb, run_selection_state],
            show_progress="hidden",
            queue=False,
        )
        def render_grouped_runs(project, group_key, filter_text, selection):
            """Render one show/hide toggle and run checklist per config group.

            Nothing is rendered when no group-by field is selected.
            """
            if not group_key:
                return
            selection = selection or RunSelection()
            groups = fns.group_runs_by_config(project, group_key, filter_text)
            for label, runs in groups.items():
                ordered_current = utils.ordered_subset(runs, selection.selected)
                with gr.Group():
                    show_group_cb = gr.Checkbox(
                        label="Show/Hide",
                        value=bool(ordered_current),
                        key=f"show-cb-{group_key}-{label}",
                        preserved_by_key=["value"],
                    )
                    with gr.Accordion(
                        f"{label} ({len(runs)})",
                        open=False,
                        key=f"accordion-{group_key}-{label}",
                        preserved_by_key=["open"],
                    ):
                        group_cb = gr.CheckboxGroup(
                            choices=runs,
                            value=ordered_current,
                            show_label=False,
                            key=f"group-cb-{group_key}-{label}",
                        )
                    gr.on(
                        [group_cb.change],
                        fn=fns.handle_group_checkbox_change,
                        inputs=[
                            group_cb,
                            run_selection_state,
                            gr.State(runs),
                        ],
                        outputs=[
                            run_selection_state,
                            group_cb,
                            run_cb,
                        ],
                        show_progress="hidden",
                        api_name=False,
                        queue=False,
                    )
                    gr.on(
                        [show_group_cb.change],
                        fn=fns.handle_group_toggle,
                        inputs=[
                            show_group_cb,
                            run_selection_state,
                            gr.State(runs),
                        ],
                        outputs=[run_selection_state, group_cb, run_cb],
                        show_progress="hidden",
                        api_name=False,
                        queue=False,
                    )

# Additional pages; they are reached through the custom navbar rather than
# Gradio's automatic one (show_in_navbar=False).
with demo.route("Runs", show_in_navbar=False):
    run_page.render()

with demo.route("Run", show_in_navbar=False):
    run_detail_page.render()
# A fresh random token is generated each time this module is executed and
# attached to the main app and both sub-pages as their `write_token` attribute.
write_token = secrets.token_urlsafe(32)
demo.write_token = write_token
run_page.write_token = write_token
run_detail_page.write_token = write_token

# When run as a script, launch the dashboard; the logo directory is added to
# the paths Gradio is allowed to serve.
if __name__ == "__main__":
    demo.launch(allowed_paths=[utils.TRACKIO_LOGO_DIR], show_api=False, show_error=True)