import os
import json
import threading
from omegaconf import OmegaConf
from funasr_detach.download.name_maps_from_hub import name_maps_ms, name_maps_hf
# Global cache for downloaded models to avoid repeated downloads
# Key: (repo_id, model_revision, model_hub)
# Value: repo_cache_dir
_model_cache = {}
_cache_lock = threading.Lock()
def download_model(**kwargs):
    """Resolve ``kwargs["model"]`` to a local model directory, downloading it
    from ModelScope or HuggingFace when needed, then merge the model's
    configuration files into ``kwargs`` and return the result as a plain dict.
    """
model_hub = kwargs.get("model_hub", "ms")
model_or_path = kwargs.get("model")
repo_path = kwargs.get("repo_path", "")
# Handle name mapping based on model_hub
if model_hub == "ms" and model_or_path in name_maps_ms:
model_or_path = name_maps_ms[model_or_path]
elif model_hub == "hf" and model_or_path in name_maps_hf:
model_or_path = name_maps_hf[model_or_path]
model_revision = kwargs.get("model_revision")
# Download model if it doesn't exist locally
if not os.path.exists(model_or_path):
if model_hub == "local":
# For local models, the path should already exist
raise FileNotFoundError(f"Local model path does not exist: {model_or_path}")
elif model_hub in ["ms", "hf"]:
repo_path, model_or_path = get_or_download_model_dir(
model_or_path,
model_revision,
is_training=kwargs.get("is_training"),
                check_latest=kwargs.get("check_latest", True),
model_hub=model_hub,
)
else:
raise ValueError(f"Unsupported model_hub: {model_hub}")
print(f"Using model path: {model_or_path}")
kwargs["model_path"] = model_or_path
kwargs["repo_path"] = repo_path
# Common logic for processing configuration files (same for all model hubs)
if os.path.exists(os.path.join(model_or_path, "configuration.json")):
with open(
os.path.join(model_or_path, "configuration.json"), "r", encoding="utf-8"
) as f:
conf_json = json.load(f)
cfg = {}
add_file_root_path(model_or_path, conf_json["file_path_metas"], cfg)
cfg.update(kwargs)
config = OmegaConf.load(cfg["config"])
kwargs = OmegaConf.merge(config, cfg)
kwargs["model"] = config["model"]
elif os.path.exists(os.path.join(model_or_path, "config.yaml")) and os.path.exists(
os.path.join(model_or_path, "model.pt")
):
config = OmegaConf.load(os.path.join(model_or_path, "config.yaml"))
kwargs = OmegaConf.merge(config, kwargs)
        init_param = os.path.join(model_or_path, "model.pt")
kwargs["init_param"] = init_param
if os.path.exists(os.path.join(model_or_path, "tokens.txt")):
kwargs["tokenizer_conf"]["token_list"] = os.path.join(
model_or_path, "tokens.txt"
)
if os.path.exists(os.path.join(model_or_path, "tokens.json")):
kwargs["tokenizer_conf"]["token_list"] = os.path.join(
model_or_path, "tokens.json"
)
if os.path.exists(os.path.join(model_or_path, "seg_dict")):
kwargs["tokenizer_conf"]["seg_dict"] = os.path.join(
model_or_path, "seg_dict"
)
if os.path.exists(os.path.join(model_or_path, "bpe.model")):
kwargs["tokenizer_conf"]["bpemodel"] = os.path.join(
model_or_path, "bpe.model"
)
kwargs["model"] = config["model"]
if os.path.exists(os.path.join(model_or_path, "am.mvn")):
kwargs["frontend_conf"]["cmvn_file"] = os.path.join(model_or_path, "am.mvn")
if os.path.exists(os.path.join(model_or_path, "jieba_usr_dict")):
kwargs["jieba_usr_dict"] = os.path.join(model_or_path, "jieba_usr_dict")
return OmegaConf.to_container(kwargs, resolve=True)
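
# Illustrative usage (not executed here: it would trigger a network download).
# The arguments shown are a sketch; any model id valid on the chosen hub, or a
# local directory for model_hub="local", works the same way:
#
#     kwargs = download_model(
#         model="organization/repo",   # hub id or local path
#         model_hub="hf",              # "ms" (ModelScope), "hf", or "local"
#         model_revision=None,
#     )
#     kwargs["model_path"]             # -> resolved local snapshot directory
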
def add_file_root_path(model_or_path: str, file_path_metas: dict, cfg=None):
    """Recursively join the relative paths in ``file_path_metas`` onto
    ``model_or_path``, keeping only entries whose files exist on disk."""
    if cfg is None:
        cfg = {}
if isinstance(file_path_metas, dict):
for k, v in file_path_metas.items():
if isinstance(v, str):
p = os.path.join(model_or_path, v)
if os.path.exists(p):
cfg[k] = p
elif isinstance(v, dict):
if k not in cfg:
cfg[k] = {}
add_file_root_path(model_or_path, v, cfg[k])
return cfg
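
# Sketch of the recursion above on a nested mapping (paths hypothetical):
#
#     metas = {"model": "model.pt",
#              "tokenizer_conf": {"token_list": "tokens.txt"}}
#     add_file_root_path("/cache/repo", metas, cfg={})
#     # -> {"model": "/cache/repo/model.pt",
#     #     "tokenizer_conf": {"token_list": "/cache/repo/tokens.txt"}}
#
# Entries whose files do not exist on disk are silently skipped.
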
def get_or_download_model_dir(
model,
model_revision=None,
is_training=False,
check_latest=True,
model_hub="ms",
):
"""Get local model directory or download model if necessary.
Args:
model (str): model id or path to local model directory.
For HF subfolders, use format: "repo_id/subfolder_path"
model_revision (str, optional): model version number.
is_training (bool): Whether this is for training
check_latest (bool): Whether to check for latest version
model_hub (str): Model hub type ("ms" for ModelScope, "hf" for HuggingFace)
"""
# Extract repo_id for caching (handle subfolder case)
if "/" in model and len(model.split("/")) > 2:
parts = model.split("/")
repo_id = "/".join(parts[:2]) # e.g., "organization/repo" or "stepfun-ai/Step-Audio-EditX"
subfolder = "/".join(parts[2:]) # e.g., "subfolder/model"
else:
repo_id = model
subfolder = None
# Create cache key
cache_key = (repo_id, model_revision, model_hub)
# Check cache first
with _cache_lock:
if cache_key in _model_cache:
cached_repo_dir = _model_cache[cache_key]
print(f"Using cached model for {repo_id}: {cached_repo_dir}")
# For subfolder case, construct the model_cache_dir from cached repo
if subfolder:
model_cache_dir = os.path.join(cached_repo_dir, subfolder)
if not os.path.exists(model_cache_dir):
raise FileNotFoundError(f"Subfolder {subfolder} not found in cached repo {repo_id}")
else:
model_cache_dir = cached_repo_dir
return cached_repo_dir, model_cache_dir
# Cache miss, need to download
if model_hub == "ms":
# ModelScope download
from modelscope.hub.snapshot_download import snapshot_download
from modelscope.utils.constant import Invoke, ThirdParty
key = Invoke.LOCAL_TRAINER if is_training else Invoke.PIPELINE
# Download the repo (use repo_id, not the full model path with subfolder)
repo_cache_dir = snapshot_download(
repo_id,
revision=model_revision,
user_agent={Invoke.KEY: key, ThirdParty.KEY: "funasr"},
)
repo_cache_dir = normalize_cache_path(repo_cache_dir)
# Construct model_cache_dir
if subfolder:
model_cache_dir = os.path.join(repo_cache_dir, subfolder)
if not os.path.exists(model_cache_dir):
raise FileNotFoundError(f"Subfolder {subfolder} not found in downloaded repo {repo_id}")
else:
            model_cache_dir = repo_cache_dir  # already normalized above
elif model_hub == "hf":
# HuggingFace download
try:
from huggingface_hub import snapshot_download
except ImportError:
raise ImportError(
"huggingface_hub is required for downloading from HuggingFace. "
"Please install it with: pip install huggingface_hub"
)
# Download the repo (use repo_id, not the full model path with subfolder)
repo_cache_dir = snapshot_download(
repo_id=repo_id,
revision=model_revision,
allow_patterns=None, # Download all files to ensure resource files are available
)
repo_cache_dir = normalize_cache_path(repo_cache_dir)
# Construct model_cache_dir
if subfolder:
model_cache_dir = os.path.join(repo_cache_dir, subfolder)
if not os.path.exists(model_cache_dir):
raise FileNotFoundError(f"Subfolder {subfolder} not found in downloaded repo {repo_id}")
else:
            model_cache_dir = repo_cache_dir  # already normalized above
else:
raise ValueError(f"Unsupported model_hub: {model_hub}")
# Cache the result before returning
with _cache_lock:
_model_cache[cache_key] = repo_cache_dir
print(f"Model downloaded to: {model_cache_dir}")
return repo_cache_dir, model_cache_dir
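
# Illustration of the subfolder convention handled above: an id with more than
# two path components, e.g. "stepfun-ai/Step-Audio-EditX/funasr" (subfolder
# name hypothetical), splits into repo_id="stepfun-ai/Step-Audio-EditX" and
# subfolder="funasr". The repo is snapshotted and cached once under that
# repo_id, and model_cache_dir points at <snapshot>/funasr.
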
def normalize_cache_path(cache_path):
"""Normalize cache path to ensure consistent format with snapshots/{commit_id}."""
# Check if the cache_path directory contains a snapshots folder
snapshots_dir = os.path.join(cache_path, "snapshots")
if os.path.exists(snapshots_dir) and os.path.isdir(snapshots_dir):
# Find the commit_id subdirectory in snapshots
try:
snapshot_items = os.listdir(snapshots_dir)
            # Pick the first directory entry (assumed to be the commit_id;
            # os.listdir order is arbitrary if multiple snapshots exist)
for item in snapshot_items:
item_path = os.path.join(snapshots_dir, item)
if os.path.isdir(item_path):
# Found commit_id directory, return the full path
return os.path.join(cache_path, "snapshots", item)
except OSError:
pass
# If no snapshots directory found or error occurred, return original path
return cache_path
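
# A minimal, self-contained smoke test of the two filesystem helpers above.
# This is a sketch, not part of the original module: the layout it builds
# (snapshots/abc123, model.pt) only mimics the hub cache format.
if __name__ == "__main__":
    import tempfile

    with tempfile.TemporaryDirectory() as root:
        # normalize_cache_path: a directory containing snapshots/<commit_id>
        # resolves to that snapshot; a plain directory is returned unchanged.
        snap = os.path.join(root, "snapshots", "abc123")
        os.makedirs(snap)
        assert normalize_cache_path(root) == snap
        assert normalize_cache_path(snap) == snap

        # add_file_root_path: entries are kept only if the file exists.
        open(os.path.join(snap, "model.pt"), "w").close()
        metas = {"model": "model.pt", "missing": "absent.bin"}
        assert add_file_root_path(snap, metas, cfg={}) == {
            "model": os.path.join(snap, "model.pt")
        }
        print("helper smoke test passed")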