Spaces:
Runtime error
Runtime error
| # coding=utf-8 | |
| # Copyright 2024 The HuggingFace Inc. team. | |
| # | |
| # Licensed under the Apache License, Version 2.0 (the "License"); | |
| # you may not use this file except in compliance with the License. | |
| # You may obtain a copy of the License at | |
| # | |
| # http://www.apache.org/licenses/LICENSE-2.0 | |
| # | |
| """ | |
| Utility that checks that modules like attention processors are listed in the documentation file. | |
| ```bash | |
| python utils/check_support_list.py | |
| ``` | |
| It has no auto-fix mode. | |
| """ | |
| import os | |
| import re | |
| # All paths are set with the intent that you run this script from the root of the repo | |
| REPO_PATH = "." | |
def read_documented_classes(doc_path, autodoc_regex=r"\[\[autodoc\]\]\s([^\n]+)"):
    """
    Reads documented classes from a doc file using a regex to find lines like [[autodoc]] my.module.Class.

    Args:
        doc_path: Path of the documentation file, joined onto `REPO_PATH`.
        autodoc_regex: Pattern with one capture group holding the dotted path that follows `[[autodoc]]`.

    Returns:
        A list of documented class names (just the last dotted component), in order of appearance.
    """
    # Decode explicitly as UTF-8 so the result does not depend on the platform's locale encoding.
    with open(os.path.join(REPO_PATH, doc_path), "r", encoding="utf-8") as f:
        doctext = f.read()
    matches = re.findall(autodoc_regex, doctext)
    # The capture group runs to the end of the line, so surrounding whitespace must be dropped
    # before the class name is extracted; otherwise "Foo  " would never compare equal to "Foo".
    return [match.strip().split(".")[-1] for match in matches]
def read_source_classes(src_path, class_regex, exclude_conditions=None):
    """
    Reads class names from a source file using a regex that captures class definitions.

    Args:
        src_path: Path of the source file, joined onto `REPO_PATH`.
        class_regex: Pattern whose capture group is the class name.
        exclude_conditions: Optional list of predicates (functions that take a class name and return bool).
            A class is dropped as soon as any predicate returns a truthy value for it.

    Returns:
        The list of matched class names that pass every exclusion predicate, in order of appearance.
    """
    if exclude_conditions is None:
        exclude_conditions = []
    # Decode explicitly as UTF-8 so the result does not depend on the platform's locale encoding.
    with open(os.path.join(REPO_PATH, src_path), "r", encoding="utf-8") as f:
        source_text = f.read()
    classes = re.findall(class_regex, source_text)
    # Filter out classes that meet any of the exclude conditions
    return [c for c in classes if not any(cond(c) for cond in exclude_conditions)]
def check_documentation(doc_path, src_path, doc_regex, src_regex, exclude_conditions=None):
    """
    Compare the classes defined in `src_path` against the classes documented in `doc_path`.

    The doc file is parsed with `doc_regex`, the source file with `src_regex`; classes matching any of
    `exclude_conditions` are ignored on the source side.

    Returns a sorted list of the class names that are defined in the source but absent from the docs.
    """
    documented_names = frozenset(read_documented_classes(doc_path, doc_regex))
    defined_names = frozenset(read_source_classes(src_path, src_regex, exclude_conditions=exclude_conditions))
    # Sort so the report is deterministic regardless of set iteration order.
    missing = [name for name in defined_names if name not in documented_names]
    missing.sort()
    return missing
if __name__ == "__main__":
    # Patterns shared by several of the checks below.
    autodoc_pattern = r"\[\[autodoc\]\]\s([^\n]+)"
    processor_pattern = r"class\s+(\w+Processor(?:\d*_?\d*))[:(]"
    nn_module_pattern = r"class\s+(\w+)\s*\(.*?nn\.Module.*?\):"

    # One entry per documentation page to verify: where the docs live, where the
    # source lives, how to extract names from each, and which names to ignore.
    checks = {
        "Attention Processors": {
            "doc_path": "docs/source/en/api/attnprocessor.md",
            "src_path": "src/diffusers/models/attention_processor.py",
            "doc_regex": autodoc_pattern,
            "src_regex": processor_pattern,
            "exclude_conditions": [lambda name: "LoRA" in name, lambda name: name == "Attention"],
        },
        "Image Processors": {
            "doc_path": "docs/source/en/api/image_processor.md",
            "src_path": "src/diffusers/image_processor.py",
            "doc_regex": autodoc_pattern,
            "src_regex": processor_pattern,
        },
        "Activations": {
            "doc_path": "docs/source/en/api/activations.md",
            "src_path": "src/diffusers/models/activations.py",
            "doc_regex": autodoc_pattern,
            "src_regex": nn_module_pattern,
        },
        "Normalizations": {
            "doc_path": "docs/source/en/api/normalization.md",
            "src_path": "src/diffusers/models/normalization.py",
            "doc_regex": autodoc_pattern,
            "src_regex": nn_module_pattern,
            "exclude_conditions": [
                # Exclude LayerNorm as it's an intentional exception
                lambda name: name == "LayerNorm"
            ],
        },
        "LoRA Mixins": {
            "doc_path": "docs/source/en/api/loaders/lora.md",
            "src_path": "src/diffusers/loaders/lora_pipeline.py",
            "doc_regex": autodoc_pattern,
            "src_regex": r"class\s+(\w+LoraLoaderMixin(?:\d*_?\d*))[:(]",
        },
    }

    # Run every check and keep only the categories that have undocumented classes.
    missing_items = {}
    for category, params in checks.items():
        undocumented = check_documentation(
            params["doc_path"],
            params["src_path"],
            params["doc_regex"],
            params["src_regex"],
            exclude_conditions=params.get("exclude_conditions"),
        )
        if undocumented:
            missing_items[category] = undocumented

    # Report all categories at once so a single run surfaces every gap.
    if missing_items:
        report_lines = ["Some classes are not documented properly:\n"] + [
            f"- {category}: {', '.join(sorted(classes))}" for category, classes in missing_items.items()
        ]
        raise ValueError("\n".join(report_lines))