Improve the generated README with the original model card and a usage reference from the Transformers.js docs
Browse files- app.py +138 -13
 - requirements.txt +1 -0
 
    	
        app.py
    CHANGED
    
    | 
         @@ -3,12 +3,14 @@ import os 
     | 
|
| 3 | 
         
             
            import subprocess
         
     | 
| 4 | 
         
             
            import sys
         
     | 
| 5 | 
         
             
            import shutil
         
     | 
| 
         | 
|
| 6 | 
         
             
            from pathlib import Path
         
     | 
| 7 | 
         
             
            from typing import List, Optional, Tuple
         
     | 
| 8 | 
         
             
            from dataclasses import dataclass
         
     | 
| 9 | 
         | 
| 10 | 
         
             
            import streamlit as st
         
     | 
| 11 | 
         
            -
            from huggingface_hub import HfApi, whoami
         
     | 
| 
         | 
|
| 12 | 
         | 
| 13 | 
         
             
            logging.basicConfig(level=logging.INFO)
         
     | 
| 14 | 
         
             
            logger = logging.getLogger(__name__)
         
     | 
| 
         @@ -58,6 +60,86 @@ class ModelConverter: 
     | 
|
| 58 | 
         
             
                    self.config = config
         
     | 
| 59 | 
         
             
                    self.api = HfApi(token=config.hf_token)
         
     | 
| 60 | 
         | 
| 
         | 
|
| 
         | 
|
| 
         | 
|
| 
         | 
|
| 
         | 
|
| 
         | 
|
| 
         | 
|
| 
         | 
|
| 
         | 
|
| 
         | 
|
| 
         | 
|
| 
         | 
|
| 
         | 
|
| 
         | 
|
| 
         | 
|
| 
         | 
|
| 
         | 
|
| 
         | 
|
| 
         | 
|
| 
         | 
|
| 
         | 
|
| 
         | 
|
| 
         | 
|
| 
         | 
|
| 
         | 
|
| 
         | 
|
| 
         | 
|
| 
         | 
|
| 
         | 
|
| 
         | 
|
| 
         | 
|
| 
         | 
|
| 
         | 
|
| 
         | 
|
| 
         | 
|
| 
         | 
|
| 
         | 
|
| 
         | 
|
| 
         | 
|
| 
         | 
|
| 
         | 
|
| 
         | 
|
| 
         | 
|
| 
         | 
|
| 
         | 
|
| 
         | 
|
| 
         | 
|
| 
         | 
|
| 
         | 
|
| 
         | 
|
| 
         | 
|
| 
         | 
|
| 
         | 
|
| 
         | 
|
| 
         | 
|
| 
         | 
|
| 
         | 
|
| 
         | 
|
| 
         | 
|
| 
         | 
|
| 
         | 
|
| 
         | 
|
| 
         | 
|
| 
         | 
|
| 
         | 
|
| 
         | 
|
| 
         | 
|
| 
         | 
|
| 
         | 
|
| 
         | 
|
| 
         | 
|
| 
         | 
|
| 
         | 
|
| 
         | 
|
| 
         | 
|
| 
         | 
|
| 
         | 
|
| 
         | 
|
| 
         | 
|
| 
         | 
|
| 61 | 
         
             
                def setup_repository(self) -> None:
         
     | 
| 62 | 
         
             
                    """Ensure the bundled transformers.js repository is present."""
         
     | 
| 63 | 
         
             
                    if not self.config.repo_path.exists():
         
     | 
| 
         @@ -112,6 +194,14 @@ class ModelConverter: 
     | 
|
| 112 | 
         
             
                        if output_attentions:
         
     | 
| 113 | 
         
             
                            extra_args.append("--output_attentions")
         
     | 
| 114 | 
         | 
| 
         | 
|
| 
         | 
|
| 
         | 
|
| 
         | 
|
| 
         | 
|
| 
         | 
|
| 
         | 
|
| 
         | 
