import os
import json
import threading

from omegaconf import OmegaConf

from funasr_detach.download.name_maps_from_hub import name_maps_ms, name_maps_hf

# Global cache for downloaded models to avoid repeated downloads.
# Key: (repo_id, model_revision, model_hub)
# Value: repo_cache_dir
_model_cache = {}
_cache_lock = threading.Lock()


def download_model(**kwargs):
    """Resolve a model id or path to a local directory and assemble its config.

    Downloads the model from ModelScope ("ms") or HuggingFace ("hf") when the
    path does not already exist locally, merges the repo's configuration files
    into kwargs, and returns the merged configuration as a plain container.
    """
    model_hub = kwargs.get("model_hub", "ms")
    model_or_path = kwargs.get("model")
    repo_path = kwargs.get("repo_path", "")

    # Handle name mapping based on model_hub
    if model_hub == "ms" and model_or_path in name_maps_ms:
        model_or_path = name_maps_ms[model_or_path]
    elif model_hub == "hf" and model_or_path in name_maps_hf:
        model_or_path = name_maps_hf[model_or_path]

    model_revision = kwargs.get("model_revision")

    # Download the model if it doesn't exist locally
    if not os.path.exists(model_or_path):
        if model_hub == "local":
            # For local models, the path must already exist
            raise FileNotFoundError(
                f"Local model path does not exist: {model_or_path}"
            )
        elif model_hub in ["ms", "hf"]:
            repo_path, model_or_path = get_or_download_model_dir(
                model_or_path,
                model_revision,
                is_training=kwargs.get("is_training"),
                check_latest=kwargs.get("check_latest", True),
                model_hub=model_hub,
            )
        else:
            raise ValueError(f"Unsupported model_hub: {model_hub}")
        print(f"Using model path: {model_or_path}")

    kwargs["model_path"] = model_or_path
    kwargs["repo_path"] = repo_path

    # Common logic for processing configuration files (same for all model hubs)
    if os.path.exists(os.path.join(model_or_path, "configuration.json")):
        with open(
            os.path.join(model_or_path, "configuration.json"), "r", encoding="utf-8"
        ) as f:
            conf_json = json.load(f)
        cfg = {}
        add_file_root_path(model_or_path, conf_json["file_path_metas"], cfg)
        cfg.update(kwargs)
        config = OmegaConf.load(cfg["config"])
        kwargs = OmegaConf.merge(config, cfg)
        kwargs["model"] = config["model"]
    elif os.path.exists(os.path.join(model_or_path, "config.yaml")) and os.path.exists(
        os.path.join(model_or_path, "model.pt")
    ):
        config = OmegaConf.load(os.path.join(model_or_path, "config.yaml"))
        kwargs = OmegaConf.merge(config, kwargs)
        # Load the checkpoint whose existence was checked above (model.pt)
        kwargs["init_param"] = os.path.join(model_or_path, "model.pt")
        if os.path.exists(os.path.join(model_or_path, "tokens.txt")):
            kwargs["tokenizer_conf"]["token_list"] = os.path.join(
                model_or_path, "tokens.txt"
            )
        if os.path.exists(os.path.join(model_or_path, "tokens.json")):
            kwargs["tokenizer_conf"]["token_list"] = os.path.join(
                model_or_path, "tokens.json"
            )
        if os.path.exists(os.path.join(model_or_path, "seg_dict")):
            kwargs["tokenizer_conf"]["seg_dict"] = os.path.join(
                model_or_path, "seg_dict"
            )
        if os.path.exists(os.path.join(model_or_path, "bpe.model")):
            kwargs["tokenizer_conf"]["bpemodel"] = os.path.join(
                model_or_path, "bpe.model"
            )
        kwargs["model"] = config["model"]
        if os.path.exists(os.path.join(model_or_path, "am.mvn")):
            kwargs["frontend_conf"]["cmvn_file"] = os.path.join(
                model_or_path, "am.mvn"
            )
        if os.path.exists(os.path.join(model_or_path, "jieba_usr_dict")):
            kwargs["jieba_usr_dict"] = os.path.join(model_or_path, "jieba_usr_dict")

    return OmegaConf.to_container(kwargs, resolve=True)


def add_file_root_path(model_or_path: str, file_path_metas: dict, cfg=None):
    """Recursively resolve relative paths in file_path_metas against model_or_path."""
    # Use None instead of a mutable default argument ({}) so state cannot
    # leak between calls.
    if cfg is None:
        cfg = {}
    if isinstance(file_path_metas, dict):
        for k, v in file_path_metas.items():
            if isinstance(v, str):
                p = os.path.join(model_or_path, v)
                if os.path.exists(p):
                    cfg[k] = p
            elif isinstance(v, dict):
                if k not in cfg:
                    cfg[k] = {}
                add_file_root_path(model_or_path, v, cfg[k])
    return cfg
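# Illustrative example (hypothetical paths, not part of the library): for a
# repo downloaded to /cache/repo containing "config.yaml" and "am.mvn", a
# configuration.json with
#     {"file_path_metas": {"config": "config.yaml",
#                          "frontend_conf": {"cmvn_file": "am.mvn"}}}
# would be resolved by add_file_root_path to
#     {"config": "/cache/repo/config.yaml",
#      "frontend_conf": {"cmvn_file": "/cache/repo/am.mvn"}}
# Entries whose files do not exist on disk are skipped rather than raising.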
def get_or_download_model_dir(
    model,
    model_revision=None,
    is_training=False,
    check_latest=True,
    model_hub="ms",
):
    """Get the local model directory, downloading the model if necessary.

    Args:
        model (str): Model id or path to a local model directory. For HF
            subfolders, use the format "repo_id/subfolder_path".
        model_revision (str, optional): Model version number.
        is_training (bool): Whether this is for training.
        check_latest (bool): Whether to check for the latest version.
        model_hub (str): Model hub type ("ms" for ModelScope, "hf" for
            HuggingFace).
    """
    # Extract repo_id for caching (handle the subfolder case)
    if "/" in model and len(model.split("/")) > 2:
        parts = model.split("/")
        repo_id = "/".join(parts[:2])  # e.g., "organization/repo" or "stepfun-ai/Step-Audio-EditX"
        subfolder = "/".join(parts[2:])  # e.g., "subfolder/model"
    else:
        repo_id = model
        subfolder = None

    # Create cache key
    cache_key = (repo_id, model_revision, model_hub)

    # Check cache first
    with _cache_lock:
        if cache_key in _model_cache:
            cached_repo_dir = _model_cache[cache_key]
            print(f"Using cached model for {repo_id}: {cached_repo_dir}")
            # For the subfolder case, construct model_cache_dir from the cached repo
            if subfolder:
                model_cache_dir = os.path.join(cached_repo_dir, subfolder)
                if not os.path.exists(model_cache_dir):
                    raise FileNotFoundError(
                        f"Subfolder {subfolder} not found in cached repo {repo_id}"
                    )
            else:
                model_cache_dir = cached_repo_dir
            return cached_repo_dir, model_cache_dir

    # Cache miss: download the model
    if model_hub == "ms":
        # ModelScope download
        from modelscope.hub.snapshot_download import snapshot_download
        from modelscope.utils.constant import Invoke, ThirdParty

        key = Invoke.LOCAL_TRAINER if is_training else Invoke.PIPELINE
        # Download the repo (use repo_id, not the full model path with subfolder)
        repo_cache_dir = snapshot_download(
            repo_id,
            revision=model_revision,
            user_agent={Invoke.KEY: key, ThirdParty.KEY: "funasr"},
        )
        repo_cache_dir = normalize_cache_path(repo_cache_dir)
        # Construct model_cache_dir
        if subfolder:
            model_cache_dir = os.path.join(repo_cache_dir, subfolder)
            if not os.path.exists(model_cache_dir):
                raise FileNotFoundError(
                    f"Subfolder {subfolder} not found in downloaded repo {repo_id}"
                )
        else:
            # repo_cache_dir is already normalized above
            model_cache_dir = repo_cache_dir
    elif model_hub == "hf":
        # HuggingFace download
        try:
            from huggingface_hub import snapshot_download
        except ImportError:
            raise ImportError(
                "huggingface_hub is required for downloading from HuggingFace. "
                "Please install it with: pip install huggingface_hub"
            )
        # Download the repo (use repo_id, not the full model path with subfolder)
        repo_cache_dir = snapshot_download(
            repo_id=repo_id,
            revision=model_revision,
            allow_patterns=None,  # Download all files so resource files are available
        )
        repo_cache_dir = normalize_cache_path(repo_cache_dir)
        # Construct model_cache_dir
        if subfolder:
            model_cache_dir = os.path.join(repo_cache_dir, subfolder)
            if not os.path.exists(model_cache_dir):
                raise FileNotFoundError(
                    f"Subfolder {subfolder} not found in downloaded repo {repo_id}"
                )
        else:
            # repo_cache_dir is already normalized above
            model_cache_dir = repo_cache_dir
    else:
        raise ValueError(f"Unsupported model_hub: {model_hub}")

    # Cache the result before returning
    with _cache_lock:
        _model_cache[cache_key] = repo_cache_dir

    print(f"Model downloaded to: {model_cache_dir}")
    return repo_cache_dir, model_cache_dir
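# Illustrative note on the id-splitting convention above (hypothetical ids):
# get_or_download_model_dir("org/repo") uses cache key ("org/repo", None, "ms")
# and returns (repo_dir, repo_dir), whereas
# get_or_download_model_dir("org/repo/sub/model", model_hub="hf") splits the id
# into repo_id="org/repo" and subfolder="sub/model" and returns
# (repo_dir, os.path.join(repo_dir, "sub/model")).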
" "Please install it with: pip install huggingface_hub" ) # Download the repo (use repo_id, not the full model path with subfolder) repo_cache_dir = snapshot_download( repo_id=repo_id, revision=model_revision, allow_patterns=None, # Download all files to ensure resource files are available ) repo_cache_dir = normalize_cache_path(repo_cache_dir) # Construct model_cache_dir if subfolder: model_cache_dir = os.path.join(repo_cache_dir, subfolder) if not os.path.exists(model_cache_dir): raise FileNotFoundError(f"Subfolder {subfolder} not found in downloaded repo {repo_id}") else: model_cache_dir = normalize_cache_path(repo_cache_dir) else: raise ValueError(f"Unsupported model_hub: {model_hub}") # Cache the result before returning with _cache_lock: _model_cache[cache_key] = repo_cache_dir print(f"Model downloaded to: {model_cache_dir}") return repo_cache_dir, model_cache_dir def normalize_cache_path(cache_path): """Normalize cache path to ensure consistent format with snapshots/{commit_id}.""" # Check if the cache_path directory contains a snapshots folder snapshots_dir = os.path.join(cache_path, "snapshots") if os.path.exists(snapshots_dir) and os.path.isdir(snapshots_dir): # Find the commit_id subdirectory in snapshots try: snapshot_items = os.listdir(snapshots_dir) # Look for the first directory (should be the commit_id) for item in snapshot_items: item_path = os.path.join(snapshots_dir, item) if os.path.isdir(item_path): # Found commit_id directory, return the full path return os.path.join(cache_path, "snapshots", item) except OSError: pass # If no snapshots directory found or error occurred, return original path return cache_path