# Hugging Face Spaces page residue ("Spaces: Running on Zero") removed;
# this file contains FunASR model-download utilities.
| import os | |
| import json | |
| import threading | |
| from omegaconf import OmegaConf | |
| from funasr_detach.download.name_maps_from_hub import name_maps_ms, name_maps_hf | |
# Global cache for downloaded models to avoid repeated downloads.
# Key: (repo_id, model_revision, model_hub)
# Value: repo_cache_dir (local snapshot directory for the whole repo)
_model_cache: dict = {}
# Guards _model_cache for concurrent download_model calls.
_cache_lock = threading.Lock()
def download_model(**kwargs):
    """Resolve ``kwargs["model"]`` to a local directory and merge its config.

    Maps well-known model aliases to hub ids, downloads the model from
    ModelScope ("ms") or HuggingFace ("hf") when the path does not already
    exist locally, then loads either ``configuration.json`` (ModelScope
    layout) or ``config.yaml`` + ``model.pt`` (plain FunASR layout) and
    merges the file paths and configuration into ``kwargs``.

    Args:
        **kwargs: Must contain ``model`` (model id or local path). Optional
            keys: ``model_hub`` ("ms", "hf", "local"; default "ms"),
            ``model_revision``, ``repo_path``, ``is_training``,
            ``check_latest``.

    Returns:
        dict: The merged configuration as a plain Python container.

    Raises:
        FileNotFoundError: ``model_hub == "local"`` and the path is missing.
        ValueError: Unsupported ``model_hub`` value.
    """
    model_hub = kwargs.get("model_hub", "ms")
    model_or_path = kwargs.get("model")
    repo_path = kwargs.get("repo_path", "")

    # Translate well-known model aliases to their canonical hub ids.
    if model_hub == "ms" and model_or_path in name_maps_ms:
        model_or_path = name_maps_ms[model_or_path]
    elif model_hub == "hf" and model_or_path in name_maps_hf:
        model_or_path = name_maps_hf[model_or_path]

    model_revision = kwargs.get("model_revision")

    # Download model if it doesn't exist locally.
    if not os.path.exists(model_or_path):
        if model_hub == "local":
            # For local models, the path should already exist.
            raise FileNotFoundError(f"Local model path does not exist: {model_or_path}")
        elif model_hub in ["ms", "hf"]:
            repo_path, model_or_path = get_or_download_model_dir(
                model_or_path,
                model_revision,
                is_training=kwargs.get("is_training", False),
                # BUG FIX: previously read kwargs.get("kwargs", True), so a
                # caller-supplied "check_latest" value was always ignored.
                check_latest=kwargs.get("check_latest", True),
                model_hub=model_hub,
            )
        else:
            raise ValueError(f"Unsupported model_hub: {model_hub}")

    print(f"Using model path: {model_or_path}")
    kwargs["model_path"] = model_or_path
    kwargs["repo_path"] = repo_path

    # Common logic for processing configuration files (same for all model hubs).
    if os.path.exists(os.path.join(model_or_path, "configuration.json")):
        # ModelScope layout: configuration.json lists relative resource paths
        # that are rewritten to absolute paths under the model directory.
        with open(
            os.path.join(model_or_path, "configuration.json"), "r", encoding="utf-8"
        ) as f:
            conf_json = json.load(f)
        cfg = {}
        add_file_root_path(model_or_path, conf_json["file_path_metas"], cfg)
        cfg.update(kwargs)
        config = OmegaConf.load(cfg["config"])
        kwargs = OmegaConf.merge(config, cfg)
        kwargs["model"] = config["model"]
    elif os.path.exists(os.path.join(model_or_path, "config.yaml")) and os.path.exists(
        os.path.join(model_or_path, "model.pt")
    ):
        # Plain FunASR layout: config.yaml plus a model.pt checkpoint.
        config = OmegaConf.load(os.path.join(model_or_path, "config.yaml"))
        kwargs = OmegaConf.merge(config, kwargs)
        # BUG FIX: was "model.pb" — a file this branch never verified; the
        # guard above requires "model.pt", so init_param must point there.
        init_param = os.path.join(model_or_path, "model.pt")
        kwargs["init_param"] = init_param
        if os.path.exists(os.path.join(model_or_path, "tokens.txt")):
            kwargs["tokenizer_conf"]["token_list"] = os.path.join(
                model_or_path, "tokens.txt"
            )
        if os.path.exists(os.path.join(model_or_path, "tokens.json")):
            # tokens.json takes precedence over tokens.txt when both exist.
            kwargs["tokenizer_conf"]["token_list"] = os.path.join(
                model_or_path, "tokens.json"
            )
        if os.path.exists(os.path.join(model_or_path, "seg_dict")):
            kwargs["tokenizer_conf"]["seg_dict"] = os.path.join(
                model_or_path, "seg_dict"
            )
        if os.path.exists(os.path.join(model_or_path, "bpe.model")):
            kwargs["tokenizer_conf"]["bpemodel"] = os.path.join(
                model_or_path, "bpe.model"
            )
        kwargs["model"] = config["model"]
    if os.path.exists(os.path.join(model_or_path, "am.mvn")):
        kwargs["frontend_conf"]["cmvn_file"] = os.path.join(model_or_path, "am.mvn")
    if os.path.exists(os.path.join(model_or_path, "jieba_usr_dict")):
        kwargs["jieba_usr_dict"] = os.path.join(model_or_path, "jieba_usr_dict")
    return OmegaConf.to_container(kwargs, resolve=True)
