Spaces:
Runtime error
Runtime error
| # coding=utf-8 | |
| # Copyright 2025 The HuggingFace Inc. team. | |
| # | |
| # Licensed under the Apache License, Version 2.0 (the "License"); | |
| # you may not use this file except in compliance with the License. | |
| # You may obtain a copy of the License at | |
| # | |
| # http://www.apache.org/licenses/LICENSE-2.0 | |
| # | |
| # Unless required by applicable law or agreed to in writing, software | |
| # distributed under the License is distributed on an "AS IS" BASIS, | |
| # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. | |
| # See the License for the specific language governing permissions and | |
| # limitations under the License. | |
| """ | |
Utility that sorts the imports in the custom inits of Diffusers. Diffusers uses init files that delay the
import of an object to when it's actually needed. This is to avoid the main init importing all models, which would
make the line `import diffusers` very slow when the user has all optional dependencies installed. The inits with
delayed imports have two halves: one defining a dictionary `_import_structure` which maps modules to the name of the
objects in each module, and one in `TYPE_CHECKING` which looks like a normal init for type-checkers. `isort` or `ruff`
properly sort the second half, which looks like traditional imports; the goal of this script is to sort the first half.
| Use from the root of the repo with: | |
| ```bash | |
| python utils/custom_init_isort.py | |
| ``` | |
| which will auto-sort the imports (used in `make style`). | |
| For a check only (as used in `make quality`) run: | |
| ```bash | |
| python utils/custom_init_isort.py --check_only | |
| ``` | |
| """ | |
| import argparse | |
| import os | |
| import re | |
| from typing import Any, Callable, List, Optional | |
# Path is defined with the intent you should run this script from the root of the repo. Despite its name, it points
# at the Diffusers sources: every `__init__.py` found below it is checked (or sorted).
PATH_TO_TRANSFORMERS = "src/diffusers"
# Pattern that captures the leading whitespace of a line containing at least one non-whitespace character.
_re_indent = re.compile(r"^(\s*)\S")
# Pattern that matches `"key":` and captures `key` in its first group.
_re_direct_key = re.compile(r'^\s*"([^"]+)":')
# Pattern that matches `_import_structure["key"]` and captures `key` in its first group.
_re_indirect_key = re.compile(r'^\s*_import_structure\["([^"]+)"\]')
# Pattern that matches a line made only of `"key",` (plus surrounding whitespace) and captures `key` in its first group.
_re_strip_line = re.compile(r'^\s*"([^"]+)",\s*$')
# Pattern that matches any non-empty `[stuff]` and captures `stuff` in its first group.
_re_bracket_content = re.compile(r"\[([^\]]+)\]")
def get_indent(line: str) -> str:
    """Return the leading whitespace of `line`, or an empty string when the line is empty or whitespace-only."""
    match = _re_indent.search(line)
    if match is None:
        return ""
    return match.group(1)
