|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
""" |
|
|
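Configuration utility functions: load OmegaConf configs (with `__inherit__` support and dotlist overrides) and create objects from `__object__` specs.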
|
|
""" |
|
|
|
|
|
import importlib |
|
|
from typing import Any, Callable, List, Union |
|
|
from omegaconf import DictConfig, ListConfig, OmegaConf |
|
|
|
|
|
# Register an "eval" resolver so config values can be computed from Python
# expressions via interpolation, e.g. ${eval:'2 * 3'}.
# SECURITY: the interpolated string is passed to Python's built-in `eval`,
# which executes arbitrary code. Only load config files (and dotlist overrides
# passed to `load_config`) that come from trusted sources.
# NOTE(review): this runs at import time; OmegaConf may reject registering the
# same resolver name twice (e.g. on module reload) — confirm for the pinned
# OmegaConf version.
OmegaConf.register_new_resolver("eval", eval)
|
|
|
|
|
|
|
|
def load_config(
    path: str, argv: Union[List[str], None] = None
) -> Union[DictConfig, ListConfig]:
    """
    Load a configuration file and resolve its inheritance.

    :param path: path of the config file, passed to ``OmegaConf.load``.
    :param argv: optional dotlist overrides such as ``["model.lr=0.1"]``.
        They are merged on top of the file *before* inheritance is resolved,
        so they also take precedence over values inherited from
        ``__inherit__`` parents.
    :return: the loaded config, with ``__inherit__`` entries (at the root and
        in nested nodes) replaced by the merged contents of the referenced
        parent file(s).

    NOTE(review): ``__inherit__`` paths are loaded exactly as written, so a
    relative path is resolved against the current working directory rather
    than the directory containing ``path`` — confirm this is intended.
    """
    config = OmegaConf.load(path)

    if argv is not None:
        # Overrides go in before inheritance so that, when a node is later
        # merged on top of its parents, the overridden values win.
        config = OmegaConf.merge(config, OmegaConf.from_dotlist(argv))

    return resolve_recursive(config, resolve_inheritance)
|
|
|
|
|
|
|
|
def resolve_recursive(
    config: Any,
    resolver: Callable[[Union[DictConfig, ListConfig]], Union[DictConfig, ListConfig]],
) -> Any:
    """
    Apply ``resolver`` to ``config`` and then, depth-first, to every nested
    DictConfig / ListConfig it contains.

    The resolver runs on a node *before* its children are visited, so any
    children it introduces (e.g. keys merged in from a parent file) are
    visited too. Each processed child is written back into its parent, since
    the resolver may return a different object than it was given.

    Children that are interpolations (``${...}``) are skipped. Reading them
    with ``get`` would resolve the interpolation early and hand back the node
    it points to. The recursion would then mutate that target (e.g. popping
    its ``__inherit__`` key before the target itself is visited), overwrite
    the interpolation with a detached copy, and evaluate resolvers such as
    ``${eval:...}`` prematurely. The target of a node interpolation is visited
    on its own when the traversal reaches it.

    :param config: node to process; non-container values are passed through
        the resolver and returned without recursion.
    :param resolver: callable that takes a config node and returns the node
        to use in its place.
    :return: the processed config.
    """
    config = resolver(config)

    if isinstance(config, DictConfig):
        # Snapshot the keys: entries are reassigned while iterating.
        children = list(config.keys())
    elif isinstance(config, ListConfig):
        children = list(range(len(config)))
    else:
        return config

    for key in children:
        if OmegaConf.is_interpolation(config, key):
            continue
        value = config.get(key)
        if isinstance(value, (DictConfig, ListConfig)):
            config[key] = resolve_recursive(value, resolver)

    return config
|
|
|
|
|
|
|
|
def resolve_inheritance(config: Union[DictConfig, ListConfig]) -> Any:
    """
    Resolve one level of inheritance for a single config node.

    If ``config`` is a DictConfig containing::

        __inherit__: path/to/parent.yaml   # or a list of such paths

    the key is removed and each parent is loaded with ``load_config``, which
    resolves the parents' own inheritance first. The parents are merged in
    list order, so later parents override earlier ones, and ``config``'s
    remaining keys are merged on top of the result. If ``config`` has no keys
    besides ``__inherit__``, the merged parent is returned as-is. This also
    allows inheriting from a file whose root is a list.

    Nested nodes are not visited here; ``load_config`` applies this function
    throughout the tree via ``resolve_recursive``.

    NOTE(review): parent paths are passed to ``load_config`` unchanged, so a
    relative path is resolved against the current working directory, not the
    directory of the inheriting file.

    :param config: node to process. Anything that is not a DictConfig, or a
        DictConfig without a (truthy) ``__inherit__`` entry, is returned
        unchanged (apart from the popped key).
    :return: the node with its inheritance applied.
    :raises TypeError: if an ``__inherit__`` entry is not a string path.
    """
    if not isinstance(config, DictConfig):
        return config

    inherit = config.pop("__inherit__", None)
    if not inherit:
        return config

    parent_paths = inherit if isinstance(inherit, ListConfig) else [inherit]

    parent_config = None
    for parent_path in parent_paths:
        # Explicit check rather than `assert`, which `python -O` strips.
        if not isinstance(parent_path, str):
            raise TypeError(
                "__inherit__ entries must be file paths (str), got "
                f"{type(parent_path).__name__}: {parent_path!r}"
            )
        loaded = load_config(parent_path)
        parent_config = (
            loaded if parent_config is None else OmegaConf.merge(parent_config, loaded)
        )

    # Only merge when the node has keys of its own; otherwise hand back the
    # parent directly (merging a list-rooted parent with a dict would fail).
    if len(config.keys()) > 0:
        return OmegaConf.merge(parent_config, config)
    return parent_config
|
|
|
|
|
|
|
|
def import_item(path: str, name: str) -> Any:
    """
    Import the module at dotted ``path`` and return its attribute ``name``.

    Example: import_item("path.to.file", "MyClass") -> MyClass

    Errors from the import (e.g. ModuleNotFoundError) and from the attribute
    lookup (AttributeError) propagate to the caller.
    """
    module = importlib.import_module(path)
    found = getattr(module, name)
    return found
|
|
|
|
|
|
|
|
def create_object(config: DictConfig) -> Any:
    """
    Create an object from config.

    The config is expected to contain the following:

        __object__:
            path: path.to.module
            name: MyClass
            args: as_config | as_params (defaults to as_config)

    ``path`` and ``name`` are resolved with ``import_item``. Then:

    * ``as_config``: the callable receives the whole ``config`` (including
      its ``__object__`` entry) as its single positional argument.
    * ``as_params``: ``config`` is converted with ``OmegaConf.to_object``,
      the ``__object__`` entry is dropped, and the remaining top-level keys
      are passed as keyword arguments. Nested values are passed as converted,
      and nested ``__object__`` specs are NOT instantiated.
      NOTE(review): ``to_object`` returns a dataclass instance for structured
      configs, which would not support ``pop`` here — presumably only plain
      DictConfigs are used with ``as_params``; confirm.

    :raises NotImplementedError: if ``args`` has any other value.
    """
    spec = config.__object__
    item = import_item(path=spec.path, name=spec.name)
    args = spec.get("args", "as_config")

    if args == "as_config":
        return item(config)

    if args == "as_params":
        params = OmegaConf.to_object(config)
        params.pop("__object__")
        return item(**params)

    raise NotImplementedError(f"Unknown args type: {args}")