def add_file_root_path(model_or_path: str, file_path_metas: dict, cfg=None):
    """Recursively resolve relative paths in *file_path_metas* against *model_or_path*.

    Every string value that names an existing file under *model_or_path* is
    rewritten to its absolute path and stored in *cfg*; nested dicts are
    resolved recursively into nested dicts. Non-existent paths and
    non-str/non-dict values are skipped.

    Args:
        model_or_path: Root directory the relative paths are joined against.
        file_path_metas: Mapping of config keys to relative paths (or nested
            mappings of the same shape).
        cfg: Output mapping to fill in place; a fresh dict is created when
            omitted.

    Returns:
        dict: *cfg*, with resolved absolute paths added.
    """
    # BUG FIX: the default was a mutable ``cfg={}`` shared across calls, so
    # successive calls without an explicit cfg accumulated stale entries.
    if cfg is None:
        cfg = {}
    if isinstance(file_path_metas, dict):
        for key, value in file_path_metas.items():
            if isinstance(value, str):
                candidate = os.path.join(model_or_path, value)
                # Only record paths that actually exist on disk.
                if os.path.exists(candidate):
                    cfg[key] = candidate
            elif isinstance(value, dict):
                cfg.setdefault(key, {})
                add_file_root_path(model_or_path, value, cfg[key])
    return cfg