|
| 115 | 
         
             
                        result = self._run_conversion_subprocess(
         
     | 
| 116 | 
         
             
                            input_model_id, extra_args=extra_args or None
         
     | 
| 117 | 
         
             
                        )
         
     | 
| 
         @@ -133,9 +223,8 @@ class ModelConverter: 
     | 
|
| 133 | 
         | 
| 134 | 
         
             
                        readme_path = f"{model_folder_path}/README.md"
         
     | 
| 135 | 
         | 
| 136 | 
         
            -
                         
     | 
| 137 | 
         
            -
                             
     | 
| 138 | 
         
            -
                                file.write(self.generate_readme(input_model_id))
         
     | 
| 139 | 
         | 
| 140 | 
         
             
                        self.api.upload_folder(
         
     | 
| 141 | 
         
             
                            folder_path=str(model_folder_path), repo_id=output_model_id
         
     | 
| 
         @@ -147,18 +236,54 @@ class ModelConverter: 
     | 
|
| 147 | 
         
             
                        shutil.rmtree(model_folder_path, ignore_errors=True)
         
     | 
| 148 | 
         | 
| 149 | 
         
             
                def generate_readme(self, imi: str):
         
     | 
| 150 | 
         
            -
                     
     | 
| 151 | 
         
            -
                         
     | 
| 152 | 
         
            -
                        " 
     | 
| 153 | 
         
            -
             
     | 
| 154 | 
         
            -
                         
     | 
| 155 | 
         
            -
             
     | 
| 156 | 
         
            -
             
     | 
| 
         | 
|
| 
         | 
|
| 
         | 
|
| 
         | 
|
| 
         | 
|
| 
         | 
|
| 
         | 
|
| 
         | 
|
| 
         | 
|
| 
         | 
|
| 
         | 
|
| 
         | 
|
| 
         | 
|
| 
         | 
|
| 
         | 
|
| 
         | 
|
| 
         | 
|
| 
         | 
|
| 
         | 
|
| 
         | 
|
| 157 | 
         
             
                        f"This is an ONNX version of [{imi}](https://huggingface.co/{imi}). "
         
     | 
| 158 | 
         
             
                        "It was automatically converted and uploaded using "
         
     | 
| 159 | 
         
            -
                        "[this  
     | 
| 160 | 
         
             
                    )
         
     | 
| 161 | 
         | 
| 
         | 
|
| 
         | 
|
| 
         | 
|
| 
         | 
|
| 
         | 
|
| 
         | 
|
| 
         | 
|
| 
         | 
|
| 
         | 
|
| 
         | 
|
| 
         | 
|
| 
         | 
|
| 
         | 
|
| 
         | 
|
| 
         | 
|
| 
         | 
|
| 162 | 
         | 
| 163 | 
         
             
            def main():
         
     | 
| 164 | 
         
             
                """Main application entry point."""
         
     | 
| 
         @@ -195,7 +320,7 @@ def main(): 
     | 
|
| 195 | 
         | 
| 196 | 
         
             
                    if config.hf_username == input_model_id.split("/")[0]:
         
     | 
| 197 | 
         
             
                        same_repo = st.checkbox(
         
     | 
| 198 | 
         
            -
                            " 
     | 
| 199 | 
         
             
                        )
         
     | 
| 200 | 
         
             
                    else:
         
     | 
| 201 | 
         
             
                        same_repo = False
         
     | 
| 
         | 
|
| 3 | 
         
             
            import subprocess
         
     | 
| 4 | 
         
             
            import sys
         
     | 
| 5 | 
         
             
            import shutil
         
     | 
| 6 | 
         
            +
            import re
         
     | 
| 7 | 
         
             
            from pathlib import Path
         
     | 
| 8 | 
         
             
            from typing import List, Optional, Tuple
         
     | 
| 9 | 
         
             
            from dataclasses import dataclass
         
     | 
| 10 | 
         | 
