import os
import json
import threading
from omegaconf import OmegaConf

from funasr_detach.download.name_maps_from_hub import name_maps_ms, name_maps_hf

# Global cache for downloaded models to avoid repeated downloads
# Key: (repo_id, model_revision, model_hub)
# Value: repo_cache_dir
_model_cache = {}
_cache_lock = threading.Lock()


def download_model(**kwargs):
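    """Resolve a model name or path to a local directory and merge its config.

    Looks up hub-specific name aliases, downloads from ModelScope ("ms") or
    HuggingFace ("hf") when the path is not already local, then merges the
    repo's configuration.json or config.yaml into kwargs via OmegaConf.

    Example (hypothetical model id; the first call needs network access):
        >>> cfg = download_model(model="org/asr-model", model_hub="hf")
        >>> cfg["model_path"]  # local snapshot directory
    """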
    model_hub = kwargs.get("model_hub", "ms")
    model_or_path = kwargs.get("model")
    repo_path = kwargs.get("repo_path", "")

    # Handle name mapping based on model_hub
    if model_hub == "ms" and model_or_path in name_maps_ms:
        model_or_path = name_maps_ms[model_or_path]
    elif model_hub == "hf" and model_or_path in name_maps_hf:
        model_or_path = name_maps_hf[model_or_path]

    model_revision = kwargs.get("model_revision")

    # Download model if it doesn't exist locally
    if not os.path.exists(model_or_path):
        if model_hub == "local":
            # For local models, the path should already exist
            raise FileNotFoundError(f"Local model path does not exist: {model_or_path}")
        elif model_hub in ["ms", "hf"]:
            repo_path, model_or_path = get_or_download_model_dir(
                model_or_path,
                model_revision,
                is_training=kwargs.get("is_training"),
                check_latest=kwargs.get("check_latest", True),
                model_hub=model_hub,
            )
        else:
            raise ValueError(f"Unsupported model_hub: {model_hub}")

    print(f"Using model path: {model_or_path}")
    kwargs["model_path"] = model_or_path
    kwargs["repo_path"] = repo_path

    # Common logic for processing configuration files (same for all model hubs)
    if os.path.exists(os.path.join(model_or_path, "configuration.json")):
        with open(
            os.path.join(model_or_path, "configuration.json"), "r", encoding="utf-8"
        ) as f:
            conf_json = json.load(f)
            cfg = {}
            add_file_root_path(model_or_path, conf_json["file_path_metas"], cfg)
            cfg.update(kwargs)
            config = OmegaConf.load(cfg["config"])
            kwargs = OmegaConf.merge(config, cfg)
        kwargs["model"] = config["model"]
    elif os.path.exists(os.path.join(model_or_path, "config.yaml")) and os.path.exists(
        os.path.join(model_or_path, "model.pt")
    ):
        config = OmegaConf.load(os.path.join(model_or_path, "config.yaml"))
        kwargs = OmegaConf.merge(config, kwargs)
        init_param = os.path.join(model_or_path, "model.pb")
        kwargs["init_param"] = init_param
        if os.path.exists(os.path.join(model_or_path, "tokens.txt")):
            kwargs["tokenizer_conf"]["token_list"] = os.path.join(
                model_or_path, "tokens.txt"
            )
        if os.path.exists(os.path.join(model_or_path, "tokens.json")):
            kwargs["tokenizer_conf"]["token_list"] = os.path.join(
                model_or_path, "tokens.json"
            )
        if os.path.exists(os.path.join(model_or_path, "seg_dict")):
            kwargs["tokenizer_conf"]["seg_dict"] = os.path.join(
                model_or_path, "seg_dict"
            )
        if os.path.exists(os.path.join(model_or_path, "bpe.model")):
            kwargs["tokenizer_conf"]["bpemodel"] = os.path.join(
                model_or_path, "bpe.model"
            )
        kwargs["model"] = config["model"]
        if os.path.exists(os.path.join(model_or_path, "am.mvn")):
            kwargs["frontend_conf"]["cmvn_file"] = os.path.join(model_or_path, "am.mvn")
        if os.path.exists(os.path.join(model_or_path, "jieba_usr_dict")):
            kwargs["jieba_usr_dict"] = os.path.join(model_or_path, "jieba_usr_dict")

    # kwargs may still be a plain dict when no config file was found above;
    # wrap it with OmegaConf.create so to_container accepts it either way.
    return OmegaConf.to_container(OmegaConf.create(kwargs), resolve=True)


def add_file_root_path(model_or_path: str, file_path_metas: dict, cfg=None):
    # Use None as the default: a mutable default dict would be shared across
    # calls and accumulate entries between invocations.
    if cfg is None:
        cfg = {}

    if isinstance(file_path_metas, dict):
        for k, v in file_path_metas.items():
            if isinstance(v, str):
                p = os.path.join(model_or_path, v)
                if os.path.exists(p):
                    cfg[k] = p
            elif isinstance(v, dict):
                if k not in cfg:
                    cfg[k] = {}
                add_file_root_path(model_or_path, v, cfg[k])

    return cfg
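
# Sketch of add_file_root_path on a nested file_path_metas dict (hypothetical
# paths; an entry is kept only when the file actually exists on disk):
#
#     metas = {"config": "config.yaml",
#              "tokenizer_conf": {"token_list": "tokens.txt"}}
#     add_file_root_path("/cache/model_dir", metas, {})
#     # -> {"config": "/cache/model_dir/config.yaml",
#     #     "tokenizer_conf": {"token_list": "/cache/model_dir/tokens.txt"}}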


def get_or_download_model_dir(
    model,
    model_revision=None,
    is_training=False,
    check_latest=True,
    model_hub="ms",
):
    """Get local model directory or download model if necessary.

    Args:
        model (str): model id or path to local model directory.
                    For HF subfolders, use format: "repo_id/subfolder_path"
        model_revision (str, optional): model version number.
        is_training (bool): Whether this is for training
        check_latest (bool): Whether to check for latest version
        model_hub (str): Model hub type ("ms" for ModelScope, "hf" for HuggingFace)
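
    Example (hypothetical ids):
        get_or_download_model_dir("org/repo/sub/dir", model_hub="hf")
        downloads the whole "org/repo" snapshot once and returns
        (repo_cache_dir, repo_cache_dir + "/sub/dir"); repeated calls with
        the same (repo_id, revision, hub) reuse the in-process cache.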
    """
    # Extract repo_id for caching (handle subfolder case)
    if "/" in model and len(model.split("/")) > 2:
        parts = model.split("/")
        repo_id = "/".join(parts[:2])  # e.g., "organization/repo" or "stepfun-ai/Step-Audio-EditX"
        subfolder = "/".join(parts[2:])  # e.g., "subfolder/model"
    else:
        repo_id = model
        subfolder = None

    # Create cache key
    cache_key = (repo_id, model_revision, model_hub)

    # Check cache first
    with _cache_lock:
        if cache_key in _model_cache:
            cached_repo_dir = _model_cache[cache_key]
            print(f"Using cached model for {repo_id}: {cached_repo_dir}")

            # For subfolder case, construct the model_cache_dir from cached repo
            if subfolder:
                model_cache_dir = os.path.join(cached_repo_dir, subfolder)
                if not os.path.exists(model_cache_dir):
                    raise FileNotFoundError(f"Subfolder {subfolder} not found in cached repo {repo_id}")
            else:
                model_cache_dir = cached_repo_dir

            return cached_repo_dir, model_cache_dir

    # Cache miss, need to download
    if model_hub == "ms":
        # ModelScope download
        from modelscope.hub.snapshot_download import snapshot_download
        from modelscope.utils.constant import Invoke, ThirdParty

        key = Invoke.LOCAL_TRAINER if is_training else Invoke.PIPELINE

        # Download the repo (use repo_id, not the full model path with subfolder)
        repo_cache_dir = snapshot_download(
            repo_id,
            revision=model_revision,
            user_agent={Invoke.KEY: key, ThirdParty.KEY: "funasr"},
        )
        repo_cache_dir = normalize_cache_path(repo_cache_dir)

        # Construct model_cache_dir
        if subfolder:
            model_cache_dir = os.path.join(repo_cache_dir, subfolder)
            if not os.path.exists(model_cache_dir):
                raise FileNotFoundError(f"Subfolder {subfolder} not found in downloaded repo {repo_id}")
        else:
            model_cache_dir = repo_cache_dir  # already normalized above

    elif model_hub == "hf":
        # HuggingFace download
        try:
            from huggingface_hub import snapshot_download
        except ImportError:
            raise ImportError(
                "huggingface_hub is required for downloading from HuggingFace. "
                "Please install it with: pip install huggingface_hub"
            )

        # Download the repo (use repo_id, not the full model path with subfolder)
        repo_cache_dir = snapshot_download(
            repo_id=repo_id,
            revision=model_revision,
            allow_patterns=None,  # Download all files to ensure resource files are available
        )
        repo_cache_dir = normalize_cache_path(repo_cache_dir)

        # Construct model_cache_dir
        if subfolder:
            model_cache_dir = os.path.join(repo_cache_dir, subfolder)
            if not os.path.exists(model_cache_dir):
                raise FileNotFoundError(f"Subfolder {subfolder} not found in downloaded repo {repo_id}")
        else:
            model_cache_dir = repo_cache_dir  # already normalized above
    else:
        raise ValueError(f"Unsupported model_hub: {model_hub}")

    # Cache the result before returning
    with _cache_lock:
        _model_cache[cache_key] = repo_cache_dir

    print(f"Model downloaded to: {model_cache_dir}")
    return repo_cache_dir, model_cache_dir


def normalize_cache_path(cache_path):
    """Normalize cache path to ensure consistent format with snapshots/{commit_id}."""
    # Check if the cache_path directory contains a snapshots folder
    snapshots_dir = os.path.join(cache_path, "snapshots")
    if os.path.exists(snapshots_dir) and os.path.isdir(snapshots_dir):
        # Find the commit_id subdirectory in snapshots
        try:
            snapshot_items = os.listdir(snapshots_dir)
            # Return the first subdirectory, which should be the commit id
            # (a repo cache normally contains a single snapshot).
            for item in snapshot_items:
                item_path = os.path.join(snapshots_dir, item)
                if os.path.isdir(item_path):
                    return item_path
        except OSError:
            pass

    # If no snapshots directory found or error occurred, return original path
    return cache_path
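

if __name__ == "__main__":
    # Minimal offline smoke test for normalize_cache_path: a sketch that
    # builds a fake snapshots/<commit> layout in a temp directory instead of
    # touching a real hub cache.
    import tempfile

    with tempfile.TemporaryDirectory() as tmp:
        commit_dir = os.path.join(tmp, "snapshots", "abc123")
        os.makedirs(commit_dir)
        assert normalize_cache_path(tmp) == commit_dir
        assert normalize_cache_path(commit_dir) == commit_dir  # already normalized
        print("normalize_cache_path OK")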