def get_or_download_model_dir(
    model,
    model_revision=None,
    is_training=False,
    check_latest=True,
    model_hub="ms",
):
    """Get local model directory or download model if necessary.

    Args:
        model (str): model id or path to local model directory.
            For HF subfolders, use format: "repo_id/subfolder_path"
        model_revision (str, optional): model version number.
        is_training (bool): Whether this is for training (selects the
            ModelScope invocation key; unused for HuggingFace).
        check_latest (bool): Whether to check for latest version.
            NOTE(review): currently not forwarded to either hub client.
        model_hub (str): Model hub type ("ms" for ModelScope, "hf" for
            HuggingFace).

    Returns:
        tuple[str, str]: ``(repo_cache_dir, model_cache_dir)`` — the local
        repo snapshot directory and the (possibly subfolder) model dir.

    Raises:
        FileNotFoundError: The requested subfolder is absent from the repo.
        ValueError: Unsupported ``model_hub`` value.
        ImportError: ``model_hub == "hf"`` and huggingface_hub is missing.
    """
    # Extract repo_id for caching (handle the "org/repo/sub/dir" case).
    if "/" in model and len(model.split("/")) > 2:
        parts = model.split("/")
        repo_id = "/".join(parts[:2])  # e.g., "stepfun-ai/Step-Audio-EditX"
        subfolder = "/".join(parts[2:])  # e.g., "subfolder/model"
    else:
        repo_id = model
        subfolder = None

    cache_key = (repo_id, model_revision, model_hub)

    # Check the process-wide cache first to avoid re-downloading.
    with _cache_lock:
        if cache_key in _model_cache:
            cached_repo_dir = _model_cache[cache_key]
            print(f"Using cached model for {repo_id}: {cached_repo_dir}")
            if subfolder:
                model_cache_dir = os.path.join(cached_repo_dir, subfolder)
                if not os.path.exists(model_cache_dir):
                    raise FileNotFoundError(f"Subfolder {subfolder} not found in cached repo {repo_id}")
            else:
                model_cache_dir = cached_repo_dir
            return cached_repo_dir, model_cache_dir

    # Cache miss: download the whole repo (repo_id only, never the
    # subfolder-qualified path) from the requested hub.
    if model_hub == "ms":
        from modelscope.hub.snapshot_download import snapshot_download
        from modelscope.utils.constant import Invoke, ThirdParty

        key = Invoke.LOCAL_TRAINER if is_training else Invoke.PIPELINE
        repo_cache_dir = snapshot_download(
            repo_id,
            revision=model_revision,
            user_agent={Invoke.KEY: key, ThirdParty.KEY: "funasr"},
        )
    elif model_hub == "hf":
        try:
            from huggingface_hub import snapshot_download
        except ImportError:
            raise ImportError(
                "huggingface_hub is required for downloading from HuggingFace. "
                "Please install it with: pip install huggingface_hub"
            )
        repo_cache_dir = snapshot_download(
            repo_id=repo_id,
            revision=model_revision,
            allow_patterns=None,  # Download all files to ensure resource files are available
        )
    else:
        raise ValueError(f"Unsupported model_hub: {model_hub}")

    # Normalize exactly once. BUG FIX: the non-subfolder branches previously
    # normalized a second time, which could wrongly descend into a repo's own
    # "snapshots" directory if the repo happens to contain one. The identical
    # subfolder-resolution logic was also duplicated per hub; it is hub-
    # independent and now runs once here.
    repo_cache_dir = normalize_cache_path(repo_cache_dir)
    if subfolder:
        model_cache_dir = os.path.join(repo_cache_dir, subfolder)
        if not os.path.exists(model_cache_dir):
            raise FileNotFoundError(f"Subfolder {subfolder} not found in downloaded repo {repo_id}")
    else:
        model_cache_dir = repo_cache_dir

    # Cache the repo directory before returning so later calls hit the cache.
    with _cache_lock:
        _model_cache[cache_key] = repo_cache_dir
    print(f"Model downloaded to: {model_cache_dir}")
    return repo_cache_dir, model_cache_dir
def normalize_cache_path(cache_path):
    """Descend into ``snapshots/<commit_id>`` under *cache_path* when present.

    Hub caches lay repos out as ``<cache>/snapshots/<commit_id>/...``; when
    that layout is detected, the path of the first commit directory found is
    returned. In every other case (no ``snapshots`` folder, no subdirectory
    inside it, or an unreadable directory) *cache_path* comes back unchanged.
    """
    snapshots_dir = os.path.join(cache_path, "snapshots")
    # isdir() already implies existence; non-directories fall through.
    if not os.path.isdir(snapshots_dir):
        return cache_path
    try:
        entries = os.listdir(snapshots_dir)
    except OSError:
        # Unreadable snapshots folder: behave as if it were absent.
        return cache_path
    for entry in entries:
        candidate = os.path.join(snapshots_dir, entry)
        if os.path.isdir(candidate):
            # First subdirectory is taken as the commit_id snapshot.
            return os.path.join(cache_path, "snapshots", entry)
    return cache_path