def split_code_in_indented_blocks(
    code: str, indent_level: str = "", start_prompt: Optional[str] = None, end_prompt: Optional[str] = None
) -> List[str]:
    """
    Split some code into its indented blocks, starting at a given level.

    A new block starts at every non-empty line whose indent is exactly `indent_level`. Empty lines are appended to
    the block being built. A line at `indent_level` that directly follows a more-indented line (like a closing
    parenthesis/bracket) is attached to the end of the current block instead of starting a new one.

    Args:
        code (`str`): The code to split.
        indent_level (`str`): The indent level (as string) to use for identifying the blocks to split.
        start_prompt (`str`, *optional*): If provided, only starts splitting at the line where this text is.
        end_prompt (`str`, *optional*): If provided, stops splitting at a line where this text is.

    Warning:
        The text before `start_prompt` or after `end_prompt` (if provided) is not ignored, just not split. The input `code`
        can thus be retrieved by joining the result.

    Returns:
        `List[str]`: The list of blocks.

    Raises:
        `IndexError`: If `start_prompt` is provided but no line of `code` starts with it.
    """
    # Let's split the code into lines and move to start_index.
    index = 0
    lines = code.split("\n")
    if start_prompt is not None:
        # No bounds check: this runs past the end (IndexError) if no line starts with `start_prompt`.
        while not lines[index].startswith(start_prompt):
            index += 1
        # Everything before the start prompt is kept as one unsplit block.
        blocks = ["\n".join(lines[:index])]
    else:
        blocks = []

    # This variable contains the block treated at a given time.
    current_block = [lines[index]]
    index += 1
    # We split into blocks until we get to the `end_prompt` (or the end of the file).
    while index < len(lines) and (end_prompt is None or not lines[index].startswith(end_prompt)):
        # We have a non-empty line with the proper indent -> start of a new block
        if len(lines[index]) > 0 and get_indent(lines[index]) == indent_level:
            # Store the current block in the result and reset. There are two cases: the line is part of the block (like
            # a closing parenthesis) or not.
            if len(current_block) > 0 and get_indent(current_block[-1]).startswith(indent_level + " "):
                # Line is part of the current block
                current_block.append(lines[index])
                blocks.append("\n".join(current_block))
                if index < len(lines) - 1:
                    # NOTE(review): the following line is consumed here as the first line of the next block without
                    # being checked against `indent_level` or `end_prompt` (the outer `index += 1` skips past it).
                    current_block = [lines[index + 1]]
                    index += 1
                else:
                    current_block = []
            else:
                # Line is not part of the current block
                blocks.append("\n".join(current_block))
                current_block = [lines[index]]
        else:
            # Just add the line to the current block
            current_block.append(lines[index])
        index += 1

    # Adds current block if it's nonempty.
    if len(current_block) > 0:
        blocks.append("\n".join(current_block))

    # Add final block after end_prompt if provided.
    if end_prompt is not None and index < len(lines):
        blocks.append("\n".join(lines[index:]))

    return blocks
def ignore_underscore_and_lowercase(key: Callable[[Any], str]) -> Callable[[Any], str]:
    """
    Wraps a key function (as used in a sort) to lowercase and ignore underscores.

    Args:
        key (`Callable[[Any], str]`): The key function to wrap.

    Returns:
        `Callable[[Any], str]`: A key function returning `key(x)` lowercased, with every underscore removed.
    """

    def _inner(x):
        return key(x).lower().replace("_", "")

    return _inner


def sort_objects(objects: List[Any], key: Optional[Callable[[Any], str]] = None) -> List[Any]:
    """
    Sort a list of objects following the rules of isort (all uppercased first, camel-cased second and lower-cased
    last).

    Every object lands in exactly one of the three groups:
    - constants: `key(obj).isupper()` (all cased characters uppercase, e.g. `FOO`, `_FOO`, `FOO_2`);
    - classes: not a constant, but the name starts with an uppercase letter (e.g. `FooBar`);
    - functions: everything else (lowercase names, names starting with `_` or a digit that are not constants, empty
      names).
    Each group is then sorted case-insensitively, ignoring underscores.

    Args:
        objects (`List[Any]`):
            The list of objects to sort.
        key (`Callable[[Any], str]`, *optional*):
            A function taking an object as input and returning a string, used to sort them by alphabetical order.
            If not provided, will default to noop (so a `key` must be provided if the `objects` are not of type string).

    Returns:
        `List[Any]`: The sorted list with the same elements as in the inputs (each element appears exactly once).
    """

    # If no key is provided, we use a noop.
    def noop(x):
        return x

    if key is None:
        key = noop

    constants, classes, functions = [], [], []
    for obj in objects:
        name = key(obj)
        if name.isupper():
            # Constants are all uppercase, they go first.
            constants.append(obj)
        elif name[:1].isupper():
            # Classes are not all uppercase but start with a capital, they go second. `[:1]` avoids an IndexError on
            # empty names.
            classes.append(obj)
        else:
            # Everything else goes last. Testing only `name[0]` here (as a separate filter) would also catch names like
            # `_FOO` that are already constants and return them twice.
            functions.append(obj)

    # Then we sort each group.
    key1 = ignore_underscore_and_lowercase(key)
    return sorted(constants, key=key1) + sorted(classes, key=key1) + sorted(functions, key=key1)
