# NOTE(review): the lines previously here ("Spaces:", "Runtime error") were
# web-page scrape residue, not part of the Python source below.
| import time | |
| import os | |
| from datetime import timedelta | |
| from loguru import logger | |
| from pathlib import Path | |
| from typing import Optional, List | |
| from huggingface_hub import HfApi, hf_hub_download | |
| from huggingface_hub.constants import HUGGINGFACE_HUB_CACHE | |
| from huggingface_hub.utils import ( | |
| LocalEntryNotFoundError, | |
| EntryNotFoundError, | |
| RevisionNotFoundError, # Import here to ease try/except in other part of the lib | |
| ) | |
# Optional directory that short-circuits the HF cache lookup for weight files
# (see `weight_files`); None when the env var is unset.
WEIGHTS_CACHE_OVERRIDE = os.getenv("WEIGHTS_CACHE_OVERRIDE", None)
def weight_hub_files(
    model_id: str, revision: Optional[str] = None, extension: str = ".safetensors"
) -> List[str]:
    """Get the weights filenames on the hub.

    Queries the hub for `model_id` at `revision` and returns the repo
    filenames ending in `extension`. Raises EntryNotFoundError when no
    such file exists in the repo.
    """
    info = HfApi().model_info(model_id, revision=revision)
    matching = [
        sibling.rfilename
        for sibling in info.siblings
        if sibling.rfilename.endswith(extension)
    ]
    if not matching:
        raise EntryNotFoundError(
            f"No {extension} weights found for model {model_id} and revision {revision}.",
            None,
        )
    return matching
def try_to_load_from_cache(
    model_id: str, revision: Optional[str], filename: str
) -> Optional[Path]:
    """Try to load a file from the Hugging Face cache.

    Returns the cached path for `filename` under `model_id`/`revision`,
    or None when the repo, revision, or file is not cached locally.
    """
    resolved = "main" if revision is None else revision

    repo_cache = Path(HUGGINGFACE_HUB_CACHE) / f"models--{model_id.replace('/', '--')}"
    if not repo_cache.is_dir():
        # Nothing cached for this model at all.
        return None

    # A ref file (e.g. "main") maps a symbolic revision to a commit sha.
    ref_path = repo_cache / "refs" / resolved
    if ref_path.is_file():
        resolved = ref_path.read_text()

    snapshots_dir = repo_cache / "snapshots"
    if not snapshots_dir.exists():
        return None
    if resolved not in os.listdir(snapshots_dir):
        # Unknown revision: never fall back to an arbitrary cached one.
        return None

    candidate = snapshots_dir / resolved / filename
    return candidate if candidate.is_file() else None
def weight_files(
    model_id: str, revision: Optional[str] = None, extension: str = ".safetensors"
) -> List[Path]:
    """Get the local paths of the weight files for `model_id`.

    Resolution order: a local directory named `model_id`, then the
    `WEIGHTS_CACHE_OVERRIDE` directory (if set), then the HF hub cache.
    Raises LocalEntryNotFoundError when an expected file is missing locally,
    EntryNotFoundError when the hub has no matching weights at all.
    """
    # Local model: `model_id` is a directory on disk, not a hub repo id.
    model_path = Path(model_id)
    if model_path.exists() and model_path.is_dir():
        return list(model_path.glob(f"*{extension}"))

    try:
        filenames = weight_hub_files(model_id, revision, extension)
    except EntryNotFoundError as e:
        if extension != ".safetensors":
            raise e
        # Try to see if there are pytorch weights
        pt_filenames = weight_hub_files(model_id, revision, extension=".bin")
        # Change pytorch extension to safetensors extension
        # It is possible that we have safetensors weights locally even though they are not on the
        # hub if we converted weights locally without pushing them
        filenames = []
        for f in pt_filenames:
            stem = Path(f).stem
            # BUGFIX: str.lstrip("pytorch_") strips any leading run of the
            # characters {p,y,t,o,r,c,h,_}, not the literal prefix — it would
            # mangle stems whose post-prefix text starts with one of those
            # characters. Strip the exact prefix instead.
            if stem.startswith("pytorch_"):
                stem = stem[len("pytorch_"):]
            filenames.append(f"{stem}.safetensors")

    if WEIGHTS_CACHE_OVERRIDE is not None:
        files = []
        for filename in filenames:
            p = Path(WEIGHTS_CACHE_OVERRIDE) / filename
            if not p.exists():
                raise LocalEntryNotFoundError(
                    f"File {p} not found in {WEIGHTS_CACHE_OVERRIDE}."
                )
            files.append(p)
        return files

    files = []
    for filename in filenames:
        cache_file = try_to_load_from_cache(
            model_id, revision=revision, filename=filename
        )
        if cache_file is None:
            # BUGFIX: the message carried a mangled "(unknown)" placeholder;
            # report the actual missing filename.
            raise LocalEntryNotFoundError(
                f"File {filename} of model {model_id} not found in "
                f"{os.getenv('HUGGINGFACE_HUB_CACHE', 'the local cache')}. "
                f"Please run `text-generation-server download-weights {model_id}` first."
            )
        files.append(cache_file)
    return files
def download_weights(
    filenames: List[str], model_id: str, revision: Optional[str] = None
) -> List[Path]:
    """Download the safetensors files from the hub.

    Downloads each of `filenames` for `model_id` at `revision` (skipping files
    already cached) and returns their local paths, logging per-file progress
    and an ETA after each download.
    """

    def download_file(filename):
        # Skip the network round-trip when the file is already cached.
        local_file = try_to_load_from_cache(model_id, revision, filename)
        if local_file is not None:
            # BUGFIX: mangled "(unknown)" placeholder -> the actual filename.
            logger.info(f"File {filename} already present in cache.")
            return Path(local_file)

        # BUGFIX: mangled "(unknown)" placeholder -> the actual filename.
        logger.info(f"Download file: {filename}")
        start_time = time.time()
        local_file = hf_hub_download(
            filename=filename,
            repo_id=model_id,
            revision=revision,
            local_files_only=False,
        )
        logger.info(
            f"Downloaded {local_file} in {timedelta(seconds=int(time.time() - start_time))}."
        )
        return Path(local_file)

    # We do this instead of using tqdm because we want to parse the logs with the launcher
    start_time = time.time()
    files = []
    for i, filename in enumerate(filenames):
        file = download_file(filename)

        elapsed = timedelta(seconds=int(time.time() - start_time))
        remaining = len(filenames) - (i + 1)
        eta = (elapsed / (i + 1)) * remaining if remaining > 0 else 0
        logger.info(f"Download: [{i + 1}/{len(filenames)}] -- ETA: {eta}")
        files.append(file)

    return files