"""Convert Hugging Face models to ONNX format. |
|
|
|
|
|
This application provides a Streamlit interface for converting Hugging Face models |
|
|
to ONNX format using the Transformers.js conversion scripts. It handles: |
|
|
- Model conversion with optional trust_remote_code and output_attentions |
|
|
- Automatic task inference with fallback support |
|
|
- README generation with merged metadata from the original model |
|
|
- Upload to Hugging Face Hub |
|
|
""" |
|
|
|
|
|
import logging |
|
|
import os |
|
|
import re |
|
|
import shutil |
|
|
import subprocess |
|
|
import sys |
|
|
from dataclasses import dataclass |
|
|
from pathlib import Path |
|
|
from typing import List, Optional, Tuple |
|
|
|
|
|
import streamlit as st |
|
|
import yaml |
|
|
from huggingface_hub import HfApi, hf_hub_download, model_info, whoami |
|
|
|
|
|
logging.basicConfig(level=logging.INFO) |
|
|
logger = logging.getLogger(__name__) |
|
|
|
|
|
|
|
|
@dataclass
class Config:
    """Application configuration containing authentication and path settings.

    Attributes:
        hf_token: Hugging Face API token (user token takes precedence over system token)
        hf_username: Hugging Face username associated with the token
        is_using_user_token: True if using a user-provided token, False if using system token
        hf_base_url: Base URL for Hugging Face Hub
        repo_path: Path to the bundled transformers.js repository
    """

    hf_token: str
    hf_username: str
    is_using_user_token: bool
    hf_base_url: str = "https://huggingface.co"
    repo_path: Path = Path("./transformers.js")

    @classmethod
    def from_env(cls) -> "Config":
        """Create configuration from environment variables and Streamlit session state.

        Priority order for tokens:
        1. User-provided token from Streamlit session (st.session_state.user_hf_token)
        2. System token from environment variable (HF_TOKEN)

        Returns:
            Config: Initialized configuration object

        Raises:
            ValueError: If no valid token is available
        """
        system_token = os.getenv("HF_TOKEN")
        user_token = st.session_state.get("user_hf_token")

        if user_token:
            hf_username = whoami(token=user_token)["name"]
        else:
            hf_username = (
                os.getenv("SPACE_AUTHOR_NAME") or whoami(token=system_token)["name"]
            )

        hf_token = user_token or system_token
        if not hf_token:
            raise ValueError(
                "When the user token is not provided, the system token must be set."
            )

        return cls(
            hf_token=hf_token,
            hf_username=hf_username,
            is_using_user_token=bool(user_token),
        )
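
# Typical wiring, as done in main() below (sketch):
#
#     config = Config.from_env()
#     converter = ModelConverter(config)
#     converter.setup_repository()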


class ModelConverter:
    """Handles model conversion to ONNX format and upload to Hugging Face Hub.

    This class manages the entire conversion workflow:
    1. Fetching original model metadata and README
    2. Running the ONNX conversion subprocess
    3. Generating an enhanced README with merged metadata
    4. Uploading the converted model to Hugging Face Hub

    Attributes:
        config: Application configuration containing tokens and paths
        api: Hugging Face API client for repository operations
    """

    def __init__(self, config: Config):
        """Initialize the converter with configuration.

        Args:
            config: Application configuration object
        """
        self.config = config
        self.api = HfApi(token=config.hf_token)

    def _fetch_original_readme(self, repo_id: str) -> str:
        """Download the README from the original model repository.

        Args:
            repo_id: Hugging Face model repository ID (e.g., 'username/model-name')

        Returns:
            str: Content of the README file, or empty string if not found
        """
        try:
            readme_path = hf_hub_download(
                repo_id=repo_id, filename="README.md", token=self.config.hf_token
            )
            with open(readme_path, "r", encoding="utf-8", errors="ignore") as f:
                return f.read()
        except Exception:
            # A missing or unreadable README is non-fatal; fall back to empty.
            return ""

    def _strip_yaml_frontmatter(self, text: str) -> str:
        """Remove YAML frontmatter from text, returning only the body.

        YAML frontmatter is delimited by '---' at the start and end.

        Args:
            text: Text that may contain YAML frontmatter

        Returns:
            str: Text with frontmatter removed, or original text if no frontmatter found
        """
        if not text:
            return ""
        if text.startswith("---"):
            match = re.match(r"^---[\s\S]*?\n---\s*\n", text)
            if match:
                return text[match.end():]
        return text
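
    # For example (illustrative input): "---\nlicense: mit\n---\nBody text"
    # strips to "Body text", while text without a leading "---" is returned
    # unchanged.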
    def _extract_yaml_frontmatter(self, text: str) -> Tuple[dict, str]:
        """Parse and extract YAML frontmatter from text.

        Args:
            text: Text that may contain YAML frontmatter

        Returns:
            Tuple containing:
            - dict: Parsed YAML frontmatter as a dictionary (empty dict if none found)
            - str: Remaining body text after the frontmatter
        """
        if not text or not text.startswith("---"):
            return {}, text or ""

        match = re.match(r"^---\s*\n([\s\S]*?)\n---\s*\n", text)
        if not match:
            return {}, text

        frontmatter_text = match.group(1)
        body = text[match.end():]

        try:
            parsed_data = yaml.safe_load(frontmatter_text)
            if not isinstance(parsed_data, dict):
                parsed_data = {}
        except Exception:
            # Malformed YAML is treated as "no frontmatter" rather than an error.
            parsed_data = {}

        return parsed_data, body
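
    # Illustrative example (hypothetical input):
    #
    #     _extract_yaml_frontmatter("---\nlicense: mit\n---\n# Title\n")
    #
    # returns ({"license": "mit"}, "# Title\n").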
    def _get_pipeline_docs_url(self, pipeline_tag: Optional[str]) -> str:
        """Generate Transformers.js documentation URL for a given pipeline tag.

        Args:
            pipeline_tag: Hugging Face pipeline tag (e.g., 'text-generation')

        Returns:
            str: URL to the relevant Transformers.js pipeline documentation
        """
        base_url = "https://huggingface.co/docs/transformers.js/api/pipelines"

        if not pipeline_tag:
            return base_url

        pipeline_class_mapping = {
            "text-classification": "TextClassificationPipeline",
            "token-classification": "TokenClassificationPipeline",
            "question-answering": "QuestionAnsweringPipeline",
            "fill-mask": "FillMaskPipeline",
            "text2text-generation": "Text2TextGenerationPipeline",
            "summarization": "SummarizationPipeline",
            "translation": "TranslationPipeline",
            "text-generation": "TextGenerationPipeline",
            "zero-shot-classification": "ZeroShotClassificationPipeline",
            "feature-extraction": "FeatureExtractionPipeline",
            "image-feature-extraction": "ImageFeatureExtractionPipeline",
            "audio-classification": "AudioClassificationPipeline",
            "zero-shot-audio-classification": "ZeroShotAudioClassificationPipeline",
            "automatic-speech-recognition": "AutomaticSpeechRecognitionPipeline",
            "image-to-text": "ImageToTextPipeline",
            "image-classification": "ImageClassificationPipeline",
            "image-segmentation": "ImageSegmentationPipeline",
            "background-removal": "BackgroundRemovalPipeline",
            "zero-shot-image-classification": "ZeroShotImageClassificationPipeline",
            "object-detection": "ObjectDetectionPipeline",
            "zero-shot-object-detection": "ZeroShotObjectDetectionPipeline",
            "document-question-answering": "DocumentQuestionAnsweringPipeline",
            "text-to-audio": "TextToAudioPipeline",
            "image-to-image": "ImageToImagePipeline",
            "depth-estimation": "DepthEstimationPipeline",
        }

        pipeline_class = pipeline_class_mapping.get(pipeline_tag)
        if not pipeline_class:
            return base_url

        return f"{base_url}#module_pipelines.{pipeline_class}"
    def _normalize_pipeline_tag(self, pipeline_tag: Optional[str]) -> Optional[str]:
        """Normalize pipeline tag to match expected task names.

        Some pipeline tags use abbreviations that need to be expanded
        for the conversion script to recognize them.

        Args:
            pipeline_tag: Original pipeline tag from model metadata

        Returns:
            Optional[str]: Normalized task name, or None if input is None
        """
        if not pipeline_tag:
            return None

        tag_synonyms = {
            "vqa": "visual-question-answering",
        }

        return tag_synonyms.get(pipeline_tag, pipeline_tag)
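
    # For example, _normalize_pipeline_tag("vqa") returns
    # "visual-question-answering"; all other tags pass through unchanged.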
    def setup_repository(self) -> None:
        """Verify that the transformers.js repository exists.

        Raises:
            RuntimeError: If the repository is not found at the expected path
        """
        if not self.config.repo_path.exists():
            raise RuntimeError(
                f"Expected transformers.js repository at {self.config.repo_path} "
                "but it was not found."
            )

    def _run_conversion_subprocess(
        self, input_model_id: str, extra_args: Optional[List[str]] = None
    ) -> subprocess.CompletedProcess:
        """Execute the ONNX conversion script as a subprocess.

        Args:
            input_model_id: Hugging Face model ID to convert
            extra_args: Additional command-line arguments for the conversion script

        Returns:
            subprocess.CompletedProcess: Result of the subprocess execution
        """
        command = [
            sys.executable,
            "-m",
            "scripts.convert",
            "--quantize",
            "--model_id",
            input_model_id,
        ]

        if extra_args:
            command.extend(extra_args)

        return subprocess.run(
            command,
            cwd=self.config.repo_path,
            capture_output=True,
            text=True,
            # The subprocess receives a minimal environment containing only
            # HF_TOKEN.
            env={
                "HF_TOKEN": self.config.hf_token,
            },
        )
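
    # The resulting invocation is equivalent to running, from inside the
    # transformers.js checkout:
    #
    #     python -m scripts.convert --quantize --model_id <input_model_id> [extra args...]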
    def convert_model(
        self,
        input_model_id: str,
        trust_remote_code: bool = False,
        output_attentions: bool = False,
    ) -> Tuple[bool, Optional[str]]:
        """Convert a Hugging Face model to ONNX format.

        Args:
            input_model_id: Hugging Face model repository ID
            trust_remote_code: Whether to trust and execute remote code from the model
            output_attentions: Whether to output attention weights (required for some tasks)

        Returns:
            Tuple containing:
            - bool: True if conversion succeeded, False otherwise
            - Optional[str]: Error message if failed, or conversion log if succeeded
        """
        try:
            conversion_args: List[str] = []

            if trust_remote_code:
                if not self.config.is_using_user_token:
                    raise Exception(
                        "Trust Remote Code requires your own Hugging Face token."
                    )
                conversion_args.append("--trust_remote_code")

            if output_attentions:
                conversion_args.append("--output_attentions")

            # Task inference is best-effort: if the model's pipeline tag cannot
            # be resolved, the conversion script falls back to its own default.
            try:
                info = model_info(repo_id=input_model_id, token=self.config.hf_token)
                pipeline_tag = getattr(info, "pipeline_tag", None)
                task = self._normalize_pipeline_tag(pipeline_tag)
                if task:
                    conversion_args.extend(["--task", task])
            except Exception:
                pass

            result = self._run_conversion_subprocess(
                input_model_id, extra_args=conversion_args or None
            )

            if result.returncode != 0:
                return False, result.stderr

            # The conversion script writes its progress log to stderr.
            return True, result.stderr

        except Exception as e:
            return False, str(e)
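
    # Direct (non-UI) usage sketch, assuming HF_TOKEN is set in the environment:
    #
    #     converter = ModelConverter(Config.from_env())
    #     ok, log = converter.convert_model("EleutherAI/pythia-14m")
    #     # ok is True on success; log holds the conversion script's stderr output.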
    def upload_model(self, input_model_id: str, output_model_id: str) -> Optional[str]:
        """Upload the converted ONNX model to Hugging Face Hub.

        This method:
        1. Creates the target repository (if it doesn't exist)
        2. Generates an enhanced README with merged metadata
        3. Uploads all model files to the repository
        4. Cleans up local files after upload

        Args:
            input_model_id: Original model repository ID
            output_model_id: Target repository ID for the ONNX model

        Returns:
            Optional[str]: Error message if upload failed, None if successful
        """
        model_folder_path = self.config.repo_path / "models" / input_model_id

        try:
            self.api.create_repo(output_model_id, exist_ok=True, private=False)

            readme_path = model_folder_path / "README.md"
            readme_content = self.generate_readme(input_model_id)
            readme_path.write_text(readme_content, encoding="utf-8")

            self.api.upload_folder(
                folder_path=str(model_folder_path), repo_id=output_model_id
            )

            return None

        except Exception as e:
            return str(e)
        finally:
            # Remove the local conversion output whether or not the upload
            # succeeded.
            shutil.rmtree(model_folder_path, ignore_errors=True)
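
    # Note: because cleanup runs in the finally block, a failed upload also
    # removes the converted files; retrying requires re-running the conversion.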
    def generate_readme(self, input_model_id: str) -> str:
        """Generate an enhanced README for the ONNX model.

        This method creates a README that:
        1. Merges metadata from the original model with ONNX-specific metadata
        2. Adds a description and link to the conversion space
        3. Includes usage instructions with links to Transformers.js docs
        4. Appends the original model's README content

        Args:
            input_model_id: Original model repository ID

        Returns:
            str: Complete README content in Markdown format with YAML frontmatter
        """
        try:
            info = model_info(repo_id=input_model_id, token=self.config.hf_token)
            pipeline_tag = getattr(info, "pipeline_tag", None)
        except Exception:
            pipeline_tag = None

        original_text = self._fetch_original_readme(input_model_id)
        original_meta, original_body = self._extract_yaml_frontmatter(original_text)
        # Fall back to stripping the frontmatter manually if extraction
        # returned no body.
        original_body = (
            original_body or self._strip_yaml_frontmatter(original_text)
        ).strip()

        # Merge the original metadata, then overwrite the keys that mark this
        # repository as a Transformers.js/ONNX conversion.
        merged_meta = {}
        if isinstance(original_meta, dict):
            merged_meta.update(original_meta)
        merged_meta["library_name"] = "transformers.js"
        merged_meta["base_model"] = [input_model_id]
        if pipeline_tag is not None:
            merged_meta["pipeline_tag"] = pipeline_tag

        frontmatter_yaml = yaml.safe_dump(merged_meta, sort_keys=False).strip()
        header = f"---\n{frontmatter_yaml}\n---\n\n"

        readme_sections: List[str] = []
        readme_sections.append(header)

        model_name = input_model_id.split("/")[-1]
        readme_sections.append(f"# {model_name} (ONNX)\n")

        readme_sections.append(
            f"This is an ONNX version of [{input_model_id}](https://huggingface.co/{input_model_id}). "
            "It was automatically converted and uploaded using "
            "[this Hugging Face Space](https://huggingface.co/spaces/onnx-community/convert-to-onnx)."
        )

        docs_url = self._get_pipeline_docs_url(pipeline_tag)
        if docs_url:
            readme_sections.append("\n## Usage with Transformers.js\n")
            if pipeline_tag:
                readme_sections.append(
                    f"See the pipeline documentation for `{pipeline_tag}`: {docs_url}"
                )
            else:
                readme_sections.append(f"See the pipelines documentation: {docs_url}")

        if original_body:
            readme_sections.append("\n---\n")
            readme_sections.append(original_body)

        return "\n\n".join(readme_sections) + "\n"


def main():
    """Main application entry point for the Streamlit interface.

    This function:
    1. Initializes configuration and converter
    2. Displays the UI for model input and options
    3. Handles the conversion workflow
    4. Shows progress and results to the user
    """
    st.write("## Convert a Hugging Face model to ONNX")

    try:
        config = Config.from_env()
        converter = ModelConverter(config)
        converter.setup_repository()

        input_model_id = st.text_input(
            "Enter the Hugging Face model ID to convert. Example: `EleutherAI/pythia-14m`"
        )

        if not input_model_id:
            return

        st.text_input(
            "Optional: Your Hugging Face write token. Fill it in if you want to upload the model under your account.",
            type="password",
            key="user_hf_token",
        )

        trust_remote_code = st.toggle("Optional: Trust Remote Code.")
        if trust_remote_code:
            st.warning(
                "Only enable this option for repositories you trust and whose code you have read, as it will execute arbitrary code present in the model repository. When this option is enabled, you must use your own Hugging Face write token."
            )

        output_attentions = False
        if "whisper" in input_model_id.lower():
            output_attentions = st.toggle(
                "Output attentions from the Whisper model. This is required for word-level (token) timestamps."
            )

        # Offer to upload into the original repository when the user owns it.
        if config.hf_username == input_model_id.split("/")[0]:
            same_repo = st.checkbox(
                "Upload the ONNX weights to the existing repository"
            )
        else:
            same_repo = False

        model_name = input_model_id.split("/")[-1]
        output_model_id = f"{config.hf_username}/{model_name}"

        if not same_repo:
            output_model_id += "-ONNX"
        output_model_url = f"{config.hf_base_url}/{output_model_id}"

        if not same_repo and converter.api.repo_exists(output_model_id):
            st.write("This model has already been converted! 🎉")
            st.link_button(f"Go to {output_model_id}", output_model_url, type="primary")
            return

        st.write("The model will be converted and uploaded to:")
        st.code(output_model_url, language="plaintext")

        if not st.button(label="Proceed", type="primary"):
            return

        with st.spinner("Converting model..."):
            success, stderr = converter.convert_model(
                input_model_id,
                trust_remote_code=trust_remote_code,
                output_attentions=output_attentions,
            )
            if not success:
                st.error(f"Conversion failed: {stderr}")
                return

            st.success("Conversion successful!")
            st.code(stderr)

        with st.spinner("Uploading model..."):
            error = converter.upload_model(input_model_id, output_model_id)
            if error:
                st.error(f"Upload failed: {error}")
                return

            st.success("Upload successful!")
            st.write("You can now go and view the model on Hugging Face!")
            st.link_button(f"Go to {output_model_id}", output_model_url, type="primary")

    except Exception as e:
        logger.exception("Application error")
        st.error(f"An error occurred: {str(e)}")


if __name__ == "__main__":
    main()