def sort_objects_in_import(import_statement: str) -> str:
    """
    Sorts the imports in a single import statement.

    Three layouts are handled, distinguished by the number of lines of `import_statement`:
    - more than 3 lines: one `"name",` per line, between one or two opening and closing lines;
    - exactly 3 lines: all names on the middle line (with or without surrounding brackets);
    - 1 or 2 lines: names inside `[...]` on the statement itself.
    Names are ordered with `sort_objects` (constants, then classes, then functions).

    Args:
        import_statement (`str`): The import statement in which to sort the imports.

    Returns:
        `str`: The same as the input, but with objects properly sorted.
    """

    # This inner function sort imports between [ ].
    # Quotes are stripped and re-added as double quotes, and a trailing comma inside the brackets is dropped.
    def _replace(match):
        imports = match.groups()[0]
        # If there is one import only, nothing to do.
        if "," not in imports:
            return f"[{imports}]"
        keys = [part.strip().replace('"', "") for part in imports.split(",")]
        # We will have a final empty element if the line finished with a comma.
        if len(keys[-1]) == 0:
            keys = keys[:-1]
        return "[" + ", ".join([f'"{k}"' for k in sort_objects(keys)]) + "]"

    lines = import_statement.split("\n")
    if len(lines) > 3:
        # Here we have to sort internal imports that are on several lines (one per name):
        # key: [
        #     "object1",
        #     "object2",
        #     ...
        # ]
        # We may have to ignore one or two lines on each side.
        # Two lines are skipped on each side when the `[` sits alone on the second line.
        idx = 2 if lines[1].strip() == "[" else 1
        # NOTE(review): every middle line must be a lone `"name",`; a comment line or an entry without a trailing comma
        # makes `.search(...)` return None and this raises AttributeError.
        keys_to_sort = [(i, _re_strip_line.search(line).groups()[0]) for i, line in enumerate(lines[idx:-idx])]
        # Despite its name, this holds (offset into the middle lines, name) pairs in sorted order.
        sorted_indices = sort_objects(keys_to_sort, key=lambda x: x[1])
        sorted_lines = [lines[x[0] + idx] for x in sorted_indices]
        return "\n".join(lines[:idx] + sorted_lines + lines[-idx:])
    elif len(lines) == 3:
        # Here we have to sort internal imports that are on one separate line:
        # key: [
        #     "object1", "object2", ...
        # ]
        if _re_bracket_content.search(lines[1]) is not None:
            lines[1] = _re_bracket_content.sub(_replace, lines[1])
        else:
            keys = [part.strip().replace('"', "") for part in lines[1].split(",")]
            # We will have a final empty element if the line finished with a comma.
            if len(keys[-1]) == 0:
                keys = keys[:-1]
            # The line is rebuilt with its original indent but without any trailing comma.
            lines[1] = get_indent(lines[1]) + ", ".join([f'"{k}"' for k in sort_objects(keys)])
        return "\n".join(lines)
    else:
        # Finally we have to deal with imports fitting on one line
        import_statement = _re_bracket_content.sub(_replace, import_statement)
        return import_statement