| 11 | 
         
             
            import streamlit as st
         
     | 
| 12 | 
         
            +
            from huggingface_hub import HfApi, whoami, model_info, hf_hub_download
         
     | 
| 13 | 
         
            +
            import yaml
         
     | 
| 14 | 
         | 
| 15 | 
         
             
            logging.basicConfig(level=logging.INFO)
         
     | 
| 16 | 
         
             
            logger = logging.getLogger(__name__)
         
     | 
| 
         | 
|
| 60 | 
         
             
                    self.config = config
         
     | 
| 61 | 
         
             
                    self.api = HfApi(token=config.hf_token)
         
     | 
| 62 | 
         | 
| 63 | 
         
            +
                def _fetch_original_readme(self, repo_id: str) -> str:
         
     | 
| 64 | 
         
            +
                    try:
         
     | 
| 65 | 
         
            +
                        path = hf_hub_download(
         
     | 
| 66 | 
         
            +
                            repo_id=repo_id, filename="README.md", token=self.config.hf_token
         
     | 
| 67 | 
         
            +
                        )
         
     | 
| 68 | 
         
            +
                        with open(path, "r", encoding="utf-8", errors="ignore") as f:
         
     | 
| 69 | 
         
            +
                            return f.read()
         
     | 
| 70 | 
         
            +
                    except Exception:
         
     | 
| 71 | 
         
            +
                        return ""
         
     | 
| 72 | 
         
            +
             
     | 
| 73 | 
         
            +
                def _strip_yaml_frontmatter(self, text: str) -> str:
         
     | 
| 74 | 
         
            +
                    if not text:
         
     | 
| 75 | 
         
            +
                        return ""
         
     | 
| 76 | 
         
            +
                    if text.startswith("---"):
         
     | 
| 77 | 
         
            +
                        m = re.match(r"^---[\s\S]*?\n---\s*\n", text)
         
     | 
| 78 | 
         
            +
                        if m:
         
     | 
| 79 | 
         
            +
                            return text[m.end() :]
         
     | 
| 80 | 
         
            +
                    return text
         
     | 
| 81 | 
         
            +
             
     | 
| 82 | 
         
            +
                def _extract_yaml_frontmatter(self, text: str) -> Tuple[dict, str]:
         
     | 
| 83 | 
         
            +
                    """Return (frontmatter_dict, body). If no frontmatter, returns ({}, text)."""
         
     | 
| 84 | 
         
            +
                    if not text or not text.startswith("---"):
         
     | 
| 85 | 
         
            +
                        return {}, text or ""
         
     | 
| 86 | 
         
            +
                    m = re.match(r"^---\s*\n([\s\S]*?)\n---\s*\n", text)
         
     | 
| 87 | 
         
            +
                    if not m:
         
     | 
| 88 | 
         
            +
                        return {}, text
         
     | 
| 89 | 
         
            +
                    fm_text = m.group(1)
         
     | 
| 90 | 
         
            +
                    body = text[m.end() :]
         
     | 
| 91 | 
         
            +
                    try:
         
     | 
| 92 | 
         
            +
                        data = yaml.safe_load(fm_text)
         
     | 
| 93 | 
         
            +
                        if not isinstance(data, dict):
         
     | 
| 94 | 
         
            +
                            data = {}
         
     | 
| 95 | 
         
            +
                    except Exception:
         
     | 
| 96 | 
         
            +
                        data = {}
         
     | 
| 97 | 
         
            +
                    return data, body
         
     | 
| 98 | 
         
            +
             
     | 
| 99 | 
         
            +
                def _pipeline_docs_url(self, pipeline_tag: Optional[str]) -> Optional[str]:
         
     | 
| 100 | 
         
            +
                    base = "https://huggingface.co/docs/transformers.js/api/pipelines"
         
     | 
| 101 | 
         
            +
                    if not pipeline_tag:
         
     | 
| 102 | 
         
            +
                        return base
         
     | 
