import itertools
import os
import re
import tempfile
from dataclasses import dataclass
from datetime import datetime
from typing import Any, Dict, List, Optional, Tuple, Union
from urllib.parse import quote, unquote

import fsspec

from ._commit_api import CommitOperationCopy, CommitOperationDelete
from .constants import DEFAULT_REVISION, ENDPOINT, REPO_TYPE_MODEL, REPO_TYPES_MAPPING, REPO_TYPES_URL_PREFIXES
from .hf_api import HfApi
from .utils import (
    EntryNotFoundError,
    HFValidationError,
    RepositoryNotFoundError,
    RevisionNotFoundError,
    hf_raise_for_status,
    http_backoff,
    paginate,
    parse_datetime,
)
# Regex used to match special revisions with "/" in them (see #1710)
SPECIAL_REFS_REVISION_REGEX = re.compile(
    r"""
    (^refs\/convert\/parquet)  # `refs/convert/parquet` revisions
    |
    (^refs\/pr\/\d+)  # PR revisions
    """,
    re.VERBOSE,
)
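# Illustrative matches (hypothetical values): searching "refs/convert/parquet/en/train.parquet"
# yields the group "refs/convert/parquet", searching "refs/pr/6/README.md" yields "refs/pr/6",
# and "main/README.md" does not match at all.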

@dataclass
class HfFileSystemResolvedPath:
    """Data structure containing information about a resolved Hugging Face file system path."""

    repo_type: str
    repo_id: str
    revision: str
    path_in_repo: str

    def unresolve(self) -> str:
        return (
            f"{REPO_TYPES_URL_PREFIXES.get(self.repo_type, '') + self.repo_id}@{safe_revision(self.revision)}/{self.path_in_repo}"
            .rstrip("/")
        )
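
# Illustrative round trip (hypothetical repo): resolving "datasets/user/ds@main/data/train.csv"
# gives HfFileSystemResolvedPath("dataset", "user/ds", "main", "data/train.csv"), and
# .unresolve() rebuilds "datasets/user/ds@main/data/train.csv". With an empty `path_in_repo`,
# the trailing slash is stripped, e.g. "user/model@main".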

class HfFileSystem(fsspec.AbstractFileSystem):
    """
    Access a remote Hugging Face Hub repository as if it were a local file system.

    Args:
        endpoint (`str`, *optional*):
            The endpoint to use. If not provided, the default one (https://huggingface.co) is used.
        token (`str`, *optional*):
            Authentication token, obtained with the [`HfApi.login`] method. Will default to the stored token.

    Usage:

    ```python
    >>> from huggingface_hub import HfFileSystem

    >>> fs = HfFileSystem()

    >>> # List files
    >>> fs.glob("my-username/my-model/*.bin")
    ['my-username/my-model/pytorch_model.bin']

    >>> fs.ls("datasets/my-username/my-dataset", detail=False)
    ['datasets/my-username/my-dataset/.gitattributes', 'datasets/my-username/my-dataset/README.md', 'datasets/my-username/my-dataset/data.json']

    >>> # Read/write files
    >>> with fs.open("my-username/my-model/pytorch_model.bin") as f:
    ...     data = f.read()
    >>> with fs.open("my-username/my-model/pytorch_model.bin", "wb") as f:
    ...     f.write(data)
    ```
    """

    root_marker = ""
    protocol = "hf"

    def __init__(
        self,
        *args,
        endpoint: Optional[str] = None,
        token: Optional[str] = None,
        **storage_options,
    ):
        super().__init__(*args, **storage_options)
        self.endpoint = endpoint or ENDPOINT
        self.token = token
        self._api = HfApi(endpoint=endpoint, token=token)
        # Maps (repo_type, repo_id, revision) to a 2-tuple with:
        #  * the 1st element indicating whether the repository and the revision exist
        #  * the 2nd element being the exception raised if the repository or revision doesn't exist
        self._repo_and_revision_exists_cache: Dict[
            Tuple[str, str, Optional[str]], Tuple[bool, Optional[Exception]]
        ] = {}

    def _repo_and_revision_exist(
        self, repo_type: str, repo_id: str, revision: Optional[str]
    ) -> Tuple[bool, Optional[Exception]]:
        if (repo_type, repo_id, revision) not in self._repo_and_revision_exists_cache:
            try:
                self._api.repo_info(repo_id, revision=revision, repo_type=repo_type)
            except (RepositoryNotFoundError, HFValidationError) as e:
                self._repo_and_revision_exists_cache[(repo_type, repo_id, revision)] = False, e
                self._repo_and_revision_exists_cache[(repo_type, repo_id, None)] = False, e
            except RevisionNotFoundError as e:
                self._repo_and_revision_exists_cache[(repo_type, repo_id, revision)] = False, e
                self._repo_and_revision_exists_cache[(repo_type, repo_id, None)] = True, None
            else:
                self._repo_and_revision_exists_cache[(repo_type, repo_id, revision)] = True, None
                self._repo_and_revision_exists_cache[(repo_type, repo_id, None)] = True, None
        return self._repo_and_revision_exists_cache[(repo_type, repo_id, revision)]
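
    # Illustrative cache state after probing a missing branch on an existing repo
    # (hypothetical names): the failed revision is cached as absent while the repo
    # itself is cached as present, so later probes skip the network round trip:
    #   {("model", "user/repo", "missing-branch"): (False, RevisionNotFoundError(...)),
    #    ("model", "user/repo", None): (True, None)}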

    def exists(self, path, **kwargs):
        """Is there a file at the given path

        Exact same implementation as in fsspec, except that instead of catching all exceptions we let
        `NotImplementedError` propagate (which we do want to raise). Catching a `NotImplementedError`
        can lead to undesired behavior.

        Adapted from https://github.com/fsspec/filesystem_spec/blob/f5d24b80a0768bf07a113647d7b4e74a3a2999e0/fsspec/spec.py#L649C1-L656C25
        """
        try:
            self.info(path, **kwargs)
            return True
        except NotImplementedError:
            raise
        except Exception:
            # any exception allowed bar FileNotFoundError?
            return False

    def resolve_path(self, path: str, revision: Optional[str] = None) -> HfFileSystemResolvedPath:
        def _align_revision_in_path_with_revision(
            revision_in_path: Optional[str], revision: Optional[str]
        ) -> Optional[str]:
            if revision is not None:
                if revision_in_path is not None and revision_in_path != revision:
                    raise ValueError(
                        f'Revision specified in path ("{revision_in_path}") and in `revision` argument ("{revision}")'
                        " are not the same."
                    )
            else:
                revision = revision_in_path
            return revision

        path = self._strip_protocol(path)
        if not path:
            # can't list repositories at root
            raise NotImplementedError("Access to repositories lists is not implemented.")
        elif path.split("/")[0] + "/" in REPO_TYPES_URL_PREFIXES.values():
            if "/" not in path:
                # can't list repositories at the repository type level
                raise NotImplementedError("Access to repositories lists is not implemented.")
            repo_type, path = path.split("/", 1)
            repo_type = REPO_TYPES_MAPPING[repo_type]
        else:
            repo_type = REPO_TYPE_MODEL
        if path.count("/") > 0:
            if "@" in path:
                repo_id, revision_in_path = path.split("@", 1)
                if "/" in revision_in_path:
                    match = SPECIAL_REFS_REVISION_REGEX.search(revision_in_path)
                    if match is not None and revision in (None, match.group()):
                        # Handle `refs/convert/parquet` and PR revisions separately
                        path_in_repo = SPECIAL_REFS_REVISION_REGEX.sub("", revision_in_path).lstrip("/")
                        revision_in_path = match.group()
                    else:
                        revision_in_path, path_in_repo = revision_in_path.split("/", 1)
                else:
                    path_in_repo = ""
                revision_in_path = unquote(revision_in_path)
                revision = _align_revision_in_path_with_revision(revision_in_path, revision)
                repo_and_revision_exist, err = self._repo_and_revision_exist(repo_type, repo_id, revision)
                if not repo_and_revision_exist:
                    raise FileNotFoundError(path) from err
            else:
                repo_id_with_namespace = "/".join(path.split("/")[:2])
                path_in_repo_with_namespace = "/".join(path.split("/")[2:])
                repo_id_without_namespace = path.split("/")[0]
                path_in_repo_without_namespace = "/".join(path.split("/")[1:])
                repo_id = repo_id_with_namespace
                path_in_repo = path_in_repo_with_namespace
                repo_and_revision_exist, err = self._repo_and_revision_exist(repo_type, repo_id, revision)
                if not repo_and_revision_exist:
                    if isinstance(err, (RepositoryNotFoundError, HFValidationError)):
                        repo_id = repo_id_without_namespace
                        path_in_repo = path_in_repo_without_namespace
                        repo_and_revision_exist, _ = self._repo_and_revision_exist(repo_type, repo_id, revision)
                        if not repo_and_revision_exist:
                            raise FileNotFoundError(path) from err
                    else:
                        raise FileNotFoundError(path) from err
        else:
            repo_id = path
            path_in_repo = ""
            if "@" in path:
                repo_id, revision_in_path = path.split("@", 1)
                revision_in_path = unquote(revision_in_path)
                revision = _align_revision_in_path_with_revision(revision_in_path, revision)
            repo_and_revision_exist, _ = self._repo_and_revision_exist(repo_type, repo_id, revision)
            if not repo_and_revision_exist:
                raise NotImplementedError("Access to repositories lists is not implemented.")

        revision = revision if revision is not None else DEFAULT_REVISION
        return HfFileSystemResolvedPath(repo_type, repo_id, revision, path_in_repo)
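
    # Illustrative resolutions (hypothetical repos):
    #   "user/model"                      -> ("model", "user/model", DEFAULT_REVISION, "")
    #   "datasets/user/ds/train.csv"      -> ("dataset", "user/ds", DEFAULT_REVISION, "train.csv")
    #   "user/model@refs/pr/1/README.md"  -> ("model", "user/model", "refs/pr/1", "README.md")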

    def invalidate_cache(self, path: Optional[str] = None) -> None:
        if not path:
            self.dircache.clear()
            self._repo_and_revision_exists_cache.clear()
        else:
            path = self.resolve_path(path).unresolve()
            while path:
                self.dircache.pop(path, None)
                path = self._parent(path)
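
    # Illustrative: invalidating "user/model/dir/file.bin" first unresolves it to
    # "user/model@main/dir/file.bin" (assuming the default revision), then pops that key
    # and each of its parents ("user/model@main/dir", "user/model@main", ...) from the dircache.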

    def _open(
        self,
        path: str,
        mode: str = "rb",
        revision: Optional[str] = None,
        **kwargs,
    ) -> "HfFileSystemFile":
        if mode == "ab":
            raise NotImplementedError("Appending to remote files is not yet supported.")
        return HfFileSystemFile(self, path, mode=mode, revision=revision, **kwargs)
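
    # Callers normally go through fsspec's public `open()`, which delegates here, e.g.
    # (illustrative): `with fs.open("user/model/config.json", "r") as f: config = f.read()`.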

    def _rm(self, path: str, revision: Optional[str] = None, **kwargs) -> None:
        resolved_path = self.resolve_path(path, revision=revision)
        self._api.delete_file(
            path_in_repo=resolved_path.path_in_repo,
            repo_id=resolved_path.repo_id,
            token=self.token,
            repo_type=resolved_path.repo_type,
            revision=resolved_path.revision,
            commit_message=kwargs.get("commit_message"),
            commit_description=kwargs.get("commit_description"),
        )
        self.invalidate_cache(path=resolved_path.unresolve())

    def rm(
        self,
        path: str,
        recursive: bool = False,
        maxdepth: Optional[int] = None,
        revision: Optional[str] = None,
        **kwargs,
    ) -> None:
        resolved_path = self.resolve_path(path, revision=revision)
        root_path = REPO_TYPES_URL_PREFIXES.get(resolved_path.repo_type, "") + resolved_path.repo_id
        paths = self.expand_path(path, recursive=recursive, maxdepth=maxdepth, revision=resolved_path.revision)
        paths_in_repo = [path[len(root_path) + 1 :] for path in paths if not self.isdir(path)]
        operations = [CommitOperationDelete(path_in_repo=path_in_repo) for path_in_repo in paths_in_repo]
        commit_message = f"Delete {path} "
        commit_message += "recursively " if recursive else ""
        commit_message += f"up to depth {maxdepth} " if maxdepth is not None else ""
        # TODO: use `commit_description` to list all the deleted paths?
        self._api.create_commit(
            repo_id=resolved_path.repo_id,
            repo_type=resolved_path.repo_type,
            token=self.token,
            operations=operations,
            revision=resolved_path.revision,
            commit_message=kwargs.get("commit_message", commit_message),
            commit_description=kwargs.get("commit_description"),
        )
        self.invalidate_cache(path=resolved_path.unresolve())
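
    # Illustrative: `fs.rm("user/model/checkpoints", recursive=True)` expands to every file
    # under "checkpoints/" and deletes them all in a single commit on the resolved revision.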

    def ls(
        self, path: str, detail: bool = True, refresh: bool = False, revision: Optional[str] = None, **kwargs
    ) -> List[Union[str, Dict[str, Any]]]:
        """List the contents of a directory."""
        resolved_path = self.resolve_path(path, revision=revision)
        revision_in_path = "@" + safe_revision(resolved_path.revision)
        has_revision_in_path = revision_in_path in path
        path = resolved_path.unresolve()
        if path not in self.dircache or refresh:
            path_prefix = (
                HfFileSystemResolvedPath(
                    resolved_path.repo_type, resolved_path.repo_id, resolved_path.revision, ""
                ).unresolve()
                + "/"
            )
            tree_path = path
            tree_iter = self._iter_tree(tree_path, revision=resolved_path.revision)
            try:
                tree_item = next(tree_iter)
            except EntryNotFoundError:
                if "/" in resolved_path.path_in_repo:
                    tree_path = self._parent(path)
                    tree_iter = self._iter_tree(tree_path, revision=resolved_path.revision)
                else:
                    raise
            else:
                tree_iter = itertools.chain([tree_item], tree_iter)
            child_infos = []
            for tree_item in tree_iter:
                child_info = {
                    "name": path_prefix + tree_item["path"],
                    "size": tree_item["size"],
                    "type": tree_item["type"],
                }
                if tree_item["type"] == "file":
                    child_info.update(
                        {
                            "blob_id": tree_item["oid"],
                            "lfs": tree_item.get("lfs"),
                            "last_modified": parse_datetime(tree_item["lastCommit"]["date"]),
                        },
                    )
                child_infos.append(child_info)
            self.dircache[tree_path] = child_infos
        out = self._ls_from_cache(path)
        if not has_revision_in_path:
            out = [{**o, "name": o["name"].replace(revision_in_path, "", 1)} for o in out]
        return out if detail else [o["name"] for o in out]
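
    # Illustrative `detail=True` entry for a file (hypothetical values); the "@<revision>" part
    # is stripped from "name" when the queried path did not contain it:
    #   {"name": "user/model/pytorch_model.bin", "size": 548105171, "type": "file",
    #    "blob_id": "...", "lfs": {...}, "last_modified": datetime(2023, 1, 1, ...)}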

    def _iter_tree(self, path: str, revision: Optional[str] = None):
        # TODO: use HfApi.list_files_info instead when it supports "lastCommit" and "expand=True"
        # See https://github.com/huggingface/moon-landing/issues/5993
        resolved_path = self.resolve_path(path, revision=revision)
        path = f"{self._api.endpoint}/api/{resolved_path.repo_type}s/{resolved_path.repo_id}/tree/{safe_quote(resolved_path.revision)}/{resolved_path.path_in_repo}".rstrip(
            "/"
        )
        headers = self._api._build_hf_headers()
        yield from paginate(path, params={"expand": True}, headers=headers)
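
    # The paginated tree endpoint hit above looks like (illustrative):
    #   https://huggingface.co/api/models/user/model/tree/main/subdir?expand=True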

    def cp_file(self, path1: str, path2: str, revision: Optional[str] = None, **kwargs) -> None:
        resolved_path1 = self.resolve_path(path1, revision=revision)
        resolved_path2 = self.resolve_path(path2, revision=revision)

        same_repo = (
            resolved_path1.repo_type == resolved_path2.repo_type and resolved_path1.repo_id == resolved_path2.repo_id
        )

        # TODO: Wait for https://github.com/huggingface/huggingface_hub/issues/1083 to be resolved to simplify this logic
        if same_repo and self.info(path1, revision=resolved_path1.revision)["lfs"] is not None:
            commit_message = f"Copy {path1} to {path2}"
            self._api.create_commit(
                repo_id=resolved_path1.repo_id,
                repo_type=resolved_path1.repo_type,
                revision=resolved_path2.revision,
                commit_message=kwargs.get("commit_message", commit_message),
                commit_description=kwargs.get("commit_description", ""),
                operations=[
                    CommitOperationCopy(
                        src_path_in_repo=resolved_path1.path_in_repo,
                        path_in_repo=resolved_path2.path_in_repo,
                        src_revision=resolved_path1.revision,
                    )
                ],
            )
        else:
            with self.open(path1, "rb", revision=resolved_path1.revision) as f:
                content = f.read()
            commit_message = f"Copy {path1} to {path2}"
            self._api.upload_file(
                path_or_fileobj=content,
                path_in_repo=resolved_path2.path_in_repo,
                repo_id=resolved_path2.repo_id,
                token=self.token,
                repo_type=resolved_path2.repo_type,
                revision=resolved_path2.revision,
                commit_message=kwargs.get("commit_message", commit_message),
                commit_description=kwargs.get("commit_description"),
            )
        self.invalidate_cache(path=resolved_path1.unresolve())
        self.invalidate_cache(path=resolved_path2.unresolve())
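
    # Illustrative: `fs.cp_file("user/model/weights.bin", "user/model/backup/weights.bin")`
    # issues a server-side LFS copy commit when source and destination live in the same repo
    # and the source is an LFS file; otherwise it downloads the bytes and re-uploads them.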

    def modified(self, path: str, **kwargs) -> datetime:
        info = self.info(path, **kwargs)
        if "last_modified" not in info:
            raise IsADirectoryError(path)
        return info["last_modified"]

    def info(self, path: str, **kwargs) -> Dict[str, Any]:
        resolved_path = self.resolve_path(path)
        if not resolved_path.path_in_repo:
            revision_in_path = "@" + safe_revision(resolved_path.revision)
            has_revision_in_path = revision_in_path in path
            name = resolved_path.unresolve()
            name = name.replace(revision_in_path, "", 1) if not has_revision_in_path else name
            return {"name": name, "size": 0, "type": "directory"}
        return super().info(path, **kwargs)


class HfFileSystemFile(fsspec.spec.AbstractBufferedFile):
    def __init__(self, fs: HfFileSystem, path: str, revision: Optional[str] = None, **kwargs):
        super().__init__(fs, path, **kwargs)
        self.fs: HfFileSystem
        self.resolved_path = fs.resolve_path(path, revision=revision)

    def _fetch_range(self, start: int, end: int) -> bytes:
        headers = {
            "range": f"bytes={start}-{end - 1}",
            **self.fs._api._build_hf_headers(),
        }
        url = f"{self.fs.endpoint}/{REPO_TYPES_URL_PREFIXES.get(self.resolved_path.repo_type, '') + self.resolved_path.repo_id}/resolve/{safe_quote(self.resolved_path.revision)}/{safe_quote(self.resolved_path.path_in_repo)}"
        r = http_backoff("GET", url, headers=headers)
        hf_raise_for_status(r)
        return r.content
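
    # Illustrative: `_fetch_range(0, 10)` sends "range: bytes=0-9" to a URL of the form
    #   https://huggingface.co/user/model/resolve/main/pytorch_model.bin
    # so only the requested slice of the file is downloaded.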

    def _initiate_upload(self) -> None:
        self.temp_file = tempfile.NamedTemporaryFile(prefix="hffs-", delete=False)

    def _upload_chunk(self, final: bool = False) -> None:
        self.buffer.seek(0)
        block = self.buffer.read()
        self.temp_file.write(block)
        if final:
            self.temp_file.close()
            self.fs._api.upload_file(
                path_or_fileobj=self.temp_file.name,
                path_in_repo=self.resolved_path.path_in_repo,
                repo_id=self.resolved_path.repo_id,
                token=self.fs.token,
                repo_type=self.resolved_path.repo_type,
                revision=self.resolved_path.revision,
                commit_message=self.kwargs.get("commit_message"),
                commit_description=self.kwargs.get("commit_description"),
            )
            os.remove(self.temp_file.name)
            self.fs.invalidate_cache(
                path=self.resolved_path.unresolve(),
            )


def safe_revision(revision: str) -> str:
    return revision if SPECIAL_REFS_REVISION_REGEX.match(revision) else safe_quote(revision)


def safe_quote(s: str) -> str:
    return quote(s, safe="")
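

# Illustrative: safe_revision("refs/pr/1") returns "refs/pr/1" unchanged (special refs keep
# their slashes in URLs), while safe_revision("my/branch") returns "my%2Fbranch" and
# safe_quote("a b") returns "a%20b".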