# convert-to-onnx / app.py
# Last commit 77cbaf0 by Felladrin: "Remove unnecessary environment variables"
"""Convert Hugging Face models to ONNX format.
This application provides a Streamlit interface for converting Hugging Face models
to ONNX format using the Transformers.js conversion scripts. It handles:
- Model conversion with optional trust_remote_code and output_attentions
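- Automatic task inference from the model's pipeline tag, falling back to the
  conversion script's own task detection when the tag cannot be determined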
- README generation with merged metadata from the original model
- Upload to Hugging Face Hub
"""
import logging
import os
import re
import shutil
import subprocess
import sys
from dataclasses import dataclass
from pathlib import Path
from typing import List, Optional, Tuple
import streamlit as st
import yaml
from huggingface_hub import HfApi, hf_hub_download, model_info, whoami
# Configure root logging at INFO level; this module logs through its own named logger.
logging.basicConfig(level="INFO")
logger = logging.getLogger(name=__name__)
@dataclass
class Config:
    """Application configuration containing authentication and path settings.

    Attributes:
        hf_token: Hugging Face API token (user token takes precedence over system token)
        hf_username: Hugging Face username associated with the token
        is_using_user_token: True if using a user-provided token, False if using system token
        hf_base_url: Base URL for Hugging Face Hub
        repo_path: Path to the bundled transformers.js repository
    """

    hf_token: str
    hf_username: str
    is_using_user_token: bool
    hf_base_url: str = "https://huggingface.co"
    repo_path: Path = Path("./transformers.js")

    @staticmethod
    def _select_token(
        user_token: Optional[str], system_token: Optional[str]
    ) -> Tuple[str, bool]:
        """Choose which Hugging Face token to use.

        Args:
            user_token: Token entered by the user in the UI, if any. Surrounding
                whitespace is removed; a blank value counts as not provided.
            system_token: Token from the ``HF_TOKEN`` environment variable, if any.

        Returns:
            Tuple of (token, is_user_token), where is_user_token is True when
            the user-provided token was selected.

        Raises:
            ValueError: If neither a user token nor a system token is available.
        """
        cleaned_user_token = (user_token or "").strip()
        if cleaned_user_token:
            return cleaned_user_token, True
        if system_token:
            return system_token, False
        raise ValueError(
            "When the user token is not provided, the system token must be set."
        )

    @classmethod
    def from_env(cls) -> "Config":
        """Create configuration from environment variables and Streamlit session state.

        Priority order for tokens:
        1. User-provided token from Streamlit session (st.session_state.user_hf_token)
        2. System token from environment variable (HF_TOKEN)

        The token is validated before any Hub request is made. A missing token
        therefore raises the ValueError below instead of failing inside
        ``whoami`` with an unrelated authentication error.

        Returns:
            Config: Initialized configuration object

        Raises:
            ValueError: If no valid token is available
        """
        hf_token, is_user_token = cls._select_token(
            st.session_state.get("user_hf_token"), os.getenv("HF_TOKEN")
        )

        if is_user_token:
            hf_username = whoami(token=hf_token)["name"]
        else:
            # SPACE_AUTHOR_NAME (presumably set by the Spaces runtime) avoids a
            # Hub round-trip; otherwise ask the Hub who owns the system token.
            hf_username = (
                os.getenv("SPACE_AUTHOR_NAME") or whoami(token=hf_token)["name"]
            )

        return cls(
            hf_token=hf_token,
            hf_username=hf_username,
            is_using_user_token=is_user_token,
        )
class ModelConverter:
    """Handles model conversion to ONNX format and upload to Hugging Face Hub.

    This class manages the entire conversion workflow:
    1. Fetching original model metadata and README
    2. Running the ONNX conversion subprocess
    3. Generating an enhanced README with merged metadata
    4. Uploading the converted model to Hugging Face Hub

    Attributes:
        config: Application configuration containing tokens and paths
        api: Hugging Face API client for repository operations
    """

    # YAML frontmatter: an opening "---" line, the YAML text, then a closing
    # "---" line. The closing delimiter may be the last line of the file (with
    # or without a trailing newline), the YAML may be empty, and CRLF line
    # endings are accepted. A line such as "----" is not treated as a delimiter.
    _FRONTMATTER_PATTERN = re.compile(
        r"\A---[ \t]*\r?\n(.*?)^---[ \t]*\r?$\n?", re.DOTALL | re.MULTILINE
    )

    _PIPELINE_DOCS_BASE_URL = (
        "https://huggingface.co/docs/transformers.js/api/pipelines"
    )

    # Map Hugging Face pipeline tags to Transformers.js pipeline class names
    _PIPELINE_CLASS_NAMES = {
        "text-classification": "TextClassificationPipeline",
        "token-classification": "TokenClassificationPipeline",
        "question-answering": "QuestionAnsweringPipeline",
        "fill-mask": "FillMaskPipeline",
        "text2text-generation": "Text2TextGenerationPipeline",
        "summarization": "SummarizationPipeline",
        "translation": "TranslationPipeline",
        "text-generation": "TextGenerationPipeline",
        "zero-shot-classification": "ZeroShotClassificationPipeline",
        "feature-extraction": "FeatureExtractionPipeline",
        "image-feature-extraction": "ImageFeatureExtractionPipeline",
        "audio-classification": "AudioClassificationPipeline",
        "zero-shot-audio-classification": "ZeroShotAudioClassificationPipeline",
        "automatic-speech-recognition": "AutomaticSpeechRecognitionPipeline",
        "image-to-text": "ImageToTextPipeline",
        "image-classification": "ImageClassificationPipeline",
        "image-segmentation": "ImageSegmentationPipeline",
        "background-removal": "BackgroundRemovalPipeline",
        "zero-shot-image-classification": "ZeroShotImageClassificationPipeline",
        "object-detection": "ObjectDetectionPipeline",
        "zero-shot-object-detection": "ZeroShotObjectDetectionPipeline",
        "document-question-answering": "DocumentQuestionAnsweringPipeline",
        "text-to-audio": "TextToAudioPipeline",
        "image-to-image": "ImageToImagePipeline",
        "depth-estimation": "DepthEstimationPipeline",
    }

    # Map abbreviated pipeline tags to the task names the conversion script expects
    _TAG_SYNONYMS = {
        "vqa": "visual-question-answering",
    }

    def __init__(self, config: "Config"):
        """Initialize the converter with configuration.

        Args:
            config: Application configuration object
        """
        self.config = config
        self.api = HfApi(token=config.hf_token)

    # ============================================================================
    # README Processing Methods
    # ============================================================================

    def _fetch_original_readme(self, repo_id: str) -> str:
        """Download the README from the original model repository.

        Args:
            repo_id: Hugging Face model repository ID (e.g., 'username/model-name')

        Returns:
            str: Content of the README file, or empty string if not found
        """
        try:
            readme_path = hf_hub_download(
                repo_id=repo_id, filename="README.md", token=self.config.hf_token
            )
            with open(readme_path, "r", encoding="utf-8", errors="ignore") as f:
                return f.read()
        except Exception as e:
            # Best effort: a missing or unreadable README only means there is no
            # original content to append, so log it and carry on.
            logger.info("Could not fetch README for %s: %s", repo_id, e)
            return ""

    def _strip_yaml_frontmatter(self, text: str) -> str:
        """Remove YAML frontmatter from text, returning only the body.

        Uses the same delimiter rules as ``_extract_yaml_frontmatter`` so both
        methods agree on where the frontmatter ends.

        Args:
            text: Text that may contain YAML frontmatter

        Returns:
            str: Text with frontmatter removed, or original text if no frontmatter found
        """
        if not text:
            return ""
        match = self._FRONTMATTER_PATTERN.match(text)
        return text[match.end() :] if match else text

    def _extract_yaml_frontmatter(self, text: str) -> Tuple[dict, str]:
        """Parse and extract YAML frontmatter from text.

        Args:
            text: Text that may contain YAML frontmatter

        Returns:
            Tuple containing:
            - dict: Parsed YAML frontmatter as a dictionary (empty dict if none
              found, if the YAML is invalid, or if it is not a mapping)
            - str: Remaining body text after the frontmatter
        """
        if not text:
            return {}, ""

        match = self._FRONTMATTER_PATTERN.match(text)
        if not match:
            return {}, text

        body = text[match.end() :]

        # The README comes from an arbitrary repository, so any parse failure
        # just drops the metadata while keeping the body.
        try:
            parsed_data = yaml.safe_load(match.group(1))
        except Exception:
            return {}, body

        return (parsed_data if isinstance(parsed_data, dict) else {}), body

    def _get_pipeline_docs_url(self, pipeline_tag: Optional[str]) -> str:
        """Generate Transformers.js documentation URL for a given pipeline tag.

        Args:
            pipeline_tag: Hugging Face pipeline tag (e.g., 'text-generation')

        Returns:
            str: URL to the relevant Transformers.js pipeline documentation, or the
            general pipelines page when the tag is missing or unknown
        """
        pipeline_class = self._PIPELINE_CLASS_NAMES.get(pipeline_tag or "")
        if not pipeline_class:
            return self._PIPELINE_DOCS_BASE_URL
        return f"{self._PIPELINE_DOCS_BASE_URL}#module_pipelines.{pipeline_class}"

    def _normalize_pipeline_tag(self, pipeline_tag: Optional[str]) -> Optional[str]:
        """Normalize pipeline tag to match expected task names.

        Some pipeline tags use abbreviations that need to be expanded
        for the conversion script to recognize them.

        Args:
            pipeline_tag: Original pipeline tag from model metadata

        Returns:
            Optional[str]: Normalized task name, or None if input is None or empty
        """
        if not pipeline_tag:
            return None
        return self._TAG_SYNONYMS.get(pipeline_tag, pipeline_tag)

    # ============================================================================
    # Model Conversion Methods
    # ============================================================================

    def setup_repository(self) -> None:
        """Verify that the transformers.js repository exists.

        Raises:
            RuntimeError: If the repository is not found at the expected path
        """
        if not self.config.repo_path.exists():
            raise RuntimeError(
                f"Expected transformers.js repository at {self.config.repo_path} "
                f"but it was not found."
            )

    def _run_conversion_subprocess(
        self, input_model_id: str, extra_args: Optional[List[str]] = None
    ) -> subprocess.CompletedProcess:
        """Execute the ONNX conversion script as a subprocess.

        Args:
            input_model_id: Hugging Face model ID to convert
            extra_args: Additional command-line arguments for the conversion script

        Returns:
            subprocess.CompletedProcess: Result of the subprocess execution
            (stdout/stderr captured as text; a non-zero exit code is not raised)
        """
        command = [
            sys.executable,
            "-m",
            "scripts.convert",
            "--quantize",
            "--model_id",
            input_model_id,
        ]
        if extra_args:
            command.extend(extra_args)

        # The child receives ONLY HF_TOKEN, not the parent's environment.
        # NOTE(review): presumably intentional, so that remote code run with
        # --trust_remote_code cannot read other secrets. Confirm before widening.
        return subprocess.run(
            command,
            cwd=self.config.repo_path,
            capture_output=True,
            text=True,
            env={
                "HF_TOKEN": self.config.hf_token,
            },
        )

    def convert_model(
        self,
        input_model_id: str,
        trust_remote_code: bool = False,
        output_attentions: bool = False,
    ) -> Tuple[bool, Optional[str]]:
        """Convert a Hugging Face model to ONNX format.

        Args:
            input_model_id: Hugging Face model repository ID
            trust_remote_code: Whether to trust and execute remote code from the model
            output_attentions: Whether to output attention weights (required for some tasks)

        Returns:
            Tuple containing:
            - bool: True if conversion succeeded, False otherwise
            - Optional[str]: Error message if failed, or conversion log (stderr) if succeeded
        """
        try:
            conversion_args: List[str] = []

            # Remote code runs inside the conversion process, so it is only
            # allowed with the user's own token, never the system token.
            if trust_remote_code:
                if not self.config.is_using_user_token:
                    return False, "Trust Remote Code requires your own HuggingFace token."
                conversion_args.append("--trust_remote_code")

            # Needed for word-level timestamps in Whisper
            if output_attentions:
                conversion_args.append("--output_attentions")

            # Pass the task inferred from model metadata when available. If the
            # lookup fails, the conversion script infers the task on its own.
            try:
                info = model_info(repo_id=input_model_id, token=self.config.hf_token)
                task = self._normalize_pipeline_tag(getattr(info, "pipeline_tag", None))
            except Exception as e:
                logger.info("Could not fetch pipeline tag for %s: %s", input_model_id, e)
                task = None
            if task:
                conversion_args.extend(["--task", task])

            result = self._run_conversion_subprocess(
                input_model_id, extra_args=conversion_args or None
            )
            return result.returncode == 0, result.stderr
        except Exception as e:
            return False, str(e)

    # ============================================================================
    # Upload Methods
    # ============================================================================

    def upload_model(self, input_model_id: str, output_model_id: str) -> Optional[str]:
        """Upload the converted ONNX model to Hugging Face Hub.

        This method:
        1. Refuses model IDs whose local folder would resolve outside
           ``<repo_path>/models`` (the folder is deleted afterwards)
        2. Creates the target repository (if it doesn't exist)
        3. Generates an enhanced README with merged metadata, unless uploading
           into the original repository, whose model card is left untouched
        4. Uploads all model files to the repository
        5. Cleans up local files after upload

        Args:
            input_model_id: Original model repository ID
            output_model_id: Target repository ID for the ONNX model

        Returns:
            Optional[str]: Error message if upload failed, None if successful
        """
        model_folder_path = self.config.repo_path / "models" / input_model_id

        # The folder is rmtree'd below, so an ID such as "../x" must never be
        # allowed to point outside the models directory.
        models_root = (self.config.repo_path / "models").resolve()
        if models_root not in model_folder_path.resolve().parents:
            return f"Invalid model ID for local conversion output: {input_model_id}"

        try:
            # Create the target repository (public by default)
            self.api.create_repo(output_model_id, exist_ok=True, private=False)

            # A generated card would describe the repository as an ONNX copy of
            # itself (base_model pointing to itself), so keep the existing card
            # when uploading into the original repository.
            if output_model_id != input_model_id:
                readme_path = model_folder_path / "README.md"
                readme_content = self.generate_readme(input_model_id)
                readme_path.write_text(readme_content, encoding="utf-8")

            self.api.upload_folder(
                folder_path=str(model_folder_path), repo_id=output_model_id
            )
            return None  # Success
        except Exception as e:
            return str(e)
        finally:
            # Always clean up local files, even if upload failed
            shutil.rmtree(model_folder_path, ignore_errors=True)

    # ============================================================================
    # README Generation Methods
    # ============================================================================

    def generate_readme(self, input_model_id: str) -> str:
        """Generate an enhanced README for the ONNX model.

        This method creates a README that:
        1. Merges metadata from the original model with ONNX-specific metadata
        2. Adds a description and link to the conversion space
        3. Includes usage instructions with links to Transformers.js docs
        4. Appends the original model's README content

        Args:
            input_model_id: Original model repository ID

        Returns:
            str: Complete README content in Markdown format with YAML frontmatter
        """
        try:
            info = model_info(repo_id=input_model_id, token=self.config.hf_token)
            pipeline_tag = getattr(info, "pipeline_tag", None)
        except Exception as e:
            logger.info("Could not fetch model info for %s: %s", input_model_id, e)
            pipeline_tag = None

        original_text = self._fetch_original_readme(input_model_id)
        original_meta, original_body = self._extract_yaml_frontmatter(original_text)
        original_body = original_body.strip()

        # Original metadata first; our ONNX-specific keys take precedence.
        merged_meta = dict(original_meta)
        merged_meta["library_name"] = "transformers.js"
        merged_meta["base_model"] = [input_model_id]
        if pipeline_tag is not None:
            merged_meta["pipeline_tag"] = pipeline_tag

        # allow_unicode keeps non-ASCII metadata readable instead of \u-escaped.
        frontmatter_yaml = yaml.safe_dump(
            merged_meta, sort_keys=False, allow_unicode=True
        ).strip()

        model_name = input_model_id.split("/")[-1]
        docs_url = self._get_pipeline_docs_url(pipeline_tag)

        readme_sections: List[str] = [
            f"---\n{frontmatter_yaml}\n---\n\n",
            f"# {model_name} (ONNX)\n",
            f"This is an ONNX version of [{input_model_id}](https://huggingface.co/{input_model_id}). "
            "It was automatically converted and uploaded using "
            "[this Hugging Face Space](https://huggingface.co/spaces/onnx-community/convert-to-onnx).",
            "\n## Usage with Transformers.js\n",
        ]
        if pipeline_tag:
            readme_sections.append(
                f"See the pipeline documentation for `{pipeline_tag}`: {docs_url}"
            )
        else:
            readme_sections.append(f"See the pipelines documentation: {docs_url}")

        if original_body:
            readme_sections.append("\n---\n")
            readme_sections.append(original_body)

        return "\n\n".join(readme_sections) + "\n"
# A Hugging Face model ID is "name" or "namespace/name". Requiring every segment
# to start with a word character rules out "." / ".." segments and absolute
# paths. This matters because the ID is later joined into a local directory
# path that is deleted after upload.
_MODEL_ID_PATTERN = re.compile(r"(?:\w[\w.-]*/)?\w[\w.-]*")


def is_valid_model_id(model_id: str) -> bool:
    """Return True if *model_id* has the shape of a Hugging Face model repo ID.

    Args:
        model_id: Candidate ID, e.g. ``EleutherAI/pythia-14m`` or ``gpt2``.

    Returns:
        bool: True when the ID is ``name`` or ``namespace/name`` and each segment
        starts with a letter, digit or underscore, followed only by letters,
        digits, ``_``, ``-`` or ``.``.
    """
    return _MODEL_ID_PATTERN.fullmatch(model_id) is not None


def main():
    """Main application entry point for the Streamlit interface.

    This function:
    1. Initializes configuration and converter
    2. Displays the UI for model input and options
    3. Handles the conversion workflow
    4. Shows progress and results to the user
    """
    st.write("## Convert a Hugging Face model to ONNX")

    try:
        # Initialize configuration and converter
        config = Config.from_env()
        converter = ModelConverter(config)
        converter.setup_repository()

        # Get model ID from user; whitespace picked up by copy/paste is ignored
        input_model_id = (
            st.text_input(
                "Enter the Hugging Face model ID to convert. Example: `EleutherAI/pythia-14m`"
            )
            or ""
        ).strip()
        if not input_model_id:
            return
        if not is_valid_model_id(input_model_id):
            st.error(
                "Invalid model ID. Expected `model-name` or `owner/model-name`, "
                "for example `EleutherAI/pythia-14m`."
            )
            return

        # Optional: User token input
        st.text_input(
            "Optional: Your Hugging Face write token. Fill it if you want to upload the model under your account.",
            type="password",
            key="user_hf_token",
        )

        # Optional: Trust remote code toggle (requires user token)
        trust_remote_code = st.toggle("Optional: Trust Remote Code.")
        if trust_remote_code:
            st.warning(
                "This option should only be enabled for repositories you trust and in which you have read the code, as it will execute arbitrary code present in the model repository. When this option is enabled, you must use your own Hugging Face write token."
            )

        # Optional: Output attentions (for Whisper models)
        output_attentions = False
        if "whisper" in input_model_id.lower():
            output_attentions = st.toggle(
                "Whether to output attentions from the Whisper model. This is required for word-level (token) timestamps."
            )

        # Determine output repository.
        # If the user owns the model, allow uploading to the same repo.
        if config.hf_username == input_model_id.split("/")[0]:
            same_repo = st.checkbox(
                "Upload the ONNX weights to the existing repository"
            )
        else:
            same_repo = False

        model_name = input_model_id.split("/")[-1]
        output_model_id = f"{config.hf_username}/{model_name}"

        # Add -ONNX suffix if creating a new repository
        if not same_repo:
            output_model_id += "-ONNX"

        output_model_url = f"{config.hf_base_url}/{output_model_id}"

        # Check if model already exists
        if not same_repo and converter.api.repo_exists(output_model_id):
            st.write("This model has already been converted! 🎉")
            st.link_button(f"Go to {output_model_id}", output_model_url, type="primary")
            return

        # Show where the model will be uploaded
        st.write("URL where the model will be converted and uploaded to:")
        st.code(output_model_url, language="plaintext")

        # Wait for user confirmation before proceeding
        if not st.button(label="Proceed", type="primary"):
            return

        # Step 1: Convert the model to ONNX
        with st.spinner("Converting model..."):
            success, stderr = converter.convert_model(
                input_model_id,
                trust_remote_code=trust_remote_code,
                output_attentions=output_attentions,
            )

        if not success:
            st.error(f"Conversion failed: {stderr}")
            return

        st.success("Conversion successful!")
        st.code(stderr)

        # Step 2: Upload the converted model to Hugging Face
        with st.spinner("Uploading model..."):
            error = converter.upload_model(input_model_id, output_model_id)

        if error:
            st.error(f"Upload failed: {error}")
            return

        st.success("Upload successful!")
        st.write("You can now go and view the model on Hugging Face!")
        st.link_button(f"Go to {output_model_id}", output_model_url, type="primary")

    except Exception as e:
        logger.exception("Application error")
        st.error(f"An error occurred: {e}")


if __name__ == "__main__":
    main()