| 103 | 
         
            +
                    mapping = {
         
     | 
| 104 | 
         
            +
                        "text-classification": "TextClassificationPipeline",
         
     | 
| 105 | 
         
            +
                        "token-classification": "TokenClassificationPipeline",
         
     | 
| 106 | 
         
            +
                        "question-answering": "QuestionAnsweringPipeline",
         
     | 
| 107 | 
         
            +
                        "fill-mask": "FillMaskPipeline",
         
     | 
| 108 | 
         
            +
                        "text2text-generation": "Text2TextGenerationPipeline",
         
     | 
| 109 | 
         
            +
                        "summarization": "SummarizationPipeline",
         
     | 
| 110 | 
         
            +
                        "translation": "TranslationPipeline",
         
     | 
| 111 | 
         
            +
                        "text-generation": "TextGenerationPipeline",
         
     | 
| 112 | 
         
            +
                        "zero-shot-classification": "ZeroShotClassificationPipeline",
         
     | 
| 113 | 
         
            +
                        "feature-extraction": "FeatureExtractionPipeline",
         
     | 
| 114 | 
         
            +
                        "image-feature-extraction": "ImageFeatureExtractionPipeline",
         
     | 
| 115 | 
         
            +
                        "audio-classification": "AudioClassificationPipeline",
         
     | 
| 116 | 
         
            +
                        "zero-shot-audio-classification": "ZeroShotAudioClassificationPipeline",
         
     | 
| 117 | 
         
            +
                        "automatic-speech-recognition": "AutomaticSpeechRecognitionPipeline",
         
     | 
| 118 | 
         
            +
                        "image-to-text": "ImageToTextPipeline",
         
     | 
| 119 | 
         
            +
                        "image-classification": "ImageClassificationPipeline",
         
     | 
| 120 | 
         
            +
                        "image-segmentation": "ImageSegmentationPipeline",
         
     | 
| 121 | 
         
            +
                        "background-removal": "BackgroundRemovalPipeline",
         
     | 
| 122 | 
         
            +
                        "zero-shot-image-classification": "ZeroShotImageClassificationPipeline",
         
     | 
| 123 | 
         
            +
                        "object-detection": "ObjectDetectionPipeline",
         
     | 
| 124 | 
         
            +
                        "zero-shot-object-detection": "ZeroShotObjectDetectionPipeline",
         
     | 
| 125 | 
         
            +
                        "document-question-answering": "DocumentQuestionAnsweringPipeline",
         
     | 
| 126 | 
         
            +
                        "text-to-audio": "TextToAudioPipeline",
         
     | 
| 127 | 
         
            +
                        "image-to-image": "ImageToImagePipeline",
         
     | 
| 128 | 
         
            +
                        "depth-estimation": "DepthEstimationPipeline",
         
     | 
| 129 | 
         
            +
                    }
         
     | 
| 130 | 
         
            +
                    cls = mapping.get(pipeline_tag)
         
     | 
| 131 | 
         
            +
                    if not cls:
         
     | 
| 132 | 
         
            +
                        return base
         
     | 
| 133 | 
         
            +
                    return f"{base}#module_pipelines.{cls}"
         
     | 
| 134 | 
         
            +
             
     | 
| 135 | 
         
            +
                def _map_pipeline_to_task(self, pipeline_tag: Optional[str]) -> Optional[str]:
         
     | 
| 136 | 
         
            +
                    if not pipeline_tag:
         
     | 
| 137 | 
         
            +
                        return None
         
     | 
| 138 | 
         
            +
                    synonyms = {
         
     | 
| 139 | 
         
            +
                        "vqa": "visual-question-answering",
         
     | 
| 140 | 
         
            +
                    }
         
     | 
| 141 | 
         
            +
                    return synonyms.get(pipeline_tag, pipeline_tag)
         
     | 
| 142 | 
         
            +
             
     | 
| 143 | 
         
             
                def setup_repository(self) -> None:
         
     | 
