|
|
"""Utility helpers for Kaloscope-related model management.""" |
|
|
|
|
|
from __future__ import annotations |
|
|
|
|
|
import os |
|
|
from typing import Any, Dict, Iterable, Tuple |
|
|
|
|
|
|
|
|
def _iter_repo_files(repo_info: Any) -> Iterable[Any]: |
|
|
"""Yield file metadata objects from a repository info response.""" |
|
|
siblings = getattr(repo_info, "siblings", None) |
|
|
if not siblings: |
|
|
return [] |
|
|
return siblings |
|
|
|
|
|
|
|
|
def auto_register_kaloscope_variants(models: Dict[str, Dict[str, Any]]) -> None:
    """Discover Kaloscope ``.pth`` checkpoints and append them to *models*.

    Queries the ``heathcliff01/Kaloscope`` model repository. Every ``.pth``
    file that is not already referenced by an entry in *models* gets a new
    config. Files are matched on ``(repo_id, subfolder, filename)``.

    The display name is ``<repo name>/<last subfolder component>``. For
    files at the repository root it is ``<repo name>/<file stem>``. Clashing
    names get a `` #2``, `` #3``, ... suffix.

    This is best-effort. If ``huggingface_hub`` cannot be imported, or the
    repository query raises, a message is printed and *models* is left
    unchanged.

    Args:
        models: Mapping of display name -> model config; mutated in place.
    """
    try:
        from huggingface_hub import HfApi
    except ImportError:
        print("huggingface_hub is not available; skipping Kaloscope auto-discovery.")
        return

    repo_id = "heathcliff01/Kaloscope"
    api = HfApi()

    try:
        repo_info = api.model_info(repo_id=repo_id)
    except Exception as exc:
        # Discovery is optional; any query failure is reported, not raised.
        print(f"Failed to query Hugging Face repo info: {exc}")
        return

    existing_entries = _existing_model_keys(models)
    discovered = _discover_kaloscope_checkpoints(
        repo_id, _iter_repo_files(repo_info), existing_entries
    )

    for display_name, config, key in discovered:
        candidate_name = _unique_model_name(display_name, models)
        models[candidate_name] = config
        existing_entries.add(key)
        print(f"Auto-registered: {candidate_name}")


def _normalize_subfolder(subfolder: Any) -> str:
    """Return *subfolder* with surrounding slashes removed ("" for None/empty)."""
    return (subfolder or "").strip("/")


def _existing_model_keys(
    models: Dict[str, Dict[str, Any]],
) -> set[Tuple[str, str, str]]:
    """Collect normalized ``(repo_id, subfolder, filename)`` keys of repo-backed configs."""
    # Normalizing the subfolder means a config that stores it as None or with
    # stray slashes still matches a discovered file, so it is not registered twice.
    return {
        (
            config["repo_id"],
            _normalize_subfolder(config.get("subfolder")),
            config["filename"],
        )
        for config in models.values()
        if config.get("repo_id") and config.get("filename")
    }


def _discover_kaloscope_checkpoints(
    repo_id: str,
    siblings: Iterable[Any],
    existing_entries: set[Tuple[str, str, str]],
) -> list[Tuple[str, Dict[str, Any], Tuple[str, str, str]]]:
    """Build ``(display_name, config, key)`` tuples for new ``.pth`` files, sorted by name."""
    repo_name = repo_id.split("/")[-1]
    discovered: list[Tuple[str, Dict[str, Any], Tuple[str, str, str]]] = []

    for sibling in siblings:
        # Accept either attribute name for the repo-relative file path.
        path = getattr(sibling, "rfilename", None) or getattr(sibling, "path", None)
        if not path or not path.endswith(".pth"):
            continue

        # Split on "/" explicitly so the result does not depend on the local
        # OS path flavor.
        head, _, filename = path.rpartition("/")
        subfolder = _normalize_subfolder(head)

        key = (repo_id, subfolder, filename)
        if key in existing_entries:
            continue

        if subfolder:
            display_name = f"{repo_name}/{subfolder.split('/')[-1]}"
        else:
            display_name = f"{repo_name}/{os.path.splitext(filename)[0]}"

        config: Dict[str, Any] = {
            "type": "pytorch",
            "path": "",
            "repo_id": repo_id,
            "filename": filename,
            "arch": "lsnet_xl_artist",
        }
        if subfolder:
            config["subfolder"] = subfolder

        discovered.append((display_name, config, key))

    # Deterministic registration order (and thus deterministic "#N" suffixes).
    discovered.sort(key=lambda item: item[0])
    return discovered


def _unique_model_name(base_name: str, models: Dict[str, Dict[str, Any]]) -> str:
    """Return *base_name*, or ``"<base_name> #N"`` with the smallest N >= 2 not in *models*."""
    candidate_name = base_name
    suffix = 2
    while candidate_name in models:
        candidate_name = f"{base_name} #{suffix}"
        suffix += 1
    return candidate_name
|
|
|
|
|
|
|
|
def validate_models(models: Dict[str, Dict[str, Any]]) -> Dict[str, Dict[str, Any]]:
    """
    Validate models at startup and return only the usable entries.

    An entry is kept when either:

    * its ``path`` exists on the local filesystem, or
    * it has a non-empty ``repo_id`` and ``filename`` (optionally with
      ``subfolder`` and ``revision``), and ``huggingface_hub.file_exists``
      reports the file as present.

    Every rejected entry is reported with a ``print`` explaining why. The
    input mapping is not modified.

    Args:
        models: Mapping of display name -> model config.

    Returns:
        A new dict with the surviving ``name -> config`` pairs, in input
        order. The config objects themselves are not copied.
    """
    valid_models: Dict[str, Dict[str, Any]] = {}

    for model_name, config in models.items():
        model_path = config.get("path") or ""

        # A local checkpoint takes precedence over any remote configuration.
        if model_path and os.path.exists(model_path):
            valid_models[model_name] = config
            continue

        # Require non-empty values. This is the same rule auto-discovery uses to
        # identify repo-backed entries, and it avoids calling the Hub with None.
        if not (config.get("repo_id") and config.get("filename")):
            print(f"Skipping {model_name}: no valid path or repo configuration")
            continue

        try:
            # Imported here so a missing huggingface_hub only affects
            # repo-backed entries; the resulting ImportError is caught below.
            from huggingface_hub import file_exists

            filename = config["filename"]
            subfolder = (config.get("subfolder") or "").strip("/")
            if subfolder:
                # Stripping slashes avoids paths like "sub//file.pth".
                filename = f"{subfolder}/{filename}"

            check_kwargs: Dict[str, Any] = {
                "repo_id": config["repo_id"],
                "filename": filename,
                "repo_type": "model",
            }
            if config.get("revision"):
                check_kwargs["revision"] = config["revision"]

            if file_exists(**check_kwargs):
                valid_models[model_name] = config
            else:
                print(f"Skipping {model_name}: file not found on Hugging Face")
        except Exception as e:
            # Report per model instead of aborting startup.
            print(f"Skipping {model_name}: validation error - {e}")

    return valid_models
|
|
|