File size: 4,144 Bytes
cef9e84
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
106
107
108
109
110
111
112
113
114
115
116
117
118
119
120
121
122
123
124
125
126
127
import json
import importlib
import yaml
import os
from pathlib import Path
from typing import Any, IO
from omegaconf import OmegaConf

class IncludeLoader(yaml.SafeLoader):
    """
    Safe YAML loader that additionally understands the custom `!include` and
    `!module` tags used to compose configuration files.

    Adapted from: https://gist.github.com/joshbode/569627ced3076931b02f
    """
    def __init__(self, stream: IO) -> None:
        """
        Create a loader for the YAML document read from *stream*.

        :param stream: the stream holding the YAML document to parse
        """
        # `!include` paths are joined onto this root, i.e. they are resolved
        # relative to the process's working directory, not the including file.
        self.root = os.curdir
        super().__init__(stream)

class BaseConfiguration:
    """
    Represents the configuration parameters for running the process.

    The configuration is read from a YAML file (which may use the `!include` and
    `!module` tags handled by `IncludeLoader`), passed through OmegaConf so that
    interpolations are resolved, and stored as plain Python containers in
    `self.config`. Subclasses customise validation and setup by overriding the
    `check_config` and `create_directory_structure` hooks, which are called at
    the end of construction.
    """
    def __init__(self, path):
        """
        Initializes the configuration with contents from the specified file
        :param path: path to the configuration file in YAML format
        """
        # YAML is a Unicode format: decode explicitly as UTF-8 instead of relying on
        # the platform's locale-dependent default encoding.
        with open(path, 'r', encoding='utf-8') as f:
            yaml_object = yaml.load(f, IncludeLoader)

        # Uses the experimental "allow_objects" flag so that classes and functions
        # (e.g. those produced by the `!module` tag) can be stored directly in the configuration
        omegaconf_config = OmegaConf.create(yaml_object, flags={"allow_objects": True})
        # Resolves interpolations and converts the configuration to plain containers
        self.config = OmegaConf.to_container(omegaconf_config, resolve=True)

        # Checks the configuration
        self.check_config()

        # Creates the directory structure
        self.create_directory_structure()

    def get_config(self):
        """
        :return: the resolved configuration as plain Python containers
        """
        return self.config

    def check_config(self):
        """
        Checks that the configuration is well-formed.
        Subclasses should raise an exception if it is not; the base implementation accepts everything.
        :return:
        """
        pass

    def create_directory_structure(self):
        """
        Creates the directory structure needed by the configuration
        Eg. logging/checkpoints/results directories.
        The base implementation creates nothing; subclasses override it.
        :return:
        """
        pass


class Configuration(BaseConfiguration):
    """
    Represents the configuration parameters for running the process.

    Construct with the path to the YAML configuration file (the constructor is
    inherited from BaseConfiguration). In addition to the base behaviour, this
    creates the checkpoints directory named in the `logging` section, if any.
    """

    def create_directory_structure(self):
        """
        Creates the directory structure needed by the configuration.

        Currently this is only ``logging.checkpoints_directory``, created together
        with any missing parent directories. A missing or empty ``logging`` section,
        or an unset checkpoints directory, is skipped instead of raising.
        :return:
        """
        logging_config = self.config.get("logging") if isinstance(self.config, dict) else None
        # A bare `logging:` key in YAML yields None, so require an actual mapping here
        if not isinstance(logging_config, dict):
            return

        checkpoints_directory = logging_config.get("checkpoints_directory")
        if checkpoints_directory is None:
            return

        Path(checkpoints_directory).mkdir(parents=True, exist_ok=True)

def get_class_by_name(name):
    """
    Gets a class (or any other object, e.g. a function) by its fully qualified name.

    The longest importable module prefix of *name* is imported and the remaining
    dotted components are resolved as attributes. Both top-level names such as
    "mypackage.mymodule.MyClass" and nested ones such as
    "mypackage.mymodule.MyClass.my_method" are therefore supported.

    :param name: fully qualified name eg "mypackage.mymodule.MyClass"
    :return: the requested object
    :raises ValueError: if *name* has no module component (contains no dot)
    :raises ModuleNotFoundError: if the module part of *name* cannot be imported
    :raises AttributeError: if the module is found but an attribute along the path is missing
    """
    parts = name.split('.')
    if len(parts) < 2:
        raise ValueError(
            f"Expected a fully qualified name such as 'package.module.Name', got {name!r}")

    module_error = None
    # Try the longest module path first: for the common "package.module.Name" form this
    # performs exactly the single import a split on the last dot would.
    for split_index in range(len(parts) - 1, 0, -1):
        module_name = '.'.join(parts[:split_index])
        try:
            resolved = importlib.import_module(module_name)
        except ModuleNotFoundError as error:
            # Fall back to a shorter prefix only when the missing module is this candidate
            # or one of its parents; a dependency missing *inside* an existing module is a
            # genuine error and must propagate unchanged.
            missing = error.name
            if missing is None or not (module_name == missing or module_name.startswith(missing + '.')):
                raise
            if module_error is None:
                module_error = error
            continue

        try:
            for attribute_name in parts[split_index:]:
                resolved = getattr(resolved, attribute_name)
        except AttributeError:
            # If we only got here by falling back, report the original import failure
            if module_error is not None:
                raise module_error from None
            raise
        return resolved

    raise module_error

def construct_include(loader: IncludeLoader, node: yaml.Node) -> Any:
    """
    Manages inclusion of the file referenced at node.

    The scalar value of *node* is a path, joined onto ``loader.root`` and made
    absolute. The file is read as UTF-8 and parsed according to its extension
    (matched case-insensitively):

    * ``yaml`` / ``yml`` -- loaded with IncludeLoader, so tags inside the
      included document are processed too;
    * ``json`` -- loaded with ``json.load``;
    * anything else -- returned verbatim as a string.

    :param loader: the loader processing the document that contains the tag
    :param node: the ``!include`` scalar node holding the path
    :return: the parsed (or raw) content of the referenced file
    """
    filename = os.path.abspath(os.path.join(loader.root, loader.construct_scalar(node)))
    extension = os.path.splitext(filename)[1].lstrip('.').lower()

    with open(filename, 'r', encoding='utf-8') as f:
        if extension in ('yaml', 'yml'):
            # NOTE: the nested document gets a fresh IncludeLoader, so its own `!include`
            # paths are resolved against that loader's root, not this file's directory.
            return yaml.load(f, IncludeLoader)
        if extension == 'json':
            return json.load(f)
        return f.read()

def construct_module(loader: IncludeLoader, node: yaml.Node) -> Any:
    """
    Resolve a ``!module`` tag into the Python object it names.

    The node's scalar value is a dotted path (e.g. ``package.module.function``)
    which is looked up with ``get_class_by_name``.

    Security: the lookup imports the named module, which runs that module's
    top-level code, so only load configuration files from trusted sources.
    """
    return get_class_by_name(loader.construct_scalar(node))



# Registers the custom tag constructors on IncludeLoader
IncludeLoader.add_constructor('!include', construct_include)
IncludeLoader.add_constructor('!module', construct_module)