| 144 | 
         
             
                    """Ensure the bundled transformers.js repository is present."""
         
     | 
| 145 | 
         
             
                    if not self.config.repo_path.exists():
         
     | 
| 
         | 
|
| 194 | 
         
             
                        if output_attentions:
         
     | 
| 195 | 
         
             
                            extra_args.append("--output_attentions")
         
     | 
| 196 | 
         | 
| 197 | 
         
            +
                        try:
         
     | 
| 198 | 
         
            +
                            info = model_info(repo_id=input_model_id, token=self.config.hf_token)
         
     | 
| 199 | 
         
            +
                            task = self._map_pipeline_to_task(getattr(info, "pipeline_tag", None))
         
     | 
| 200 | 
         
            +
                            if task:
         
     | 
| 201 | 
         
            +
                                extra_args.extend(["--task", task])
         
     | 
| 202 | 
         
            +
                        except Exception:
         
     | 
| 203 | 
         
            +
                            pass
         
     | 
| 204 | 
         
            +
             
     | 
| 205 | 
         
             
                        result = self._run_conversion_subprocess(
         
     | 
| 206 | 
         
             
                            input_model_id, extra_args=extra_args or None
         
     | 
| 207 | 
         
             
                        )
         
     | 
| 
         | 
|
| 223 | 
         | 
| 224 | 
         
             
                        readme_path = f"{model_folder_path}/README.md"
         
     | 
| 225 | 
         | 
| 226 | 
         
            +
                        with open(readme_path, "w") as file:
         
     | 
| 227 | 
         
            +
                            file.write(self.generate_readme(input_model_id))
         
     | 
| 
         | 
|
| 228 | 
         | 