def sort_imports(file: str, check_only: bool = True):
    """
    Sort the imports defined in the `_import_structure` of a given init.

    Args:
        file (`str`): The path to the init to check/fix.
        check_only (`bool`, *optional*, defaults to `True`): Whether or not to just check (and not auto-fix) the init.

    Returns:
        `True` when `check_only=True` and the init is not properly sorted, `None` otherwise. With `check_only=False`,
        an unsorted init is rewritten in place instead.
    """
    with open(file, encoding="utf-8") as f:
        code = f.read()

    # If the file is not a custom init, there is nothing to do.
    if "_import_structure" not in code:
        return

    start_prompt = "_import_structure = {"
    # `split_code_in_indented_blocks` raises an IndexError when no line starts with the start prompt (e.g. when
    # `_import_structure` is only referenced, or only defined inside an indented block), so bail out early.
    if not any(line.startswith(start_prompt) for line in code.split("\n")):
        return

    # Blocks of indent level 0
    main_blocks = split_code_in_indented_blocks(code, start_prompt=start_prompt, end_prompt="if TYPE_CHECKING:")

    # We ignore block 0 (everything until start_prompt) and the last block (everything after end_prompt).
    for block_idx in range(1, len(main_blocks) - 1):
        # Check if the block contains some `_import_structure`s thingy to sort.
        block_lines = main_blocks[block_idx].split("\n")

        # Get to the start of the imports.
        line_idx = 0
        while line_idx < len(block_lines) and "_import_structure" not in block_lines[line_idx]:
            # Skip dummy import blocks
            if "import dummy" in block_lines[line_idx]:
                line_idx = len(block_lines)
            else:
                line_idx += 1
        # Nothing to sort when no `_import_structure` line was found, or when the only one is the block's last line:
        # that line is kept verbatim below, so processing the (empty) remainder would insert a spurious blank line into
        # the block, or fail with an IndexError on `block_lines[1]` for single-line blocks.
        if line_idx >= len(block_lines) - 1:
            continue

        # Lines before `line_idx` and the block's last line are kept as they are; only the lines in between are sorted.
        internal_block_code = "\n".join(block_lines[line_idx:-1])
        indent = get_indent(block_lines[1])
        # Split the internal block into blocks of indent level 1.
        internal_blocks = split_code_in_indented_blocks(internal_block_code, indent_level=indent)
        # We have two categories of import key: list or _import_structure[key].append/extend
        pattern = _re_direct_key if start_prompt in block_lines[0] else _re_indirect_key
        # Grab the keys, but there is a trap: some lines are empty or just comments.
        keys = []
        for internal_block in internal_blocks:
            match = pattern.search(internal_block)
            keys.append(None if match is None else match.groups()[0])
        # We only sort the lines with a key.
        keys_to_sort = [(i, key) for i, key in enumerate(keys) if key is not None]
        sorted_indices = [x[0] for x in sorted(keys_to_sort, key=lambda x: x[1])]

        # We reorder the blocks by leaving empty lines/comments as they were and reorder the rest.
        count = 0
        reordered_blocks = []
        for i, internal_block in enumerate(internal_blocks):
            if keys[i] is None:
                reordered_blocks.append(internal_block)
            else:
                reordered_blocks.append(sort_objects_in_import(internal_blocks[sorted_indices[count]]))
                count += 1

        # And we put our main block back together with its first and last line.
        main_blocks[block_idx] = "\n".join(block_lines[:line_idx] + reordered_blocks + [block_lines[-1]])

    new_code = "\n".join(main_blocks)
    if code != new_code:
        if check_only:
            return True
        print(f"Overwriting {file}.")
        with open(file, "w", encoding="utf-8") as f:
            f.write(new_code)
def sort_imports_in_all_inits(check_only=True):
    """
    Sort the imports defined in the `_import_structure` of all inits in the repo.

    Every `__init__.py` found while walking `PATH_TO_TRANSFORMERS` is passed to `sort_imports`.

    Args:
        check_only (`bool`, *optional*, defaults to `True`): Whether or not to just check (and not auto-fix) the init.

    Raises:
        `ValueError`: In check-only mode, if at least one init is not properly sorted. The message reports how many.
    """
    failures = []
    for root, _, files in os.walk(PATH_TO_TRANSFORMERS):
        if "__init__.py" in files:
            init_file = os.path.join(root, "__init__.py")
            # `sort_imports` only returns a truthy value in check-only mode, when the init would be rewritten.
            if sort_imports(init_file, check_only=check_only):
                # Accumulate (rather than overwrite) so the reported count covers every offending init.
                failures.append(init_file)
    if len(failures) > 0:
        raise ValueError(f"Would overwrite {len(failures)} files, run `make style`.")
if __name__ == "__main__":
    # Command-line entry point: sorts every init in place, or only reports with `--check_only`.
    cli = argparse.ArgumentParser()
    cli.add_argument("--check_only", action="store_true", help="Whether to only check or fix style.")
    cli_args = cli.parse_args()
    sort_imports_in_all_inits(check_only=cli_args.check_only)