| 229 | 
         
             
                        self.api.upload_folder(
         
     | 
| 230 | 
         
             
                            folder_path=str(model_folder_path), repo_id=output_model_id
         
     | 
| 
         | 
|
| 236 | 
         
             
                        shutil.rmtree(model_folder_path, ignore_errors=True)
         
     | 
| 237 | 
         | 
| 238 | 
         
             
                def generate_readme(self, imi: str):
         
     | 
| 239 | 
         
            +
                    try:
         
     | 
| 240 | 
         
            +
                        info = model_info(repo_id=imi, token=self.config.hf_token)
         
     | 
| 241 | 
         
            +
                        pipeline_tag = getattr(info, "pipeline_tag", None)
         
     | 
| 242 | 
         
            +
                    except Exception:
         
     | 
| 243 | 
         
            +
                        pipeline_tag = None
         
     | 
| 244 | 
         
            +
             
     | 
| 245 | 
         
            +
                    original_text = self._fetch_original_readme(imi)
         
     | 
| 246 | 
         
            +
                    original_meta, original_body = self._extract_yaml_frontmatter(original_text)
         
     | 
| 247 | 
         
            +
                    original_body = (
         
     | 
| 248 | 
         
            +
                        original_body or self._strip_yaml_frontmatter(original_text)
         
     | 
| 249 | 
         
            +
                    ).strip()
         
     | 
| 250 | 
         
            +
             
     | 
| 251 | 
         
            +
                    merged_meta = {}
         
     | 
| 252 | 
         
            +
                    if isinstance(original_meta, dict):
         
     | 
| 253 | 
         
            +
                        merged_meta.update(original_meta)
         
     | 
| 254 | 
         
            +
                    merged_meta["library_name"] = "transformers.js"
         
     | 
| 255 | 
         
            +
                    merged_meta["base_model"] = [imi]
         
     | 
| 256 | 
         
            +
                    if pipeline_tag is not None:
         
     | 
| 257 | 
         
            +
                        merged_meta["pipeline_tag"] = pipeline_tag
         
     | 
| 258 | 
         
            +
             
     | 
| 259 | 
         
            +
                    fm_yaml = yaml.safe_dump(merged_meta, sort_keys=False).strip()
         
     | 
| 260 | 
         
            +
                    header = f"---\n{fm_yaml}\n---\n\n"
         
     | 
| 261 | 
         
            +
             
     | 
| 262 | 
         
            +
                    parts: List[str] = []
         
     | 
| 263 | 
         
            +
                    parts.append(header)
         
     | 
| 264 | 
         
            +
                    parts.append(f"# {imi.split('/')[-1]} (ONNX)\n")
         
     | 
| 265 | 
         
            +
                    parts.append(
         
     | 
| 266 | 
         
             
                        f"This is an ONNX version of [{imi}](https://huggingface.co/{imi}). "
         
     | 
| 267 | 
         
             
                        "It was automatically converted and uploaded using "
         
     | 
| 268 | 
         
            +
                        "[this Hugging Face Space](https://huggingface.co/spaces/onnx-community/convert-to-onnx)."
         
     | 
| 269 | 
         
             
                    )
         
     | 
| 270 | 
         | 
| 271 | 
         
            +
                    docs_url = self._pipeline_docs_url(pipeline_tag)
         
     | 
| 272 | 
         
            +
                    if docs_url:
         
     | 
| 273 | 
         
            +
                        parts.append("\n## Usage with Transformers.js\n")
         
     | 
| 274 | 
         
            +
                        if pipeline_tag:
         
     | 
| 275 | 
         
            +
                            parts.append(
         
     | 
| 276 | 
         
            +
                                f"See the pipeline documentation for `{pipeline_tag}`: {docs_url}"
         
     | 
| 277 | 
         
            +
                            )
         
     | 
| 278 | 
         
            +
                        else:
         
     | 
| 279 | 
         
            +
                            parts.append(f"See the pipelines documentation: {docs_url}")
         
     | 
| 280 | 
         
            +
             
     | 
| 281 | 
         
            +
                    if original_body:
         
     | 
| 282 | 
         
            +
                        parts.append("\n---\n")
         
     | 
| 283 | 
         
            +
                        parts.append(original_body)
         
     | 
| 284 | 
         
            +
             
     | 
| 285 | 
         
            +
                    return "\n\n".join(parts) + "\n"
         
     | 
| 286 | 
         
            +
             
     | 
| 287 | 
         | 
| 288 | 
         
             
            def main():
         
     | 
| 289 | 
         
             
                """Main application entry point."""
         
     | 
| 
         | 
|
| 320 | 
         | 
| 321 | 
         
             
                    if config.hf_username == input_model_id.split("/")[0]:
         
     | 
| 322 | 
         
             
                        same_repo = st.checkbox(
         
     | 
| 323 | 
         
            +
                            "Upload the ONNX weights to the existing repository"
         
     | 
| 324 | 
         
             
                        )
         
     | 
| 325 | 
         
             
                    else:
         
     | 
| 326 | 
         
             
                        same_repo = False
         
     | 
    	
        requirements.txt
    CHANGED
    
    | 
         @@ -1,5 +1,6 @@ 
     | 
|
| 1 | 
         
             
            huggingface_hub==0.35.3
         
     | 
| 2 | 
         
             
            streamlit==1.50.0
         
     | 
| 
         | 
|
| 3 | 
         
             
            onnxscript==0.5.4
         
     | 
| 4 | 
         
             
            onnxconverter_common==1.16.0
         
     | 
| 5 | 
         
             
            onnx_graphsurgeon==0.5.8
         
     | 
| 
         | 
|
| 1 | 
         
             
            huggingface_hub==0.35.3
         
     | 
| 2 | 
         
             
            streamlit==1.50.0
         
     | 
| 3 | 
         
            +
            PyYAML==6.0.2
         
     | 
| 4 | 
         
             
            onnxscript==0.5.4
         
     | 
| 5 | 
         
             
            onnxconverter_common==1.16.0
         
     | 
| 6 | 
         
             
            onnx_graphsurgeon==0.5.8
         
     |